From 658c300136f1397032bbf879de64337746288a88 Mon Sep 17 00:00:00 2001 From: himkwtn Date: Fri, 23 Aug 2024 14:16:57 -0700 Subject: [PATCH] ENH: create unit test for get_regularization shape validation --- pysindy/utils/base.py | 26 +++++---- test/utils/test_utils.py | 122 ++++++++++++++------------------------- 2 files changed, 57 insertions(+), 91 deletions(-) diff --git a/pysindy/utils/base.py b/pysindy/utils/base.py index 0184d48f..1ee96f1b 100644 --- a/pysindy/utils/base.py +++ b/pysindy/utils/base.py @@ -154,18 +154,18 @@ def reorder_constraints(arr, n_features, output_order="feature"): return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1)) -def validate_prox_and_reg_inputs(func): - @wraps(func) +def validate_prox_and_reg_inputs(func, regularization): def wrapper(x, regularization_weight): # Example validation: check if both a and b are positive integers - if isinstance(regularization_weight, np.ndarray) and ( - regularization_weight.shape != x.shape - and regularization_weight.shape != (1, 1) - ): + if regularization[:8] == 'weighted' and \ + (not isinstance(regularization_weight, np.ndarray) or (regularization_weight.shape != x.shape)): raise ValueError( - f"Invalid shape for 'regularization_weight': {regularization_weight.shape}. Must be the same shape as x: {x.shape}." + f"Invalid shape for 'regularization_weight': { + regularization_weight.shape if isinstance(regularization, np.ndarray) else '()'}. Must be the same shape as x: {x.shape}." ) - + elif regularization[:8] != 'weighted' and not isinstance(regularization_weight, (int, float)) \ + and (isinstance(regularization_weight, np.ndarray) and regularization_weight.shape not in [(1, 1), (1,)]): + raise ValueError("'regularization_weight' must be a scalar") # If validation passes, call the original function return func(x, regularization_weight) @@ -235,8 +235,9 @@ def prox_weighted_l2( "l2": prox_l2, "weighted_l2": prox_weighted_l2, } - if regularization.lower() in prox: - return validate_prox_and_reg_inputs(prox[regularization.lower()]) + regularization = regularization.lower() + if regularization in prox: + return validate_prox_and_reg_inputs(prox[regularization], regularization) else: raise NotImplementedError("{} has not been implemented".format(regularization)) @@ -291,8 +292,9 @@ def regualization_weighted_l2( "l2": regularization_l2, "weighted_l2": regualization_weighted_l2, } - if regularization.lower() in regularization_fn: - return validate_prox_and_reg_inputs(regularization_fn[regularization.lower()]) + regularization = regularization.lower() + if regularization in regularization_fn: + return validate_prox_and_reg_inputs(regularization_fn[regularization], regularization) else: raise NotImplementedError("{} has not been implemented".format(regularization)) diff --git a/test/utils/test_utils.py b/test/utils/test_utils.py index 3fa5d16c..ed4e0067 100644 --- a/test/utils/test_utils.py +++ b/test/utils/test_utils.py @@ -65,117 +65,81 @@ def test_validate_controls(): @pytest.mark.parametrize( - ["regularization", "expected"], [("l0", 6), ("l1", 20), ("l2", 76)] + ["regularization", "lam", "expected"], + [ + ("l0", 2, 4), + ("l1", 2, 14), + ("l2", 2, 58), + ("weighted_l0", np.array([[3, 2]]).T, 5), + ("weighted_l1", np.array([[3, 2]]).T, 16), + ("weighted_l2", np.array([[3, 2]]).T, 62), + ], ) -def test_get_regularization_1d(regularization, expected): - data = np.array([[-2, 3, 5]]).T - lam = np.array([[2]]) +def test_get_regularization(regularization, lam, expected): + data = np.array([[-2, 5]]).T reg = get_regularization(regularization) result = reg(data, lam) assert result == expected -@pytest.mark.parametrize( - ["regularization", "expected"], [("l0", 10), ("l1", 56), ("l2", 416)] -) -def test_get_regularization_2d(regularization, expected): - data = np.array([[-2, 3, 5], [7, 11, 0]]).T - lam = np.array([[2]]) - +@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"]) +@pytest.mark.parametrize("lam", [1, np.array([1]), np.array([[1]])]) +def test_get_regularization_shape(regularization, lam): + data = np.array([[-2, 5]]).T reg = get_regularization(regularization) result = reg(data, lam) - assert result == expected + assert result != None @pytest.mark.parametrize( - ["regularization", "expected"], - [("weighted_l0", 5.5), ("weighted_l1", 14.5), ("weighted_l2", 42.5)], + "regularization", ["weighted_l0", "weighted_l1", "weighted_l2"] ) -def test_get_weighted_regularization_1d(regularization, expected): - data = np.array([[-2, 3, 5]]).T - lam = np.array([[3, 2, 0.5]]).T - +@pytest.mark.parametrize("lam", [np.array([[1, 2]]).T]) +def test_get_weighted_regularization_shape(regularization, lam): + data = np.array([[-2, 5]]).T reg = get_regularization(regularization) result = reg(data, lam) - assert result == expected + assert result != None +@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"]) @pytest.mark.parametrize( - ["regularization", "expected"], - [("weighted_l0", 19.5), ("weighted_l1", 164.5), ("weighted_l2", 1664.5)], + "lam", [np.array([[1, 2]]), np.array([1, 2]), np.array([[1, 2]]).T] ) -def test_get_weighted_regularization_2d(regularization, expected): - data = np.array([[-2, 3, 5], [7, 11, 0]]).T - lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T - +def test_get_regularization_bad_shape(regularization, lam): + data = np.array([[-2, 5]]).T reg = get_regularization(regularization) - result = reg(data, lam) - assert result == expected - - -@pytest.mark.parametrize( - ["regularization", "expected"], - [ - ("l0", np.array([[0, 3, 5]]).T), - ("l1", np.array([[0, 0, 2]]).T), - ("l2", np.array([[-2 / 7, 3 / 7, 5 / 7]]).T), - ], -) -def test_get_prox_1d(regularization, expected): - data = np.array([[-2, 3, 5]]).T - lam = np.array([[3]]) - - prox = get_prox(regularization) - result = prox(data, lam) - assert_array_equal(result, expected) + with pytest.raises(ValueError): + reg(data, lam) @pytest.mark.parametrize( - ["regularization", "expected"], - [ - ("l0", np.array([[0, 3, 5], [-7, 11, 0]]).T), - ("l1", np.array([[0, 0, 2], [-4, 8, 0]]).T), - ("l2", np.array([[-2 / 7, 3 / 7, 5 / 7], [-7 / 7, 11 / 7, 0 / 7]]).T), - ], + "regularization", ["weighted_l0", "weighted_l1", "weighted_l2"] ) -def test_get_prox_2d(regularization, expected): - data = np.array([[-2, 3, 5], [-7, 11, 0]]).T - lam = np.array([[3]]) - - prox = get_prox(regularization) - result = prox(data, lam) - assert_array_equal(result, expected) - - @pytest.mark.parametrize( - ["regularization", "expected"], - [ - ("l0", np.array([[0, 3, 5]]).T), - ("l1", np.array([[0, 1, 4.5]]).T), - ("l2", np.array([[-2 / 7, 3 / 5, 5 / 2]]).T), - ], + "lam", [np.array([[1, 2]]), np.array([1, 2, 3]), np.array([[1, 2, 3]]).T, 1] ) -def test_get_weighted_prox_1d(regularization, expected): - data = np.array([[-2, 3, 5]]).T - lam = np.array([[3, 2, 0.5]]).T - - prox = get_prox(regularization) - result = prox(data, lam) - assert_array_equal(result, expected) +def test_get_weighted_regularization_bad_shape(regularization, lam): + data = np.array([[-2, 5]]).T + reg = get_regularization(regularization) + with pytest.raises(ValueError): + reg(data, lam) @pytest.mark.parametrize( - ["regularization", "expected"], + ["regularization", "lam", "expected"], [ - ("l0", np.array([[0, 3, 5], [-7, 11, 0]]).T), - ("l1", np.array([[0, 1, 4.5], [-6, 0, 0]]).T), - ("l2", np.array([[-2 / 7, 3 / 5, 5 / 2], [-7 / 3, 11 / 27, 0 / 35]]).T), + ("l0", 3, np.array([[0, 5]]).T), + ("l1", 3, np.array([[0, 2]]).T), + ("l2", 3, np.array([[-2 / 7, 5 / 7]]).T), + ("weighted_l0", np.array([[3, 2]]).T, np.array([[0, 5]]).T), + ("weighted_l1", np.array([[3, 2]]).T, np.array([[0, 3]]).T), + ("weighted_l2", np.array([[3, 2]]).T, np.array([[-2 / 7, 5 / 5]]).T), ], ) -def test_get_weighted_prox_2d(regularization, expected): - data = np.array([[-2, 3, 5], [-7, 11, 0]]).T - lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T +def test_get_prox(regularization, expected, lam): + data = np.array([[-2, 5]]).T prox = get_prox(regularization) result = prox(data, lam)