Skip to content

Commit d1b850e

Browse files
committed
Simplify Grad(CellwiseConstant)
1 parent 768f403 commit d1b850e

File tree

3 files changed

+45
-27
lines changed

3 files changed

+45
-27
lines changed

test/test_algorithms.py

+17
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,20 @@ def test_remove_component_tensors(domain):
197197
fd = compute_form_data(form)
198198

199199
assert "ComponentTensor" not in repr(fd.preprocessed_form)
200+
201+
202+
def test_grad_cellwise_constant(domain):
203+
element = FiniteElement("Lagrange", triangle, 3, (), identity_pullback, H1)
204+
space = FunctionSpace(domain, element)
205+
u = Coefficient(space)
206+
207+
# Applying four derivatives to a cubic should simplify to zero
208+
f = div(grad(div(grad(u))))
209+
form = f * dx
210+
211+
fd = compute_form_data(
212+
form,
213+
do_apply_function_pullbacks=True,
214+
)
215+
assert fd.preprocessed_form.empty()
216+
assert fd.num_coefficients == 0

ufl/algorithms/apply_derivatives.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ufl.action import Action
1414
from ufl.algorithms.analysis import extract_arguments
15+
from ufl.algorithms.estimate_degrees import SumDegreeEstimator
1516
from ufl.algorithms.map_integrands import map_integrand_dags
1617
from ufl.algorithms.replace_derivative_nodes import replace_derivative_nodes
1718
from ufl.argument import BaseArgument
@@ -562,6 +563,14 @@ def __init__(self, geometric_dimension):
562563
"""Initialise."""
563564
GenericDerivativeRuleset.__init__(self, var_shape=(geometric_dimension,))
564565
self._Id = Identity(geometric_dimension)
566+
self.degree_estimator = SumDegreeEstimator(1, {})
567+
568+
def is_cellwise_constant(self, o):
569+
"""More precise checks for cellwise constants."""
570+
if is_cellwise_constant(o):
571+
return True
572+
degree = map_expr_dag(self.degree_estimator, o)
573+
return degree == 0
565574

566575
# --- Specialized rules for geometric quantities
567576

@@ -572,7 +581,7 @@ def geometric_quantity(self, o):
572581
otherwise transform derivatives to reference derivatives.
573582
Override for specific types if other behaviour is needed.
574583
"""
575-
if is_cellwise_constant(o):
584+
if self.is_cellwise_constant(o):
576585
return self.independent_terminal(o)
577586
else:
578587
domain = extract_unique_domain(o)
@@ -583,7 +592,7 @@ def geometric_quantity(self, o):
583592
def jacobian_inverse(self, o):
584593
"""Differentiate a jacobian_inverse."""
585594
# grad(K) == K_ji rgrad(K)_rj
586-
if is_cellwise_constant(o):
595+
if self.is_cellwise_constant(o):
587596
return self.independent_terminal(o)
588597
if not o._ufl_is_terminal_:
589598
raise ValueError("ReferenceValue can only wrap a terminal")
@@ -654,8 +663,10 @@ def reference_value(self, o):
654663
def reference_grad(self, o):
655664
"""Differentiate a reference_grad."""
656665
# grad(o) == grad(rgrad(rv(f))) -> K_ji*rgrad(rgrad(rv(f)))_rj
657-
f = o.ufl_operands[0]
666+
if self.is_cellwise_constant(o):
667+
return self.independent_terminal(o)
658668

669+
f = o.ufl_operands[0]
659670
valid_operand = f._ufl_is_in_reference_frame_ or isinstance(
660671
f, (JacobianInverse, SpatialCoordinate, Jacobian, JacobianDeterminant)
661672
)
@@ -676,7 +687,6 @@ def grad(self, o):
676687
# Check that o is a "differential terminal"
677688
if not isinstance(o.ufl_operands[0], (Grad, Terminal)):
678689
raise ValueError("Expecting only grads applied to a terminal.")
679-
680690
return Grad(o)
681691

682692
def _grad(self, o):

ufl/algorithms/remove_component_tensors.py

+14-23
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,7 @@
88
#
99
# SPDX-License-Identifier: LGPL-3.0-or-later
1010

11-
from ufl.algorithms.estimate_degrees import SumDegreeEstimator
12-
from ufl.classes import (
13-
ComponentTensor,
14-
Form,
15-
Index,
16-
MultiIndex,
17-
Zero,
18-
)
11+
from ufl.classes import ComponentTensor, Form, Index, MultiIndex, Zero
1912
from ufl.corealg.map_dag import map_expr_dag
2013
from ufl.corealg.multifunction import MultiFunction, memoized_handler
2114

@@ -42,13 +35,10 @@ def zero(self, o):
4235
free_indices = []
4336
index_dimensions = []
4437
for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions):
45-
if Index(i) in self.fimap:
46-
ind_j = self.fimap[Index(i)]
47-
if isinstance(ind_j, Index):
48-
free_indices.append(ind_j.count())
49-
index_dimensions.append(d)
50-
else:
51-
free_indices.append(i)
38+
k = Index(i)
39+
j = self.fimap.get(k, k)
40+
if isinstance(j, Index):
41+
free_indices.append(j.count())
5242
index_dimensions.append(d)
5343
return Zero(
5444
shape=o.ufl_shape,
@@ -69,26 +59,24 @@ def __init__(self):
6959
"""Initialise."""
7060
MultiFunction.__init__(self)
7161
self._object_cache = {}
72-
self.degree_estimator = SumDegreeEstimator(1, {})
7362

7463
expr = MultiFunction.reuse_if_untouched
7564

7665
@memoized_handler
77-
def reference_grad(self, o):
78-
"""Simplify ReferenceGrad(Constant)."""
66+
def _unary_operator(self, o):
67+
"""Simplify UnaryOperator(Zero)."""
7968
(operand,) = o.ufl_operands
80-
operand = map_expr_dag(self, operand)
81-
degree = map_expr_dag(self.degree_estimator, operand)
82-
if degree == 0:
69+
f = map_expr_dag(self, operand)
70+
if isinstance(f, Zero):
8371
return Zero(
8472
shape=o.ufl_shape,
8573
free_indices=o.ufl_free_indices,
8674
index_dimensions=o.ufl_index_dimensions,
8775
)
88-
if operand is o.ufl_operands[0]:
76+
if f is operand:
8977
# Reuse if untouched
9078
return o
91-
return o._ufl_expr_reconstruct_(operand)
79+
return o._ufl_expr_reconstruct_(f)
9280

9381
@memoized_handler
9482
def indexed(self, o):
@@ -111,6 +99,9 @@ def indexed(self, o):
11199
return o
112100
return o._ufl_expr_reconstruct_(expr, i1)
113101

102+
reference_grad = _unary_operator
103+
reference_value = _unary_operator
104+
114105

115106
def remove_component_tensors(o):
116107
"""Remove component tensors."""

0 commit comments

Comments
 (0)