Skip to content

Commit

Permalink
Fix known bugs scalarizing NEMO
Browse files Browse the repository at this point in the history
  • Loading branch information
LonelyCat124 committed Jan 31, 2025
1 parent 956c164 commit 250c259
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
16 changes: 13 additions & 3 deletions src/psyclone/psyir/transformations/scalarization_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@

from psyclone.core import VariablesAccessInfo, Signature
from psyclone.psyGen import Kern
from psyclone.psyir.nodes import Call, CodeBlock, \
Loop, Reference, Routine
from psyclone.psyir.symbols import DataSymbol, RoutineSymbol
from psyclone.psyir.nodes import Call, CodeBlock, \
Loop, Reference, Routine, StructureReference
from psyclone.psyir.symbols import DataSymbol, DataTypeSymbol, RoutineSymbol
from psyclone.psyir.transformations.loop_trans import LoopTrans


Expand Down Expand Up @@ -117,6 +117,16 @@ def _is_local_array(signature: Signature,
base_symbol = var_accesses[signature].all_accesses[0].node.symbol
if not base_symbol.is_automatic:
return False
# If its a derived type then we don't scalarize.
if isinstance(var_accesses[signature].all_accesses[0].node,
StructureReference):
return False
# Find the containing routine
rout = var_accesses[signature].all_accesses[0].node.ancestor(Routine)
# If the array is the return symbol then its not a local
# array symbol
if base_symbol is rout.return_symbol:
return False

return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,27 @@


def test_scalararizationtrans_is_local_array(fortran_reader):
code = '''subroutine test(a)
use mymod, only: arr
code = '''function test(a) result(x)
use mymod, only: arr, atype
integer :: i
integer :: k
real, dimension(1:100) :: local
real, dimension(1:100) :: a
character(2), dimension(1:100) :: b
real, dimension(1:100) :: x
type(atype) :: custom
type(atype), dimension(1:100) :: custom2
do i = 1, 100
arr(i) = i
a(i) = i
local(i) = i
b(i) = b(i) // "c"
x(i) = i
custom%type(i) = i
custom2(i)%typeb(i) = i
end do
end subroutine'''
end function'''
psyir = fortran_reader.psyir_from_source(code)
node = psyir.children[0].children[0]
var_accesses = VariablesAccessInfo(nodes=node.loop_body)
Expand All @@ -82,6 +88,20 @@ def test_scalararizationtrans_is_local_array(fortran_reader):
assert not ScalarizationTrans._is_local_array(keys[4],
var_accesses)

# Test x - the return value is not classed as a local array.
assert var_accesses[keys[5]].var_name == "x"
assert not ScalarizationTrans._is_local_array(keys[5],
var_accesses)

# Test custom - we don't scalarize derived types.
assert var_accesses[keys[6]].var_name == "custom%type"
assert not ScalarizationTrans._is_local_array(keys[6],
var_accesses)
# Test custom2 - we don't scalarize derived types.
assert var_accesses[keys[7]].var_name == "custom2%typeb"
assert not ScalarizationTrans._is_local_array(keys[7],
var_accesses)

# Test filter behaviour same as used in the transformation
local_arrays = filter(
lambda sig: ScalarizationTrans._is_local_array(sig, var_accesses),
Expand Down

0 comments on commit 250c259

Please sign in to comment.