From 29bd169d3e593ed6c89f85207bbc576b9c3f269a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 26 Jul 2024 12:03:50 +0200 Subject: [PATCH] Revert "Improving the bf16 tests for PT+TF. (#505)" This reverts commit 3bfe61321b1f6ad870685f92a600f87c10593ff2. --- bindings/python/tests/test_pt_comparison.py | 6 +++--- bindings/python/tests/test_tf_comparison.py | 7 +------ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/bindings/python/tests/test_pt_comparison.py b/bindings/python/tests/test_pt_comparison.py index eb1aa65a..0f3821b8 100644 --- a/bindings/python/tests/test_pt_comparison.py +++ b/bindings/python/tests/test_pt_comparison.py @@ -52,9 +52,9 @@ def test_serialization(self): def test_odd_dtype(self): data = { - "test": torch.randn((2, 2), dtype=torch.bfloat16), - "test2": torch.randn((2, 2), dtype=torch.float16), - "test3": torch.randn((2, 2), dtype=torch.bool), + "test": torch.zeros((2, 2), dtype=torch.bfloat16), + "test2": torch.zeros((2, 2), dtype=torch.float16), + "test3": torch.zeros((2, 2), dtype=torch.bool), } local = "./tests/data/out_safe_pt_mmap_small.safetensors" diff --git a/bindings/python/tests/test_tf_comparison.py b/bindings/python/tests/test_tf_comparison.py index 2c50bb76..ac41e6f6 100644 --- a/bindings/python/tests/test_tf_comparison.py +++ b/bindings/python/tests/test_tf_comparison.py @@ -64,7 +64,7 @@ def test_deserialization_safe(self): def test_bfloat16(self): data = { - "test": tf.randn((1024, 1024), dtype=tf.bfloat16), + "test": tf.zeros((1024, 1024), dtype=tf.bfloat16), } save_file(data, self.sf_filename) weights = {} @@ -76,11 +76,6 @@ def test_bfloat16(self): tv = data[k] self.assertTrue(tf.experimental.numpy.allclose(v, tv)) - weights = load_file(self.sf_filename) - for k, v in weights.items(): - tv = data[k] - self.assertTrue(tf.experimental.numpy.allclose(v, tv)) - def test_deserialization_safe_open(self): weights = {} with safe_open(self.sf_filename, framework="tf") as f: