Skip to content

Commit

Permalink
simplification through sympy
Browse files Browse the repository at this point in the history
Added symbolic math functionality by converting into sympy and back.
In particular, the simplify function simplifies an expression and its constants and converts back into the SRToolkit representation.
  • Loading branch information
brencej committed Jan 29, 2025
1 parent 9a2b805 commit 5c07395
Showing 1 changed file with 192 additions and 0 deletions.
192 changes: 192 additions & 0 deletions SRToolkit/utils/symbolic_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import numpy as np
from sympy import sympify, expand
from sympy.core import Mul, Add, Pow
from sympy import symbols as sp_symbols
import re
from SRToolkit.utils.expression_tree import Node

def simplify(list_of_tokens, constant, variables):
"""
Simplifies a mathematical expression by:
1. making use of sympy's simplification functions
2. simplifying constants, e.g. C*C + C -> C
Examples:
>>> expr = ("C", "+", "C" "*", "C", "+", "X_0", "*", "X_1", "/", "X_0")
>>> print(simplify(expr, "C", ["X_0", "X_1"]))
C+X_1
Args:
list_of_tokens (list): List of tokens representing the expression.
constant (str): The character representing numerical constants.
variables (list): List of characters representing variables.
Returns:
Node: simplified SRtoolkit expression tree.
"""
ex = simplify_expression("".join(list_of_tokens), constant, variables)
ex = sympify(denumerate_constants(str(ex), constant), evaluate=False)
return sympy_to_sr(ex)

def sympy_to_sr(expr):
"""
Converts a Sympy expression into an SRtoolkit tree node, explicitly handling left-associative division.
Args:
expr (sympy.Expr): The Sympy expression.
Returns:
Node: The root node of the SRtoolkit expression tree.
"""
if expr.is_Number or expr.is_Symbol:
return Node(str(expr))

if expr.is_Function:
func_name = expr.func.__name__
arg = sympy_to_sr(expr.args[0])
return Node(func_name, left=arg)

if isinstance(expr, Add):
args = expr.as_ordered_terms()
# Detect subtraction
if len(args) == 2 and args[1].is_Mul and args[1].args[0] == -1:
return Node('-', sympy_to_sr(-args[1]), sympy_to_sr(args[0]))
# Handle regular addition
root = Node('+', sympy_to_sr(args[1]), sympy_to_sr(args[0]))
for term in args[2:]:
root = Node('+', sympy_to_sr(term), root)
return root

if isinstance(expr, Mul):
# Process factors explicitly, ensuring left-to-right associativity
factors = list(expr.args)
root = sympy_to_sr(factors[0]) # Start with the first factor
for factor in factors[1:]:
if factor.is_Pow and factor.args[1] == -1: # Division
divisor = sympy_to_sr(factor.args[0])
root = Node('/', divisor, root) # Left-to-right division
else: # Multiplication
multiplicand = sympy_to_sr(factor)
root = Node('*', multiplicand, root)
return root

if isinstance(expr, Pow):
base, exp = expr.args
return Node('^', sympy_to_sr(exp), sympy_to_sr(base))

if expr.is_Rational and expr.q != 1:
# Handle rational division (e.g., 2/3)
return Node('/', sympy_to_sr(expr.q), sympy_to_sr(expr.p))

raise ValueError(f"Unsupported Sympy expression: {expr}")

def simplify_constants(eq, c, var):
""" Simplifies the constants in a Sympy expression. output[2][0][1] is the simplified expression.
Args:
eq: The Sympy expression.
c: The constant symbol.
var: List of symbols representing variables.
Returns:
- bool: True if the expression contains a variable.
- bool: True if the expression contains the constant.
- list: List of tuples containing the original and simplified expressions
"""
if len(eq.args) == 0:
if eq in var:
return True, False, [(eq, eq)]
elif str(eq)[0] == str(c):
return False, True, [(eq, eq)]
else:
return False, False, [(eq, eq)]
else:
has_var, has_c, subs = [], [], []
for a in eq.args:
a_rec = simplify_constants (a, c, var)
has_var += [a_rec[0]]; has_c += [a_rec[1]]; subs += [a_rec[2]]
if sum(has_var) == 0 and True in has_c:
return False, True, [(eq, c)]
else:
args = []
if isinstance(eq, (Add, Mul, Pow)):
has_free_c = False
if True in [has_c[i] and not has_var[i] for i in range(len(has_c))]:
has_free_c = True

for i in range(len(has_var)):
if has_var[i] or (not has_free_c and not has_c[i]):
if len(subs[i]) > 0:
args += [eq.args[i].subs(subs[i])]
else:
args += [eq.args[i]]
if has_free_c:
args += [c]

else:
for i in range(len(has_var)):
if len(subs[i]) > 0:
args += [eq.args[i].subs(subs[i])]
else:
args += [eq.args[i]]
return True in has_var, True in has_c, [(eq, eq.func(*args))]

def enumerate_constants(expr, constant):
""" Enumerates the constants in a Sympy expression.
Example: C*x**2 + C*x + C -> C0*x**2 + C1*x + C2
Input:
expr - Sympy expression
constant - constant symbol
Returns:
Sympy expression with enumerated constants
list of enumerated constants"""

char_list = np.array(list(str(expr)), dtype='<U16')
constind = np.where(char_list == constant)[0]
""" Rename all constants: c -> cn, where n is the index of the associated term"""
constants = [constant+str(i) for i in range(len(constind))]
char_list[constind] = constants
return sympify("".join(char_list)), tuple(constants)

def denumerate_constants(expr, constant):
""" Removes the enumeration of constants in a Sympy expression.
Args:
expr: Sympy expression
constant: constant symbol
Returns:
Sympy expression with denumerated constants
"""
return re.sub(f'{constant}\\d', constant, expr)

def simplify_expression (expr_str, constant, variables):
"""Simplifies a mathematical expression.
Args:
expr_str: String representing the expression.
constant: The character representing numerical constants.
variables: List of characters representing variables.
Returns:
expr: Sympy expression object in canonical form.
symbols_params: Tuple of enumerated constants.
"""
x = [sp_symbols(s.strip("'")) for s in variables]
c = sp_symbols(constant)

expr, _ = enumerate_constants(expr_str, constant)
expr = simplify_constants(expr, c, x)[2][0][1]
expr, _ = enumerate_constants(expr, constant)
expr = expand(expr)
expr = simplify_constants(expr, c, x)[2][0][1]
expr, _ = enumerate_constants(expr, constant)
return expr


if __name__ == "__main__":
expr = ("C", "+", "C" "*", "C", "+", "X_0", "*", "X_1", "/", "X_0")
print(simplify(expr, "C", ["X_0", "X_1"]))

0 comments on commit 5c07395

Please sign in to comment.