Skip to content

Commit d328fe5

Browse files
committed
Remove component tensors
1 parent 7d7c676 commit d328fe5

6 files changed

+157
-2
lines changed

test/test_algorithms.py

+15
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
FacetNormal,
1313
FunctionSpace,
1414
Mesh,
15+
SpatialCoordinate,
1516
TestFunction,
1617
TrialFunction,
1718
adjoint,
@@ -21,9 +22,11 @@
2122
dx,
2223
grad,
2324
inner,
25+
sin,
2426
triangle,
2527
)
2628
from ufl.algorithms import (
29+
compute_form_data,
2730
expand_derivatives,
2831
expand_indices,
2932
extract_arguments,
@@ -182,3 +185,15 @@ def test_adjoint(domain):
182185
d = adjoint(b)
183186
d_arg_degrees = [arg.ufl_element().embedded_superdegree for arg in extract_arguments(d)]
184187
assert d_arg_degrees == [2, 1]
188+
189+
190+
def test_remove_component_tensors(domain):
191+
x = SpatialCoordinate(domain)
192+
u = sin(x[0])
193+
194+
f = div(grad(div(grad(u))))
195+
form = f * dx
196+
197+
fd = compute_form_data(form)
198+
199+
assert "ComponentTensor" not in repr(fd.preprocessed_form)

test/test_derivative.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ def test_vector_coefficient_scalar_derivatives(self):
665665
integrand = inner(f, g)
666666

667667
i0, i1, i2, i3, i4 = [Index(count=c) for c in range(5)]
668-
expected = as_tensor(df[i1] * dv, (i1,))[i0] * g[i0]
668+
expected = as_tensor(df[i1], (i1,))[i0] * dv * g[i0]
669669

670670
F = integrand * dx
671671
J = derivative(F, u, dv, cd)
@@ -693,7 +693,7 @@ def test_vector_coefficient_derivatives(self):
693693
integrand = inner(f, g)
694694

695695
i0, i1, i2, i3, i4 = [Index(count=c) for c in range(5)]
696-
expected = as_tensor(df[i2, i1] * dv[i1], (i2,))[i0] * g[i0]
696+
expected = as_tensor(df[i2, i1], (i2,))[i0] * dv[i1] * g[i0]
697697

698698
F = integrand * dx
699699
J = derivative(F, u, dv, cd)

ufl/algebra.py

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ufl.core.operator import Operator
1414
from ufl.core.ufl_type import ufl_type
1515
from ufl.index_combination_utils import merge_unique_indices
16+
from ufl.indexed import Indexed
1617
from ufl.precedence import parstr
1718
from ufl.sorting import sorted_expr
1819

@@ -89,6 +90,11 @@ def __init__(self, a, b):
8990
"""Initialise."""
9091
Operator.__init__(self)
9192

93+
def _simplify_indexed(self, multiindex):
94+
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
95+
a, b = self.ufl_operands
96+
return Sum(Indexed(a, multiindex), Indexed(b, multiindex))
97+
9298
def evaluate(self, x, mapping, component, index_values):
9399
"""Evaluate."""
94100
return sum(o.evaluate(x, mapping, component, index_values) for o in self.ufl_operands)

ufl/algorithms/compute_form_data.py

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ufl.algorithms.formdata import FormData
3434
from ufl.algorithms.formtransformations import compute_form_arities
3535
from ufl.algorithms.remove_complex_nodes import remove_complex_nodes
36+
from ufl.algorithms.remove_component_tensors import remove_component_tensors
3637
from ufl.classes import Coefficient, Form, FunctionSpace, GeometricFacetQuantity
3738
from ufl.corealg.traversal import traverse_unique_terminals
3839
from ufl.domain import extract_unique_domain
@@ -328,6 +329,9 @@ def compute_form_data(
328329

329330
form = apply_coordinate_derivatives(form)
330331

332+
# Remove component tensors
333+
form = remove_component_tensors(form)
334+
331335
# Propagate restrictions to terminals
332336
if do_apply_restrictions:
333337
form = apply_restrictions(form, apply_default=do_apply_default_restrictions)
+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Remove component tensors.
2+
3+
This module contains classes and functions to remove component tensors.
4+
"""
5+
# Copyright (C) 2008-2016 Martin Sandve Alnæs
6+
#
7+
# This file is part of UFL (https://www.fenicsproject.org)
8+
#
9+
# SPDX-License-Identifier: LGPL-3.0-or-later
10+
11+
from ufl.classes import (
12+
ComponentTensor,
13+
Form,
14+
Index,
15+
MultiIndex,
16+
Zero,
17+
)
18+
from ufl.corealg.map_dag import map_expr_dag
19+
from ufl.corealg.multifunction import MultiFunction, memoized_handler
20+
21+
22+
class IndexReplacer(MultiFunction):
23+
"""Replace Indices."""
24+
25+
def __init__(self, fimap: dict):
26+
"""Initialise.
27+
28+
Args:
29+
fimap: map for index replacements.
30+
31+
"""
32+
MultiFunction.__init__(self)
33+
self.fimap = fimap
34+
self._object_cache = {}
35+
36+
expr = MultiFunction.reuse_if_untouched
37+
38+
@memoized_handler
39+
def zero(self, o):
40+
"""Handle Zero."""
41+
free_indices = []
42+
index_dimensions = []
43+
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
44+
if Index(i) in self.fimap:
45+
ind_j = self.fimap[Index(i)]
46+
if isinstance(ind_j, Index):
47+
free_indices.append(ind_j.count())
48+
index_dimensions.append(d)
49+
else:
50+
free_indices.append(i)
51+
index_dimensions.append(d)
52+
return Zero(
53+
shape=o.ufl_shape,
54+
free_indices=tuple(free_indices),
55+
index_dimensions=tuple(index_dimensions),
56+
)
57+
58+
@memoized_handler
59+
def multi_index(self, o):
60+
"""Handle MultiIndex."""
61+
return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices()))
62+
63+
64+
class IndexRemover(MultiFunction):
65+
"""Remove Indexed."""
66+
67+
def __init__(self):
68+
"""Initialise."""
69+
MultiFunction.__init__(self)
70+
self._object_cache = {}
71+
72+
expr = MultiFunction.reuse_if_untouched
73+
74+
@memoized_handler
75+
def _zero_simplify(self, o):
76+
"""Apply simplification for Zero()."""
77+
(operand,) = o.ufl_operands
78+
operand = map_expr_dag(self, operand)
79+
if isinstance(operand, Zero):
80+
return Zero(
81+
shape=o.ufl_shape,
82+
free_indices=o.ufl_free_indices,
83+
index_dimensions=o.ufl_index_dimensions,
84+
)
85+
return o._ufl_expr_reconstruct_(operand)
86+
87+
@memoized_handler
88+
def indexed(self, o):
89+
"""Simplify Indexed."""
90+
o1, i1 = o.ufl_operands
91+
if isinstance(o1, ComponentTensor):
92+
# Simplify Indexed ComponentTensor
93+
o2, i2 = o1.ufl_operands
94+
assert len(i2) == len(i1)
95+
fimap = dict(zip(i2, i1))
96+
rule = IndexReplacer(fimap)
97+
v = map_expr_dag(self, o2)
98+
return map_expr_dag(rule, v)
99+
100+
expr = map_expr_dag(self, o1)
101+
if expr is o1:
102+
# Reuse if untouched
103+
return o
104+
return o._ufl_expr_reconstruct_(expr, i1)
105+
106+
# Do something nicer
107+
positive_restricted = _zero_simplify
108+
negative_restricted = _zero_simplify
109+
reference_grad = _zero_simplify
110+
reference_value = _zero_simplify
111+
112+
113+
def remove_component_tensors(o):
114+
"""Remove component tensors."""
115+
if isinstance(o, Form):
116+
integrals = []
117+
for integral in o.integrals():
118+
integrand = remove_component_tensors(integral.integrand())
119+
if not isinstance(integrand, Zero):
120+
integrals.append(integral.reconstruct(integrand=integrand))
121+
return o._ufl_expr_reconstruct_(integrals)
122+
else:
123+
rule = IndexRemover()
124+
return map_expr_dag(rule, o)

ufl/indexsum.py

+6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from ufl.core.multiindex import MultiIndex
1212
from ufl.core.operator import Operator
1313
from ufl.core.ufl_type import ufl_type
14+
from ufl.indexed import Indexed
1415
from ufl.precedence import parstr
1516

1617
# --- Sum over an index ---
@@ -69,6 +70,11 @@ def ufl_shape(self):
6970
"""Get UFL shape."""
7071
return self.ufl_operands[0].ufl_shape
7172

73+
def _simplify_indexed(self, multiindex):
74+
"""Return a simplified Expr used in the constructor of Indexed(self, multiindex)."""
75+
A, i = self.ufl_operands
76+
return IndexSum(Indexed(A, multiindex), i)
77+
7278
def evaluate(self, x, mapping, component, index_values):
7379
"""Evaluate."""
7480
(i,) = self.ufl_operands[1]

0 commit comments

Comments
 (0)