Skip to content

Commit

Permalink
Small changes to evaluator and expression simplification to make it b…
Browse files Browse the repository at this point in the history
…etter to use
  • Loading branch information
smeznar committed Jan 30, 2025
1 parent f186df9 commit 58c444a
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
14 changes: 13 additions & 1 deletion SRToolkit/evaluation/sr_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
get_results(top_k): Returns the results of the evaluation.
"""
self.models = dict()
self.invalid = list()
self.metadata = metadata
self.symbol_library = symbol_library
self.max_evaluations = max_evaluations
Expand Down Expand Up @@ -116,7 +117,18 @@ def evaluate_expr(self, expr: Union[List[str], Node], simplify_expr=False) -> fl
return np.nan
else:
if simplify_expr:
expr = simplify(expr, self.symbol_library)
try:
expr = simplify(expr, self.symbol_library)
except Exception as e:
if isinstance(expr, Node):
expr_list = expr.to_list(symbol_library=self.symbol_library)
else:
expr_list = expr
print(f"Unable to simplify: {''.join(expr_list)}, problems with subexpression {e}")

self.invalid.append(expr_list)
return np.inf


if isinstance(expr, Node):
expr_list = expr.to_list(symbol_library=self.symbol_library)
Expand Down
11 changes: 6 additions & 5 deletions SRToolkit/utils/expression_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
from sympy import sympify, expand, Expr, Basic
from sympy.core import Mul, Add, Pow
from sympy.core import Mul, Add, Pow, Symbol
from sympy import symbols as sp_symbols
import re

Expand Down Expand Up @@ -106,7 +106,7 @@ def _sympy_to_sr(expr: Union[Expr, Basic]) -> Node:
if expr.is_Rational and expr.q != 1:
return Node('/', _sympy_to_sr(expr.q), _sympy_to_sr(expr.p))

raise ValueError(f"Unsupported Sympy expression: {expr}")
raise ValueError(f"{expr}")


def _simplify_constants(eq, c, var):
Expand All @@ -125,8 +125,8 @@ def _simplify_constants(eq, c, var):
if len(eq.args) == 0:
if eq in var:
return True, False, [(eq, eq)]
elif str(eq)[0] == str(c):
return False, True, [(eq, eq)]
elif eq in eq.free_symbols:
return False, True, [(eq, c)]
else:
return False, False, [(eq, eq)]
else:
Expand Down Expand Up @@ -215,7 +215,8 @@ def _simplify_expression (expr_str, constant, variables):
expr, _ = _enumerate_constants(expr, constant)
expr = expand(expr)
expr = _simplify_constants(expr, c, x)[2][0][1]
expr, _ = _enumerate_constants(expr, constant)
# expr, _ = _enumerate_constants(expr, constant)
# expr = _simplify_constants(expr, c, x)[2][0][1]
return expr


Expand Down
8 changes: 6 additions & 2 deletions examples/SR_evaluation_minimal_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ def read_expressions(filename):
start_time = time.time()

# Evaluate expressions one by one
for expr in expressions:
evaluator.evaluate_expr(expr)
for i, expr in enumerate(expressions):
# print(f"{i}: Expr: {''.join(expr)}")
evaluator.evaluate_expr(expr, simplify_expr=True)

print(f"Total time: {time.time() - start_time}")

# Get and print the results
print(evaluator.get_results())

# Simplified: 396.6958358287811
# Non simplified: 156.62710046768188

0 comments on commit 58c444a

Please sign in to comment.