-
-
Notifications
You must be signed in to change notification settings - Fork 6
/
convert.py
28 lines (23 loc) · 829 Bytes
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import argparse
import os
import torch
def main():
    """Unpack every ``*.bin`` checkpoint in a directory into raw tensor files.

    Command line:
        src: root directory to scan (optional; defaults to ``llama-7b-hf``).

    Each file in ``src`` whose name ends with ``.bin`` is loaded with
    ``torch.load`` and treated as a mapping of name -> tensor. Every name is
    split on ``"."`` and the pieces become nested directories under ``src``,
    with the last piece as the file name (``a.b.weight`` is written to
    ``<src>/a/b/weight``). The tensor is written with ``ndarray.tofile``,
    i.e. as raw element bytes; shape and dtype are only printed, not stored.
    """
    parser = argparse.ArgumentParser()
    # nargs="?" makes the positional optional; without it argparse treats
    # "src" as required and the default below is never used.
    parser.add_argument(
        "src", nargs="?", default="llama-7b-hf", help="root directory"
    )
    args = parser.parse_args()

    # Sorted so shards are processed in a reproducible order.
    for f in sorted(os.listdir(args.src)):
        if not f.endswith(".bin"):
            continue
        print(f"Loading {f}")
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only run this on checkpoints you trust.
        # map_location="cpu" keeps tensors saved from a GPU convertible to
        # NumPy, including on machines that have no GPU.
        sd = torch.load(os.path.join(args.src, f), map_location="cpu")
        for key, tensor in sd.items():
            print("Saving", key, tensor.shape, tensor.dtype)
            path = os.path.join(args.src, *key.split("."))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            # detach() so tensors that still require grad can be converted.
            np_array = tensor.detach().numpy()
            # Binary mode: the payload is raw bytes, not text.
            with open(path, "wb") as fp:
                np_array.tofile(fp)
            del np_array
        # Release this shard before loading the next one to limit peak memory.
        del sd
if __name__ == "__main__":
    # Run the conversion only when executed as a script, not when imported.
    main()