Skip to content

Commit

Permalink
Fix for parameter fitting when expression is given as an instance of …
Browse files Browse the repository at this point in the history
…Node
  • Loading branch information
smeznar committed Jan 27, 2025
1 parent ca5ea71 commit c213a97
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 5 additions & 1 deletion SRToolkit/evaluation/parameter_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ def estimate_parameters(self, expr: Union[List[str], Node]) -> Tuple[float, np.n
If the number of constants in the expression exceeds the maximum allowed, NaN and an empty array are returned.
If there are no constants in the expression, the RMSE is calculated directly without optimization.
"""
num_constants = sum([1 for t in expr if self.symbol_library.get_type(t) == "const"])
if isinstance(expr, Node):
expr_str = expr.to_list(notation="prefix")
num_constants = sum([1 for t in expr_str if self.symbol_library.get_type(t) == "const"])
else:
num_constants = sum([1 for t in expr if self.symbol_library.get_type(t) == "const"])
if 0 <= self.estimation_settings["max_constants"] < num_constants:
return np.nan, np.array([])

Expand Down
8 changes: 6 additions & 2 deletions SRToolkit/evaluation/sr_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def evaluate_expr(self, expr: Union[List[str], Node]) -> float:
f"Maximum number of evaluations ({self.max_evaluations}) reached. Stopping evaluation.")
return np.nan
else:
expr_str = "".join(expr)
if isinstance(expr, Node):
expr_list = expr.to_list(symbol_library=self.symbol_library)
else:
expr_list = expr
expr_str = "".join(expr_list)
if expr_str in self.models:
# print(f"Already evaluated {expr_str}")
# print(self.models[expr_str])
Expand All @@ -115,7 +119,7 @@ def evaluate_expr(self, expr: Union[List[str], Node]) -> float:
self.models[expr_str] = {
"rmse": rmse,
"parameters": parameters,
"expr": expr,
"expr": expr_list,
}
return rmse

Expand Down

0 comments on commit c213a97

Please sign in to comment.