Skip to content

Commit

Permalink
[CALCITE-5855] Implement frame exclusion in window functions
Browse files Browse the repository at this point in the history
  • Loading branch information
itiels authored and mihaibudiu committed Jun 18, 2024
1 parent 97d62ac commit 4a1da22
Show file tree
Hide file tree
Showing 23 changed files with 532 additions and 43 deletions.
26 changes: 25 additions & 1 deletion core/src/main/codegen/templates/Parser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -2736,6 +2736,7 @@ SqlWindow WindowSpecification() :
final SqlNodeList orderList;
final SqlLiteral isRows;
final SqlNode lowerBound, upperBound;
final SqlLiteral exclude;
final Span s, s1, s2;
final SqlLiteral allowPartial;
}
Expand Down Expand Up @@ -2768,9 +2769,11 @@ SqlWindow WindowSpecification() :
lowerBound = WindowRange()
{ upperBound = null; }
)
exclude = WindowExclusion()
|
{
isRows = SqlLiteral.createBoolean(false, SqlParserPos.ZERO);
exclude = SqlWindow.createExcludeNoOthers(getPos());
lowerBound = upperBound = null;
}
)
Expand All @@ -2787,7 +2790,7 @@ SqlWindow WindowSpecification() :
<RPAREN>
{
return SqlWindow.create(null, id, partitionList, orderList,
isRows, lowerBound, upperBound, allowPartial, s.end(this));
isRows, lowerBound, upperBound, allowPartial, exclude, s.end(this));
}
}

Expand Down Expand Up @@ -2826,6 +2829,27 @@ SqlNode WindowRange() :
)
}

/** Parses an exclusion clause for WINDOW FRAME. */
SqlLiteral WindowExclusion() :
{
}
{
(
<EXCLUDE>
(
<CURRENT> <ROW> { return SqlWindow.createExcludeCurrentRow(getPos()); }
|
<NO> <OTHERS> { return SqlWindow.createExcludeNoOthers(getPos()); }
|
<GROUP> { return SqlWindow.createExcludeGroup(getPos()); }
|
<TIES> { return SqlWindow.createExcludeTies(getPos()); }
)
|
{ return SqlWindow.createExcludeNoOthers(SqlParserPos.ZERO); }
)
}

/** Parses a QUALIFY clause for SELECT. */
SqlNode Qualify() :
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowExclusion;
import org.apache.calcite.runtime.SortedMultiMap;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.validate.SqlConformance;
Expand Down Expand Up @@ -301,7 +302,7 @@ private static void sampleOfTheGeneratedWindowedAggregate() {
}

declareAndResetState(typeFactory, builder, result, windowIdx, aggs,
outputPhysType, outputRow);
outputPhysType, outputRow, group.exclude);

