Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support parsing lambda function #866

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ RT_SQR_PRTHS: ']';
SINGLE_QUOTE: '\'';
DOUBLE_QUOTE: '"';
BACKTICK: '`';
// Arrow token placed between lambda parameters and the lambda body expression.
ARROW: '->';

// Operators. Bit

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ valueExpression
| timestampFunction # timestampFunctionCall
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
| LT_SQR_PRTHS subSearch RT_SQR_PRTHS # scalarSubqueryExpr
| ident ARROW expression # lambda
| LT_PRTHS ident (COMMA ident)+ RT_PRTHS ARROW expression # lambda
;

primaryExpression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.sql.ast.expression.EqualTo;
import org.opensearch.sql.ast.expression.Field;
import org.opensearch.sql.ast.expression.FieldList;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.tree.FieldSummary;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
Expand Down Expand Up @@ -183,6 +184,10 @@ public T visitFunction(Function node, C context) {
return visitChildren(node, context);
}

/**
 * Visits a {@link LambdaFunction} node. This default implementation simply delegates to
 * {@code visitChildren}.
 *
 * @param node the lambda function node being visited
 * @param context the caller-supplied visiting context
 * @return whatever {@code visitChildren(node, context)} returns
 */
public T visitLambdaFunction(LambdaFunction node, C context) {
return visitChildren(node, context);
}

/**
 * Visits an {@link IsEmpty} node. This default implementation simply delegates to
 * {@code visitChildren}.
 *
 * @param node the node being visited
 * @param context the caller-supplied visiting context
 * @return whatever {@code visitChildren(node, context)} returns
 */
