Skip to content

Commit

Permalink
update field validation and ploting
Browse files Browse the repository at this point in the history
  • Loading branch information
obouchaara committed Dec 15, 2023
1 parent b725edd commit 9b07413
Show file tree
Hide file tree
Showing 5 changed files with 333 additions and 65 deletions.
8 changes: 4 additions & 4 deletions notebooks/coord.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
{
"data": {
"text/plain": [
"SymbolicCartesianCoordSystem(origin=[0, 0, 0], basis_symbols=[x, y, z])"
"SymbolicCartesianCoordSystem(origin=(0, 0, 0), basis_symbols=(x, y, z))"
]
},
"metadata": {},
Expand All @@ -79,7 +79,7 @@
{
"data": {
"text/plain": [
"SymbolicCylindricalCoordSystem(origin=[0, 0, 0], basis_symbols=[r, theta, z])"
"SymbolicCylindricalCoordSystem(origin=(0, 0, 0), basis_symbols=(r, theta, z))"
]
},
"metadata": {},
Expand Down Expand Up @@ -116,7 +116,7 @@
{
"data": {
"text/plain": [
"SymbolicCylindricalCoordSystem(origin=[0, 0, 0], basis_symbols=[r, theta, z])"
"SymbolicCylindricalCoordSystem(origin=(0, 0, 0), basis_symbols=(r, theta, z))"
]
},
"metadata": {},
Expand All @@ -125,7 +125,7 @@
{
"data": {
"text/plain": [
"SymbolicCartesianCoordSystem(origin=[0, 0, 0], basis_symbols=[sqrt(x**2 + y**2), atan2(y, x), z])"
"SymbolicCartesianCoordSystem(origin=(0, 0, 0), basis_symbols=[sqrt(x**2 + y**2), atan2(y, x), z])"
]
},
"metadata": {},
Expand Down
4 changes: 2 additions & 2 deletions notebooks/field.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b72a02030aa490d8c1c41ae0accac63",
"model_id": "77c42c06012648c0a58085bf4d4797d0",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -133,7 +133,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7e5027294bb344f39c1fb2544b4af28f",
"model_id": "f30dc3d5d92942efbf597e3fc3568254",
"version_major": 2,
"version_minor": 0
},
Expand Down
209 changes: 189 additions & 20 deletions notebooks/field_2.ipynb

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/mechpy/core/symbolic/coord.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def __repr__(self):

class SymbolicCartesianCoordSystem(SymbolicCoordSystem):
def __init__(self, basis_symbols=None):
origin = sp.ImmutableDenseNDimArray([0, 0, 0])
basis_symbols = basis_symbols or [sp.symbols(_) for _ in ["x", "y", "z"]]
origin = (0, 0, 0)
basis_symbols = basis_symbols or sp.symbols("x y z")
super().__init__(origin, basis_symbols)

def get_basis_cylindrical_exprs(self) -> dict:
Expand Down Expand Up @@ -59,8 +59,8 @@ def get_spherical_coord(self, values):

class SymbolicCylindricalCoordSystem(SymbolicCoordSystem):
def __init__(self, basis_symbols=None):
origin = sp.ImmutableDenseNDimArray([0, 0, 0])
basis_symbols = basis_symbols or [sp.symbols(_) for _ in ["r", "theta", "z"]]
origin = (0, 0, 0)
basis_symbols = basis_symbols or sp.symbols("r theta z")
super().__init__(origin, basis_symbols)

def get_basis_cartesian_exprs(self, cartesian_basis_symbols=None) -> dict:
Expand Down Expand Up @@ -97,8 +97,8 @@ def get_cartesian_coords(self, values):

class SymbolicSphericalCoordSystem(SymbolicCoordSystem):
def __init__(self, basis_symbols=None):
origin = sp.ImmutableDenseNDimArray([0, 0, 0])
basis_symbols = basis_symbols or [sp.symbols(_) for _ in ["r", "theta", "phi"]]
origin = (0, 0, 0)
basis_symbols = basis_symbols or sp.symbols("r theta phi")
super().__init__(origin, basis_symbols)

def get_basis_cartesian_exprs(self, cartesian_basis_symbols=None) -> dict:
Expand Down
165 changes: 132 additions & 33 deletions src/mechpy/core/symbolic/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,124 @@ def __init__(self, data, coord_system, field_params=None):
if isinstance(coord_system, SymbolicCoordSystem):
self.data = data
self.coord_system = coord_system
self.field_params = field_params or []
self.field_params = list(field_params or [])
self.validate_field()
else:
raise ValueError("coord system must be a SymbolicCoordSystem")

def __repr__(self):
return f"{self.__class__.__name__}(\n{self.data},\n{self.coord_system.basis_symbols}\n)"
return f"{self.__class__.__name__}(\n{self.data},\n{self.coord_system.basis_symbols},\n{self.field_params})"

def validate_field(self):
self.validate_field_params()
self.validate_basis_symbols()

def validate_field_params(self):
if self.field_params:
field_param_symbols = set(self.field_params)
basis_symbols = set(self.coord_system.basis_symbols)

if not field_param_symbols.isdisjoint(basis_symbols):
raise ValueError(
"Field parameters must not overlap with coordinate system basis symbols."
)

def validate_basis_symbols(self):
"""
Validates that all free symbols in the field data are either part of the
coordinate system's basis symbols or the field parameters.
Raises a ValueError if there are any symbols in the field data that are
not included in either the basis symbols of the coordinate system or the
field parameters.
This ensures that the field data is properly defined with respect to the
coordinate system and any additional parameters.
"""
basis_symbols = set(self.coord_system.basis_symbols)
field_param_symbols = set(self.field_params)
valid_symbols = basis_symbols.union(field_param_symbols)

free_symbols = (
self.data.free_symbols
if isinstance(self.data, sp.Expr)
else set().union(*[element.free_symbols for element in self.data])
)

invalid_symbols = free_symbols - valid_symbols
if invalid_symbols:
raise ValueError(
"The field data contains symbols not in the basis or field parameters: "
+ ", ".join(str(symbol) for symbol in invalid_symbols)
)

def subs_field_params(self, param_values):
"""
Substitute the provided field parameters with specific values, and
remove them from self.field_params. Raise an error if a parameter in
param_values is not in self.field_params.
:param param_values: A dictionary mapping parameters to their values.
:return: None. The method updates self.data and self.field_params in place.
"""
if not isinstance(param_values, dict):
raise TypeError("param_values must be a dictionary")

# Perform the substitution for provided parameters
for param, value in param_values.items():
if param in self.field_params:
self.data = self.data.subs(param, value)
self.field_params.remove(param)
else:
raise ValueError(f"Parameter {param} not found in field parameters")

def to_cartesian(self):
"""
Converts the scalar field from its current coordinate system
(cylindrical or spherical) to the Cartesian coordinate system.
Returns:
SymbolicScalarField: A new instance of SymbolicScalarField in the
Cartesian coordinate system.
Raises:
ValueError: If the current coordinate system is not cylindrical or spherical.
"""
if not isinstance(
self.coord_system,
(SymbolicCylindricalCoordSystem, SymbolicSphericalCoordSystem),
):
raise ValueError(
"Conversion to Cartesian is only implemented for cylindrical and spherical coordinate systems."
)
expr_dict = self.coord_system.get_basis_cartesian_exprs()
cartesian_data = self.data.subs(expr_dict)
cartesian_coord_system = SymbolicCartesianCoordSystem()
return SymbolicScalarField(
cartesian_data, cartesian_coord_system, self.field_params
)

def to_cylindrical(self):
if isinstance(self.coord_system, SymbolicCartesianCoordSystem):
expr_dict = self.coord_system.get_basis_cylindrical_exprs()
else:
if not isinstance(self.coord_system, SymbolicCartesianCoordSystem):
raise NotImplementedError(
"Conversion from non-Cartesian systems is not implemented"
)

expr_dict = self.coord_system.get_basis_cylindrical_exprs()
cylindrical_data = self.data.subs(expr_dict)
cylindrical_coord_system = SymbolicCylindricalCoordSystem()
return self.__class__(cylindrical_data, cylindrical_coord_system)
return self.__class__(
cylindrical_data, cylindrical_coord_system, self.field_params
)

def to_spherical(self):
if not isinstance(self.coord_system, SymbolicCartesianCoordSystem):
raise NotImplementedError(
"Conversion from non-Cartesian systems is not implemented"
)
expr_dict = self.coord_system.get_basis_spherical_exprs()
spherical_data = self.data.subs(expr_dict)
spherical_coord_system = SymbolicSphericalCoordSystem()
return self.__class__(spherical_data, spherical_coord_system, self.field_params)


class SymbolicField3D(SymbolicField):
Expand All @@ -47,6 +147,24 @@ def __init__(self, data, coord_system, field_params=None):
else:
raise ValueError("Input data must be a SymPy Expr or SymPy Array")

def lambdify(self):
"""
Converts the symbolic field data into a lambda function for numerical evaluation.
If the field is not in Cartesian coordinates, it first converts it to Cartesian.
Returns:
function: A lambda function for numerical evaluation of the field.
"""
# Ensure the field is in Cartesian coordinates
if not isinstance(self.coord_system, SymbolicCartesianCoordSystem):
field_in_cartesian = self.to_cartesian()
else:
field_in_cartesian = self

basis_symbols = field_in_cartesian.coord_system.basis_symbols
data = field_in_cartesian.data
return sp.lambdify(basis_symbols, data, "numpy")


class SymbolicScalarField(SymbolicField3D):
shape = (3,)
Expand Down Expand Up @@ -74,24 +192,17 @@ def create_linear(cls, data, coord_system=None, field_params=None):
return cls(scalar_field, coord_system, field_params)

def plot(self, x_limits=[-100, 100], y_limits=[-100, 100], z_limits=[-100, 100]):
x, y, z = sp.symbols("x y z")
if self.field_params:
raise ValueError(
"Cannot plot field with unresolved parameters: "
+ ", ".join(str(p) for p in self.field_params)
)

x_vals = np.linspace(*x_limits, 100)
y_vals = np.linspace(*y_limits, 100)
X, Y = np.meshgrid(x_vals, y_vals)

if isinstance(self.coord_system, SymbolicCartesianCoordSystem):
data = self.data
elif isinstance(
self.coord_system,
(SymbolicCylindricalCoordSystem, SymbolicSphericalCoordSystem),
):
expr_dict = self.coord_system.get_basis_cartesian_exprs()
data = self.data.subs(expr_dict)
else:
raise ValueError("Unsupported coordinate system. The coordinate system.")

f = sp.lambdify((x, y, z), data, "numpy")
f = self.lambdify()

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
Expand Down Expand Up @@ -160,19 +271,7 @@ def plot(self, x_limits=[-100, 100], y_limits=[-100, 100], z_limits=[-100, 100])
z_vals = np.linspace(*z_limits, 10)
X, Y, Z = np.meshgrid(x_vals, y_vals, z_vals)

if isinstance(self.coord_system, SymbolicCartesianCoordSystem):
data = self.data
elif isinstance(
self.coord_system,
(SymbolicCylindricalCoordSystem, SymbolicSphericalCoordSystem),
):
expr_dict = self.coord_system.get_basis_cartesian_exprs()
data = self.data.subs(expr_dict)
else:
raise ValueError("Unsupported coordinate system. The coordinate system.")

# Convert the symbolic expressions to numerical functions
f = sp.lambdify(sp.symbols("x y z"), data, "numpy")
f = self.lambdify()

# Evaluate the function at each point in the grid
U, V, W = f(X, Y, Z)
Expand Down

0 comments on commit 9b07413

Please sign in to comment.