Skip to content

Commit

Permalink
Unsharding utility
Browse files Browse the repository at this point in the history
  • Loading branch information
turboderp committed Sep 11, 2023
1 parent 49c8d9e commit 6e14f88
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions util/unshard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import argparse, json, math, os, glob
from safetensors import safe_open
from safetensors.torch import save_file

parser = argparse.ArgumentParser(description = "Combine sharded .safetensors files")
parser.add_argument("output_file", type = str, help = "Path to output file")
args = parser.parse_args()

output_file = args.output_file
output_base, _ = os.path.splitext(output_file)

# Combine

output_dict = {}
input_files = glob.glob(output_base + "-*.safetensors")

for input_file in input_files:
print(f" -- Scanning tensors in {input_file}")
with safe_open(input_file, framework = "pt", device = "cpu") as f:
for key in f.keys():
print(f" -- Reading: {key}")
output_dict[key] = f.get_tensor(key)

# Write output

print(f" -- Writing: {output_file}")
save_file(output_dict, output_file)

# Done

print(f" -- Done")

0 comments on commit 6e14f88

Please sign in to comment.