Skip to content

Commit

Permalink
#1960 correct/improve lower-bound handling in where processing
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Jan 31, 2024
1 parent aae4ec5 commit 7fa7249
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 33 deletions.
60 changes: 39 additions & 21 deletions src/psyclone/psyir/frontend/fparser2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3662,32 +3662,50 @@ def _array_syntax_to_indexed(self, parent, loop_vars):
# Replace the PSyIR Ranges with appropriate index expressions.
range_idx = 0
for idx, child in enumerate(array.indices):
if isinstance(child, Range):
symbol = table.lookup(loop_vars[range_idx])
if isinstance(shape[range_idx], ArrayType.Extent):
# We don't know the bounds of this array so we have
# to query using LBOUND.
if not isinstance(child, Range):
continue
# We need the lower bound of the appropriate dimension of this
# array as we will index relative to it. Note that the 'shape'
# of the datatype only gives us extents, not the lower bounds
# of the declaration or slice.
if isinstance(shape[range_idx], ArrayType.Extent):
# We don't know the bounds of this array so we have
# to query using LBOUND.
lbound = IntrinsicCall.create(
IntrinsicCall.Intrinsic.LBOUND,
[base_ref.copy(),
("dim", Literal(str(idx+1), INTEGER_TYPE))])
else:
if array.is_full_range(idx):
# The access to this index is to the full range of
# the array.
# TODO #949 - ideally we would try to find the lower
# bound of the array by interrogating `array.symbol.
# datatype` but the fparser2 frontend doesn't currently
# support array declarations with explicit lower bounds
lbound = IntrinsicCall.create(
IntrinsicCall.Intrinsic.LBOUND,
[base_ref.copy(),
("dim", Literal(str(idx+1), INTEGER_TYPE))])
else:
lbound = shape[range_idx].lower.copy()
# Create the index expression.
if isinstance(lbound, Literal) and lbound.value == "1":
# Lower bound is just unity so we can use the loop-idx
# directly.
expr2 = Reference(symbol)
else:
# We don't know what the lower bound is so have to
# have an expression:
# idx-expr = array-lower-bound + loop-idx - 1
expr = BinaryOperation.create(
add_op, lbound, Reference(symbol))
expr2 = BinaryOperation.create(sub_op, expr,
one.copy())
array.children[idx] = expr2
range_idx += 1
# We need the lower bound of this access.
lbound = child.start.copy()

# Create the index expression.
symbol = table.lookup(loop_vars[range_idx])
if isinstance(lbound, Literal) and lbound.value == "1":
# Lower bound is just unity so we can use the loop-idx
# directly.
expr2 = Reference(symbol)
else:
# We don't know what the lower bound is so have to
# have an expression:
# idx-expr = array-lower-bound + loop-idx - 1
expr = BinaryOperation.create(
add_op, lbound, Reference(symbol))
expr2 = BinaryOperation.create(sub_op, expr, one.copy())
array.children[idx] = expr2
range_idx += 1

