#!/usr/bin/env python3

# ---- User configuration ------------------------------------------------------
# Input soft-prompt archives. The tensor from the first archive is placed
# before the tensor from the second one in the combined output.
PATH_TO_SOFTPROMPT_1 = "softprompt1.zip"
PATH_TO_SOFTPROMPT_2 = "softprompt2.zip"

# Destination archive for the combined soft prompt (opened in "w" mode, so an
# existing file at this path is replaced).
OUTPUT_PATH = "softprompt_combined.zip"

# Metadata written verbatim as meta.json inside OUTPUT_PATH; edit before running.
META = dict(
    name="Soft prompt name",
    author="Soft prompt author",
    description="Soft prompt description",
    supported="Generic 6B",
)

################################################################################

import numpy as np
import json
import zipfile

def bf16_to_fp32(tensor):
    if tensor.dtype == "V2":
        tensor.dtype = np.uint16
        tensor = np.uint32(tensor) << 16
        tensor.dtype = np.float32
    return tensor

# Pull the "tensor" member out of each input archive. The arrays are read out
# while the archives are open and used after they are closed.
with np.load(PATH_TO_SOFTPROMPT_1) as first_npz:
    with np.load(PATH_TO_SOFTPROMPT_2) as second_npz:
        t1 = first_npz["tensor"]
        t2 = second_npz["tensor"]

# Differing precisions cannot be joined as-is: promote both sides to float32,
# decoding raw bfloat16 ("V2") input first. When the dtypes already match
# (e.g. both "V2"), the tensors are concatenated in their original dtype.
if t1.dtype != t2.dtype:
    t1, t2 = (np.float32(bf16_to_fp32(part)) for part in (t1, t2))

# Join along the second-to-last axis, first prompt first. Presumably this is
# the token axis of an (n_tokens, d_model) prompt -- TODO confirm.
tensor = np.concatenate([t1, t2], axis=-2)

# Write both members in a single session: the tensor LZMA-compressed, the
# metadata stored uncompressed via an explicit ZipInfo entry.
with zipfile.ZipFile(OUTPUT_PATH, "w", compression=zipfile.ZIP_LZMA) as archive:
    with archive.open("tensor.npy", "w") as member:
        np.save(member, tensor, allow_pickle=False)
    meta_entry = zipfile.ZipInfo("meta.json")
    meta_entry.compress_type = zipfile.ZIP_STORED
    with archive.open(meta_entry, "w") as member:
        member.write(json.dumps(META, indent=2).encode("utf-8"))