public T visitIsEmpty(IsEmpty node, C context) {
return visitChildren(node, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ast.expression;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;

/**
* Expression node of lambda function. Params include function name (@funcName) and function
* arguments (@funcArgs)
*/
@Getter
@EqualsAndHashCode(callSuper = false)
@RequiredArgsConstructor
public class LambdaFunction extends UnresolvedExpression {
private final UnresolvedExpression function;
private final List<QualifiedName> funcArgs;

@Override
public List<UnresolvedExpression> getChild() {
List<UnresolvedExpression> children = new ArrayList<>();
children.add(function);
children.addAll(funcArgs);
return children;
}

@Override
public <R, C> R accept(AbstractNodeVisitor<R, C> nodeVisitor, C context) {
return nodeVisitor.visitLambdaFunction(this, context);
}

@Override
public String toString() {
return String.format(
"(%s) -> %s",
funcArgs.stream().map(Object::toString).collect(Collectors.joining(", ")),
function.toString()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,25 @@

package org.opensearch.sql.ppl;

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute;
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
import org.apache.spark.sql.catalyst.expressions.CaseWhen;
import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
import org.apache.spark.sql.catalyst.expressions.Exists$;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GreaterThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.In$;
import org.apache.spark.sql.catalyst.expressions.InSubquery$;
import org.apache.spark.sql.catalyst.expressions.LessThan;
import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
import org.apache.spark.sql.catalyst.expressions.ListQuery$;
import org.apache.spark.sql.catalyst.expressions.MakeInterval$;
import org.apache.spark.sql.catalyst.expressions.NamedExpression;
import org.apache.spark.sql.catalyst.expressions.Predicate;
import org.apache.spark.sql.catalyst.expressions.RowFrame$;
import org.apache.spark.sql.catalyst.expressions.ScalaUDF;
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery$;
import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
import org.apache.spark.sql.catalyst.expressions.WindowExpression;
import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable;
import org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable$;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.types.DataTypes;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand All @@ -38,7 +35,6 @@
import org.opensearch.sql.ast.expression.BinaryExpression;
import org.opensearch.sql.ast.expression.Case;
import org.opensearch.sql.ast.expression.Compare;
import org.opensearch.sql.ast.expression.DataType;
import org.opensearch.sql.ast.expression.FieldsMapping;
import org.opensearch.sql.ast.expression.Function;
import org.opensearch.sql.ast.expression.In;
Expand All @@ -47,6 +43,7 @@
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
import org.opensearch.sql.ast.expression.Or;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.QualifiedName;
import org.opensearch.sql.ast.expression.Span;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
Expand All @@ -61,14 +58,14 @@
import org.opensearch.sql.ast.tree.FillNull;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.RareTopN;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.ast.tree.UnresolvedPlan;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.expression.function.SerializableUdf;
import org.opensearch.sql.ppl.utils.AggregatorTransformer;
import org.opensearch.sql.ppl.utils.BuiltinFunctionTransformer;
import org.opensearch.sql.ppl.utils.ComparatorTransformer;
import org.opensearch.sql.ppl.utils.JavaToScalaTransformer;
import scala.Option;
import scala.PartialFunction;
import scala.Tuple2;
import scala.collection.Seq;

Expand Down Expand Up @@ -432,6 +429,25 @@ public Expression visitCidr(org.opensearch.sql.ast.expression.Cidr node, Catalys
return context.getNamedParseExpressions().push(udf);
}

/**
 * Translates a PPL {@link LambdaFunction} AST node into a Catalyst lambda function expression
 * and pushes it onto the context's named-parse-expression stack.
 *
 * @param node the PPL lambda node (body plus parameter names)
 * @param context the plan-building context whose expression stack is updated
 * @return the Catalyst lambda function that was pushed
 */
@Override
public Expression visitLambdaFunction(LambdaFunction node, CatalystPlanContext context) {
  // Rewrites every UnresolvedAttribute found in the body into an
  // UnresolvedNamedLambdaVariable carrying the same name parts.
  PartialFunction<Expression, Expression> attributeToLambdaVariable =
      JavaToScalaTransformer.toPartialFunction(
          candidate -> candidate instanceof UnresolvedAttribute,
          candidate ->
              new UnresolvedNamedLambdaVariable(((UnresolvedAttribute) candidate).nameParts()));

  Expression visitedBody = node.getFunction().accept(this, context);
  Expression lambdaBody = visitedBody.transformUp(attributeToLambdaVariable);
  // NOTE(review): presumably this discards the body expression that the visit above left on
  // the stack, so only the finished lambda remains - confirm against CatalystPlanContext.
  context.popNamedParseExpressions();

  List<NamedExpression> lambdaVariables =
      node.getFuncArgs().stream()
          .map(param -> UnresolvedNamedLambdaVariable$.MODULE$.apply(seq(param.getParts())))
          .collect(Collectors.toList());

  return context
      .getNamedParseExpressions()
      .push(
          new org.apache.spark.sql.catalyst.expressions.LambdaFunction(
              lambdaBody, seq(lambdaVariables), false));
}

private List<Expression> visitExpressionList(List<UnresolvedExpression> expressionList, CatalystPlanContext context) {
return expressionList.isEmpty()
? emptyList()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.opensearch.sql.ast.expression.Interval;
import org.opensearch.sql.ast.expression.IntervalUnit;
import org.opensearch.sql.ast.expression.IsEmpty;
import org.opensearch.sql.ast.expression.LambdaFunction;
import org.opensearch.sql.ast.expression.Let;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.Not;
Expand All @@ -43,8 +44,6 @@
import org.opensearch.sql.ast.expression.subquery.ExistsSubquery;
import org.opensearch.sql.ast.expression.subquery.InSubquery;
import org.opensearch.sql.ast.expression.subquery.ScalarSubquery;
import org.opensearch.sql.ast.tree.Trendline;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.ppl.utils.ArgumentFactory;

Expand Down Expand Up @@ -429,6 +428,15 @@ public UnresolvedExpression visitTimestampFunctionCall(
ctx.timestampFunction().timestampFunctionName().getText(), timestampFunctionArguments(ctx));
}

/**
 * Builds a {@link LambdaFunction} AST node from a parsed lambda.
 *
 * <p>Each identifier in the parse context becomes one {@link QualifiedName} parameter; the
 * expression in the parse context becomes the lambda body.
 *
 * @param ctx the parser context of the lambda alternative
 * @return the lambda AST node
 */
@Override
public UnresolvedExpression visitLambda(OpenSearchPPLParser.LambdaContext ctx) {
  List<QualifiedName> parameters =
      ctx.ident().stream()
          .map(identCtx -> visitIdentifiers(Collections.singletonList(identCtx)))
          .collect(Collectors.toList());
  // Parameters are resolved before the body is visited, matching the original evaluation order.
  return new LambdaFunction(visitExpression(ctx.expression()), parameters);
}

private List<UnresolvedExpression> timestampFunctionArguments(
OpenSearchPPLParser.TimestampFunctionCallContext ctx) {
List<UnresolvedExpression> args =
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.ppl.utils;

import scala.PartialFunction;
import scala.runtime.AbstractPartialFunction;

/** Helpers for adapting Java functional interfaces to Scala function types. */
public interface JavaToScalaTransformer {
  /**
   * Wraps a Java predicate/function pair as a Scala {@link PartialFunction}.
   *
   * @param isDefinedAt decides whether the partial function is defined for a value
   * @param apply the transformation used for values accepted by {@code isDefinedAt}
   * @param <T> the input and output type of the partial function
   * @return a partial function whose {@code isDefinedAt} mirrors the predicate and whose
   *     {@code apply} transforms accepted values and returns any other value unchanged
   */
  static <T> PartialFunction<T, T> toPartialFunction(
      java.util.function.Predicate<T> isDefinedAt,
      java.util.function.Function<T, T> apply) {
    return new AbstractPartialFunction<T, T>() {
      @Override
      public boolean isDefinedAt(T value) {
        return isDefinedAt.test(value);
      }

      @Override
      public T apply(T value) {
        // Values rejected by the predicate pass through untouched.
        return isDefinedAt.test(value) ? apply.apply(value) : value;
      }
    };
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ package org.opensearch.flint.spark.ppl
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.common.antlr.SyntaxCheckException
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, GreaterThan, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, GreaterThan, LambdaFunction, Literal, NamedExpression, SortOrder, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
Expand Down Expand Up @@ -396,4 +397,21 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
|""".stripMargin),
context)
}

test("test lambda function") {
  // Translate the PPL query first, then build the expected Catalyst plan to compare against.
  val planContext = new CatalystPlanContext
  val actualPlan =
    planTransformer.visit(
      plan(pplParser, """source=t | eval lambda = (x -> x > 0) """.stripMargin),
      planContext)

  // Expected lambda: x -> x > 0, with x as an unresolved named lambda variable.
  val expectedLambda = LambdaFunction(
    GreaterThan(UnresolvedNamedLambdaVariable(seq("x")), Literal(0)),
    Seq(UnresolvedNamedLambdaVariable(seq("x"))))
  // The eval projects all fields plus the aliased lambda; the outer project selects everything.
  val evalProjection =
    Project(Seq(UnresolvedStar(None), Alias(expectedLambda, "lambda")()), UnresolvedRelation(Seq("t")))
  val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProjection)

  comparePlans(expectedPlan, actualPlan, checkAnalysis = false)
}
}
Loading