def _where_construct_handler(self, node, parent):
'''
Expand Down
107 changes: 95 additions & 12 deletions src/psyclone/tests/psyir/frontend/fparser2_where_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ def test_where_within_loop(fortran_reader):
assert isinstance(where_loop.loop_body[0], IfBlock)
assign = where_loop.loop_body[0].if_body[0]
assert isinstance(assign, Assignment)
assert assign.lhs.indices[0].debug_string() == "widx1"
assert (assign.lhs.indices[0].debug_string() ==
"LBOUND(var2, dim=1) + widx1 - 1")
assert assign.lhs.indices[1].debug_string() == "jl"
assert where_loop.start_expr.value == "1"
assert where_loop.stop_expr.debug_string() == "SIZE(var, dim=1)"
Expand Down Expand Up @@ -356,7 +357,10 @@ def test_basic_where():
ifblock = loops[2].loop_body[0]
assert isinstance(ifblock, IfBlock)
assert "was_where" in ifblock.annotations
assert ifblock.condition.debug_string() == "dry(widx1,widx2,widx3)"
assert (ifblock.condition.debug_string() ==
"dry(LBOUND(dry, dim=1) + widx1 - 1,"
"LBOUND(dry, dim=2) + widx2 - 1,"
"LBOUND(dry, dim=3) + widx3 - 1)")


@pytest.mark.usefixtures("parser")
Expand All @@ -381,9 +385,85 @@ def test_where_array_subsections():
# Check that the array reference is indexed correctly
assign = ifblock.if_body[0]
assert isinstance(assign, Assignment)
assert isinstance(assign.lhs.children[0], Reference)
assert assign.lhs.children[0].debug_string() == "widx1"
assert assign.lhs.children[2].debug_string() == "widx2"
assert isinstance(assign.lhs.children[0], BinaryOperation)
assert (assign.lhs.children[0].debug_string() ==
"LBOUND(z1_st, dim=1) + widx1 - 1")
assert (assign.lhs.children[2].debug_string() ==
"LBOUND(z1_st, dim=3) + widx2 - 1")


def test_where_mask_starting_value(fortran_reader, fortran_writer):
'''
Check handling of a case where the mask array is indexed from values other
than unity.
# TODO #949 - we can't currently take advantage of any knowledge of the
# declared lower bounds of arrays because the fparser2 frontend doesn't yet
# capture this information (we get an UnknownFortranType).
'''
code = '''\
program my_sub
use some_mod
real, dimension(-5:5,-5:5) :: picefr
real, DIMENSION(11,11,jpl) :: zevap_ice
real, dimension(-2:8,jpl,-3:7) :: snow
real, dimension(-22:0,jpl,-32:0) :: slush
WHERE( picefr(:,:) > 1.e-10 )
zevap_ice(:,:,1) = snow(:,3,:) * frcv(jpr_ievp)%z3(:,:,1) / picefr(:,:)
ELSEWHERE
zevap_ice(:,:,1) = snow(:,map(jpl),:) + slush(-22:-11,jpl,-32:-21)
END WHERE
end program my_sub
'''
psyir = fortran_reader.psyir_from_source(code)
output = fortran_writer(psyir)
print(output)
expected = '''\
do widx2 = 1, SIZE(picefr, dim=2), 1
do widx1 = 1, SIZE(picefr, dim=1), 1
if (picefr(LBOUND(picefr, dim=1) + widx1 - 1,\
LBOUND(picefr, dim=2) + widx2 - 1) > 1.e-10) then
zevap_ice(LBOUND(zevap_ice, dim=1) + widx1 - 1,\
LBOUND(zevap_ice, dim=2) + widx2 - 1,1) = \
snow(LBOUND(snow, dim=1) + widx1 - 1,3,LBOUND(snow, dim=3) + widx2 - 1) * \
frcv(jpr_ievp)%z3(LBOUND(frcv(jpr_ievp)%z3, dim=1) + widx1 - 1,\
LBOUND(frcv(jpr_ievp)%z3, dim=2) + widx2 - 1,1) / \
picefr(LBOUND(picefr, dim=1) + widx1 - 1,LBOUND(picefr, dim=2) + widx2 - 1)
else
zevap_ice(LBOUND(zevap_ice, dim=1) + widx1 - 1,\
LBOUND(zevap_ice, dim=2) + widx2 - 1,1) = \
snow(LBOUND(snow, dim=1) + widx1 - 1,map(jpl),\
LBOUND(snow, dim=3) + widx2 - 1) + slush(-22 + widx1 - 1,jpl,-32 + widx2 - 1)
'''
assert expected in output


def test_where_mask_is_slice(fortran_reader, fortran_writer):
'''
Check that the correct loop bounds and index expressions are created
when the mask expression uses a slice with specified bounds.
'''
code = '''\
program my_sub
use some_mod
WHERE( picefr(2:4,jstart:jstop) > 1.e-10 )
zevap_ice(:,:,1) = frcv(jpr_ievp)%z3(:,:,1) / picefr(:,:)
ELSEWHERE
zevap_ice(:,:,1) = 0.0
END WHERE
end program my_sub
'''
psyir = fortran_reader.psyir_from_source(code)
out = fortran_writer(psyir)
# Check that created loops have the correct number of iterations
assert "do widx2 = 1, jstop - jstart + 1, 1" in out
assert "do widx1 = 1, 4 - 2 + 1, 1" in out
# Check that the indexing into the mask expression uses the lower bounds
# specified in the original slice.
assert "if (picefr(2 + widx1 - 1,jstart + widx2 - 1) > 1.e-10)" in out
assert ("zevap_ice(LBOUND(zevap_ice, dim=1) + widx1 - 1,"
"LBOUND(zevap_ice, dim=2) + widx2 - 1,1) = 0.0" in out)


def test_where_body_containing_sum_with_dim(fortran_reader, fortran_writer):
Expand Down Expand Up @@ -536,7 +616,9 @@ def test_elsewhere():
assert isinstance(ifblock.condition, BinaryOperation)
assert ifblock.condition.operator == BinaryOperation.Operator.GT
assert (ifblock.condition.debug_string() ==
"ptsu(widx1,widx2,widx3) > 10._wp")
"ptsu(LBOUND(ptsu, dim=1) + widx1 - 1,"
"LBOUND(ptsu, dim=2) + widx2 - 1,"
"LBOUND(ptsu, dim=3) + widx3 - 1) > 10._wp")
# Check that this IF block has an else body which contains another IF
assert ifblock.else_body is not None
ifblock2 = ifblock.else_body[0]
Expand All @@ -545,7 +627,9 @@ def test_elsewhere():
assert isinstance(ifblock2.condition, BinaryOperation)
assert ifblock2.condition.operator == BinaryOperation.Operator.LT
assert (ifblock2.condition.debug_string() ==
"ptsu(widx1,widx2,widx3) < 0.0_wp")
"ptsu(LBOUND(ptsu, dim=1) + widx1 - 1,"
"LBOUND(ptsu, dim=2) + widx2 - 1,"
"LBOUND(ptsu, dim=3) + widx3 - 1) < 0.0_wp")
# Check that this IF block too has an else body
assert isinstance(ifblock2.else_body[0], Assignment)
# Check that we have three assignments of the correct form and with the
Expand All @@ -554,10 +638,10 @@ def test_elsewhere():
assert len(assigns) == 3
for assign in assigns:
assert isinstance(assign.lhs, ArrayReference)
refs = assign.lhs.walk(Reference)
assert len(refs) == 4
assert (assign.lhs.debug_string() ==
"z1_st(widx1,widx2,widx3)")
"z1_st(LBOUND(z1_st, dim=1) + widx1 - 1,"
"LBOUND(z1_st, dim=2) + widx2 - 1,"
"LBOUND(z1_st, dim=3) + widx3 - 1)")
assert isinstance(assign.parent.parent, IfBlock)

assert isinstance(assigns[0].rhs, BinaryOperation)
Expand Down Expand Up @@ -694,5 +778,4 @@ def test_where_derived_type(fortran_reader, fortran_writer, code, size_arg):
# All ArrayMember accesses should now use the `widx1` loop variable
array_members = loops[0].walk(ArrayMember)
for member in array_members:
idx_refs = member.indices[0].walk(Reference)
assert all([idx.name == "widx1" for idx in idx_refs])
assert "+ widx1 - 1" in member.indices[0].debug_string()

0 comments on commit 7fa7249

Please sign in to comment.