From 22351a6ee60d378cb5e94e219331eebd16ec38d1 Mon Sep 17 00:00:00 2001 From: LakshanWeerasinghe Date: Mon, 20 Nov 2023 21:45:37 +0530 Subject: [PATCH] Add on conflict clause to query pipeline --- .../compiler/desugar/QueryDesugar.java | 62 ++++++++-- .../lang.query/src/main/ballerina/helpers.bal | 111 +++++++++++++++--- .../lang.query/src/main/ballerina/types.bal | 63 ++++++++++ 3 files changed, 210 insertions(+), 26 deletions(-) diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/QueryDesugar.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/QueryDesugar.java index f898371b3cb3..79eeee302ae6 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/QueryDesugar.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/QueryDesugar.java @@ -217,6 +217,7 @@ public class QueryDesugar extends BLangNodeVisitor { private static final Name QUERY_CREATE_GROUP_BY_FUNCTION = new Name("createGroupByFunction"); private static final Name QUERY_CREATE_COLLECT_FUNCTION = new Name("createCollectFunction"); private static final Name QUERY_CREATE_SELECT_FUNCTION = new Name("createSelectFunction"); + private static final Name QUERY_CREATE_ON_CONFLICT_FUNCTION = new Name("createOnConflictFunction"); private static final Name QUERY_CREATE_DO_FUNCTION = new Name("createDoFunction"); private static final Name QUERY_CREATE_LIMIT_FUNCTION = new Name("createLimitFunction"); private static final Name QUERY_ADD_STREAM_FUNCTION = new Name("addStreamFunction"); @@ -226,8 +227,12 @@ public class QueryDesugar extends BLangNodeVisitor { private static final Name QUERY_TO_STRING_FUNCTION = new Name("toString"); private static final Name QUERY_TO_XML_FUNCTION = new Name("toXML"); private static final Name QUERY_ADD_TO_TABLE_FUNCTION = new Name("addToTable"); + private static final Name QUERY_ADD_TO_TABLE_FOR_ON_CONFLICT_FUNCTION = new Name("addToTableForOnConflict"); private static final Name QUERY_ADD_TO_MAP_FUNCTION = new Name("addToMap"); + private static final Name QUERY_ADD_TO_MAP_FOR_ON_CONFLICT_FUNCTION = new Name("addToMapForOnConflict"); private static final Name QUERY_GET_STREAM_FROM_PIPELINE_FUNCTION = new Name("getStreamFromPipeline"); + private static final Name QUERY_GET_STREAM_FOR_ON_CONFLICT_FROM_PIPELINE_FUNCTION = + new Name("getStreamForOnConflictFromPipeline"); private static final Name QUERY_GET_QUERY_ERROR_ROOT_CAUSE_FUNCTION = new Name("getQueryErrorRootCause"); private static final String FRAME_PARAMETER_NAME = "$frame$"; private static final Name QUERY_BODY_DISTINCT_ERROR_NAME = new Name("Error"); @@ -295,23 +300,27 @@ BLangStatementExpression desugar(BLangQueryExpr queryExpr, SymbolEnv env, if (queryExpr.isStream) { resultType = streamRef.getBType(); } else if (queryExpr.isTable) { - onConflictExpr = (onConflictExpr == null) - ? ASTBuilderUtil.createLiteral(pos, symTable.nilType, Names.NIL_VALUE) - : onConflictExpr; BLangVariableReference tableRef = addTableConstructor(queryExpr, queryBlock); - result = getStreamFunctionVariableRef(queryBlock, - QUERY_ADD_TO_TABLE_FUNCTION, Lists.of(streamRef, tableRef, onConflictExpr, isReadonly), pos); + if (onConflictExpr == null) { + result = getStreamFunctionVariableRef(queryBlock, + QUERY_ADD_TO_TABLE_FUNCTION, Lists.of(streamRef, tableRef, isReadonly), pos); + } else { + result = getStreamFunctionVariableRef(queryBlock, + QUERY_ADD_TO_TABLE_FOR_ON_CONFLICT_FUNCTION, Lists.of(streamRef, tableRef, isReadonly), pos); + } resultType = tableRef.getBType(); onConflictExpr = null; } else if (queryExpr.isMap) { - onConflictExpr = (onConflictExpr == null) - ? ASTBuilderUtil.createLiteral(pos, symTable.nilType, Names.NIL_VALUE) - : onConflictExpr; BMapType mapType = getMapType(queryExpr.getBType()); BLangRecordLiteral.BLangMapLiteral mapLiteral = new BLangRecordLiteral.BLangMapLiteral(queryExpr.pos, mapType, new ArrayList<>()); - result = getStreamFunctionVariableRef(queryBlock, - QUERY_ADD_TO_MAP_FUNCTION, Lists.of(streamRef, mapLiteral, onConflictExpr, isReadonly), pos); + if (onConflictExpr == null) { + result = getStreamFunctionVariableRef(queryBlock, + QUERY_ADD_TO_MAP_FUNCTION, Lists.of(streamRef, mapLiteral, isReadonly), pos); + } else { + result = getStreamFunctionVariableRef(queryBlock, + QUERY_ADD_TO_MAP_FOR_ON_CONFLICT_FUNCTION, Lists.of(streamRef, mapLiteral, isReadonly), pos); + } onConflictExpr = null; } else if (queryExpr.getFinalClause().getKind() == NodeKind.COLLECT) { result = getStreamFunctionVariableRef(queryBlock, COLLECT_QUERY_FUNCTION, Lists.of(streamRef), pos); @@ -559,6 +568,8 @@ BLangVariableReference buildStream(List clauses, BType resultType, Sy case ON_CONFLICT: final BLangOnConflictClause onConflict = (BLangOnConflictClause) clause; onConflictExpr = onConflict.expression; + BLangVariableReference onConflictRef = addOnConflictFunction(block, onConflict, stmtsToBePropagated); + addStreamFunction(block, initPipeline, onConflictRef); break; } } @@ -924,6 +935,28 @@ BLangVariableReference addSelectFunction(BLangBlockStmt blockStmt, BLangSelectCl return getStreamFunctionVariableRef(blockStmt, QUERY_CREATE_SELECT_FUNCTION, Lists.of(lambda), pos); } + /** + * Desugar onConflictClause to below and return a reference to created onConflict _StreamFunction. + * _StreamFunction onConflictFunc = createOnConflictFunction + * @param blockStmt + * @param onConflictClause + * @param stmtsToBePropagated + * @return + */ + BLangVariableReference addOnConflictFunction(BLangBlockStmt blockStmt, BLangOnConflictClause onConflictClause, + List stmtsToBePropagated) { + Location pos = onConflictClause.pos; + BLangLambdaFunction lambda = createPassthroughLambda(pos); + BLangBlockFunctionBody body = (BLangBlockFunctionBody) lambda.function.body; + body.stmts.addAll(0, stmtsToBePropagated); + BVarSymbol oldFrameSymbol = lambda.function.requiredParams.get(0).symbol; + BLangSimpleVarRef frame = ASTBuilderUtil.createVariableRef(pos, oldFrameSymbol); + // $frame#[$error$] = on-conflict-expr; + BLangStatement assignment = getAddToFrameStmt(pos, frame, "$error$", onConflictClause.expression); + body.stmts.add(body.stmts.size() - 1, assignment); + lambda = rewrite(lambda); + return getStreamFunctionVariableRef(blockStmt, QUERY_CREATE_ON_CONFLICT_FUNCTION, Lists.of(lambda), pos); + } /** * Desugar doClause to below and return a reference to created do _StreamFunction. * _StreamFunction doFunc = createDoFunction(function(_Frame frame) { @@ -993,8 +1026,12 @@ void addStreamFunction(BLangBlockStmt blockStmt, BLangVariableReference pipeline */ BLangVariableReference addGetStreamFromPipeline(BLangBlockStmt blockStmt, BLangVariableReference pipelineRef) { Location pos = pipelineRef.pos; + if (onConflictExpr == null) { + return getStreamFunctionVariableRef(blockStmt, + QUERY_GET_STREAM_FROM_PIPELINE_FUNCTION, null, Lists.of(pipelineRef), pos); + } return getStreamFunctionVariableRef(blockStmt, - QUERY_GET_STREAM_FROM_PIPELINE_FUNCTION, null, Lists.of(pipelineRef), pos); + QUERY_GET_STREAM_FOR_ON_CONFLICT_FROM_PIPELINE_FUNCTION, null, Lists.of(pipelineRef), pos); } /** @@ -2066,6 +2103,9 @@ public void visit(BLangErrorConstructorExpr errorConstructorExpr) { if (errorConstructorExpr.namedArgs != null) { rewrite(errorConstructorExpr.namedArgs); } + if (errorConstructorExpr.positionalArgs != null) { + rewrite(errorConstructorExpr.positionalArgs); + } errorConstructorExpr.errorDetail = rewrite(errorConstructorExpr.errorDetail); result = errorConstructorExpr; } diff --git a/langlib/lang.query/src/main/ballerina/helpers.bal b/langlib/lang.query/src/main/ballerina/helpers.bal index b9602f0d258c..425caef45d10 100644 --- a/langlib/lang.query/src/main/ballerina/helpers.bal +++ b/langlib/lang.query/src/main/ballerina/helpers.bal @@ -73,6 +73,11 @@ function createSelectFunction(function(_Frame _frame) returns _Frame|error? sele return new _SelectFunction(selectFunc); } +function createOnConflictFunction(function(_Frame _frame) returns _Frame|error? onConflictFunc) + returns _StreamFunction { + return new _OnConflictFunction(onConflictFunc); +} + function createCollectFunction(string[] nonGroupingKeys, function(_Frame _frame) returns _Frame|error? collectFunc) returns _StreamFunction { return new _CollectFunction(nonGroupingKeys, collectFunc); } @@ -93,6 +98,10 @@ function getStreamFromPipeline(_StreamPipeline pipeline) returns stream { + return pipeline.getStreamForOnConflict(); +} + function toArray(stream strm, Type[] arr, boolean isReadOnly) returns Type[]|error { if isReadOnly { // In this case arr will be an immutable array. Therefore, we will create a new mutable array and pass it to the @@ -159,7 +168,8 @@ function toString(stream strm) returns string|error { return result; } -function addToTable(stream strm, table> tbl, error? err, boolean isReadOnly) returns table>|error { +function addToTable(stream strm, table> tbl, boolean isReadOnly) + returns table>|error { if isReadOnly { // TODO: Properly fix readonly scenario - Issue lang/#36721 // In this case tbl will be an immutable table. Therefore, we will create a new mutable table. Next, we will @@ -168,52 +178,85 @@ function addToTable(stream strm, table> tbl, err // and make it immutable with createImmutableTable(). table> tempTbl = table []; table> tbl2 = createTableWithKeySpecifier(tbl, typeof(tempTbl)); - table> tempTable = check createTable(strm, tbl2, err); + table> tempTable = check createTable(strm, tbl2); return createImmutableTable(tbl, tempTable.toArray()); } - return createTable(strm, tbl, err); + return createTable(strm, tbl); } -function createTable(stream strm, table> tbl, error? err) returns table>|error { +function createTable(stream strm, table> tbl) returns table>|error { record {| Type value; |}|CompletionType v = strm.next(); while (v is record {| Type value; |}) { error? e = trap tbl.add(> checkpanic v.value); if (e is error) { - if (err is error) { - return err; - } tbl.put(> checkpanic v.value); } v = strm.next(); } + if v is error { + return v; + } + return tbl; +} + +function addToTableForOnConflict(stream strm, table> tbl, boolean isReadOnly) + returns table>|error { + if isReadOnly { + // TODO: Properly fix readonly scenario - Issue lang/#36721 + // In this case tbl will be an immutable table. Therefore, we will create a new mutable table. Next, we will + // pass the newly created table into createTableWithKeySpecifier() to add the key specifier details from the + // original table variable (tbl). Then the newly created table variable will be populated using createTable() + // and make it immutable with createImmutableTable(). + table> tempTbl = table []; + table> tbl2 = createTableWithKeySpecifier(tbl, typeof(tempTbl)); + table> tempTable = check createTableForOnConflict(strm, tbl2); + return createImmutableTable(tbl, tempTable.toArray()); + } + return createTableForOnConflict(strm, tbl); +} + +function createTableForOnConflict(stream strm, table> tbl) + returns table>|error { + record {| Type value; |}|CompletionType v = strm.next(); + while (v is record {| Type value; |}) { + record {|Type v; error? err;|}|error value = trap ( checkpanic v.value); + if value is error { + return value; + } + error? e = trap tbl.add(> checkpanic value.v); + error? err = value.err; + if e is error && err is error { + return err; + } + if e is error && err is () { + tbl.put(> checkpanic value.v); + } + v = strm.next(); + } if (v is error) { return v; } return tbl; } -function addToMap(stream strm, map mp, error? err, boolean isReadOnly) returns map|error { -// Here, `err` is used to get the expression of on-conflict clause +function addToMap(stream strm, map mp, boolean isReadOnly) returns map|error { if isReadOnly { // In this case mp will be an immutable map. Therefore, we will create a new mutable map and pass it to the // createMap() (because we can't update immutable map). Then it will populate the members into it and the // resultant map will be passed into createImmutableValue() to make it immutable. map mp2 = {}; - createImmutableValue(check createMap(strm, mp2, err)); + createImmutableValue(check createMap(strm, mp2)); return mp2; } - return createMap(strm, mp, err); + return createMap(strm, mp); } -function createMap(stream strm, map mp, error? err) returns map|error { +function createMap(stream strm, map mp) returns map|error { record {| Type value; |}|CompletionType v = strm.next(); while (v is record {| Type value; |}) { [string, Type]|error value = trap (<[string, Type]> checkpanic v.value); if value !is error { string key = value[0]; - if mp.hasKey(key) && err is error { - return err; - } mp[key] = value[1]; } else { return value; @@ -227,6 +270,44 @@ function createMap(stream strm, map mp, error? err) return mp; } +function addToMapForOnConflict(stream strm, map mp, boolean isReadOnly) + returns map|error { + if isReadOnly { + // In this case mp will be an immutable map. Therefore, we will create a new mutable map and pass it to the + // createMap() (because we can't update immutable map). Then it will populate the members into it and the + // resultant map will be passed into createImmutableValue() to make it immutable. + map mp2 = {}; + createImmutableValue(check createMapForOnConflict(strm, mp2)); + return mp2; + } + return createMapForOnConflict(strm, mp); +} + +function createMapForOnConflict(stream strm, map mp) returns map|error { + record {| Type value; |}|CompletionType v = strm.next(); + while (v is record {| Type value; |}) { + record {|Type v; error? err;|}|error value = trap ( checkpanic v.value); + if value is error { + return value; + } + [string, Type]|error keyValue = trap (<[string, Type]> checkpanic value.v); + if keyValue is error { + return keyValue; + } + string key = keyValue[0]; + error? err = value.err; + if mp.hasKey(key) && err is error { + return err; + } + mp[key] = keyValue[1]; + v = strm.next(); + } + if (v is error) { + return v; + } + return mp; +} + function consumeStream(stream strm) returns any|error { any|error? v = strm.next(); while (!(v is () || v is error)) { diff --git a/langlib/lang.query/src/main/ballerina/types.bal b/langlib/lang.query/src/main/ballerina/types.bal index 831969bbe5cd..fc91889a785e 100644 --- a/langlib/lang.query/src/main/ballerina/types.bal +++ b/langlib/lang.query/src/main/ballerina/types.bal @@ -115,6 +115,12 @@ class _StreamPipeline { var strm = internal:construct(self.constraintTd, self.completionTd, itrObj); return strm; } + + public function getStreamForOnConflict() returns stream { + OnConflictIterHelper itrObj = new (self, self.constraintTd); + var strm = internal:construct(self.constraintTd, self.completionTd, itrObj); + return strm; + } } class _InitFunction { @@ -794,6 +800,40 @@ class _SelectFunction { } } +class _OnConflictFunction { + *_StreamFunction; + + # Desugared function to do; + # on conflict error("Duplicate key") + public function (_Frame _frame) returns _Frame|error? onConflictFunc; + + function init(function (_Frame _frame) returns _Frame|error? onConflictFunc) { + self.onConflictFunc = onConflictFunc; + self.prevFunc = (); + } + + public function process() returns _Frame|error? { + _StreamFunction pf = <_StreamFunction>self.prevFunc; + function (_Frame _frame) returns _Frame|error? f = self.onConflictFunc; + _Frame|error? pFrame = pf.process(); + if (pFrame is _Frame) { + _Frame|error? cFrame = f(pFrame); + if (cFrame is error) { + return prepareQueryBodyError(cFrame); + } + return cFrame; + } + return pFrame; + } + + public function reset() { + _StreamFunction? pf = self.prevFunc; + if (pf is _StreamFunction) { + pf.reset(); + } + } +} + class _DoFunction { *_StreamFunction; @@ -931,6 +971,29 @@ class IterHelper { } } +class OnConflictIterHelper { + public _StreamPipeline pipeline; + public typedesc outputType; + + function init(_StreamPipeline pipeline, typedesc outputType) { + self.pipeline = pipeline; + self.outputType = outputType; + } + + public isolated function next() returns record {|Type value;|}|error? { + _StreamPipeline p = self.pipeline; + _Frame|error? f = p.next(); + if (f is _Frame) { + Type v = f["$value$"]; + error? err = f["$error$"]; + record {|Type v; error? err;|} value = {v, err}; + return internal:setNarrowType(self.outputType, {value: value}); + } else { + return f; + } + } +} + class _OrderTreeNode { any? key = (); _Frame[]? frames = ();