Skip to content

Commit

Permalink
Remove casting when assigning non-nil value
Browse files Browse the repository at this point in the history
  • Loading branch information
heshanpadmasiri committed Jan 19, 2024
1 parent 4d3bf1d commit 70dc076
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@
import static org.wso2.ballerinalang.compiler.desugar.ASTBuilderUtil.createVariable;
import static org.wso2.ballerinalang.compiler.desugar.ASTBuilderUtil.createVariableRef;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.getMajorVersion;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isAssignmentToOptionalField;
import static org.wso2.ballerinalang.compiler.util.Names.GENERATED_INIT_SUFFIX;
import static org.wso2.ballerinalang.compiler.util.Names.GEN_VAR_PREFIX;
import static org.wso2.ballerinalang.compiler.util.Names.IGNORE;
Expand Down Expand Up @@ -2468,32 +2469,25 @@ private void createSimpleVarDefStmt(BLangSimpleVariable simpleVariable, BLangBlo

@Override
public void visit(BLangAssignment assignNode) {
boolean isOptionalFieldAssignment = isOptionalBasicTypeFieldAssignment(assignNode);
boolean addNilToCastingType = shouldWidenExpressionTypeWithNil(assignNode);
assignNode.varRef = rewriteExpr(assignNode.varRef);
assignNode.expr = rewriteExpr(assignNode.expr);
BType castingType = assignNode.varRef.getBType();
if (isOptionalFieldAssignment) {
if (addNilToCastingType) {
castingType = types.addNilForNillableAccessType(castingType);
}
assignNode.expr = types.addConversionExprIfRequired(rewriteExpr(assignNode.expr), castingType);
result = assignNode;
}

private boolean isOptionalBasicTypeFieldAssignment(BLangAssignment assignNode) {
BLangNode varRef = assignNode.varRef;
if (varRef.getKind() != NodeKind.FIELD_BASED_ACCESS_EXPR) {
private static boolean shouldWidenExpressionTypeWithNil(BLangAssignment assignNode) {
if (!assignNode.expr.getBType().isNullable() || !isAssignmentToOptionalField(assignNode)) {
return false;
}
BLangFieldBasedAccess fieldAccessNode = (BLangFieldBasedAccess) varRef;
BType targetType = Types.getImpliedType(fieldAccessNode.expr.getBType());
if (targetType.tag != TypeTags.RECORD) {
return false;
}
BRecordType recordType = (BRecordType) targetType;
// If we are assigning to an optional field we have a field based access on a record
BLangFieldBasedAccess fieldAccessNode = (BLangFieldBasedAccess) assignNode.varRef;
BRecordType recordType = (BRecordType) Types.getImpliedType(fieldAccessNode.expr.getBType());
BField field = recordType.fields.get(fieldAccessNode.field.value);
if (field == null || !Symbols.isOptional(field.symbol)) {
return false;
}
BType fieldType = Types.getImpliedType(field.getType());
return TypeTags.isSimpleBasicType(fieldType.tag);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@
import org.wso2.ballerinalang.compiler.tree.expressions.BLangConstant;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangErrorVarRef;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangExpression;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangFieldBasedAccess;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangInvocation;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLambdaFunction;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangLetExpression;
Expand Down Expand Up @@ -246,6 +245,7 @@
import static org.ballerinalang.model.tree.NodeKind.RECORD_LITERAL_EXPR;
import static org.ballerinalang.model.tree.NodeKind.REG_EXP_CAPTURING_GROUP;
import static org.ballerinalang.model.tree.NodeKind.REG_EXP_CHARACTER_CLASS;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isAssignmentToOptionalField;

/**
* @since 0.94
Expand Down Expand Up @@ -2311,15 +2311,20 @@ public void visit(BLangAssignment assignNode, AnalyzerData data) {
validateFunctionVarRef(varRef, data);

checkInvalidTypeDef(varRef);
if (varRef.getKind() == NodeKind.FIELD_BASED_ACCESS_EXPR && data.expType.tag != TypeTags.SEMANTIC_ERROR) {
BLangFieldBasedAccess fieldBasedAccessVarRef = (BLangFieldBasedAccess) varRef;
int varRefTypeTag = Types.getImpliedType(fieldBasedAccessVarRef.expr.getBType()).tag;
if (varRefTypeTag == TypeTags.RECORD && Symbols.isOptional(fieldBasedAccessVarRef.symbol)) {
data.expType = types.addNilForNillableAccessType(data.expType);
}
BType actualExpectedType = null;
// For optional field assignments we add nil to the expected type before doing type checking in order to get
// the type in error messages correct. But we don't need an implicit conversion since desugar will add a
// cast if needed.
if (data.expType != symTable.semanticError && isAssignmentToOptionalField(assignNode)) {
actualExpectedType = data.expType;
data.expType = types.addNilForNillableAccessType(actualExpectedType);
}

data.typeChecker.checkExpr(assignNode.expr, data.env, data.expType, data.prevEnvs, data.commonAnalyzerData);
BLangExpression expr = assignNode.expr;
data.typeChecker.checkExpr(expr, data.env, data.expType, data.prevEnvs, data.commonAnalyzerData);
if (actualExpectedType != null && expr.impConversionExpr != null) {
data.typeChecker.resetImpConversionExpr(expr, expr.getBType(), actualExpectedType);
}

validateWorkerAnnAttachments(assignNode.expr, data);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6325,7 +6325,7 @@ protected void visitCheckAndCheckPanicExpr(BLangCheckedExpr checkedExpr, Analyze
data.resultType = types.checkType(checkedExpr, actualType, data.expType);
}

private void resetImpConversionExpr(BLangExpression expr, BType actualType, BType targetType) {
protected void resetImpConversionExpr(BLangExpression expr, BType actualType, BType targetType) {
expr.impConversionExpr = null;
types.setImplicitCastExpr(expr, actualType, targetType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,16 @@

import org.ballerinalang.compiler.CompilerOptionName;
import org.ballerinalang.model.elements.PackageID;
import org.ballerinalang.model.tree.NodeKind;
import org.wso2.ballerinalang.compiler.semantics.analyzer.Types;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.Symbols;
import org.wso2.ballerinalang.compiler.semantics.model.types.BField;
import org.wso2.ballerinalang.compiler.semantics.model.types.BRecordType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.tree.BLangFunction;
import org.wso2.ballerinalang.compiler.tree.BLangNode;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangFieldBasedAccess;
import org.wso2.ballerinalang.compiler.tree.statements.BLangAssignment;

import static org.wso2.ballerinalang.compiler.util.Constants.MAIN_FUNCTION_NAME;

Expand Down Expand Up @@ -69,4 +77,17 @@ public static String getPackageIDStringWithMajorVersion(PackageID packageID) {
return org + packageID.name + Names.VERSION_SEPARATOR.value + getMajorVersion(packageID.version.value);
}

public static boolean isAssignmentToOptionalField(BLangAssignment assignNode) {
BLangNode varRef = assignNode.varRef;
if (varRef.getKind() != NodeKind.FIELD_BASED_ACCESS_EXPR) {
return false;
}
BLangFieldBasedAccess fieldAccessNode = (BLangFieldBasedAccess) varRef;
BType targetType = Types.getImpliedType(fieldAccessNode.expr.getBType());
if (targetType.tag != TypeTags.RECORD) {
return false;
}
BField field = ((BRecordType) targetType).fields.get(fieldAccessNode.field.value);
return field != null && Symbols.isOptional(field.symbol);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ public Object[][] recordOptionalFieldAccessFunctions2() {
{ "testUnavailableFinalAccessInNestedAccess" },
{ "testAvailableFinalAccessInNestedAccess" },
{ "testUnavailableIntermediateAccessInNestedAccess" },
{ "testNilValuedFinalAccessInNestedAccess" }
{ "testNilValuedFinalAccessInNestedAccess" },
{ "testSubtypeAssignment" },
{ "testUnionAssignment" },
{ "testNullableAssignment" }
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@ setOptionalField function() -> () {
%0(RETURN) ();
%1(LOCAL) R3;
%2(TEMP) typeDesc<any|error>;
%4(TEMP) int|();
%5(TEMP) int;
%7(TEMP) string;
%4(TEMP) int;
%6(TEMP) string;
%7(TEMP) int|();
%8(TEMP) ();

bb0 {
%2 = newType R3;
%1 = NewMap %2{};
%5 = ConstLoad 2;
%4 = <int|()> %5;
%7 = ConstLoad x;
%1[%7] = %4;
%4 = ConstLoad 2;
%6 = ConstLoad x;
%1[%6] = %4;
%8 = ConstLoad 0;
%7 = <int|()> %8;
%6 = ConstLoad x;
%1[%6] = %7;
%0 = ConstLoad 0;
GOTO bb1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ function setNillableField() {
function setOptionalField() {
R3 r3 = {};
r3.x = 2;
r3.x = ();
}
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,88 @@ public function testNestedOptionalFieldAccessOnIntersectionTypes() {
assertEquality((), v5);
}

type MyInt int;
type Rxx record {
int val;
};
type Rxy record {
int val;
float otherVal;
};
type Rx record {|
int a?;
MyInt b?;
Rxx c?;
|};

function testSubtypeAssignment() {
int:Signed16 init = 2;
Rx r = {a:init, b:init};
int:Signed32 val = 5;
r.a = val;
assertEquality(val, r.a);
r.a = ();
assertFalse(r.hasKey("a"));
r.b = val;
assertEquality(val, r.b);
r.b = ();
assertFalse(r.hasKey("b"));
Rxy c = {val: 5, otherVal: 10.0};
r.c = c;
assertEquality(c, r.c);
r.c = ();
assertFalse(r.hasKey("b"));
}

function testNullableAssignment() {
Rx r = {};
int? val = 12;
r.a = val;
assertEquality(val, r.a);
int? b = ();
r.a = b;
assertFalse(r.hasKey("a"));
Rxy? c = {val: 5, otherVal: 10.0};
r.c = c;
assertEquality(c, r.c);
c = ();
r.c = c;
assertFalse(r.hasKey("c"));
}

type MyUnion int|float|string;
type Ry record {|
int|float|string a?;
MyUnion b?;
int|Rxx c?;
|};

function testUnionAssignment() {
int init = 5;
Ry r = {a:init, b:init};
int:Signed32 val = 5;
r.a = val;
assertEquality(val, r.a);
int|float val2 = 10;
r.a = val2;
assertEquality(val2, r.a);
r.a = ();
assertFalse(r.hasKey("a"));

r.b = val;
assertEquality(val, r.b);
r.b = val2;
assertEquality(val2, r.b);
r.b = ();
assertFalse(r.hasKey("b"));

Rxx c = {val: 5};
r.c = c;
assertEquality(c, r.c);
r.c = ();
assertFalse(r.hasKey("b"));
}

function assertTrue(anydata actual) {
assertEquality(true, actual);
}
Expand Down

0 comments on commit 70dc076

Please sign in to comment.