Skip to content

Commit

Permalink
ENH: create unit test for get_regularization shape validation
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Aug 23, 2024
1 parent 89ead0f commit 658c300
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 91 deletions.
26 changes: 14 additions & 12 deletions pysindy/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,18 @@ def reorder_constraints(arr, n_features, output_order="feature"):
return arr.reshape(starting_shape).transpose([0, 2, 1]).reshape((n_constraints, -1))


def validate_prox_and_reg_inputs(func):
@wraps(func)
def validate_prox_and_reg_inputs(func, regularization):
def wrapper(x, regularization_weight):
# Example validation: check if both a and b are positive integers
if isinstance(regularization_weight, np.ndarray) and (
regularization_weight.shape != x.shape
and regularization_weight.shape != (1, 1)
):
if regularization[:8] == 'weighted' and \
(not isinstance(regularization_weight, np.ndarray) or (regularization_weight.shape != x.shape)):
raise ValueError(
f"Invalid shape for 'regularization_weight': {regularization_weight.shape}. Must be the same shape as x: {x.shape}."
f"Invalid shape for 'regularization_weight': {
regularization_weight.shape if isinstance(regularization, np.ndarray) else '()'}. Must be the same shape as x: {x.shape}."
)

elif regularization[:8] != 'weighted' and not isinstance(regularization_weight, (int, float)) \
and (isinstance(regularization_weight, np.ndarray) and regularization_weight.shape not in [(1, 1), (1,)]):
raise ValueError("'regularization_weight' must be a scalar")
# If validation passes, call the original function
return func(x, regularization_weight)

Expand Down Expand Up @@ -235,8 +235,9 @@ def prox_weighted_l2(
"l2": prox_l2,
"weighted_l2": prox_weighted_l2,
}
if regularization.lower() in prox:
return validate_prox_and_reg_inputs(prox[regularization.lower()])
regularization = regularization.lower()
if regularization in prox:
return validate_prox_and_reg_inputs(prox[regularization], regularization)
else:
raise NotImplementedError("{} has not been implemented".format(regularization))

Expand Down Expand Up @@ -291,8 +292,9 @@ def regualization_weighted_l2(
"l2": regularization_l2,
"weighted_l2": regualization_weighted_l2,
}
if regularization.lower() in regularization_fn:
return validate_prox_and_reg_inputs(regularization_fn[regularization.lower()])
regularization = regularization.lower()
if regularization in regularization_fn:
return validate_prox_and_reg_inputs(regularization_fn[regularization], regularization)
else:
raise NotImplementedError("{} has not been implemented".format(regularization))

Expand Down
122 changes: 43 additions & 79 deletions test/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,117 +65,81 @@ def test_validate_controls():


@pytest.mark.parametrize(
["regularization", "expected"], [("l0", 6), ("l1", 20), ("l2", 76)]
["regularization", "lam", "expected"],
[
("l0", 2, 4),
("l1", 2, 14),
("l2", 2, 58),
("weighted_l0", np.array([[3, 2]]).T, 5),
("weighted_l1", np.array([[3, 2]]).T, 16),
("weighted_l2", np.array([[3, 2]]).T, 62),
],
)
def test_get_regularization_1d(regularization, expected):
data = np.array([[-2, 3, 5]]).T
lam = np.array([[2]])
def test_get_regularization(regularization, lam, expected):
data = np.array([[-2, 5]]).T

reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(
["regularization", "expected"], [("l0", 10), ("l1", 56), ("l2", 416)]
)
def test_get_regularization_2d(regularization, expected):
data = np.array([[-2, 3, 5], [7, 11, 0]]).T
lam = np.array([[2]])

@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize("lam", [1, np.array([1]), np.array([[1]])])
def test_get_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected
assert result != None


