|
| 1 | +"""Remove component tensors. |
| 2 | +
|
| 3 | +This module contains classes and functions to remove component tensors. |
| 4 | +""" |
| 5 | +# Copyright (C) 2008-2016 Martin Sandve Alnæs |
| 6 | +# |
| 7 | +# This file is part of UFL (https://www.fenicsproject.org) |
| 8 | +# |
| 9 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 10 | + |
| 11 | +from ufl.classes import ( |
| 12 | + ComponentTensor, |
| 13 | + Form, |
| 14 | + Index, |
| 15 | + MultiIndex, |
| 16 | + Zero, |
| 17 | +) |
| 18 | +from ufl.corealg.map_dag import map_expr_dag |
| 19 | +from ufl.corealg.multifunction import MultiFunction, memoized_handler |
| 20 | + |
| 21 | + |
| 22 | +class IndexReplacer(MultiFunction): |
| 23 | + """Replace Indices.""" |
| 24 | + |
| 25 | + def __init__(self, fimap: dict): |
| 26 | + """Initialise. |
| 27 | +
|
| 28 | + Args: |
| 29 | + fimap: map for index replacements. |
| 30 | +
|
| 31 | + """ |
| 32 | + MultiFunction.__init__(self) |
| 33 | + self.fimap = fimap |
| 34 | + self._object_cache = {} |
| 35 | + |
| 36 | + expr = MultiFunction.reuse_if_untouched |
| 37 | + |
| 38 | + @memoized_handler |
| 39 | + def zero(self, o): |
| 40 | + """Handle Zero.""" |
| 41 | + free_indices = [] |
| 42 | + index_dimensions = [] |
| 43 | + for i, d in zip(o.ufl_free_indices, o.ufl_index_dimensions): |
| 44 | + if Index(i) in self.fimap: |
| 45 | + ind_j = self.fimap[Index(i)] |
| 46 | + if isinstance(ind_j, Index): |
| 47 | + free_indices.append(ind_j.count()) |
| 48 | + index_dimensions.append(d) |
| 49 | + else: |
| 50 | + free_indices.append(i) |
| 51 | + index_dimensions.append(d) |
| 52 | + return Zero( |
| 53 | + shape=o.ufl_shape, |
| 54 | + free_indices=tuple(free_indices), |
| 55 | + index_dimensions=tuple(index_dimensions), |
| 56 | + ) |
| 57 | + |
| 58 | + @memoized_handler |
| 59 | + def multi_index(self, o): |
| 60 | + """Handle MultiIndex.""" |
| 61 | + return MultiIndex(tuple(self.fimap.get(i, i) for i in o.indices())) |
| 62 | + |
| 63 | + |
| 64 | +class IndexRemover(MultiFunction): |
| 65 | + """Remove Indexed.""" |
| 66 | + |
| 67 | + def __init__(self): |
| 68 | + """Initialise.""" |
| 69 | + MultiFunction.__init__(self) |
| 70 | + self._object_cache = {} |
| 71 | + |
| 72 | + expr = MultiFunction.reuse_if_untouched |
| 73 | + |
| 74 | + @memoized_handler |
| 75 | + def _zero_simplify(self, o): |
| 76 | + """Apply simplification for Zero().""" |
| 77 | + (operand,) = o.ufl_operands |
| 78 | + operand = map_expr_dag(self, operand) |
| 79 | + if isinstance(operand, Zero): |
| 80 | + return Zero( |
| 81 | + shape=o.ufl_shape, |
| 82 | + free_indices=o.ufl_free_indices, |
| 83 | + index_dimensions=o.ufl_index_dimensions, |
| 84 | + ) |
| 85 | + return o._ufl_expr_reconstruct_(operand) |
| 86 | + |
| 87 | + @memoized_handler |
| 88 | + def indexed(self, o): |
| 89 | + """Simplify Indexed.""" |
| 90 | + o1, i1 = o.ufl_operands |
| 91 | + if isinstance(o1, ComponentTensor): |
| 92 | + # Simplify Indexed ComponentTensor |
| 93 | + o2, i2 = o1.ufl_operands |
| 94 | + assert len(i2) == len(i1) |
| 95 | + fimap = dict(zip(i2, i1)) |
| 96 | + rule = IndexReplacer(fimap) |
| 97 | + v = map_expr_dag(self, o2) |
| 98 | + return map_expr_dag(rule, v) |
| 99 | + |
| 100 | + expr = map_expr_dag(self, o1) |
| 101 | + if expr is o1: |
| 102 | + # Reuse if untouched |
| 103 | + return o |
| 104 | + return o._ufl_expr_reconstruct_(expr, i1) |
| 105 | + |
| 106 | + # Do something nicer |
| 107 | + positive_restricted = _zero_simplify |
| 108 | + negative_restricted = _zero_simplify |
| 109 | + reference_grad = _zero_simplify |
| 110 | + reference_value = _zero_simplify |
| 111 | + |
| 112 | + |
| 113 | +def remove_component_tensors(o): |
| 114 | + """Remove component tensors.""" |
| 115 | + if isinstance(o, Form): |
| 116 | + integrals = [] |
| 117 | + for integral in o.integrals(): |
| 118 | + integrand = remove_component_tensors(integral.integrand()) |
| 119 | + if not isinstance(integrand, Zero): |
| 120 | + integrals.append(integral.reconstruct(integrand=integrand)) |
| 121 | + return o._ufl_expr_reconstruct_(integrals) |
| 122 | + else: |
| 123 | + rule = IndexRemover() |
| 124 | + return map_expr_dag(rule, o) |
0 commit comments