diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py index 0f3821b8..eb1aa65a 100644 --- a/bindings/python/tests/test_pt_comparison.py +++ b/bindings/python/tests/test_pt_comparison.py @@ -52,9 +52,9 @@ def test_serialization(self): def test_odd_dtype(self): data = { - "test": torch.zeros((2, 2), dtype=torch.bfloat16), - "test2": torch.zeros((2, 2), dtype=torch.float16), - "test3": torch.zeros((2, 2), dtype=torch.bool), + "test": torch.randn((2, 2), dtype=torch.bfloat16), + "test2": torch.randn((2, 2), dtype=torch.float16), + "test3": torch.randn((2, 2), dtype=torch.bool), } local = "./tests/data/out_safe_pt_mmap_small.safetensors" diff --git a/bindings/python/tests/test_tf_comparison.py b/bindings/python/tests/test_tf_comparison.py index ac41e6f6..2c50bb76 100644 --- a/bindings/python/tests/test_tf_comparison.py +++ b/bindings/python/tests/test_tf_comparison.py @@ -64,7 +64,7 @@ def test_deserialization_safe(self): def test_bfloat16(self): data = { - "test": tf.zeros((1024, 1024), dtype=tf.bfloat16), + "test": tf.randn((1024, 1024), dtype=tf.bfloat16), } save_file(data, self.sf_filename) weights = {} @@ -76,6 +76,11 @@ def test_bfloat16(self): tv = data[k] self.assertTrue(tf.experimental.numpy.allclose(v, tv)) + weights = load_file(self.sf_filename) + for k, v in weights.items(): + tv = data[k] + self.assertTrue(tf.experimental.numpy.allclose(v, tv)) + def test_deserialization_safe_open(self): weights = {} with safe_open(self.sf_filename, framework="tf") as f: