
ENH(sr3): refactor threshold and thresholder property #548

Merged
Jacob-Stevens-Haas merged 50 commits into master from refactor-sr3 on Sep 17, 2024

Conversation

himkwtn (Collaborator) commented on Aug 20, 2024

The threshold parameter is specific to the l0 norm. This PR aims to generalize SR3 to other norms by moving the threshold calculation into a separate function and using a generic regularization weight instead.

  • rename threshold to reg_weight_lam and thresholder to regularizer (a usage sketch follows this list)
  • update all related test cases
  • fix all child classes
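
A minimal before/after sketch of the renamed keywords (the parameter names are taken from the list above; the weight value, regularizer choice, and import style are illustrative assumptions, not defaults from this PR):

    import pysindy as ps

    # Before this PR, the keywords were tied to hard thresholding:
    # optimizer = ps.SR3(threshold=0.1, thresholder="l1")

    # After this PR, a generic regularization weight and regularizer name:
    optimizer = ps.SR3(reg_weight_lam=0.1, regularizer="l1")
    model = ps.SINDy(optimizer=optimizer)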

Jacob-Stevens-Haas (Member) commented on Sep 3, 2024

Thanks Watcharin! We can talk about this today, but this PR should target master. If it's awaiting #394, it should be made into a draft PR

Jacob-Stevens-Haas (Member) left a comment

Awesome, now check the rest of examples/ for Python files with SR3(threshold=...)

Don't forget the ipynb files! Run examples/publish_notebook.py on the example scripts
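
For reference, a throwaway search sketch (not part of the PR) that lists example scripts and notebooks still using the old keyword, assuming it is run from the repository root:

    from pathlib import Path

    # Print every example script/notebook that still contains the old keyword.
    for path in sorted(Path("examples").rglob("*")):
        if path.is_file() and path.suffix in {".py", ".ipynb"}:
            if "SR3(threshold=" in path.read_text():
                print(path)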

Comment on lines +138 to +156
if regularizer == "l1":
    cost = cp.sum_squares(x[:, i] - x @ xi) + cp.sum(
        cp.multiply(lam, cp.abs(xi))
    )
elif regularizer == "weighted_l1":
    cost = cp.sum_squares(x[:, i] - x @ xi) + cp.sum(
        cp.multiply(lam[:, i], cp.abs(xi))
    )
elif regularizer == "l2":
    cost = cp.sum_squares(x[:, i] - x @ xi) + cp.sum(
        cp.multiply(lam, xi**2)
    )
elif regularizer == "weighted_l2":
    cost = cp.sum_squares(x[:, i] - x @ xi) + cp.sum(
        cp.multiply(lam[:, i], xi**2)
    )
Jacob-Stevens-Haas (Member)

Can we find a way to use _calculate_penalty?

himkwtn (Collaborator, author)

The logic is slightly different from ConstrainedSR3. I'm not sure if we can combine them.

Jacob-Stevens-Haas (Member)

If it doesn't work, override the superclass _calculate_penalty()
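
For context, a sketch of how the quoted branches could be folded into a single penalty helper. The name _calculate_penalty comes from this thread; the signature and placement shown here are assumptions, not the actual pysindy API:

    import cvxpy as cp

    def _calculate_penalty(regularizer: str, lam, xi: cp.Variable) -> cp.Expression:
        # lam is a scalar weight for "l1"/"l2" and a per-coefficient weight
        # vector (already sliced to the current target) for the weighted variants.
        regularizer = regularizer.lower()
        if regularizer in ("l1", "weighted_l1"):
            return cp.sum(cp.multiply(lam, cp.abs(xi)))
        if regularizer in ("l2", "weighted_l2"):
            return cp.sum(cp.multiply(lam, xi**2))
        raise ValueError(f"Unsupported regularizer: {regularizer}")

    # The if/elif chain above would then reduce to something like:
    # cost = cp.sum_squares(x[:, i] - x @ xi) + _calculate_penalty(regularizer, lam_i, xi)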

Base automatically changed from bug-394-regularization to master September 10, 2024 20:31
# Conflicts:
#	examples/1_feature_overview/example.ipynb
#	examples/1_feature_overview/example.py
#	pysindy/optimizers/constrained_sr3.py
#	pysindy/optimizers/sr3.py
#	pysindy/optimizers/stable_linear_sr3.py
#	test/test_optimizers.py

codecov bot commented Sep 11, 2024

Codecov Report

Attention: Patch coverage is 96.07843% with 2 lines in your changes missing coverage. Please review.

Project coverage is 94.64%. Comparing base (3d011fc) to head (52e7f2b).
Report is 1 commit behind head on master.

Files with missing lines            Patch %   Lines
pysindy/optimizers/trapping_sr3.py  71.42%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #548      +/-   ##
==========================================
+ Coverage   94.23%   94.64%   +0.40%     
==========================================
  Files          37       37              
  Lines        4045     4018      -27     
==========================================
- Hits         3812     3803       -9     
+ Misses        233      215      -18     


Jacob-Stevens-Haas (Member) left a comment

Almost done with this! Mostly just a linting issue


Jacob-Stevens-Haas merged commit 137b19d into master on Sep 17, 2024 (8 checks passed).
Jacob-Stevens-Haas deleted the refactor-sr3 branch on September 17, 2024 at 20:04.