-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added symbolic math functionality by converting into sympy and back. In particular, the simplify function simplifies an expression and its constants and converts back into the SRToolkit representation.
- Loading branch information
Showing
1 changed file
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
import numpy as np | ||
from sympy import sympify, expand | ||
from sympy.core import Mul, Add, Pow | ||
from sympy import symbols as sp_symbols | ||
import re | ||
from SRToolkit.utils.expression_tree import Node | ||
|
||
def simplify(list_of_tokens, constant, variables): | ||
""" | ||
Simplifies a mathematical expression by: | ||
1. making use of sympy's simplification functions | ||
2. simplifying constants, e.g. C*C + C -> C | ||
Examples: | ||
>>> expr = ("C", "+", "C" "*", "C", "+", "X_0", "*", "X_1", "/", "X_0") | ||
>>> print(simplify(expr, "C", ["X_0", "X_1"])) | ||
C+X_1 | ||
Args: | ||
list_of_tokens (list): List of tokens representing the expression. | ||
constant (str): The character representing numerical constants. | ||
variables (list): List of characters representing variables. | ||
Returns: | ||
Node: simplified SRtoolkit expression tree. | ||
""" | ||
ex = simplify_expression("".join(list_of_tokens), constant, variables) | ||
ex = sympify(denumerate_constants(str(ex), constant), evaluate=False) | ||
return sympy_to_sr(ex) | ||
|
||
def sympy_to_sr(expr): | ||
""" | ||
Converts a Sympy expression into an SRtoolkit tree node, explicitly handling left-associative division. | ||
Args: | ||
expr (sympy.Expr): The Sympy expression. | ||
Returns: | ||
Node: The root node of the SRtoolkit expression tree. | ||
""" | ||
if expr.is_Number or expr.is_Symbol: | ||
return Node(str(expr)) | ||
|
||
if expr.is_Function: | ||
func_name = expr.func.__name__ | ||
arg = sympy_to_sr(expr.args[0]) | ||
return Node(func_name, left=arg) | ||
|
||
if isinstance(expr, Add): | ||
args = expr.as_ordered_terms() | ||
# Detect subtraction | ||
if len(args) == 2 and args[1].is_Mul and args[1].args[0] == -1: | ||
return Node('-', sympy_to_sr(-args[1]), sympy_to_sr(args[0])) | ||
# Handle regular addition | ||
root = Node('+', sympy_to_sr(args[1]), sympy_to_sr(args[0])) | ||
for term in args[2:]: | ||
root = Node('+', sympy_to_sr(term), root) | ||
return root | ||
|
||
if isinstance(expr, Mul): | ||
# Process factors explicitly, ensuring left-to-right associativity | ||
factors = list(expr.args) | ||
root = sympy_to_sr(factors[0]) # Start with the first factor | ||
for factor in factors[1:]: | ||
if factor.is_Pow and factor.args[1] == -1: # Division | ||
divisor = sympy_to_sr(factor.args[0]) | ||
root = Node('/', divisor, root) # Left-to-right division | ||
else: # Multiplication | ||
multiplicand = sympy_to_sr(factor) | ||
root = Node('*', multiplicand, root) | ||
return root | ||
|
||
if isinstance(expr, Pow): | ||
base, exp = expr.args | ||
return Node('^', sympy_to_sr(exp), sympy_to_sr(base)) | ||
|
||
if expr.is_Rational and expr.q != 1: | ||
# Handle rational division (e.g., 2/3) | ||
return Node('/', sympy_to_sr(expr.q), sympy_to_sr(expr.p)) | ||
|
||
raise ValueError(f"Unsupported Sympy expression: {expr}") | ||
|
||
def simplify_constants(eq, c, var): | ||
""" Simplifies the constants in a Sympy expression. output[2][0][1] is the simplified expression. | ||
Args: | ||
eq: The Sympy expression. | ||
c: The constant symbol. | ||
var: List of symbols representing variables. | ||
Returns: | ||
- bool: True if the expression contains a variable. | ||
- bool: True if the expression contains the constant. | ||
- list: List of tuples containing the original and simplified expressions | ||
""" | ||
if len(eq.args) == 0: | ||
if eq in var: | ||
return True, False, [(eq, eq)] | ||
elif str(eq)[0] == str(c): | ||
return False, True, [(eq, eq)] | ||
else: | ||
return False, False, [(eq, eq)] | ||
else: | ||
has_var, has_c, subs = [], [], [] | ||
for a in eq.args: | ||
a_rec = simplify_constants (a, c, var) | ||
has_var += [a_rec[0]]; has_c += [a_rec[1]]; subs += [a_rec[2]] | ||
if sum(has_var) == 0 and True in has_c: | ||
return False, True, [(eq, c)] | ||
else: | ||
args = [] | ||
if isinstance(eq, (Add, Mul, Pow)): | ||
has_free_c = False | ||
if True in [has_c[i] and not has_var[i] for i in range(len(has_c))]: | ||
has_free_c = True | ||
|
||
for i in range(len(has_var)): | ||
if has_var[i] or (not has_free_c and not has_c[i]): | ||
if len(subs[i]) > 0: | ||
args += [eq.args[i].subs(subs[i])] | ||
else: | ||
args += [eq.args[i]] | ||
if has_free_c: | ||
args += [c] | ||
|
||
else: | ||
for i in range(len(has_var)): | ||
if len(subs[i]) > 0: | ||
args += [eq.args[i].subs(subs[i])] | ||
else: | ||
args += [eq.args[i]] | ||
return True in has_var, True in has_c, [(eq, eq.func(*args))] | ||
|
||
def enumerate_constants(expr, constant): | ||
""" Enumerates the constants in a Sympy expression. | ||
Example: C*x**2 + C*x + C -> C0*x**2 + C1*x + C2 | ||
Input: | ||
expr - Sympy expression | ||
constant - constant symbol | ||
Returns: | ||
Sympy expression with enumerated constants | ||
list of enumerated constants""" | ||
|
||
char_list = np.array(list(str(expr)), dtype='<U16') | ||
constind = np.where(char_list == constant)[0] | ||
""" Rename all constants: c -> cn, where n is the index of the associated term""" | ||
constants = [constant+str(i) for i in range(len(constind))] | ||
char_list[constind] = constants | ||
return sympify("".join(char_list)), tuple(constants) | ||
|
||
def denumerate_constants(expr, constant): | ||
""" Removes the enumeration of constants in a Sympy expression. | ||
Args: | ||
expr: Sympy expression | ||
constant: constant symbol | ||
Returns: | ||
Sympy expression with denumerated constants | ||
""" | ||
return re.sub(f'{constant}\\d', constant, expr) | ||
|
||
def simplify_expression (expr_str, constant, variables): | ||
"""Simplifies a mathematical expression. | ||
Args: | ||
expr_str: String representing the expression. | ||
constant: The character representing numerical constants. | ||
variables: List of characters representing variables. | ||
Returns: | ||
expr: Sympy expression object in canonical form. | ||
symbols_params: Tuple of enumerated constants. | ||
""" | ||
x = [sp_symbols(s.strip("'")) for s in variables] | ||
c = sp_symbols(constant) | ||
|
||
expr, _ = enumerate_constants(expr_str, constant) | ||
expr = simplify_constants(expr, c, x)[2][0][1] | ||
expr, _ = enumerate_constants(expr, constant) | ||
expr = expand(expr) | ||
expr = simplify_constants(expr, c, x)[2][0][1] | ||
expr, _ = enumerate_constants(expr, constant) | ||
return expr | ||
|
||
|
||
if __name__ == "__main__": | ||
expr = ("C", "+", "C" "*", "C", "+", "X_0", "*", "X_1", "/", "X_0") | ||
print(simplify(expr, "C", ["X_0", "X_1"])) |