Skip to content

Commit

Permalink
Improving the bf16 tests for PT+TF. (#505)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Jul 26, 2024
1 parent 7d29f61 commit 3bfe613
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 3 additions & 3 deletions bindings/python/tests/test_pt_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def test_serialization(self):

def test_odd_dtype(self):
data = {
"test": torch.zeros((2, 2), dtype=torch.bfloat16),
"test2": torch.zeros((2, 2), dtype=torch.float16),
"test3": torch.zeros((2, 2), dtype=torch.bool),
"test": torch.randn((2, 2), dtype=torch.bfloat16),
"test2": torch.randn((2, 2), dtype=torch.float16),
"test3": torch.randn((2, 2), dtype=torch.bool),
}
local = "./tests/data/out_safe_pt_mmap_small.safetensors"

Expand Down
7 changes: 6 additions & 1 deletion bindings/python/tests/test_tf_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_deserialization_safe(self):

def test_bfloat16(self):
data = {
"test": tf.zeros((1024, 1024), dtype=tf.bfloat16),
"test": tf.randn((1024, 1024), dtype=tf.bfloat16),
}
save_file(data, self.sf_filename)
weights = {}
Expand All @@ -76,6 +76,11 @@ def test_bfloat16(self):
tv = data[k]
self.assertTrue(tf.experimental.numpy.allclose(v, tv))

weights = load_file(self.sf_filename)
for k, v in weights.items():
tv = data[k]
self.assertTrue(tf.experimental.numpy.allclose(v, tv))

def test_deserialization_safe_open(self):
weights = {}
with safe_open(self.sf_filename, framework="tf") as f:
Expand Down

0 comments on commit 3bfe613

Please sign in to comment.