diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java index 85178f938f34..234c8aae715f 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java @@ -341,6 +341,7 @@ import static org.wso2.ballerinalang.compiler.desugar.ASTBuilderUtil.createVariable; import static org.wso2.ballerinalang.compiler.desugar.ASTBuilderUtil.createVariableRef; import static org.wso2.ballerinalang.compiler.util.CompilerUtils.getMajorVersion; +import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isAssignmentToOptionalField; import static org.wso2.ballerinalang.compiler.util.Names.GENERATED_INIT_SUFFIX; import static org.wso2.ballerinalang.compiler.util.Names.GEN_VAR_PREFIX; import static org.wso2.ballerinalang.compiler.util.Names.IGNORE; @@ -2476,17 +2477,31 @@ private void createSimpleVarDefStmt(BLangSimpleVariable simpleVariable, BLangBlo @Override public void visit(BLangAssignment assignNode) { - boolean fieldAccessLVExpr = assignNode.varRef.getKind() == NodeKind.FIELD_BASED_ACCESS_EXPR; + // We rewrite the varRef of the BLangAssignment to a IndexBasedAssignment if it is a FieldBasedAssignment. + // Therefore we must do the shouldWidenExpressionTypeWithNil check before that. + boolean addNilToCastingType = shouldWidenExpressionTypeWithNil(assignNode); assignNode.varRef = rewriteExpr(assignNode.varRef); assignNode.expr = rewriteExpr(assignNode.expr); BType castingType = assignNode.varRef.getBType(); - if (fieldAccessLVExpr) { + if (addNilToCastingType) { castingType = types.addNilForNillableAccessType(castingType); } assignNode.expr = types.addConversionExprIfRequired(rewriteExpr(assignNode.expr), castingType); result = assignNode; } + private static boolean shouldWidenExpressionTypeWithNil(BLangAssignment assignNode) { + if (!assignNode.expr.getBType().isNullable() || !isAssignmentToOptionalField(assignNode)) { + return false; + } + // If we are assigning to an optional field we have a field based access on a record + BLangFieldBasedAccess fieldAccessNode = (BLangFieldBasedAccess) assignNode.varRef; + BRecordType recordType = (BRecordType) Types.getImpliedType(fieldAccessNode.expr.getBType()); + BField field = recordType.fields.get(fieldAccessNode.field.value); + BType fieldType = Types.getImpliedType(field.getType()); + return TypeTags.isSimpleBasicType(fieldType.tag); + } + @Override public void visit(BLangTupleDestructure tupleDestructure) { // case 1: diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/SemanticAnalyzer.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/SemanticAnalyzer.java index 7cb3e4f838d4..cf9cd9884090 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/SemanticAnalyzer.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/SemanticAnalyzer.java @@ -118,7 +118,6 @@ import org.wso2.ballerinalang.compiler.tree.expressions.BLangConstant; import org.wso2.ballerinalang.compiler.tree.expressions.BLangErrorVarRef; import org.wso2.ballerinalang.compiler.tree.expressions.BLangExpression; -import org.wso2.ballerinalang.compiler.tree.expressions.BLangFieldBasedAccess; import org.wso2.ballerinalang.compiler.tree.expressions.BLangInvocation; import org.wso2.ballerinalang.compiler.tree.expressions.BLangLambdaFunction; import org.wso2.ballerinalang.compiler.tree.expressions.BLangLetExpression; @@ -246,6 +245,7 @@ import static org.ballerinalang.model.tree.NodeKind.RECORD_LITERAL_EXPR; import static org.ballerinalang.model.tree.NodeKind.REG_EXP_CAPTURING_GROUP; import static org.ballerinalang.model.tree.NodeKind.REG_EXP_CHARACTER_CLASS; +import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isAssignmentToOptionalField; /** * @since 0.94 @@ -2318,15 +2318,20 @@ public void visit(BLangAssignment assignNode, AnalyzerData data) { validateFunctionVarRef(varRef, data); checkInvalidTypeDef(varRef); - if (varRef.getKind() == NodeKind.FIELD_BASED_ACCESS_EXPR && data.expType.tag != TypeTags.SEMANTIC_ERROR) { - BLangFieldBasedAccess fieldBasedAccessVarRef = (BLangFieldBasedAccess) varRef; - int varRefTypeTag = Types.getImpliedType(fieldBasedAccessVarRef.expr.getBType()).tag; - if (varRefTypeTag == TypeTags.RECORD && Symbols.isOptional(fieldBasedAccessVarRef.symbol)) { - data.expType = types.addNilForNillableAccessType(data.expType); - } + BType actualExpectedType = null; + // For optional field assignments we add nil to the expected type before doing type checking in order to get + // the type in error messages correct. But we don't need an implicit conversion since desugar will add a + // cast if needed. + if (data.expType != symTable.semanticError && isAssignmentToOptionalField(assignNode)) { + actualExpectedType = data.expType; + data.expType = types.addNilForNillableAccessType(actualExpectedType); } - data.typeChecker.checkExpr(assignNode.expr, data.env, data.expType, data.prevEnvs, data.commonAnalyzerData); + BLangExpression expr = assignNode.expr; + data.typeChecker.checkExpr(expr, data.env, data.expType, data.prevEnvs, data.commonAnalyzerData); + if (actualExpectedType != null && expr.impConversionExpr != null) { + data.typeChecker.resetImpConversionExpr(expr, expr.getBType(), actualExpectedType); + } validateWorkerAnnAttachments(assignNode.expr, data); diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/TypeChecker.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/TypeChecker.java index 0c89d01d0684..e029c63930c6 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/TypeChecker.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/semantics/analyzer/TypeChecker.java @@ -6506,7 +6506,7 @@ protected void visitCheckAndCheckPanicExpr(BLangCheckedExpr checkedExpr, Analyze data.resultType = types.checkType(checkedExpr, actualType, data.expType); } - private void resetImpConversionExpr(BLangExpression expr, BType actualType, BType targetType) { + protected void resetImpConversionExpr(BLangExpression expr, BType actualType, BType targetType) { expr.impConversionExpr = null; types.setImplicitCastExpr(expr, actualType, targetType); } diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/util/CompilerUtils.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/util/CompilerUtils.java index 94bc63139e90..c78ed10ae522 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/util/CompilerUtils.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/util/CompilerUtils.java @@ -19,10 +19,18 @@ import org.ballerinalang.compiler.CompilerOptionName; import org.ballerinalang.model.elements.PackageID; +import org.ballerinalang.model.tree.NodeKind; +import org.wso2.ballerinalang.compiler.semantics.analyzer.Types; import org.wso2.ballerinalang.compiler.semantics.model.symbols.BSymbol; import org.wso2.ballerinalang.compiler.semantics.model.symbols.Symbols; +import org.wso2.ballerinalang.compiler.semantics.model.types.BField; +import org.wso2.ballerinalang.compiler.semantics.model.types.BRecordType; +import org.wso2.ballerinalang.compiler.semantics.model.types.BType; import org.wso2.ballerinalang.compiler.tree.BLangFunction; +import org.wso2.ballerinalang.compiler.tree.BLangNode; import org.wso2.ballerinalang.compiler.tree.BLangSimpleVariable; +import org.wso2.ballerinalang.compiler.tree.expressions.BLangFieldBasedAccess; +import org.wso2.ballerinalang.compiler.tree.statements.BLangAssignment; import java.util.List; @@ -82,4 +90,17 @@ public static boolean isInParameterList(BSymbol symbol, List functions = Arrays.asList("setRequiredField", "setNillableField", "setOptionalField"); + result.getExpectedBIR().functions.stream().filter(function -> functions.contains(function.name.value)) + .forEach(this::assertFunctions); + } + + private void assertFunctions(BIRNode.BIRFunction function) { + String actual = BIREmitter.emitFunction(function, 0); + String expected = null; + try { + expected = readFile(function.name.value); + } catch (IOException e) { + Assert.fail("Failed to read the expected BIR file for function: " + function.name.value, e); + } + Assert.assertEquals(actual, expected); + } + + private String readFile(String name) throws IOException { + // The files in the bir-dump folder are named with the function name and contain the expected bir dump for + // the function + Path filePath = Paths.get("src", "test", "resources", "test-src", "bir", "bir-dump", name).toAbsolutePath(); + if (Files.exists(filePath)) { + StringBuilder contentBuilder = new StringBuilder(); + + Stream stream = Files.lines(filePath, StandardCharsets.UTF_8); + stream.forEach(s -> contentBuilder.append(s).append("\n")); + + return contentBuilder.toString().trim(); + } + Assert.fail("Expected BIR file not found for function: " + name); + return null; + } + + @AfterClass + public void tearDown() { + result = null; + } +} diff --git a/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/expressions/access/OptionalFieldAccessTest.java b/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/expressions/access/OptionalFieldAccessTest.java index 5ba1e5aa73b8..df415d6c4102 100644 --- a/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/expressions/access/OptionalFieldAccessTest.java +++ b/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/expressions/access/OptionalFieldAccessTest.java @@ -110,7 +110,10 @@ public Object[][] recordOptionalFieldAccessFunctions2() { { "testUnavailableFinalAccessInNestedAccess" }, { "testAvailableFinalAccessInNestedAccess" }, { "testUnavailableIntermediateAccessInNestedAccess" }, - { "testNilValuedFinalAccessInNestedAccess" } + { "testNilValuedFinalAccessInNestedAccess" }, + { "testSubtypeAssignment" }, + { "testUnionAssignment" }, + { "testNullableAssignment" } }; } @@ -167,6 +170,20 @@ public void testOptionalFieldAccessOnMethodCall() { BRunUtil.invoke(result, "testOptionalFieldAccessOnMethodCall"); } + @Test(dataProvider = "optionalFieldRemovalFunctions") + public void testOptionalFieldRemoval(String function) { + BRunUtil.invoke(result, function); + } + + @DataProvider(name = "optionalFieldRemovalFunctions") + public Object[][] optionalFieldRemovalFunctions() { + return new Object[][]{ + {"testOptionalFieldRemovalBasicType"}, + {"testOptionalFieldRemovalIndirect"}, + {"testOptionalFieldRemovalComplex"} + }; + } + @Test public void testNestedOptionalFieldAccessOnIntersectionTypes() { BRunUtil.invoke(result, "testNestedOptionalFieldAccessOnIntersectionTypes"); diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setNillableField b/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setNillableField new file mode 100644 index 000000000000..d216c36953cb --- /dev/null +++ b/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setNillableField @@ -0,0 +1,32 @@ +setNillableField function() -> () { + %0(RETURN) (); + %1(LOCAL) R2; + %2(TEMP) typeDesc; + %4(TEMP) string; + %5(TEMP) int|(); + %6(TEMP) int; + %12(TEMP) (); + + bb0 { + %2 = newType R2; + %4 = ConstLoad x; + %6 = ConstLoad 1; + %5 = %6; + %1 = NewMap %2{%4:%5}; + %6 = ConstLoad 2; + %5 = %6; + %4 = ConstLoad x; + %1[%4] = %5; + %12 = ConstLoad 0; + %5 = %12; + %4 = ConstLoad x; + %1[%4] = %5; + %0 = ConstLoad 0; + GOTO bb1; + } + bb1 { + return; + } + + +} diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setOptionalField b/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setOptionalField new file mode 100644 index 000000000000..b792e069ca63 --- /dev/null +++ b/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setOptionalField @@ -0,0 +1,28 @@ +setOptionalField function() -> () { + %0(RETURN) (); + %1(LOCAL) R3; + %2(TEMP) typeDesc; + %4(TEMP) int; + %6(TEMP) string; + %7(TEMP) int|(); + %8(TEMP) (); + + bb0 { + %2 = newType R3; + %1 = NewMap %2{}; + %4 = ConstLoad 2; + %6 = ConstLoad x; + %1[%6] = %4; + %8 = ConstLoad 0; + %7 = %8; + %6 = ConstLoad x; + %1[%6] = %7; + %0 = ConstLoad 0; + GOTO bb1; + } + bb1 { + return; + } + + +} diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setRequiredField b/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setRequiredField new file mode 100644 index 000000000000..24101ec4fd7a --- /dev/null +++ b/tests/jballerina-unit-test/src/test/resources/test-src/bir/bir-dump/setRequiredField @@ -0,0 +1,24 @@ +setRequiredField function() -> () { + %0(RETURN) (); + %1(LOCAL) R1; + %2(TEMP) typeDesc}>; + %4(TEMP) string; + %5(TEMP) int; + + bb0 { + %2 = newType R1; + %4 = ConstLoad x; + %5 = ConstLoad 1; + %1 = NewMap %2{%4:%5}; + %5 = ConstLoad 2; + %4 = ConstLoad x; + %1[%4] = %5; + %0 = ConstLoad 0; + GOTO bb1; + } + bb1 { + return; + } + + +} diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/bir/record_desugar.bal b/tests/jballerina-unit-test/src/test/resources/test-src/bir/record_desugar.bal new file mode 100644 index 000000000000..2e0e71087812 --- /dev/null +++ b/tests/jballerina-unit-test/src/test/resources/test-src/bir/record_desugar.bal @@ -0,0 +1,28 @@ +type R1 record {| + int x; +|}; + +type R2 record {| + int? x; +|}; + +type R3 record {| + int x?; +|}; + +function setRequiredField() { + R1 r1 = {x: 1}; + r1.x = 2; +} + +function setNillableField() { + R2 r2 = {x: 1}; + r2.x = 2; + r2.x = (); +} + +function setOptionalField() { + R3 r3 = {}; + r3.x = 2; + r3.x = (); +} diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/expressions/access/optional_field_access.bal b/tests/jballerina-unit-test/src/test/resources/test-src/expressions/access/optional_field_access.bal index bf3daab13bbe..ad35feb9ab00 100644 --- a/tests/jballerina-unit-test/src/test/resources/test-src/expressions/access/optional_field_access.bal +++ b/tests/jballerina-unit-test/src/test/resources/test-src/expressions/access/optional_field_access.bal @@ -40,6 +40,23 @@ type Bar record { decimal c; }; +type MyString string; + +type R1 record { + int a?; + float b?; + string c?; + boolean d?; + decimal e?; + string:Char f?; + MyString g?; +}; + +type R2 record { + int a; + R1 r1?; +}; + function testOptionalFieldAccessOnRequiredRecordField() returns boolean { string s = "Anne"; Employee e = { name: s, id: 100 }; @@ -47,6 +64,41 @@ function testOptionalFieldAccessOnRequiredRecordField() returns boolean { return name == s; } +function testOptionalFieldRemovalBasicType() { + R1 r = {a: 1, b: 2.0, c: "test", d: true, e: 3.0, f:"c", g: "test"}; + r.a = (); + r.b = (); + r.c = (); + r.d = (); + r.e = (); + r.f = (); + r.g = (); + assertFalse(r.hasKey("a")); + assertFalse(r.hasKey("b")); + assertFalse(r.hasKey("c")); + assertFalse(r.hasKey("d")); + assertFalse(r.hasKey("e")); + assertFalse(r.hasKey("f")); + assertFalse(r.hasKey("g")); +} + +function testOptionalFieldRemovalIndirect() { + R2 r = {a: 1, r1: {a: 1, b: 2.0, c: "test"}}; + r.r1.a = (); + r.r1.b = (); + r.r1.c = (); + R1 r1 = r.r1; + assertFalse(r1.hasKey("a")); + assertFalse(r1.hasKey("b")); + assertFalse(r1.hasKey("c")); +} + +function testOptionalFieldRemovalComplex() { + R2 r = {a: 1, r1: {a: 1, b: 2.0, c: "test"}}; + r.r1 = (); + assertFalse(r.hasKey("r1")); +} + function testOptionalFieldAccessOnRequiredRecordFieldInRecordUnion() returns boolean { Foo f = { a: 1, b: true }; Foo|Bar fb = f; @@ -570,10 +622,96 @@ public function testNestedOptionalFieldAccessOnIntersectionTypes() { assertEquality((), v5); } +type MyInt int; +type Rxx record { + int val; +}; +type Rxy record { + int val; + float otherVal; +}; +type Rx record {| + int a?; + MyInt b?; + Rxx c?; +|}; + +function testSubtypeAssignment() { + int:Signed16 init = 2; + Rx r = {a:init, b:init}; + int:Signed32 val = 5; + r.a = val; + assertEquality(val, r.a); + r.a = (); + assertFalse(r.hasKey("a")); + r.b = val; + assertEquality(val, r.b); + r.b = (); + assertFalse(r.hasKey("b")); + Rxy c = {val: 5, otherVal: 10.0}; + r.c = c; + assertEquality(c, r.c); + r.c = (); + assertFalse(r.hasKey("b")); +} + +function testNullableAssignment() { + Rx r = {}; + int? val = 12; + r.a = val; + assertEquality(val, r.a); + int? b = (); + r.a = b; + assertFalse(r.hasKey("a")); + Rxy? c = {val: 5, otherVal: 10.0}; + r.c = c; + assertEquality(c, r.c); + c = (); + r.c = c; + assertFalse(r.hasKey("c")); +} + +type MyUnion int|float|string; +type Ry record {| + int|float|string a?; + MyUnion b?; + int|Rxx c?; +|}; + +function testUnionAssignment() { + int init = 5; + Ry r = {a:init, b:init}; + int:Signed32 val = 5; + r.a = val; + assertEquality(val, r.a); + int|float val2 = 10; + r.a = val2; + assertEquality(val2, r.a); + r.a = (); + assertFalse(r.hasKey("a")); + + r.b = val; + assertEquality(val, r.b); + r.b = val2; + assertEquality(val2, r.b); + r.b = (); + assertFalse(r.hasKey("b")); + + Rxx c = {val: 5}; + r.c = c; + assertEquality(c, r.c); + r.c = (); + assertFalse(r.hasKey("b")); +} + function assertTrue(anydata actual) { assertEquality(true, actual); } +function assertFalse(anydata actual) { + assertEquality(false, actual); +} + function assertEquality(anydata expected, anydata actual) { if expected == actual { return;