Skip to content

Commit

Permalink
Merge pull request #41743 from nipunayf/fix-narrow-scope
Browse files Browse the repository at this point in the history
Look up closure symbols with the original symbol
  • Loading branch information
KavinduZoysa authored Dec 18, 2023
2 parents 28e7a41 + a0d5760 commit fe783fc
Show file tree
Hide file tree
Showing 20 changed files with 333 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1242,8 +1242,7 @@ private void updateClosureVariable(BVarSymbol varSymbol, BLangInvokableNode encI
!flagSet.contains(Flag.ATTACHED) && varSymbol.owner.tag != SymTag.PACKAGE;
if (!varSymbol.closure && isClosure) {
SymbolEnv encInvokableEnv = findEnclosingInvokableEnv(env, encInvokable);
BSymbol resolvedSymbol =
symResolver.lookupClosureVarSymbol(encInvokableEnv, varSymbol.name, SymTag.VARIABLE);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(encInvokableEnv, varSymbol);
if (resolvedSymbol != symTable.notFoundSymbol) {
varSymbol.closure = true;
((BLangFunction) encInvokable).closureVarSymbols.add(new ClosureVarSymbol(varSymbol, pos));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5876,8 +5876,7 @@ private SymbolEnv findEnclosingInvokableEnv(SymbolEnv env, BLangInvokableNode en
private void updateClosureVariable(BVarSymbol varSymbol, BLangInvokableNode encInvokable, Location pos) {
if (!varSymbol.closure) {
SymbolEnv encInvokableEnv = findEnclosingInvokableEnv(env, encInvokable);
BSymbol resolvedSymbol =
symResolver.lookupClosureVarSymbol(encInvokableEnv, varSymbol.name, SymTag.VARIABLE);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(encInvokableEnv, varSymbol);
if (resolvedSymbol != symTable.notFoundSymbol) {
varSymbol.closure = true;
((BLangFunction) encInvokable).closureVarSymbols.add(new ClosureVarSymbol(varSymbol, pos));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.ballerinalang.model.TreeBuilder;
import org.ballerinalang.model.clauses.OrderKeyNode;
import org.ballerinalang.model.elements.Flag;
import org.ballerinalang.model.symbols.SymbolKind;
import org.ballerinalang.model.tree.IdentifierNode;
import org.ballerinalang.model.tree.NodeKind;
import org.ballerinalang.model.tree.OperatorKind;
Expand Down Expand Up @@ -1903,19 +1904,23 @@ public void visit(BLangErrorVarRef varRefExpr) {
@Override
public void visit(BLangSimpleVarRef bLangSimpleVarRef) {
BSymbol symbol = bLangSimpleVarRef.symbol;
if (symbol == null) {
result = bLangSimpleVarRef;
return;
}
if (symbol.kind == SymbolKind.VARIABLE || symbol.kind == SymbolKind.FUNCTION) {
BVarSymbol originalSymbol = ((BVarSymbol) symbol).originalSymbol;
if (originalSymbol != null) {
symbol = originalSymbol;
}
}
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(env, symbol);
String identifier = bLangSimpleVarRef.variableName == null ? String.valueOf(bLangSimpleVarRef.varSymbol.name) :
String.valueOf(bLangSimpleVarRef.variableName);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(env,
Names.fromString(identifier), SymTag.VARIABLE);

// check whether the symbol and resolved symbol are the same.
// because, lookup using name produce unexpected results if there's variable shadowing.
if (symbol != null && symbol != resolvedSymbol && !FRAME_PARAMETER_NAME.equals(identifier)) {
if (symbol instanceof BVarSymbol) {
BVarSymbol originalSymbol = ((BVarSymbol) symbol).originalSymbol;
if (originalSymbol != null) {
symbol = originalSymbol;
}
}
if (symbol != resolvedSymbol && !FRAME_PARAMETER_NAME.equals(identifier)) {
if ((withinLambdaOrArrowFunc || queryEnv == null || !queryEnv.scope.entries.containsKey(symbol.name))
&& !identifiers.containsKey(identifier)) {
Location pos = currentQueryLambdaBody.pos;
Expand Down Expand Up @@ -1960,15 +1965,8 @@ public void visit(BLangSimpleVarRef bLangSimpleVarRef) {
bLangSimpleVarRef.symbol = symbol;
bLangSimpleVarRef.varSymbol = symbol;
}
} else if (resolvedSymbol != symTable.notFoundSymbol && symbol != null) {
} else if (!resolvedSymbol.closure && resolvedSymbol != symTable.notFoundSymbol) {
resolvedSymbol.closure = true;
// When there's a type guard, there can be a enclSymbol before type narrowing.
// So, we have to mark that as a closure as well.
BSymbol enclSymbol = symResolver.lookupClosureVarSymbol(env.enclEnv,
Names.fromString(identifier), SymTag.VARIABLE);
if (enclSymbol != null && enclSymbol != symTable.notFoundSymbol) {
enclSymbol.closure = true;
}
}
result = bLangSimpleVarRef;
}
Expand Down Expand Up @@ -2591,8 +2589,8 @@ public void visit(BLangCollectClause collectClause) {

void updateIdentifiers(SymbolEnv env) {
for (Map.Entry<String, BSymbol> identifier : identifiers.entrySet()) {
BSymbol symbol = symResolver.lookupClosureVarSymbol(env, Names.fromString(identifier.getKey()),
SymTag.SEQUENCE);
BSymbol symbol =
symResolver.lookupSymbolInGivenScope(env, Names.fromString(identifier.getKey()), SymTag.SEQUENCE);
if (symbol != symTable.notFoundSymbol && !identifier.getValue().closure) {
identifiers.put(identifier.getKey(), symbol);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -815,18 +815,13 @@ public BSymbol lookupLangLibMethod(BType type, Name name, SymbolEnv env) {
* Recursively analyse the symbol env to find the closure variable symbol that is being resolved.
*
* @param env symbol env to analyse and find the closure variable.
* @param name name of the symbol to lookup
* @param expSymTag symbol tag
* @param symbol symbol to lookup
* @return closure symbol wrapper along with the resolved count
*/
public BSymbol lookupClosureVarSymbol(SymbolEnv env, Name name, long expSymTag) {
ScopeEntry entry = env.scope.lookup(name);
public BSymbol lookupClosureVarSymbol(SymbolEnv env, BSymbol symbol) {
ScopeEntry entry = env.scope.lookup(symbol.name);
while (entry != NOT_FOUND_ENTRY) {
if (symTable.rootPkgSymbol.pkgID.equals(entry.symbol.pkgID) &&
(entry.symbol.tag & SymTag.VARIABLE_NAME) == SymTag.VARIABLE_NAME) {
return entry.symbol;
}
if ((entry.symbol.tag & expSymTag) == expSymTag && !isFieldRefFromWithinARecord(entry.symbol, env)) {
if (entry.symbol == symbol) {
return entry.symbol;
}
entry = entry.next;
Expand All @@ -836,7 +831,7 @@ public BSymbol lookupClosureVarSymbol(SymbolEnv env, Name name, long expSymTag)
return symTable.notFoundSymbol;
}

return lookupClosureVarSymbol(env.enclEnv, name, expSymTag);
return lookupClosureVarSymbol(env.enclEnv, symbol);
}

public BSymbol lookupMainSpaceSymbolInPackage(Location pos,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6615,8 +6615,7 @@ protected void markAndRegisterClosureVariable(BSymbol symbol, Location pos, Symb
BLangFunction currentFunc = (BLangFunction) encInvokable;
if ((currentFunc != null) && !currentFunc.attachedFunction &&
!(currentFunc.symbol.receiverSymbol == symbol)) {
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(oceData.capturedClosureEnv, symbol.name,
SymTag.VARIABLE);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(oceData.capturedClosureEnv, symbol);
if (resolvedSymbol != symTable.notFoundSymbol && !resolvedSymbol.closure) {
if (resolvedSymbol.owner.getKind() != SymbolKind.PACKAGE) {
updateObjectCtorClosureSymbols(pos, currentFunc, resolvedSymbol, classDef, data);
Expand Down Expand Up @@ -6647,8 +6646,7 @@ protected void markAndRegisterClosureVariable(BSymbol symbol, Location pos, Symb
return;
}
SymbolEnv encInvokableEnv = findEnclosingInvokableEnv(env, encInvokable);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(encInvokableEnv, symbol.name,
SymTag.VARIABLE);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(encInvokableEnv, symbol);
BLangClassDefinition classDef = (BLangClassDefinition) node;
if (resolvedSymbol != symTable.notFoundSymbol) {
if (resolvedSymbol.owner.getKind() == SymbolKind.PACKAGE) {
Expand Down Expand Up @@ -6678,8 +6676,7 @@ private boolean searchClosureVariableInExpressions(BSymbol symbol, Location pos,
if (encInvokable != null && encInvokable.flagSet.contains(Flag.LAMBDA)
&& !isFunctionArgument(symbol, encInvokable.requiredParams)) {
SymbolEnv encInvokableEnv = findEnclosingInvokableEnv(env, encInvokable);
BSymbol resolvedSymbol =
symResolver.lookupClosureVarSymbol(encInvokableEnv, symbol.name, SymTag.VARIABLE);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(encInvokableEnv, symbol);
if (resolvedSymbol != symTable.notFoundSymbol && !encInvokable.flagSet.contains(Flag.ATTACHED)) {
resolvedSymbol.closure = true;
((BLangFunction) encInvokable).closureVarSymbols.add(new ClosureVarSymbol(resolvedSymbol, pos));
Expand All @@ -6690,8 +6687,7 @@ private boolean searchClosureVariableInExpressions(BSymbol symbol, Location pos,
if (bLangNode.getKind() == NodeKind.ARROW_EXPR
&& !isFunctionArgument(symbol, ((BLangArrowFunction) bLangNode).params)) {
SymbolEnv encInvokableEnv = findEnclosingInvokableEnv(env, encInvokable);
BSymbol resolvedSymbol =
symResolver.lookupClosureVarSymbol(encInvokableEnv, symbol.name, SymTag.VARIABLE);
BSymbol resolvedSymbol = symResolver.lookupClosureVarSymbol(encInvokableEnv, symbol);
if (resolvedSymbol != symTable.notFoundSymbol) {
resolvedSymbol.closure = true;
((BLangArrowFunction) bLangNode).closureVarSymbols.add(new ClosureVarSymbol(resolvedSymbol, pos));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ public void testUnionType(String function) {
}

@DataProvider(name = "unionTestFunctions")
public Object[][] unionTestFunctions() {
return new Object[][]{
{"testUnionPositive"},
{"testUnionNegative"},
{"testUnionRuntimeToString"}
public Object[] unionTestFunctions() {
return new Object[]{
"testUnionPositive",
"testUnionNegative",
"testUnionRuntimeToString",
"testTernaryWithQueryForModuleImportedVariable"
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ public void testTypeNarrowingWithClosure() {
30, 25);
BAssertUtil.validateError(compileResult, index++, "operator '+' not defined for '(int|string)' and 'int'",
31, 17);
BAssertUtil.validateError(compileResult, index++, "incompatible types: expected 'int', found '(int|string)'",
44, 17);
BAssertUtil.validateError(compileResult, index++,
"operator '+' not defined for '(int|string)' and '(int|boolean|error)'", 56, 21);
Assert.assertEquals(compileResult.getErrorCount(), index);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
Expand Down Expand Up @@ -143,6 +144,19 @@ public void bitwiseAndTest() {
Assert.assertEquals(((Byte) returns.get(3)).longValue(), b & d);
}

@Test(description = "Test binary expression with query", dataProvider = "binaryExpressionWithQueryDataProvider")
public void binaryExpressionWithQuery(String fnName) {
BRunUtil.invoke(result, fnName, new Object[]{});
}

@DataProvider(name = "binaryExpressionWithQueryDataProvider")
public Object[] binaryExpressionWithQueryData() {
return new Object[] {
"binaryAndWithQuery",
"binaryOrWithQuery"
};
}

@AfterClass
public void tearDown() {
result = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,11 @@ public void testElvisExprWithUnionWithFiniteTypeContainingNull() {
BRunUtil.invoke(compileResult, "testElvisExprWithUnionWithFiniteTypeContainingNull");
}

@Test
public void testElvisExprWithQuery() {
BRunUtil.invoke(compileResult, "testElvisExprWithQuery");
}

@Test(description = "Negative test cases.")
public void testElvisOperatorNegative() {
int index = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,41 @@ public void testTernaryInModuleLevel() {
BRunUtil.invoke(compileResult, "testTernaryInModuleLevel");
}

@Test
public void testTernaryWithQueryWithLocalVariable() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryWithLocalVariable");
}

@Test
public void testTernaryWithQueryWithFunctionParameter() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryWithFunctionParameter");
}

@Test
public void testTernaryWithQueryWithTypeDef() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryWithTypeDef");
}

@Test
public void testTernaryWithQueryWithModuleVariable() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryWithModuleVariable");
}

@Test
public void testTernaryWithQueryForTwoVariables() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryForTwoVariables");
}

@Test
public void testTernaryWithQueryWithFunctionPointers() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryWithFunctionPointers");
}

@Test
public void testTernaryWithQueryWithFunctionAsClosure() {
BRunUtil.invoke(compileResult, "testTernaryWithQueryWithFunctionAsClosure");
}

@Test(description = "Test type narrowing for ternary expression")
public void testTernaryTypeNarrow() {
CompileResult compileResult = BCompileUtil.compile("test-src/expressions/ternary/ternary_expr_type_narrow.bal");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ public void testNarrowTypeInListBindingPattern3() {
BRunUtil.invoke(result, "testNarrowTypeInListBindingPattern3");
}

@Test
public void testMatchClauseWithQuery() {
BRunUtil.invoke(result, "testMatchClauseWithQuery");
}

@Test(dataProvider = "dataToTestMatchClauseWithTypeGuard", description = "Test match clause with type guard")
public void testMatchClauseWithTypeGuard(String functionName) {
BRunUtil.invoke(result, functionName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,11 @@ public void testMultipleReceiveAction() {
BAssertUtil.validateError(result, 0, "multiple receive action not yet supported", 23, 25);
}

@Test
public void testWorkerWithQuery() {
BRunUtil.invoke(result, "testWorkerWithQuery", new Object[0]);
}

@AfterClass
public void tearDown() {
result = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ function testUnionRuntimeToString() {
<string> checkpanic err.detail()["message"]);
}

function testTernaryWithQueryForModuleImportedVariable() {
int|int[] thenResult = foo:IntOrNull is int ?
from var _ in [1, 2]
where foo:IntOrNull + 2 == 5
select 2 : 2;
assertEquals([2,2], thenResult);

int|int[] elseResult = foo:IntOrNull is () ? 2 :
from var _ in [1, 2]
where foo:IntOrNull + 2 == 5
select 2;
assertEquals([2,2], elseResult);
}

function assertTrue(anydata actual) {
return assertEquals(true, actual);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ public enum BazQux {
BAZ,
QUX
}

public int? IntOrNull = 3;
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,26 @@ function testTypeNarrowingWithClosure2() returns int|string {
};
return x;
}

function testBasicClosureWithInvalidTypeNarrowing() {
int|string a = "32";
var fn = function () {
int b;
if a is int {
b = a;
}
};
}

function testMultiLevelClosureWithInvalidTypeNarrowing() {
int|string a = "32";
var fn1 = function() {
int|boolean|error b = 32;
var fn2 = function(int|string|boolean c) {
int d;
if a is int && b is int && c is int {
d = a + b + c;
}
};
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,26 @@ function bitwiseAnd(int a, int b, byte c, byte d) returns [int, byte, byte, byte
res [3] = b & d;
return res;
}

function binaryAndWithQuery() {
int? i = 3;
boolean result = i is int && (from var _ in [1, 2]
where i + 2 == 5
select 2) == [2, 2];
assertTrue(result);
}

function binaryOrWithQuery() {
int? i = 3;
boolean result = i is () || (from var _ in [1, 2]
where i + 2 == 5
select 2) == [2, 2];
assertTrue(result);
}

function assertTrue(boolean actual) {
if actual {
return;
}
panic error(string `expected 'true', found '${actual.toString()}'`);
}
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,15 @@ function testElvisExprWithUnionWithFiniteTypeContainingNull() {
assertEquals(23, f);
}

function testElvisExprWithQuery() {
int? i = ();
int|int[] res = i ?:
from var _ in [1, 2]
where i == ()
select 2;
assertEquals([2,2], res);
}

const ASSERTION_ERROR_REASON = "AssertionError";

function assertTrue(anydata actual) {
Expand Down
Loading

0 comments on commit fe783fc

Please sign in to comment.