diff --git a/shortfin/tests/api/array_ops_test.py b/shortfin/tests/api/array_ops_test.py index 164dfb479..3e778d421 100644 --- a/shortfin/tests/api/array_ops_test.py +++ b/shortfin/tests/api/array_ops_test.py @@ -100,6 +100,7 @@ def test_argmax_axis0(device): @pytest.mark.parametrize( "dtype", [ + sfnp.bfloat16, sfnp.float16, sfnp.float32, ], @@ -114,6 +115,7 @@ def test_argmax_dtypes(device, dtype): @pytest.mark.parametrize( "dtype", [ + sfnp.bfloat16, sfnp.float16, sfnp.float32, ], @@ -138,6 +140,7 @@ def test_fill_randn_default_generator(device, dtype): @pytest.mark.parametrize( "dtype", [ + sfnp.bfloat16, sfnp.float16, sfnp.float32, ], @@ -180,6 +183,7 @@ def test_fill_randn_explicit_generator(device, dtype): sfnp.int16, sfnp.int32, sfnp.int64, + sfnp.bfloat16, sfnp.float16, sfnp.float32, sfnp.float64, @@ -208,12 +212,16 @@ def round_half_away_from_zero(n): @pytest.mark.parametrize( "dtype,sfnp_func,ref_round_func", [ + (sfnp.bfloat16, sfnp.round, round_half_away_from_zero), (sfnp.float16, sfnp.round, round_half_away_from_zero), (sfnp.float32, sfnp.round, round_half_away_from_zero), + (sfnp.bfloat16, sfnp.ceil, math.ceil), (sfnp.float16, sfnp.ceil, math.ceil), (sfnp.float32, sfnp.ceil, math.ceil), + (sfnp.bfloat16, sfnp.floor, math.floor), (sfnp.float16, sfnp.floor, math.floor), (sfnp.float32, sfnp.floor, math.floor), + (sfnp.bfloat16, sfnp.trunc, math.trunc), (sfnp.float16, sfnp.trunc, math.trunc), (sfnp.float32, sfnp.trunc, math.trunc), ], @@ -309,6 +317,8 @@ def test_elementwise_forms(device): @pytest.mark.parametrize( "lhs_dtype,rhs_dtype,promoted_dtype", [ + (sfnp.float32, sfnp.bfloat16, sfnp.float32), + (sfnp.bfloat16, sfnp.float32, sfnp.float32), (sfnp.float32, sfnp.float16, sfnp.float32), (sfnp.float16, sfnp.float32, sfnp.float32), (sfnp.float32, sfnp.float64, sfnp.float64), @@ -347,6 +357,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.add, 44.0), (sfnp.uint32, sfnp.add, 44.0), (sfnp.uint64, sfnp.add, 44.0), + (sfnp.bfloat16, sfnp.add, 44.0), (sfnp.float16, sfnp.add, 44.0), (sfnp.float32, sfnp.add, 44.0), (sfnp.float64, sfnp.add, 44.0), @@ -359,6 +370,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.divide, 21.0), (sfnp.uint32, sfnp.divide, 21.0), (sfnp.uint64, sfnp.divide, 21.0), + (sfnp.bfloat16, sfnp.divide, 21.0), (sfnp.float16, sfnp.divide, 21.0), (sfnp.float32, sfnp.divide, 21.0), (sfnp.float64, sfnp.divide, 21.0), @@ -371,6 +383,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.multiply, 84.0), (sfnp.uint32, sfnp.multiply, 84.0), (sfnp.uint64, sfnp.multiply, 84.0), + (sfnp.bfloat16, sfnp.multiply, 84.0), (sfnp.float16, sfnp.multiply, 84.0), (sfnp.float32, sfnp.multiply, 84.0), (sfnp.float64, sfnp.multiply, 84.0), @@ -383,6 +396,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype): (sfnp.uint16, sfnp.subtract, 40.0), (sfnp.uint32, sfnp.subtract, 40.0), (sfnp.uint64, sfnp.subtract, 40.0), + (sfnp.bfloat16, sfnp.subtract, 40.0), (sfnp.float16, sfnp.subtract, 40.0), (sfnp.float32, sfnp.subtract, 40.0), (sfnp.float64, sfnp.subtract, 40.0), @@ -418,6 +432,7 @@ def test_elementwise_array_correctness(device, dtype, op, check_value): sfnp.uint32, sfnp.uint64, sfnp.float32, + sfnp.bfloat16, sfnp.float16, sfnp.float32, sfnp.float64,