Skip to content

Commit

Permalink
#2125 get full coverage of datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Jul 5, 2024
1 parent 817ea50 commit 0f73b67
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/psyclone/psyir/symbols/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def update_symbols_from(self, table):
try:
new_type = table.lookup(component.datatype.name)
except KeyError:
pass
new_type = component.datatype
else:
component.datatype.update_symbols_from(table)
new_type = component.datatype
Expand Down
50 changes: 50 additions & 0 deletions src/psyclone/tests/psyir/symbols/datatype_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,35 @@ def test_unsupported_fortran_type_copy(fortran_reader):
v2copy.partial_datatype.shape[0].lower.symbol)


def test_unsupported_fortran_type_update_symbols():
'''Test that update_symbols_from() correctly updates Symbols in
the _partial_datatype.
'''
decl = "type(some_type), dimension(nelem) :: var"
stype = DataTypeSymbol("some_type", UnresolvedType())
nelem = DataSymbol("nelem", INTEGER_TYPE)
ptype = ArrayType(stype, [Reference(nelem)])
utype = UnsupportedFortranType(decl, partial_datatype=ptype)
table = SymbolTable()
utype.update_symbols_from(table)
assert utype.partial_datatype.shape[0].upper.symbol is nelem
newnelem = nelem.copy()
table.add(newnelem)
utype.update_symbols_from(table)
assert utype.partial_datatype.shape[0].upper.symbol is newnelem
wp = DataSymbol("wp", INTEGER_TYPE)
ptype2 = ScalarType(ScalarType.Intrinsic.REAL, wp)
decl2 = "real(kind=wp), pointer :: var"
stype2 = UnsupportedFortranType(decl2, partial_datatype=ptype2)
stype2.update_symbols_from(table)
assert stype2.partial_datatype.precision is wp
newp = wp.copy()
table.add(newp)
stype2.update_symbols_from(table)
assert stype2.partial_datatype.precision is newp


# StructureType tests

def test_structure_type():
Expand Down Expand Up @@ -964,3 +993,24 @@ def test_structuretype_eq():
("peggy", REAL_TYPE, Symbol.Visibility.PRIVATE,
Literal("1.0", REAL_TYPE)),
("roger", INTEGER_TYPE, Symbol.Visibility.PUBLIC, None)])


def test_structuretype_update_symbols():
'''Test that update_symbols_from() correctly updates any Symbols referred
to within a StructureType.
'''
tsymbol = DataTypeSymbol("my_type", UnresolvedType())
stype = StructureType.create([
("fred", INTEGER_TYPE, Symbol.Visibility.PUBLIC, None),
("george", REAL_TYPE, Symbol.Visibility.PRIVATE,
Literal("1.0", REAL_TYPE)),
("barry", tsymbol, Symbol.Visibility.PUBLIC, None)])
table = SymbolTable()
assert stype.components["barry"].datatype is tsymbol
stype.update_symbols_from(table)
assert stype.components["barry"].datatype is tsymbol
newtsymbol = DataTypeSymbol("my_type", UnresolvedType())
table.add(newtsymbol)
stype.update_symbols_from(table)
assert stype.components["barry"].datatype is newtsymbol

0 comments on commit 0f73b67

Please sign in to comment.