Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
KyleHerndon committed Feb 12, 2025
1 parent 6f547a0 commit 3ce36bb
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions shortfin/tests/api/array_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_argmax_axis0(device):
@pytest.mark.parametrize(
"dtype",
[
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
],
Expand All @@ -114,6 +115,7 @@ def test_argmax_dtypes(device, dtype):
@pytest.mark.parametrize(
"dtype",
[
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
],
Expand All @@ -138,6 +140,7 @@ def test_fill_randn_default_generator(device, dtype):
@pytest.mark.parametrize(
"dtype",
[
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
],
Expand Down Expand Up @@ -180,6 +183,7 @@ def test_fill_randn_explicit_generator(device, dtype):
sfnp.int16,
sfnp.int32,
sfnp.int64,
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
sfnp.float64,
Expand Down Expand Up @@ -208,12 +212,16 @@ def round_half_away_from_zero(n):
@pytest.mark.parametrize(
"dtype,sfnp_func,ref_round_func",
[
(sfnp.bfloat16, sfnp.round, round_half_away_from_zero),
(sfnp.float16, sfnp.round, round_half_away_from_zero),
(sfnp.float32, sfnp.round, round_half_away_from_zero),
(sfnp.bfloat16, sfnp.ceil, math.ceil),
(sfnp.float16, sfnp.ceil, math.ceil),
(sfnp.float32, sfnp.ceil, math.ceil),
(sfnp.bfloat16, sfnp.floor, math.floor),
(sfnp.float16, sfnp.floor, math.floor),
(sfnp.float32, sfnp.floor, math.floor),
(sfnp.bfloat16, sfnp.trunc, math.trunc),
(sfnp.float16, sfnp.trunc, math.trunc),
(sfnp.float32, sfnp.trunc, math.trunc),
],
Expand Down Expand Up @@ -309,6 +317,8 @@ def test_elementwise_forms(device):
@pytest.mark.parametrize(
"lhs_dtype,rhs_dtype,promoted_dtype",
[
(sfnp.float32, sfnp.bfloat16, sfnp.float32),
(sfnp.bfloat16, sfnp.float32, sfnp.float32),
(sfnp.float32, sfnp.float16, sfnp.float32),
(sfnp.float16, sfnp.float32, sfnp.float32),
(sfnp.float32, sfnp.float64, sfnp.float64),
Expand Down Expand Up @@ -347,6 +357,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.add, 44.0),
(sfnp.uint32, sfnp.add, 44.0),
(sfnp.uint64, sfnp.add, 44.0),
(sfnp.bfloat16, sfnp.add, 44.0),
(sfnp.float16, sfnp.add, 44.0),
(sfnp.float32, sfnp.add, 44.0),
(sfnp.float64, sfnp.add, 44.0),
Expand All @@ -359,6 +370,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.divide, 21.0),
(sfnp.uint32, sfnp.divide, 21.0),
(sfnp.uint64, sfnp.divide, 21.0),
(sfnp.bfloat16, sfnp.divide, 21.0),
(sfnp.float16, sfnp.divide, 21.0),
(sfnp.float32, sfnp.divide, 21.0),
(sfnp.float64, sfnp.divide, 21.0),
Expand All @@ -371,6 +383,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.multiply, 84.0),
(sfnp.uint32, sfnp.multiply, 84.0),
(sfnp.uint64, sfnp.multiply, 84.0),
(sfnp.bfloat16, sfnp.multiply, 84.0),
(sfnp.float16, sfnp.multiply, 84.0),
(sfnp.float32, sfnp.multiply, 84.0),
(sfnp.float64, sfnp.multiply, 84.0),
Expand All @@ -383,6 +396,7 @@ def test_elementwise_promotion(device, lhs_dtype, rhs_dtype, promoted_dtype):
(sfnp.uint16, sfnp.subtract, 40.0),
(sfnp.uint32, sfnp.subtract, 40.0),
(sfnp.uint64, sfnp.subtract, 40.0),
(sfnp.bfloat16, sfnp.subtract, 40.0),
(sfnp.float16, sfnp.subtract, 40.0),
(sfnp.float32, sfnp.subtract, 40.0),
(sfnp.float64, sfnp.subtract, 40.0),
Expand Down Expand Up @@ -418,6 +432,7 @@ def test_elementwise_array_correctness(device, dtype, op, check_value):
sfnp.uint32,
sfnp.uint64,
sfnp.float32,
sfnp.bfloat16,
sfnp.float16,
sfnp.float32,
sfnp.float64,
Expand Down

0 comments on commit 3ce36bb

Please sign in to comment.