// There are assumptions that minX==0. If ever change this, look for
// frameRowCount, bounds checking, etc
Expand Down Expand Up @@ -403,8 +404,14 @@ private static void sampleOfTheGeneratedWindowedAggregate() {
group.lowerBound.isUnbounded() && group.lowerBound.isPreceding()
? Expressions.constant(false)
: Expressions.notEqual(startX, prevStart);

Expression isExcluding =
Expressions.constant(group.exclude != RexWindowExclusion.EXCLUDE_NO_OTHER);

// If there's exclude clause we need to recompute the window every time, as rows can affect
// differently for the same frame.
Expression needRecomputeWindow =
Expressions.orElse(lowerBoundCanChange,
Expressions.orElse(Expressions.orElse(isExcluding, lowerBoundCanChange),
Expressions.lessThan(endX, prevEnd));

BlockStatement resetWindowState = builder6.toBlock();
Expand Down Expand Up @@ -458,15 +465,17 @@ private static void sampleOfTheGeneratedWindowedAggregate() {
};

implementAdd(aggs, builder7, resultContextBuilder, rexArguments, jDecl);

BlockStatement forBlock = builder7.toBlock();

// Don't run the aggregate function if current row is excluded
Statement exclude = buildExcludeGuard(group, comparator_, i_, jDecl, rows_, forBlock);
if (!forBlock.statements.isEmpty()) {
// For instance, row_number does not use for loop to compute the value
Statement forAggLoop =
Expressions.for_(Arrays.asList(jDecl),
Expressions.lessThanOrEqual(jDecl.parameter, endX),
Expressions.preIncrementAssign(jDecl.parameter),
forBlock);
exclude);
if (!hasRows.equals(Expressions.constant(true))) {
forAggLoop = Expressions.ifThen(hasRows, forAggLoop);
}
Expand Down Expand Up @@ -520,6 +529,29 @@ private static void sampleOfTheGeneratedWindowedAggregate() {
return implementor.result(inputPhysType, builder.toBlock());
}

private Statement buildExcludeGuard(Group group, Expression comparator,
ParameterExpression currentRow,
DeclarationStatement jDecl, Expression rows, BlockStatement forBlock) {
if (group.exclude == RexWindowExclusion.EXCLUDE_CURRENT_ROW) {
return Expressions.ifThen(Expressions.notEqual(currentRow, jDecl.parameter), forBlock);
} else if (group.exclude == RexWindowExclusion.EXCLUDE_GROUP) {
return Expressions.ifThen(
Expressions.notEqual(Expressions.constant(0),
Expressions.call(comparator, BuiltInMethod.COMPARATOR_COMPARE.method,
Expressions.arrayIndex(rows, currentRow),
Expressions.arrayIndex(rows, jDecl.parameter))), forBlock);
} else if (group.exclude == RexWindowExclusion.EXCLUDE_TIES) {
return Expressions.ifThen(
Expressions.or(Expressions.equal(currentRow, jDecl.parameter),
Expressions.notEqual(Expressions.constant(0),
Expressions.call(comparator, BuiltInMethod.COMPARATOR_COMPARE.method,
Expressions.arrayIndex(rows, currentRow),
Expressions.arrayIndex(rows, jDecl.parameter)))), forBlock);
} else {
return forBlock;
}
}

private static Function<BlockBuilder, WinAggFrameResultContext>
getBlockBuilderWinAggFrameResultContextFunction(
final JavaTypeFactory typeFactory, final SqlConformance conformance,
Expand Down Expand Up @@ -741,7 +773,7 @@ private static Pair<Expression, Expression> getPartitionIterator(
private void declareAndResetState(final JavaTypeFactory typeFactory,
BlockBuilder builder, final Result result, int windowIdx,
List<AggImpState> aggs, PhysType outputPhysType,
List<Expression> outputRow) {
List<Expression> outputRow, RexWindowExclusion exclusion) {
for (final AggImpState agg : aggs) {
agg.context =
new WinAggContext() {
Expand Down Expand Up @@ -782,6 +814,10 @@ private void declareAndResetState(final JavaTypeFactory typeFactory,
@Override public List<? extends Type> keyTypes() {
throw new UnsupportedOperationException();
}

@Override public RexWindowExclusion getExclude() {
return exclusion;
}
};
String aggName = "a" + agg.aggIdx;
if (CalciteSystemProperty.DEBUG.value()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPatternFieldRef;
import org.apache.calcite.rex.RexWindowExclusion;
import org.apache.calcite.runtime.FlatLists;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.schema.FunctionContext;
Expand Down Expand Up @@ -1441,7 +1442,7 @@ static class CountWinImplementor extends StrictWinAggImplementor {
break;
}
}
if (!hasNullable) {
if (!hasNullable && info.getExclude() == RexWindowExclusion.EXCLUDE_NO_OTHER) {
justFrameRowCount = true;
return Collections.emptyList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
*/
package org.apache.calcite.adapter.enumerable;

import org.apache.calcite.rex.RexWindowExclusion;

/**
* Marker interface to allow
* {@link org.apache.calcite.adapter.enumerable.AggImplementor}
* to tell if it is used in regular or windowed context.
*/
public interface WinAggContext extends AggContext {
/** The exclude clause of the group of the window function. */
RexWindowExclusion getExclude();
}
15 changes: 12 additions & 3 deletions core/src/main/java/org/apache/calcite/rel/core/Window.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowExclusion;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
Expand Down Expand Up @@ -223,8 +224,8 @@ public List<RexLiteral> getConstants() {
/**
* Group of windowed aggregate calls that have the same window specification.
*
* <p>The specification is defined by an upper and lower bound, and
* also has zero or more partitioning columns.
* <p>The specification is defined by an upper and lower bound, exclusion clause,
* and also has zero or more partitioning columns.
*
* <p>A window is either logical or physical. A physical window is measured
* in terms of row count. A logical window is measured in terms of rows
Expand All @@ -247,6 +248,7 @@ public static class Group {
public final boolean isRows;
public final RexWindowBound lowerBound;
public final RexWindowBound upperBound;
public final RexWindowExclusion exclude;
public final RelCollation orderKeys;
private final String digest;

Expand All @@ -262,12 +264,14 @@ public Group(
boolean isRows,
RexWindowBound lowerBound,
RexWindowBound upperBound,
RexWindowExclusion exclude,
RelCollation orderKeys,
List<RexWinAggCall> aggCalls) {
this.keys = Objects.requireNonNull(keys, "keys");
this.isRows = isRows;
this.lowerBound = Objects.requireNonNull(lowerBound, "lowerBound");
this.upperBound = Objects.requireNonNull(upperBound, "upperBound");
this.exclude = exclude;
this.orderKeys = Objects.requireNonNull(orderKeys, "orderKeys");
this.aggCalls = ImmutableList.copyOf(aggCalls);
this.digest = computeString();
Expand Down Expand Up @@ -319,6 +323,9 @@ private String computeString(@UnderInitialization Group this) {
buf.append(lowerBound);
buf.append(" and ");
buf.append(upperBound);
if (exclude != RexWindowExclusion.EXCLUDE_NO_OTHER) {
buf.append(" ").append(exclude);
}
}
if (!aggCalls.isEmpty()) {
if (buf.length() > i) {
Expand Down Expand Up @@ -358,7 +365,9 @@ public RelCollation collation() {
public boolean isAlwaysNonEmpty() {
int lowerKey = lowerBound.getOrderKey();
int upperKey = upperBound.getOrderKey();
return lowerKey > -1 && lowerKey <= upperKey;
return lowerKey > -1 && lowerKey <= upperKey
&& (exclude == RexWindowExclusion.EXCLUDE_NO_OTHER
|| exclude == RexWindowExclusion.EXCLUDE_TIES);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowBounds;
import org.apache.calcite.rex.RexWindowExclusion;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlIdentifier;
Expand Down Expand Up @@ -772,11 +773,18 @@ public RexNode toRex(RelOptCluster cluster, Object o) {
upperBound = null;
physical = false;
}
final RexWindowExclusion exclude;
if (window.get("exclude") != null) {
exclude = toRexWindowExclusion((Map) window.get("exclude"));
} else {
exclude = RexWindowExclusion.EXCLUDE_NO_OTHER;
}
final boolean distinct = get((Map<String, Object>) map, "distinct");
return rexBuilder.makeOver(type, operator, rexOperands, partitionKeys,
ImmutableList.copyOf(orderKeys),
requireNonNull(lowerBound, "lowerBound"),
requireNonNull(upperBound, "upperBound"),
requireNonNull(exclude, "exclude"),
physical,
true, false, distinct, false);
} else {
Expand Down Expand Up @@ -969,6 +977,25 @@ private void addRexFieldCollationList(List<RexFieldCollation> list,
}
}

private @Nullable RexWindowExclusion toRexWindowExclusion(@Nullable Map<String, Object> map) {
if (map == null) {
return null;
}
final String type = get(map, "type");
switch (type) {
case "CURRENT_ROW":
return RexWindowExclusion.EXCLUDE_CURRENT_ROW;
case "GROUP":
return RexWindowExclusion.EXCLUDE_GROUP;
case "TIES":
return RexWindowExclusion.EXCLUDE_TIES;
case "NO OTHERS":
return RexWindowExclusion.EXCLUDE_NO_OTHER;
default:
throw new UnsupportedOperationException(
"cannot convert " + type + " to rex window exclusion");
}
}
private @Nullable RexWindowBound toRexWindowBound(RelInput input,
@Nullable Map<String, Object> map) {
if (map == null) {
Expand All @@ -988,7 +1015,7 @@ private void addRexFieldCollationList(List<RexFieldCollation> list,
case "FOLLOWING":
return RexWindowBounds.following(toRex(input, get(map, "offset")));
default:
throw new UnsupportedOperationException("cannot convert type to rex window bound " + type);
throw new UnsupportedOperationException("cannot convert " + type + " to rex window bound");
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexWindow;
import org.apache.calcite.rex.RexWindowBound;
import org.apache.calcite.rex.RexWindowExclusion;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
Expand Down Expand Up @@ -68,8 +69,8 @@ public final class LogicalWindow extends Window {
*
* @param cluster Cluster
* @param traitSet Trait set
* @param hints Hints for this node
* @param input Input relational expression
* @param hints Hints for this node
* @param input Input relational expression
* @param constants List of constants that are additional inputs
* @param rowType Output row type
* @param groups Window groups
Expand Down Expand Up @@ -199,6 +200,7 @@ public static RelNode create(RelOptCluster cluster,
windowKey.isRows,
windowKey.lowerBound.accept(toInputRefs),
windowKey.upperBound.accept(toInputRefs),
windowKey.exclude,
windowKey.orderKeys,
aggCalls));
}
Expand Down Expand Up @@ -334,22 +336,25 @@ private static class WindowKey {
private final boolean isRows;
private final RexWindowBound lowerBound;
private final RexWindowBound upperBound;
private final RexWindowExclusion exclude;

WindowKey(
ImmutableBitSet groupSet,
RelCollation orderKeys,
boolean isRows,
RexWindowBound lowerBound,
RexWindowBound upperBound) {
RexWindowBound upperBound,
RexWindowExclusion exclude) {
this.groupSet = groupSet;
this.orderKeys = orderKeys;
this.isRows = isRows;
this.lowerBound = lowerBound;
this.upperBound = upperBound;
this.exclude = exclude;
}

@Override public int hashCode() {
return Objects.hash(groupSet, orderKeys, isRows, lowerBound, upperBound);
return Objects.hash(groupSet, orderKeys, isRows, lowerBound, upperBound, exclude);
}

@Override public boolean equals(@Nullable Object obj) {
Expand All @@ -359,6 +364,7 @@ private static class WindowKey {
&& orderKeys.equals(((WindowKey) obj).orderKeys)
&& Objects.equals(lowerBound, ((WindowKey) obj).lowerBound)
&& Objects.equals(upperBound, ((WindowKey) obj).upperBound)
&& exclude == ((WindowKey) obj).exclude
&& isRows == ((WindowKey) obj).isRows;
}
}
Expand Down Expand Up @@ -390,7 +396,7 @@ private static void addWindows(
WindowKey windowKey =
new WindowKey(
groupSet, orderKeys, aggWindow.isRows(),
aggWindow.getLowerBound(), aggWindow.getUpperBound());
aggWindow.getLowerBound(), aggWindow.getUpperBound(), aggWindow.getExclude());
windowMap.put(windowKey, over);
}

Expand Down
Loading

0 comments on commit 4a1da22

Please sign in to comment.