Stop recreating the hashmap all the time. (#363)
* Stop recreating the hashmap all the time.

Fixes #361
Potentially supersedes #362

Co-Authored-By: Batuhan Taskaya <[email protected]>

* Adding the benches.

---------

Co-authored-by: Batuhan Taskaya <[email protected]>
Narsil and isidentical authored Sep 18, 2023
1 parent e90f607 commit 6eb419f
Showing 3 changed files with 42 additions and 2 deletions.
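
What the diff below does: `Open::get_tensor` in the Python bindings previously called `Metadata::tensors()`, which builds a fresh `HashMap<String, &TensorInfo>` over every tensor in the file, once per lookup, so loading a file tensor by tensor repeated that work for each name. The new `Metadata::info(name)` resolves a single name through the `index_map` the struct already keeps. A minimal sketch of querying one tensor's metadata with the new accessor, assuming a serialized safetensors buffer in `bytes`; `shape_of` is a hypothetical helper, not part of the library:

use safetensors::tensor::SafeTensors;

// Sketch: parse only the safetensors header from `bytes`, then resolve a single
// tensor's shape by name. `shape_of` is illustrative, not library API.
fn shape_of(bytes: &[u8], name: &str) -> Option<Vec<usize>> {
    let (_header_len, metadata) = SafeTensors::read_metadata(bytes).ok()?;
    // Old path: metadata.tensors() rebuilt a HashMap of *all* tensors per call.
    // New path (this commit): Metadata::info goes through the stored index_map.
    metadata.info(name).map(|info| info.shape.clone())
}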
bindings/python/benches/test_pt.py: 32 additions, 0 deletions
@@ -30,6 +30,14 @@ def create_gpt2(n_layers: int):
     return tensors
 
 
+def create_lora(n_layers: int):
+    tensors = {}
+    for i in range(n_layers):
+        tensors[f"lora.{i}.up.weight"] = torch.zeros((32, 32))
+        tensors[f"lora.{i}.down.weight"] = torch.zeros((32, 32))
+    return tensors
+
+
 def test_pt_pt_load_cpu(benchmark):
     # benchmark something
     weights = create_gpt2(12)
@@ -56,6 +64,30 @@ def test_pt_sf_load_cpu(benchmark):
         assert torch.allclose(v, tv)
 
 
+def test_pt_pt_load_cpu_small(benchmark):
+    weights = create_lora(500)
+    with tempfile.NamedTemporaryFile(delete=False) as f:
+        torch.save(weights, f)
+        result = benchmark(torch.load, f.name)
+    os.unlink(f.name)
+
+    for k, v in weights.items():
+        tv = result[k]
+        assert torch.allclose(v, tv)
+
+
+def test_pt_sf_load_cpu_small(benchmark):
+    weights = create_lora(500)
+    with tempfile.NamedTemporaryFile(delete=False) as f:
+        save_file(weights, f.name)
+        result = benchmark(load_file, f.name)
+    os.unlink(f.name)
+
+    for k, v in weights.items():
+        tv = result[k]
+        assert torch.allclose(v, tv)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires cuda")
 def test_pt_pt_load_gpu(benchmark):
     # benchmark something
bindings/python/src/lib.rs: 4 additions, 2 deletions
@@ -452,10 +452,12 @@ impl Open {
     ///
     /// ```
     pub fn get_tensor(&self, name: &str) -> PyResult<PyObject> {
-        let tensors = self.metadata.tensors();
-        let info = tensors.get(name).ok_or_else(|| {
+        let info = self.metadata.info(name).ok_or_else(|| {
             SafetensorError::new_err(format!("File does not contain tensor {name}",))
         })?;
+        // let info = tensors.get(name).ok_or_else(|| {
+        //     SafetensorError::new_err(format!("File does not contain tensor {name}",))
+        // })?;
 
         match &self.storage.as_ref() {
             Storage::Mmap(mmap) => {
safetensors/src/tensor.rs: 6 additions, 0 deletions
@@ -513,6 +513,12 @@ impl Metadata {
         Ok(start)
     }
 
+    /// Gives back the tensor metadata
+    pub fn info(&self, name: &str) -> Option<&TensorInfo> {
+        let index = self.index_map.get(name)?;
+        self.tensors.get(*index)
+    }
+
     /// Gives back the tensor metadata
     pub fn tensors(&self) -> HashMap<String, &TensorInfo> {
         self.index_map
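
A small end-to-end sketch of exercising the new accessor. This is not from the commit: the tensor name and sizes are illustrative, and it assumes the 0.3-era `serialize(data, &data_info)` signature.

use std::collections::HashMap;

use safetensors::tensor::{serialize, Dtype, SafeTensors, TensorView};

fn main() {
    // Build a tiny one-tensor file in memory: a 2x2 f32 tensor of zeros (16 bytes).
    let data = vec![0u8; 2 * 2 * 4];
    let view = TensorView::new(Dtype::F32, vec![2, 2], &data).expect("valid view");
    let mut tensors = HashMap::new();
    tensors.insert("lora.0.up.weight".to_string(), view);
    let bytes = serialize(tensors, &None).expect("serialization succeeds");

    // Parse only the header.
    let (_header_len, metadata) = SafeTensors::read_metadata(&bytes).expect("valid header");

    // Point lookup: one probe of the stored index_map, no HashMap rebuild.
    let info = metadata.info("lora.0.up.weight").expect("tensor present");
    assert_eq!(info.shape, vec![2, 2]);

    // Full enumeration: tensors() still builds the complete map, but only once.
    for (name, info) in metadata.tensors() {
        println!("{name}: {:?} {:?}", info.dtype, info.shape);
    }
}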