@pytest.mark.parametrize(
["regularization", "expected"],
[("weighted_l0", 5.5), ("weighted_l1", 14.5), ("weighted_l2", 42.5)],
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
def test_get_weighted_regularization_1d(regularization, expected):
data = np.array([[-2, 3, 5]]).T
lam = np.array([[3, 2, 0.5]]).T

@pytest.mark.parametrize("lam", [np.array([[1, 2]]).T])
def test_get_weighted_regularization_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected
assert result != None


@pytest.mark.parametrize("regularization", ["l0", "l1", "l2"])
@pytest.mark.parametrize(
["regularization", "expected"],
[("weighted_l0", 19.5), ("weighted_l1", 164.5), ("weighted_l2", 1664.5)],
"lam", [np.array([[1, 2]]), np.array([1, 2]), np.array([[1, 2]]).T]
)
def test_get_weighted_regularization_2d(regularization, expected):
data = np.array([[-2, 3, 5], [7, 11, 0]]).T
lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T

def test_get_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
result = reg(data, lam)
assert result == expected


@pytest.mark.parametrize(
["regularization", "expected"],
[
("l0", np.array([[0, 3, 5]]).T),
("l1", np.array([[0, 0, 2]]).T),
("l2", np.array([[-2 / 7, 3 / 7, 5 / 7]]).T),
],
)
def test_get_prox_1d(regularization, expected):
data = np.array([[-2, 3, 5]]).T
lam = np.array([[3]])

prox = get_prox(regularization)
result = prox(data, lam)
assert_array_equal(result, expected)
with pytest.raises(ValueError):
reg(data, lam)


@pytest.mark.parametrize(
["regularization", "expected"],
[
("l0", np.array([[0, 3, 5], [-7, 11, 0]]).T),
("l1", np.array([[0, 0, 2], [-4, 8, 0]]).T),
("l2", np.array([[-2 / 7, 3 / 7, 5 / 7], [-7 / 7, 11 / 7, 0 / 7]]).T),
],
"regularization", ["weighted_l0", "weighted_l1", "weighted_l2"]
)
def test_get_prox_2d(regularization, expected):
data = np.array([[-2, 3, 5], [-7, 11, 0]]).T
lam = np.array([[3]])

prox = get_prox(regularization)
result = prox(data, lam)
assert_array_equal(result, expected)


@pytest.mark.parametrize(
["regularization", "expected"],
[
("l0", np.array([[0, 3, 5]]).T),
("l1", np.array([[0, 1, 4.5]]).T),
("l2", np.array([[-2 / 7, 3 / 5, 5 / 2]]).T),
],
"lam", [np.array([[1, 2]]), np.array([1, 2, 3]), np.array([[1, 2, 3]]).T, 1]
)
def test_get_weighted_prox_1d(regularization, expected):
data = np.array([[-2, 3, 5]]).T
lam = np.array([[3, 2, 0.5]]).T

prox = get_prox(regularization)
result = prox(data, lam)
assert_array_equal(result, expected)
def test_get_weighted_regularization_bad_shape(regularization, lam):
data = np.array([[-2, 5]]).T
reg = get_regularization(regularization)
with pytest.raises(ValueError):
reg(data, lam)


@pytest.mark.parametrize(
["regularization", "expected"],
["regularization", "lam", "expected"],
[
("l0", np.array([[0, 3, 5], [-7, 11, 0]]).T),
("l1", np.array([[0, 1, 4.5], [-6, 0, 0]]).T),
("l2", np.array([[-2 / 7, 3 / 5, 5 / 2], [-7 / 3, 11 / 27, 0 / 35]]).T),
("l0", 3, np.array([[0, 5]]).T),
("l1", 3, np.array([[0, 2]]).T),
("l2", 3, np.array([[-2 / 7, 5 / 7]]).T),
("weighted_l0", np.array([[3, 2]]).T, np.array([[0, 5]]).T),
("weighted_l1", np.array([[3, 2]]).T, np.array([[0, 3]]).T),
("weighted_l2", np.array([[3, 2]]).T, np.array([[-2 / 7, 5 / 5]]).T),
],
)
def test_get_weighted_prox_2d(regularization, expected):
data = np.array([[-2, 3, 5], [-7, 11, 0]]).T
lam = np.array([[3, 2, 0.5], [1, 13, 17]]).T
def test_get_prox(regularization, expected, lam):
data = np.array([[-2, 5]]).T

prox = get_prox(regularization)
result = prox(data, lam)
Expand Down

0 comments on commit 658c300

Please sign in to comment.