diff --git a/bindings/python/benches/test_pt.py b/bindings/python/benches/test_pt.py index ea677051..e9c06fcf 100644 --- a/bindings/python/benches/test_pt.py +++ b/bindings/python/benches/test_pt.py @@ -30,6 +30,14 @@ def create_gpt2(n_layers: int): return tensors +def create_lora(n_layers: int): + tensors = {} + for i in range(n_layers): + tensors[f"lora.{i}.up.weight"] = torch.zeros((32, 32)) + tensors[f"lora.{i}.down.weight"] = torch.zeros((32, 32)) + return tensors + + def test_pt_pt_load_cpu(benchmark): # benchmark something weights = create_gpt2(12) @@ -56,6 +64,30 @@ def test_pt_sf_load_cpu(benchmark): assert torch.allclose(v, tv) +def test_pt_pt_load_cpu_small(benchmark): + weights = create_lora(500) + with tempfile.NamedTemporaryFile(delete=False) as f: + torch.save(weights, f) + result = benchmark(torch.load, f.name) + os.unlink(f.name) + + for k, v in weights.items(): + tv = result[k] + assert torch.allclose(v, tv) + + +def test_pt_sf_load_cpu_small(benchmark): + weights = create_lora(500) + with tempfile.NamedTemporaryFile(delete=False) as f: + save_file(weights, f.name) + result = benchmark(load_file, f.name) + os.unlink(f.name) + + for k, v in weights.items(): + tv = result[k] + assert torch.allclose(v, tv) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda") def test_pt_pt_load_gpu(benchmark): # benchmark something diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 1a5289d5..d088c277 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -452,10 +452,12 @@ impl Open { /// /// ``` pub fn get_tensor(&self, name: &str) -> PyResult { - let tensors = self.metadata.tensors(); - let info = tensors.get(name).ok_or_else(|| { + let info = self.metadata.info(name).ok_or_else(|| { SafetensorError::new_err(format!("File does not contain tensor {name}",)) })?; + // let info = tensors.get(name).ok_or_else(|| { + // SafetensorError::new_err(format!("File does not contain tensor {name}",)) + // })?; match &self.storage.as_ref() { Storage::Mmap(mmap) => { diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs index 45b9d405..f326f21b 100644 --- a/safetensors/src/tensor.rs +++ b/safetensors/src/tensor.rs @@ -513,6 +513,12 @@ impl Metadata { Ok(start) } + /// Gives back the tensor metadata + pub fn info(&self, name: &str) -> Option<&TensorInfo> { + let index = self.index_map.get(name)?; + self.tensors.get(*index) + } + /// Gives back the tensor metadata pub fn tensors(&self) -> HashMap { self.index_map