diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py index eb1aa65a..0f3821b8 100644 --- a/bindings/python/tests/test_pt_comparison.py +++ b/bindings/python/tests/test_pt_comparison.py @@ -52,9 +52,9 @@ def test_serialization(self): def test_odd_dtype(self): data = { - "test": torch.randn((2, 2), dtype=torch.bfloat16), - "test2": torch.randn((2, 2), dtype=torch.float16), - "test3": torch.randn((2, 2), dtype=torch.bool), + "test": torch.zeros((2, 2), dtype=torch.bfloat16), + "test2": torch.zeros((2, 2), dtype=torch.float16), + "test3": torch.zeros((2, 2), dtype=torch.bool), } local = "./tests/data/out_safe_pt_mmap_small.safetensors" diff --git a/bindings/python/tests/test_tf_comparison.py b/bindings/python/tests/test_tf_comparison.py index 2c50bb76..ac41e6f6 100644 --- a/bindings/python/tests/test_tf_comparison.py +++ b/bindings/python/tests/test_tf_comparison.py @@ -64,7 +64,7 @@ def test_deserialization_safe(self): def test_bfloat16(self): data = { - "test": tf.randn((1024, 1024), dtype=tf.bfloat16), + "test": tf.zeros((1024, 1024), dtype=tf.bfloat16), } save_file(data, self.sf_filename) weights = {} @@ -76,11 +76,6 @@ def test_bfloat16(self): tv = data[k] self.assertTrue(tf.experimental.numpy.allclose(v, tv)) - weights = load_file(self.sf_filename) - for k, v in weights.items(): - tv = data[k] - self.assertTrue(tf.experimental.numpy.allclose(v, tv)) - def test_deserialization_safe_open(self): weights = {} with safe_open(self.sf_filename, framework="tf") as f: