"""Combine sharded .safetensors files into a single file.

Given an output path such as ``model.safetensors``, every file next to it
matching ``model-*.safetensors`` is read and all of its tensors are written
to the output path.

Usage:
    python unshard.py <output_file>

Provenance: added as util/unshard.py by the patch "Unsharding utility"
(commit 6e14f8802b69119a4da28a01f6ce883a2babbe13).
"""

import argparse
import glob
import json
import math
import os


def find_shards(output_file):
    """Return the sorted list of shard files belonging to *output_file*.

    The extension of *output_file* is stripped and ``-*.safetensors`` is
    appended to form the search pattern, so ``dir/model.safetensors`` matches
    ``dir/model-00001-of-00002.safetensors`` and so on.

    Args:
        output_file: Path of the combined file that will be written.

    Returns:
        Matching shard paths in lexicographic order; empty if none exist.
    """
    output_base, _ = os.path.splitext(output_file)
    # Escape the user-supplied part so characters such as "[" or "*" in the
    # path are matched literally; only the shard suffix is a wildcard.
    pattern = glob.escape(output_base) + "-*.safetensors"
    # glob.glob() returns results in arbitrary order; sort for determinism.
    return sorted(glob.glob(pattern))


def combine_shards(input_files):
    """Read every tensor from *input_files* into a single dict.

    Args:
        input_files: Iterable of paths to .safetensors shard files.

    Returns:
        Dict mapping tensor name to the tensor loaded on the CPU.

    Raises:
        ValueError: If the same tensor name appears in more than one shard,
            since silently keeping one copy would hide a mixed-up shard set.
    """
    # Imported lazily so that argument parsing and shard discovery do not
    # require loading safetensors (and its torch backend).
    from safetensors import safe_open

    output_dict = {}
    key_source = {}
    for input_file in input_files:
        print(f" -- Scanning tensors in {input_file}")
        with safe_open(input_file, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key in output_dict:
                    raise ValueError(
                        f"Duplicate tensor {key!r} in {input_file} "
                        f"(already read from {key_source[key]})"
                    )
                print(f" -- Reading: {key}")
                output_dict[key] = f.get_tensor(key)
                key_source[key] = input_file
    return output_dict


def main(argv=None):
    """Parse the command line, merge the shards and write the output file.

    Args:
        argv: Argument list to parse; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Combine sharded .safetensors files")
    parser.add_argument("output_file", type=str, help="Path to output file")
    args = parser.parse_args(argv)

    output_file = args.output_file

    # Combine

    input_files = find_shards(output_file)
    if not input_files:
        # Without this check an empty output file would be written silently.
        output_base, _ = os.path.splitext(output_file)
        parser.error(f"no shards found matching {output_base}-*.safetensors")

    output_dict = combine_shards(input_files)

    # Write output

    from safetensors.torch import save_file

    print(f" -- Writing: {output_file}")
    save_file(output_dict, output_file)

    # Done

    print(" -- Done")


if __name__ == "__main__":
    main()