Skip to content

Commit

Permalink
#1960 WIP adding tests and code to catch reductions in WHERE [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Oct 19, 2023
1 parent 70da6a7 commit 69f3de4
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
18 changes: 17 additions & 1 deletion src/psyclone/psyir/frontend/fparser2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3597,6 +3597,22 @@ def _where_construct_handler(self, node, parent):
# the NEMO style where the fact that `a` is an array is made
# explicit using the colon notation, e.g. `a(:, :) < 0.0`.

# We do not yet support intrinsics that perform reductions within
# the mask expression.
intr_nodes = walk(logical_expr,
Fortran2003.Intrinsic_Function_Reference)
for intr in intr_nodes:
if (intr.children[0].string in
Fortran2003.Intrinsic_Name.array_reduction_names):
# These intrinsics are only a problem if they return an
# array rather than a scalar.
arg_specs = walk(intr.children[1] Fortran2003.Actual_Arg_Spec)
# TODO WORKING HERE
raise NotImplementedError(
f"WHERE constructs which contain an array reduction "
f"intrinsic in their logical expression are not supported "
f"but found '{logical_expr}'")

# For this initial processing of the logical-array expression we
# use a temporary parent as we haven't yet constructed the PSyIR
# for the loop nest and innermost IfBlock. Once we have a valid
Expand All @@ -3610,7 +3626,7 @@ def _where_construct_handler(self, node, parent):
# because the code doesn't use explicit array syntax. At least one
# variable in the logical-array expression must be an array for
# this to be a valid WHERE().
# TODO #717. Look-up the shape of the array in the SymbolTable.
# TODO #1799. Look-up the shape of the array in the SymbolTable.
raise NotImplementedError(
f"Only WHERE constructs using explicit array notation (e.g. "
f"my_array(:,:)) are supported but found '{logical_expr}'.")
Expand Down
77 changes: 77 additions & 0 deletions src/psyclone/tests/psyir/frontend/fparser2_where_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,83 @@ def test_where_array_subsections():
assert assign.lhs.children[2].name == "widx2"


def test_where_body_containing_sum_with_dim(fortran_reader, fortran_writer):
'''
Since a SUM(x, dim=y) performs a reduction but produces an array, we need
to replace it with a temporary in order to translate the WHERE into
canonical form. We can't do that without being able to declare a suitable
temporary and that requires TODO #1799.
'''
code = '''\
subroutine my_sub(picefr)
use some_mod
REAL(wp), INTENT(in), DIMENSION(:,:) :: picefr
REAL(wp), DIMENSION(jpi,jpj,jpl) :: zevap_ice
REAL(wp), ALLOCATABLE, SAVE, DIMENSION(:,:,:) :: a_i_last_couple
WHERE( picefr(:,:) > 1.e-10 )
zevap_ice(:,:,1) = frcv(jpr_ievp)%z3(:,:,1) * &
SUM( a_i_last_couple, dim=3 ) / picefr(:,:)
ELSEWHERE
zevap_ice(:,:,1) = 0.0
END WHERE
end subroutine my_sub
'''
psyir = fortran_reader.psyir_from_source(code)
routine = psyir.walk(Routine)[0]
assert isinstance(routine[0], CodeBlock)
output = fortran_writer(psyir)
assert "SUM( a_i_last_couple, dim=3 ) / picefr(:,:)" in output


def test_where_containing_sum_no_dim(fortran_reader, fortran_writer):
'''
Since a SUM without a dim argument always produces a scalar we can
translate a WHERE containing it into canonical form.
'''
code = '''\
subroutine my_sub(picefr)
use some_mod
REAL(wp), INTENT(in), DIMENSION(:,:) :: picefr
REAL(wp), DIMENSION(jpi,jpj,jpl) :: zevap_ice
REAL(wp), ALLOCATABLE, SAVE, DIMENSION(:,:,:) :: a_i_last_couple
WHERE( picefr(:,:) > SUM(picefr) )
zevap_ice(:,:,1) = frcv(jpr_ievp)%z3(:,:,1) * &
SUM( a_i_last_couple ) / picefr(:,:)
ELSEWHERE
zevap_ice(:,:,1) = 0.0
END WHERE
end subroutine my_sub
'''
psyir = fortran_reader.psyir_from_source(code)
routine = psyir.walk(Routine)[0]
assert isinstance(routine[0], CodeBlock)
output = fortran_writer(psyir)
assert "SUM( a_i_last_couple, dim=3 ) / picefr(:,:)" in output


def test_where_mask_containing_sum_with_dim(fortran_reader):
'''Since a SUM(x, dim=y) appearing in a mask expression performs
a reduction but produces an array, we need to replace it with a
temporary in order to translate the WHERE into canonical form
(since the mask expression determines the number of nested loops
required). We can't do that without being able to declare a
suitable temporary and that requires TODO #1799.
'''
code = '''\
subroutine my_sub(v2)
REAL(wp), INTENT(in), DIMENSION(:,:) :: v2
REAL(wp), ALLOCATABLE, SAVE, DIMENSION(:) :: v3
where(sum(v2(:,:), dim=2) > 0.0)
v3(:) = 1.0
end where
end subroutine my_sub
'''
psyir = fortran_reader.psyir_from_source(code)
routine = psyir.walk(Routine)[0]
assert isinstance(routine[0], CodeBlock)


@pytest.mark.usefixtures("parser")
@pytest.mark.parametrize("rhs", ["depth", "maxval(depth(:))"])
@pytest.mark.xfail(reason="#717 need to distinguish scalar and array "
Expand Down

0 comments on commit 69f3de4

Please sign in to comment.