Skip to content

Commit

Permalink
feat(pysindy): add flush option for model.print
Browse files Browse the repository at this point in the history
  • Loading branch information
himkwtn committed Jul 8, 2024
1 parent 0eefaf8 commit 4ae8fdc
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions pysindy/pysindy.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def equations(self, precision=3):
precision=precision,
)

def print(self, lhs=None, precision=3):
def print(self, lhs=None, precision=3, flush=False):
"""Print the SINDy model equations.
Parameters
Expand All @@ -362,25 +362,27 @@ def print(self, lhs=None, precision=3):
precision: int, optional (default 3)
Precision to be used when printing out model coefficients.
flush: bool, optional (default = False)
If flush is true, the output stream is forcibly flushed.
"""
eqns = self.equations(precision)
if sindy_pi_flag and isinstance(self.optimizer, SINDyPI):
feature_names = self.get_feature_names()
else:
feature_names = self.feature_names
for i, eqn in enumerate(eqns):
names = None
if self.discrete_time:
names = "(" + feature_names[i] + ")"
print(names + "[k+1] = " + eqn)
names = f"({feature_names[i]})[k+1]"
elif lhs is None:
if not sindy_pi_flag or not isinstance(self.optimizer, SINDyPI):
names = "(" + feature_names[i] + ")"
print(names + "' = " + eqn)
names = f"({feature_names[i]})'"
else:
names = feature_names[i]
print(names + " = " + eqn)
names = f"({feature_names[i]})"
else:
print(lhs[i] + " = " + eqn)
names = f"{lhs[i]}"
print(f"{names} = {eqn}", flush=flush)

def score(self, x, t=None, x_dot=None, u=None, metric=r2_score, **metric_kws):
"""
Expand Down Expand Up @@ -657,12 +659,12 @@ def check_stop_condition(xi):
"variables were not used when the model was fit"
)
for i in range(1, t):
x[i] = self.predict(x[i - 1 : i])
x[i] = self.predict(x[i - 1: i])
if check_stop_condition(x[i]):
return x[: i + 1]
else:
for i in range(1, t):
x[i] = self.predict(x[i - 1 : i], u=u[i - 1, np.newaxis])
x[i] = self.predict(x[i - 1: i], u=u[i - 1, np.newaxis])
if check_stop_condition(x[i]):
return x[: i + 1]
return x
Expand Down

0 comments on commit 4ae8fdc

Please sign in to comment.