From afa25dbf8e7841289fb151b0cff71925db850396 Mon Sep 17 00:00:00 2001
From: Jake Stevens-Haas <37048747+Jacob-Stevens-Haas@users.noreply.github.com>
Date: Sun, 26 May 2024 12:51:49 -0700
Subject: [PATCH] cln: offload some of _solve_sparse_relax_and_split() to
 ConstrainedSR3

This pulls in the changes to ConstrainedSR3 from trapping_resolve

---
 pysindy/optimizers/constrained_sr3.py | 149 +++++++++++++++-----------
 pysindy/optimizers/trapping_sr3.py    | 138 +++++++++---------------
 2 files changed, 139 insertions(+), 148 deletions(-)

diff --git a/pysindy/optimizers/constrained_sr3.py b/pysindy/optimizers/constrained_sr3.py
index f00445b9a..eb20cf159 100644
--- a/pysindy/optimizers/constrained_sr3.py
+++ b/pysindy/optimizers/constrained_sr3.py
@@ -1,4 +1,7 @@
 import warnings
+from copy import deepcopy
+from typing import Optional
+from typing import Tuple
 
 try:
     import cvxpy as cp
@@ -39,6 +42,9 @@ class ConstrainedSR3(SR3):
         to learn parsimonious physics-informed models from data."
         IEEE Access 8 (2020): 169259-169271.
 
+        Zheng, Peng, et al. "A unified framework for sparse relaxed
+        regularized regression: SR3." IEEE Access 7 (2018): 1404-1423.
+
     Parameters
     ----------
     threshold : float, optional (default 0.1)
@@ -66,14 +72,10 @@ class ConstrainedSR3(SR3):
     max_iter : int, optional (default 30)
         Maximum iterations of the optimization algorithm.
 
-    fit_intercept : boolean, optional (default False)
-        Whether to calculate the intercept for this model. If set to false, no
-        intercept will be used in calculations.
-
     constraint_lhs : numpy ndarray, optional (default None)
         Shape should be (n_constraints, n_features * n_targets),
-        The left hand side matrix C of Cw <= d.
-        There should be one row per constraint.
+        The left hand side matrix C of Cw <= d (or Cw = d for equality
+        constraints). There should be one row per constraint.
 
     constraint_rhs : numpy ndarray, shape (n_constraints,), optional (default None)
         The right hand side vector d of Cw <= d.
@@ -97,9 +99,6 @@ class ConstrainedSR3(SR3):
         is deprecated in sklearn versions >= 1.0 and will be removed. Note that
         this parameter is incompatible with the constraints!
 
-    copy_X : boolean, optional (default True)
-        If True, X will be copied; else, it may be overwritten.
-
     initial_guess : np.ndarray, optional (default None)
         Shape should be (n_features) or (n_targets, n_features).
         Initial guess for coefficients ``coef_``, (v in the mathematical equations)
@@ -128,6 +127,10 @@ class ConstrainedSR3(SR3):
         output should be verbose or not. Only relevant for optimizers that
         use the CVXPY package in some capacity.
 
+    unbias: bool (default False)
+        See base class for definition. Most options are incompatible
+        with unbiasing.
+
     Attributes
     ----------
     coef_ : array, shape (n_features,) or (n_targets, n_features)
@@ -138,11 +141,15 @@ class ConstrainedSR3(SR3):
         Weight vector(s) that are not subjected to the regularization.
         This is the w in the objective function.
 
-    unbias : boolean
-        Whether to perform an extra step of unregularized linear regression
-        to unbias the coefficients for the identified support.
-        ``unbias`` is automatically set to False if a constraint is used and
-        is otherwise left uninitialized.
+    history_ : list
+        History of sparse coefficients. ``history_[k]`` contains the
+        sparse coefficients (v in the optimization objective function)
+        at iteration k.
+
+    objective_history_ : list
+        History of the value of the objective at each step. Note that
+        the trapping SINDy problem is nonconvex, meaning that this value
+        may increase and decrease as the algorithm works.
     """
 
     def __init__(
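For illustration only (not part of the diff): a minimal sketch of the constraint interface documented above, using arbitrary sizes; only the shapes and keyword names from the docstring are assumed.

    import numpy as np
    from pysindy.optimizers import ConstrainedSR3

    n_features, n_targets = 10, 3
    # One row of C per constraint, over the flattened coefficient vector w
    constraint_lhs = np.zeros((1, n_features * n_targets))
    constraint_lhs[0, 0] = 1.0        # select the first coefficient ...
    constraint_rhs = np.array([1.7])  # ... and pin it: w[0] == 1.7

    opt = ConstrainedSR3(
        constraint_lhs=constraint_lhs,
        constraint_rhs=constraint_rhs,
        equality_constraints=True,
        unbias=False,  # per the docstring, constraints forbid unbiasing
    )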
@@ -158,17 +165,17 @@ def __init__(
         constraint_rhs=None,
         constraint_order="target",
         normalize_columns=False,
-        fit_intercept=False,
         copy_X=True,
         initial_guess=None,
         thresholds=None,
         equality_constraints=False,
         inequality_constraints=False,
-        constraint_separation_index=0,
+        constraint_separation_index: Optional[int] = None,
         verbose=False,
         verbose_cvxpy=False,
+        unbias=False,
     ):
-        super(ConstrainedSR3, self).__init__(
+        super().__init__(
             threshold=threshold,
             nu=nu,
             tol=tol,
@@ -178,10 +185,10 @@ def __init__(
             trimming_step_size=trimming_step_size,
             max_iter=max_iter,
             initial_guess=initial_guess,
-            fit_intercept=fit_intercept,
             copy_X=copy_X,
             normalize_columns=normalize_columns,
             verbose=verbose,
+            unbias=unbias,
         )
 
         self.verbose_cvxpy = verbose_cvxpy
@@ -189,7 +196,7 @@ def __init__(
         self.constraint_lhs = constraint_lhs
         self.constraint_rhs = constraint_rhs
         self.constraint_order = constraint_order
-        self.use_constraints = (constraint_lhs is not None) and (
+        self.use_constraints = (constraint_lhs is not None) or (
             constraint_rhs is not None
         )
 
@@ -203,15 +210,18 @@ def __init__(
                 " but user did not specify if the constraints were equality or"
                 " inequality constraints. Assuming equality constraints."
             )
-            self.equality_constraints = True
+            equality_constraints = True
 
         if self.use_constraints:
             if constraint_order not in ("feature", "target"):
                 raise ValueError(
                     "constraint_order must be either 'feature' or 'target'"
                 )
-
-        self.unbias = False
+            if unbias:
+                raise ValueError(
+                    "Constraints are incompatible with an unbiasing step. Set"
+                    " unbias=False"
+                )
 
         if inequality_constraints and not cvxpy_flag:
             raise ValueError(
@@ -235,6 +245,16 @@ def __init__(
             )
         self.inequality_constraints = inequality_constraints
         self.equality_constraints = equality_constraints
+        if self.use_constraints and constraint_separation_index is None:
+            if self.inequality_constraints and not self.equality_constraints:
+                constraint_separation_index = len(constraint_lhs)
+            elif self.equality_constraints and not self.inequality_constraints:
+                constraint_separation_index = 0
+            else:
+                raise ValueError(
+                    "If passing both inequality and equality constraints, must specify"
+                    " constraint_separation_index."
+                )
         self.constraint_separation_index = constraint_separation_index
 
     def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
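For illustration only (not part of the diff): how the defaulting logic above expects mixed constraints to be laid out. Inequality rows come first, and constraint_separation_index marks where the equality rows begin; sizes here are arbitrary.

    import numpy as np
    from pysindy.optimizers import ConstrainedSR3

    n_features, n_targets = 10, 3
    ineq = np.zeros((2, n_features * n_targets))  # rows enforcing Cw <= d
    ineq[0, 1], ineq[1, 2] = 1.0, 1.0
    eq = np.zeros((1, n_features * n_targets))    # rows enforcing Cw = d
    eq[0, 0] = 1.0

    opt = ConstrainedSR3(
        constraint_lhs=np.vstack([ineq, eq]),
        constraint_rhs=np.array([1.0, 1.0, 1.7]),
        inequality_constraints=True,
        equality_constraints=True,
        constraint_separation_index=2,  # rows [:2] inequality, rows [2:] equality
    )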
@@ -251,62 +271,66 @@ def _update_full_coef_constraints(self, H, x_transpose_y, coef_sparse):
         rhs = rhs.reshape(g.shape)
         return inv1.dot(rhs)
 
-    def _update_coef_cvxpy(self, x, y, coef_sparse):
-        xi = cp.Variable(coef_sparse.shape[0] * coef_sparse.shape[1])
-        cost = cp.sum_squares(x @ xi - y.flatten())
+    def _create_var_and_part_cost(
+        self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
+    ) -> Tuple[cp.Variable, cp.Expression]:
+        xi = cp.Variable(var_len)
+        cost = cp.sum_squares(x_expanded @ xi - y.flatten())
         if self.thresholder.lower() == "l1":
             cost = cost + self.threshold * cp.norm1(xi)
         elif self.thresholder.lower() == "weighted_l1":
             cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
         elif self.thresholder.lower() == "l2":
-            cost = cost + self.threshold * cp.norm2(xi)
+            cost = cost + self.threshold * cp.norm2(xi) ** 2
         elif self.thresholder.lower() == "weighted_l2":
-            cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi)
+            cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
+        return xi, cost
+
+    def _update_coef_cvxpy(self, xi, cost, var_len, coef_prev, tol):
         if self.use_constraints:
-            if self.inequality_constraints and self.equality_constraints:
-                # Process inequality constraints then equality constraints
-                prob = cp.Problem(
-                    cp.Minimize(cost),
-                    [
-                        self.constraint_lhs[: self.constraint_separation_index, :] @ xi
-                        <= self.constraint_rhs[: self.constraint_separation_index],
-                        self.constraint_lhs[self.constraint_separation_index :, :] @ xi
-                        == self.constraint_rhs[self.constraint_separation_index :],
-                    ],
-                )
-            elif self.inequality_constraints:
-                prob = cp.Problem(
-                    cp.Minimize(cost),
-                    [self.constraint_lhs @ xi <= self.constraint_rhs],
+            constraints = []
+            if self.equality_constraints:
+                constraints.append(
+                    self.constraint_lhs[self.constraint_separation_index :, :] @ xi
+                    == self.constraint_rhs[self.constraint_separation_index :],
                 )
-            else:
-                prob = cp.Problem(
-                    cp.Minimize(cost),
-                    [self.constraint_lhs @ xi == self.constraint_rhs],
+            if self.inequality_constraints:
+                constraints.append(
+                    self.constraint_lhs[: self.constraint_separation_index, :] @ xi
+                    <= self.constraint_rhs[: self.constraint_separation_index]
                 )
+            prob = cp.Problem(cp.Minimize(cost), constraints)
         else:
             prob = cp.Problem(cp.Minimize(cost))
 
-        # default solver is OSQP here but switches to ECOS for L2
+        prob_clone = deepcopy(prob)
+        # default solver is SCS/OSQP here but switches to ECOS for L2
         try:
             prob.solve(
                 max_iter=self.max_iter,
-                eps_abs=self.tol,
-                eps_rel=self.tol,
+                eps_abs=tol,
+                eps_rel=tol,
                 verbose=self.verbose_cvxpy,
             )
         # Annoying error coming from L2 norm switching to use the ECOS
         # solver, which uses "max_iters" instead of "max_iter", and
         # similar semantic changes for the other variables.
-        except TypeError:
+        except (TypeError, ValueError):
             try:
-                prob.solve(abstol=self.tol, reltol=self.tol, verbose=self.verbose_cvxpy)
+                prob = prob_clone
+                prob.solve(max_iters=self.max_iter, verbose=self.verbose_cvxpy)
+                xi = prob.variables()[0]
             except cp.error.SolverError:
-                print("Solver failed, setting coefs to zeros")
-                xi.value = np.zeros(coef_sparse.shape[0] * coef_sparse.shape[1])
+                warnings.warn("Solver failed, setting coefs to zeros")
+                xi.value = np.zeros(var_len)
         except cp.error.SolverError:
-            print("Solver failed, setting coefs to zeros")
-            xi.value = np.zeros(coef_sparse.shape[0] * coef_sparse.shape[1])
+            try:
+                prob = prob_clone
+                prob.solve(max_iter=self.max_iter, verbose=self.verbose_cvxpy)
+                xi = prob.variables()[0]
+            except cp.error.SolverError:
+                warnings.warn("Solver failed, setting coefs to zeros")
+                xi.value = np.zeros(var_len)
 
         if xi.value is None:
             warnings.warn(
@@ -315,7 +339,7 @@ def _update_coef_cvxpy(self, x, y, coef_sparse):
                 ConvergenceWarning,
             )
             return None
-        coef_new = (xi.value).reshape(coef_sparse.shape)
+        coef_new = (xi.value).reshape(coef_prev.shape)
         return coef_new
 
     def _update_sparse_coef(self, coef_full):
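For illustration only (not part of the diff): the pattern that _create_var_and_part_cost factors out, reduced to plain CVXPY with the "l1" thresholder branch and random data.

    import cvxpy as cp
    import numpy as np

    rng = np.random.default_rng(0)
    x_expanded = rng.standard_normal((50, 8))
    y = rng.standard_normal((50, 1))
    threshold = 0.1

    xi = cp.Variable(8)
    cost = cp.sum_squares(x_expanded @ xi - y.flatten())
    cost = cost + threshold * cp.norm1(xi)  # the "l1" branch above
    prob = cp.Problem(cp.Minimize(cost))
    prob.solve()  # solver kwargs omitted; the default depends on the install
    print(np.round(xi.value, 3))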
@@ -422,7 +446,11 @@ def _reduce(self, x, y):
         objective_history = []
 
         if self.inequality_constraints:
-            coef_sparse = self._update_coef_cvxpy(x_expanded, y, coef_sparse)
+            var_len = coef_sparse.shape[0] * coef_sparse.shape[1]
+            xi, cost = self._create_var_and_part_cost(var_len, x_expanded, y)
+            coef_sparse = self._update_coef_cvxpy(
+                xi, cost, var_len, coef_sparse, self.tol
+            )
             objective_history.append(self._objective(x, y, 0, coef_full, coef_sparse))
         else:
             for k in range(self.max_iter):
@@ -461,9 +489,8 @@ def _reduce(self, x, y):
                     break
             else:
                 warnings.warn(
-                    "SR3._reduce did not converge after {} iterations.".format(
-                        self.max_iter
-                    ),
+                    f"ConstrainedSR3 did not converge after {self.max_iter}"
+                    " iterations.",
                     ConvergenceWarning,
                 )
         if self.use_constraints and self.constraint_order.lower() == "target":
diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py
index 4e83dbfdf..0e769aac8 100644
--- a/pysindy/optimizers/trapping_sr3.py
+++ b/pysindy/optimizers/trapping_sr3.py
@@ -14,6 +14,7 @@
 from numpy.typing import NDArray
 from sklearn.exceptions import ConvergenceWarning
 
+from ..feature_library.polynomial_library import n_poly_features
 from ..feature_library.polynomial_library import PolynomialLibrary
 from ..utils import reorder_constraints
 from .constrained_sr3 import ConstrainedSR3
@@ -310,8 +311,10 @@ def __init__(
                 thresholder=thresholder,
                 **kwargs,
             )
+            self.method = "global"
         elif method == "local":
             super().__init__(thresholder=thresholder, **kwargs)
+            self.method = "local"
         else:
             raise ValueError(f"Can either use 'global' or 'local' method, not {method}")
 
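For illustration only (not part of the diff): the method argument recorded here decides whether _update_coef_sparse_rs (below) adds the local-stability penalty terms. A sketch assuming TrappingSR3's remaining defaults are usable as-is:

    from pysindy.optimizers import TrappingSR3

    opt_global = TrappingSR3(method="global")  # no extra ||PQ @ xi|| penalties
    opt_local = TrappingSR3(method="local")    # adds the alpha/beta penalties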
"weighted_l2": - cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2 + xi, cost = self._create_var_and_part_cost(var_len, x_expanded, y) cost = cost + cp.sum_squares(Pmatrix @ xi - A.flatten()) / self.eta # new terms minimizing quadratic piece ||P^Q @ xi||_2^2 - Q = np.reshape(self.PQ_, (r * r * r, N * r), "F") - cost = cost + cp.sum_squares(Q @ xi) / self.alpha - Q = np.reshape(self.PQ_, (r, r, r, N * r), "F") - Q_ep = Q + np.transpose(Q, [1, 2, 0, 3]) + np.transpose(Q, [2, 0, 1, 3]) - Q_ep = np.reshape(Q_ep, (r * r * r, N * r), "F") - cost = cost + cp.sum_squares(Q_ep @ xi) / self.beta - - # Constraints - if self.use_constraints: - if self.inequality_constraints: - prob = cp.Problem( - cp.Minimize(cost), - [self.constraint_lhs @ xi <= self.constraint_rhs], - ) - else: - prob = cp.Problem( - cp.Minimize(cost), - [self.constraint_lhs @ xi == self.constraint_rhs], - ) - else: - prob = cp.Problem(cp.Minimize(cost)) - - # default solver is OSQP here but switches to ECOS for L2 - try: - prob.solve( - eps_abs=self.eps_solver, - eps_rel=self.eps_solver, - verbose=self.verbose_cvxpy, - ) - # Annoying error coming from L2 norm switching to use the ECOS - # solver, which uses "max_iters" instead of "max_iter", and - # similar semantic changes for the other variables. - except TypeError: - try: - prob.solve( - abstol=self.eps_solver, - reltol=self.eps_solver, - verbose=self.verbose_cvxpy, - ) - except cp.error.SolverError: - print("Solver failed, setting coefs to zeros") - xi.value = np.zeros(N * r) - except cp.error.SolverError: - print("Solver failed, setting coefs to zeros") - xi.value = np.zeros(N * r) - - if xi.value is None: - warnings.warn( - "Infeasible solve, increase/decrease eta", - ConvergenceWarning, - ) - return None - coef_sparse = (xi.value).reshape(coef_prev.shape) - return coef_sparse + if self.method == "local": + Q = np.reshape(self.PQ_, (r * r * r, N * r), "F") + cost = cost + cp.sum_squares(Q @ xi) / self.alpha + Q = np.reshape(self.PQ_, (r, r, r, N * r), "F") + Q_ep = Q + np.transpose(Q, [1, 2, 0, 3]) + np.transpose(Q, [2, 0, 1, 3]) + Q_ep = np.reshape(Q_ep, (r * r * r, N * r), "F") + cost = cost + cp.sum_squares(Q_ep @ xi) / self.beta + + return self._update_coef_cvxpy(xi, cost, var_len, coef_prev, self.eps_solver) def _solve_m_relax_and_split(self, r, N, m_prev, m, A, coef_sparse, tk_previous): """ @@ -711,16 +662,23 @@ def _reduce(self, x, y): TrappingSR3 algorithm. Assumes initial guess for coefficients is stored in ``self.coef_``. """ - - n_samples, n_features = x.shape - self.n_features = n_features - r = y.shape[1] - N = n_features # int((r ** 2 + 3 * r) / 2.0) - if N > int((r**2 + 3 * r) / 2.0): - self._include_bias = True + self.A_history_ = [] + self.m_history_ = [] + self.p_history_ = [] + self.PW_history_ = [] + self.PWeigs_history_ = [] + self.history_ = [] + n_samples, n_tgts = y.shape + n_features = n_poly_features( + n_tgts, + 2, + include_bias=self._include_bias, + interaction_only=self._interaction_only, + ) + var_len = n_features * n_tgts if self.mod_matrix is None: - self.mod_matrix = np.eye(r) + self.mod_matrix = np.eye(n_tgts) # Define PL, PQ, PT and PM tensors, only relevant if the stability term in # trapping SINDy is turned on. 
@@ -731,7 +689,7 @@ def _reduce(self, x, y):
             self.PQ_,
             self.PT_,
             self.PM_,
-        ) = self._set_Ptensors(r)
+        ) = self._set_Ptensors(n_tgts)
 
         # Set initial coefficients
         if self.use_constraints and self.constraint_order.lower() == "target":
@@ -760,9 +718,9 @@ def _reduce(self, x, y):
         if self.A0 is not None:
             A = self.A0
         elif np.any(self.PM_ != 0.0):
-            A = np.diag(self.gamma * np.ones(r))
+            A = np.diag(self.gamma * np.ones(n_tgts))
         else:
-            A = np.diag(np.zeros(r))
+            A = np.diag(np.zeros(n_tgts))
         self.A_history_.append(A)
 
         # initial guess for m
@@ -770,14 +728,14 @@ def _reduce(self, x, y):
             m = self.m0
         else:
             np.random.seed(1)
-            m = (np.random.rand(r) - np.ones(r)) * 2
+            m = (np.random.rand(n_tgts) - np.ones(n_tgts)) * 2
         self.m_history_.append(m)
 
         # Precompute some objects for optimization
-        x_expanded = np.zeros((n_samples, r, n_features, r))
-        for i in range(r):
+        x_expanded = np.zeros((n_samples, n_tgts, n_features, n_tgts))
+        for i in range(n_tgts):
             x_expanded[:, i, :, i] = x
-        x_expanded = np.reshape(x_expanded, (n_samples * r, r * n_features))
+        x_expanded = np.reshape(x_expanded, (n_samples * n_tgts, n_tgts * n_features))
         xTx = np.dot(x_expanded.T, x_expanded)
         xTy = np.dot(x_expanded.T, y.flatten())
 
@@ -792,13 +750,13 @@ def _reduce(self, x, y):
             # update P tensor from the newest m
             mPM = np.tensordot(self.PM_, m, axes=([2], [0]))
             p = np.tensordot(self.mod_matrix, self.PL_ + mPM, axes=([1], [0]))
-            Pmatrix = p.reshape(r * r, r * n_features)
+            Pmatrix = p.reshape(n_tgts * n_tgts, n_tgts * n_features)
 
             # update w
             coef_prev = coef_sparse
             if (self.threshold > 0.0) or self.inequality_constraints:
-                coef_sparse = self._solve_sparse_relax_and_split(
-                    r, n_features, x_expanded, y, Pmatrix, A, coef_prev
+                coef_sparse = self._update_coef_sparse_rs(
+                    n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
                 )
             else:
                 # if threshold = 0, there is analytic expression
@@ -807,14 +765,20 @@ def _reduce(self, x, y):
                 pTp = np.dot(Pmatrix.T, Pmatrix)
                 # notice reshaping PQ here requires fortran-ordering
                 PQ = np.tensordot(self.mod_matrix, self.PQ_, axes=([1], [0]))
-                PQ = np.reshape(PQ, (r * r * r, r * n_features), "F")
+                PQ = np.reshape(
+                    PQ, (n_tgts * n_tgts * n_tgts, n_tgts * n_features), "F"
+                )
                 PQTPQ = np.dot(PQ.T, PQ)
-                PQ = np.reshape(self.PQ_, (r, r, r, r * n_features), "F")
+                PQ = np.reshape(
+                    self.PQ_, (n_tgts, n_tgts, n_tgts, n_tgts * n_features), "F"
+                )
                 PQ = np.tensordot(self.mod_matrix, PQ, axes=([1], [0]))
                 PQ_ep = (
                     PQ + np.transpose(PQ, [1, 2, 0, 3]) + np.transpose(PQ, [2, 0, 1, 3])
                 )
-                PQ_ep = np.reshape(PQ_ep, (r * r * r, r * n_features), "F")
+                PQ_ep = np.reshape(
+                    PQ_ep, (n_tgts * n_tgts * n_tgts, n_tgts * n_features), "F"
+                )
                 PQTPQ_ep = np.dot(PQ_ep.T, PQ_ep)
                 H = xTx + pTp / self.eta + PQTPQ / self.alpha + PQTPQ_ep / self.beta
                 P_transpose_A = np.dot(Pmatrix.T, A.flatten())
@@ -829,7 +793,7 @@ def _reduce(self, x, y):
 
             # Now solve optimization for m and A
             m_prev, m, A, tk_prev = self._solve_m_relax_and_split(
-                r, n_features, m_prev, m, A, coef_sparse, tk_prev
+                n_tgts, n_features, m_prev, m, A, coef_sparse, tk_prev
             )
 
             # If problem over m becomes infeasible, break out of the loop
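For illustration only (not part of the diff): the expanded design matrix built in _reduce, run in isolation. It replicates the library matrix once per target so that a single flattened variable xi covers every target's coefficients.

    import numpy as np

    n_samples, n_tgts, n_features = 4, 2, 3
    x = np.arange(n_samples * n_features, dtype=float).reshape(n_samples, n_features)
    x_expanded = np.zeros((n_samples, n_tgts, n_features, n_tgts))
    for i in range(n_tgts):
        x_expanded[:, i, :, i] = x  # copy of x for target i, zeros elsewhere
    x_expanded = np.reshape(x_expanded, (n_samples * n_tgts, n_tgts * n_features))
    print(x_expanded.shape)  # (8, 6): one row of x per (sample, target) pair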