diff --git a/docs/docs/operators/topn.md b/docs/docs/operators/topn.md new file mode 100644 index 000000000..8813d77aa --- /dev/null +++ b/docs/docs/operators/topn.md @@ -0,0 +1,22 @@ +--- +sidebar_position: 8 +--- + +# Top N Operator + +The *Top N Operator* implements the *Top N Relation* defined in [substrait](https://substrait.io/relations/physical_relations/#top-n-operation). +It returns the top N rows in a query based on user provided sort fields. +It stores all events in a B+ tree based on the giving ordering. For each event, it also has to check if the event is in the top, and if so +create negation event if another event is no longer in the top. + +## Metrics + +The *Top N Operator* has the following metrics: + +| Metric Name | Type | Description | +| ------------- | --------- | ----------------------------------------------------- | +| busy | Gauge | Value 0-1 on how busy the operator is. | +| backpressure | Gauge | Value 0-1 on how much backpressure the operator has. | +| health | Gauge | Value 0 or 1, if the operator is healthy or not. | +| events | Counter | How many events the operator outputs. | + diff --git a/docs/docs/sql/select/topn.md b/docs/docs/sql/select/topn.md new file mode 100644 index 000000000..fbc931ec6 --- /dev/null +++ b/docs/docs/sql/select/topn.md @@ -0,0 +1,17 @@ +--- +sidebar_position: 2 +--- + +# Top N + +The Top N operator returns only the top N results from a query. +An ordering should be provided as well. + +Example: + +``` +SELECT TOP (10) + userkey +FROM users +ORDER BY userkey +``` \ No newline at end of file diff --git a/src/FlowtideDotNet.Core/Compute/Internal/SortFieldCompareCreator.cs b/src/FlowtideDotNet.Core/Compute/Internal/SortFieldCompareCreator.cs new file mode 100644 index 000000000..8f2c59c9f --- /dev/null +++ b/src/FlowtideDotNet.Core/Compute/Internal/SortFieldCompareCreator.cs @@ -0,0 +1,214 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlexBuffers; +using FlowtideDotNet.Substrait.Expressions; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Core.Compute.Internal +{ + internal static class SortFieldCompareCreator + { + + // These methods are collected through reflection + internal static int CompareAscendingNullsFirstImplementation(FlxValue a, FlxValue b) + { + if (a.IsNull) + { + if (b.IsNull) + { + return 0; + } + else + { + return -1; + } + } + else if (b.IsNull) + { + return 1; + } + return FlxValueComparer.CompareTo(a, b); + } + + internal static int CompareAscendingNullsLastImplementation(FlxValue a, FlxValue b) + { + if (a.IsNull) + { + if (b.IsNull) + { + return 0; + } + else + { + return 1; + } + } + else if (b.IsNull) + { + return -1; + } + return FlxValueComparer.CompareTo(a, b); + } + + internal static int CompareDescendingNullsFirstImplementation(FlxValue a, FlxValue b) + { + if (a.IsNull) + { + if (b.IsNull) + { + return 0; + } + else + { + return -1; + } + } + else if (b.IsNull) + { + return 1; + } + return FlxValueComparer.CompareTo(b, a); + } + + internal static int CompareDescendingNullsLastImplementation(FlxValue a, FlxValue b) + { + if (a.IsNull) + { + if (b.IsNull) + { + return 0; + } + else + { + return 1; + } + } + else if (b.IsNull) + { + return -1; + } + return FlxValueComparer.CompareTo(b, a); + } + + private static System.Linq.Expressions.MethodCallExpression CompareAscendingNullsFirst(System.Linq.Expressions.Expression a, System.Linq.Expressions.Expression b) + { + MethodInfo? compareMethod = typeof(SortFieldCompareCreator).GetMethod("CompareAscendingNullsFirstImplementation", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static); + Debug.Assert(compareMethod != null); + return System.Linq.Expressions.Expression.Call(compareMethod, a, b); + } + + private static System.Linq.Expressions.MethodCallExpression CompareAscendingNullsLast(System.Linq.Expressions.Expression a, System.Linq.Expressions.Expression b) + { + MethodInfo? compareMethod = typeof(SortFieldCompareCreator).GetMethod("CompareAscendingNullsLastImplementation", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static); + Debug.Assert(compareMethod != null); + return System.Linq.Expressions.Expression.Call(compareMethod, a, b); + } + + private static System.Linq.Expressions.MethodCallExpression CompareDescendingNullsFirst(System.Linq.Expressions.Expression a, System.Linq.Expressions.Expression b) + { + MethodInfo? compareMethod = typeof(SortFieldCompareCreator).GetMethod("CompareDescendingNullsFirstImplementation", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static); + Debug.Assert(compareMethod != null); + return System.Linq.Expressions.Expression.Call(compareMethod, a, b); + } + + private static System.Linq.Expressions.MethodCallExpression CompareDescendingNullsLast(System.Linq.Expressions.Expression a, System.Linq.Expressions.Expression b) + { + MethodInfo? compareMethod = typeof(SortFieldCompareCreator).GetMethod("CompareDescendingNullsLastImplementation", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Static); + Debug.Assert(compareMethod != null); + return System.Linq.Expressions.Expression.Call(compareMethod, a, b); + } + + public static Func CreateComparer(List sortFields, FunctionsRegister functionsRegister) + { + var visitor = new FlowtideExpressionVisitor(functionsRegister, typeof(T)); + ParameterExpression left = System.Linq.Expressions.Expression.Parameter(typeof(T)); + ParameterExpression right = System.Linq.Expressions.Expression.Parameter(typeof(T)); + + var leftParameterInfo = new ParametersInfo(new List { left }, new List()); + var rightParameterInfo = new ParametersInfo(new List { right }, new List()); + + List comparisons = new List(); + for (int i = 0; i < sortFields.Count; i++) + { + var sortField = sortFields[i]; + var leftExpression = visitor.Visit(sortField.Expression, leftParameterInfo); + var rightExpression = visitor.Visit(sortField.Expression, rightParameterInfo); + + Debug.Assert(leftExpression != null); + Debug.Assert(rightExpression != null); + + MethodCallExpression? compareExpression = null; + if (sortField.SortDirection == SortDirection.SortDirectionAscNullsFirst) + { + compareExpression = CompareAscendingNullsFirst(leftExpression, rightExpression); + } + else if (sortField.SortDirection == SortDirection.SortDirectionAscNullsLast) + { + compareExpression = CompareAscendingNullsLast(leftExpression, rightExpression); + } + else if (sortField.SortDirection == SortDirection.SortDirectionDescNullsFirst) + { + compareExpression = CompareDescendingNullsFirst(leftExpression, rightExpression); + } + else if (sortField.SortDirection == SortDirection.SortDirectionDescNullsLast) + { + compareExpression = CompareDescendingNullsLast(leftExpression, rightExpression); + } + else if (sortField.SortDirection == SortDirection.SortDirectionUnspecified) + { + // Default is ascending with nulls first + compareExpression = CompareAscendingNullsFirst(leftExpression, rightExpression); + } + else + { + throw new NotSupportedException($"The sort order {sortField.SortDirection} is not supported"); + } + comparisons.Add(compareExpression); + } + + if (comparisons.Count == 1) + { + var lambda = System.Linq.Expressions.Expression.Lambda>(comparisons[0], left, right); + return lambda.Compile(); + } + else if (comparisons.Count > 1) + { + var tmpVar = System.Linq.Expressions.Expression.Variable(typeof(int)); + var compare = comparisons[comparisons.Count - 1]; + for (int i = comparisons.Count - 2; i >= 0; i--) + { + var res = comparisons[i]; + var assignOp = System.Linq.Expressions.Expression.Assign(tmpVar, res); + + var conditionTest = System.Linq.Expressions.Expression.Equal(tmpVar, System.Linq.Expressions.Expression.Constant(0)); + var condition = System.Linq.Expressions.Expression.Condition(conditionTest, compare, tmpVar); + var block = System.Linq.Expressions.Expression.Block(new ParameterExpression[] { tmpVar }, assignOp, condition); + compare = block; + } + var lambda = System.Linq.Expressions.Expression.Lambda>(compare, left, right); + return lambda.Compile(); + } + else + { + throw new InvalidOperationException("No sort fields specified"); + } + } + } +} diff --git a/src/FlowtideDotNet.Core/Engine/SubstraitVisitor.cs b/src/FlowtideDotNet.Core/Engine/SubstraitVisitor.cs index 2fd0c2470..82c5187e9 100644 --- a/src/FlowtideDotNet.Core/Engine/SubstraitVisitor.cs +++ b/src/FlowtideDotNet.Core/Engine/SubstraitVisitor.cs @@ -32,6 +32,7 @@ using FlowtideDotNet.Substrait.Expressions; using FlowtideDotNet.Core.Operators.TimestampProvider; using FlowtideDotNet.Core.Operators.Buffer; +using FlowtideDotNet.Core.Operators.TopN; namespace FlowtideDotNet.Core.Engine { @@ -533,5 +534,23 @@ public override IStreamVertex VisitBufferRelation(BufferRelation bufferRelation, dataflowStreamBuilder.AddPropagatorBlock(id.ToString(), op); return op; } + + public override IStreamVertex VisitTopNRelation(TopNRelation topNRelation, ITargetBlock? state) + { + var id = _operatorId++; + var op = new TopNOperator(topNRelation, functionsRegister, new ExecutionDataflowBlockOptions() { BoundedCapacity = queueSize, MaxDegreeOfParallelism = 1 }); + if (state != null) + { + op.LinkTo(state); + } + topNRelation.Input.Accept(this, op); + dataflowStreamBuilder.AddPropagatorBlock(id.ToString(), op); + return op; + } + + public override IStreamVertex VisitFetchRelation(FetchRelation fetchRelation, ITargetBlock? state) + { + throw new NotSupportedException("Fetch operation (top or limit) is not supported without an order by"); + } } } diff --git a/src/FlowtideDotNet.Core/Operators/TopN/TopNComparer.cs b/src/FlowtideDotNet.Core/Operators/TopN/TopNComparer.cs new file mode 100644 index 000000000..426adf6e0 --- /dev/null +++ b/src/FlowtideDotNet.Core/Operators/TopN/TopNComparer.cs @@ -0,0 +1,40 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Core.Operators.TopN +{ + internal class TopNComparer : IComparer + { + private readonly Func compareFunction; + + public TopNComparer(Func compareFunction) + { + this.compareFunction = compareFunction; + } + public int Compare(RowEvent x, RowEvent y) + { + var result = compareFunction(x, y); + + if (result != 0) + { + return result; + } + return RowEvent.Compare(x, y); + } + } +} diff --git a/src/FlowtideDotNet.Core/Operators/TopN/TopNOperator.cs b/src/FlowtideDotNet.Core/Operators/TopN/TopNOperator.cs new file mode 100644 index 000000000..f76ef87ea --- /dev/null +++ b/src/FlowtideDotNet.Core/Operators/TopN/TopNOperator.cs @@ -0,0 +1,255 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlowtideDotNet.Base.Metrics; +using FlowtideDotNet.Base.Vertices.Unary; +using FlowtideDotNet.Core.Compute; +using FlowtideDotNet.Core.Compute.Internal; +using FlowtideDotNet.Core.Storage; +using FlowtideDotNet.Storage.Serializers; +using FlowtideDotNet.Storage.StateManager; +using FlowtideDotNet.Storage.Tree; +using FlowtideDotNet.Substrait.Relations; +using System.Diagnostics; +using System.Diagnostics.Tracing; +using System.Threading.Tasks.Dataflow; + +namespace FlowtideDotNet.Core.Operators.TopN +{ + internal class TopNOperator : UnaryVertex + { + private readonly TopNComparer _comparer; + private readonly TopNRelation relation; + private IBPlusTree? _tree; + private ICounter? _eventsOutCounter; + + public TopNOperator(TopNRelation relation, FunctionsRegister functionsRegister, ExecutionDataflowBlockOptions executionDataflowBlockOptions) : base(executionDataflowBlockOptions) + { + var compareFunc = SortFieldCompareCreator.CreateComparer(relation.Sorts, functionsRegister); + _comparer = new TopNComparer(compareFunc); + this.relation = relation; + } + + public override string DisplayName => $"Top ({relation.Count})"; + + public override Task Compact() + { + return Task.CompletedTask; + } + + public override Task DeleteAsync() + { + return Task.CompletedTask; + } + + public override async Task OnCheckpoint() + { + Debug.Assert(_tree != null); + await _tree.Commit(); + return default; + } + + public override async IAsyncEnumerable OnRecieve(StreamEventBatch msg, long time) + { + Debug.Assert(_tree != null); + var iterator = _tree.CreateIterator(); + List output = new List(); + foreach(var e in msg.Events) + { + // Insert the value into the tree + var (op, _) = await _tree.RMW(e, e.Weight, (input, current, exists) => + { + if (exists) + { + var newWeight = current + input; + if (newWeight == 0) + { + return (0, GenericWriteOperation.Delete); + } + return (current + e.Weight, GenericWriteOperation.Upsert); + } + else + { + return (input, GenericWriteOperation.Upsert); + } + }); + + // Iterate over the tree, find the Nth value, check if this value is greater or smaller than that + await GetOutputValues(e, output, iterator, op); + } + + if (output.Count > 0) + { + Debug.Assert(_eventsOutCounter != null, nameof(_eventsOutCounter)); + _eventsOutCounter.Add(output.Count); + yield return new StreamEventBatch(output); + } + } + + private async Task GetOutputValues(RowEvent ev, List output, IBPlusTreeIterator iterator, GenericWriteOperation op) + { + await iterator.SeekFirst(); + int count = 0; + var enumerator = iterator.GetAsyncEnumerator(); + int bumpCount = -1; + int bumpWeightModifier = -1; + int pageIndex = -1; + while (await enumerator.MoveNextAsync()) + { + var page = enumerator.Current; + var index = page.Keys.BinarySearch(ev, _comparer); + var loopEndIndex = index; + if (loopEndIndex < 0) + { + loopEndIndex = ~loopEndIndex; + } + // Loop through all elements to get the count + for (int i = 0; i < loopEndIndex; i++) + { + // Count all the weights in the page, since one row could have many duplicates + count += page.Values[i]; + if (count >= relation.Count) + { + break; + } + } + // Check if all N rows have already been accounted, then no output will be given. + if (count >= relation.Count) + { + break; + } + // Check if we did not find the row we inserted, if so, continue to the next page + if (index < 0 && loopEndIndex == page.Values.Count) + { + continue; + } + + // If we reached here, the row was found and it should output values. + + // Set the index where the element is + pageIndex = loopEndIndex; + // If it is an upsert, output the event + if (op == GenericWriteOperation.Upsert) + { + if (index >= 0) + { + // Check if this value already outputs enough rows to satisfy the count + if ((count + page.Values[index] - ev.Weight) >= relation.Count) + { + // Break and do nothing, the count is already satisfied. + break; + } + var outputWeight = Math.Min(relation.Count - count, ev.Weight); + output.Add(new RowEvent(outputWeight, 0, ev.RowData)); + bumpCount = outputWeight; + bumpWeightModifier = -1; + break; + } + else + { + throw new InvalidOperationException("Got an upsert for a value that does not exist in the tree"); + } + } + else if (op == GenericWriteOperation.Delete) + { + output.Add(ev); + bumpCount = ev.Weight * -1; + bumpWeightModifier = 1; + break; + } + else + { + throw new NotSupportedException(); + } + } + + if (bumpCount > 0) + { + var stopCount = relation.Count - bumpCount; + if (bumpWeightModifier < 0) + { + // if we should remove elements, we look at an element infront of the count. + // If it is a delete, the element has already been removed from the tree, so we should not add with 1. + stopCount += 1; + } + // Iterate until the stop count where we should start adding or removing events from the output. + int bumpStartIndex = -1; + do + { + var page = enumerator.Current; + for (int i = pageIndex; i < page.Values.Count; i++) + { + count += page.Values[i]; + if (count > stopCount) + { + bumpStartIndex = i; + break; + } + } + if (count > stopCount) + { + break; + } + pageIndex = 0; + } while (await enumerator.MoveNextAsync()); + + // Check if we have an index where we should start bumping from + if (bumpStartIndex >= 0) + { + var page = enumerator.Current; + while (bumpCount > 0) + { + for (int i = bumpStartIndex; i < page.Values.Count; i++) + { + // Take the min value of the bump count and the weight of the row + var weightToRemove = Math.Min(bumpCount, page.Values[i]); + output.Add(new RowEvent(weightToRemove * bumpWeightModifier, 0, page.Keys[i].RowData)); + bumpCount -= weightToRemove; + if (bumpCount == 0) + { + break; + } + } + if (bumpCount == 0) + { + break; + } + if (await enumerator.MoveNextAsync()) + { + page = enumerator.Current; + bumpStartIndex = 0; + } + else + { + break; + } + } + } + } + } + + protected override async Task InitializeOrRestore(object? state, IStateManagerClient stateManagerClient) + { + if (_eventsOutCounter == null) + { + _eventsOutCounter = Metrics.CreateCounter("events"); + } + // Create tree that will hold all rows + _tree = await stateManagerClient.GetOrCreateTree("topn", new FlowtideDotNet.Storage.Tree.BPlusTreeOptions() + { + Comparer = _comparer, + KeySerializer = new StreamEventBPlusTreeSerializer(), + ValueSerializer = new IntSerializer() + }); + } + } +} diff --git a/src/FlowtideDotNet.Core/Optimizer/OptimizerBaseVisitor.cs b/src/FlowtideDotNet.Core/Optimizer/OptimizerBaseVisitor.cs index aed319c24..b7f980c9d 100644 --- a/src/FlowtideDotNet.Core/Optimizer/OptimizerBaseVisitor.cs +++ b/src/FlowtideDotNet.Core/Optimizer/OptimizerBaseVisitor.cs @@ -123,5 +123,23 @@ public override Relation VisitBufferRelation(BufferRelation bufferRelation, obje bufferRelation.Input = Visit(bufferRelation.Input, state); return bufferRelation; } + + public override Relation VisitFetchRelation(FetchRelation fetchRelation, object state) + { + fetchRelation.Input = Visit(fetchRelation.Input, state); + return fetchRelation; + } + + public override Relation VisitSortRelation(SortRelation sortRelation, object state) + { + sortRelation.Input = Visit(sortRelation.Input, state); + return sortRelation; + } + + public override Relation VisitTopNRelation(TopNRelation topNRelation, object state) + { + topNRelation.Input = Visit(topNRelation.Input, state); + return topNRelation; + } } } diff --git a/src/FlowtideDotNet.Storage/Tree/IBPlusTreePageIterator.cs b/src/FlowtideDotNet.Storage/Tree/IBPlusTreePageIterator.cs index cd503a029..c6289b1db 100644 --- a/src/FlowtideDotNet.Storage/Tree/IBPlusTreePageIterator.cs +++ b/src/FlowtideDotNet.Storage/Tree/IBPlusTreePageIterator.cs @@ -19,5 +19,9 @@ public interface IBPlusTreePageIterator : IEnumerable> /// /// ValueTask SavePage(); + + List Keys { get; } + + List Values { get; } } } diff --git a/src/FlowtideDotNet.Storage/Tree/Internal/BPlusTreePageIterator.cs b/src/FlowtideDotNet.Storage/Tree/Internal/BPlusTreePageIterator.cs index da09ad769..9c27c7367 100644 --- a/src/FlowtideDotNet.Storage/Tree/Internal/BPlusTreePageIterator.cs +++ b/src/FlowtideDotNet.Storage/Tree/Internal/BPlusTreePageIterator.cs @@ -66,6 +66,10 @@ public BPlusTreePageIterator(in LeafNode leaf, in int index, in BPlusTree< this.tree = tree; } + public List Keys => leaf.keys; + + public List Values => leaf.values; + public ValueTask SavePage() { var isFull = tree.m_stateClient.AddOrUpdate(leaf.Id, leaf); diff --git a/src/FlowtideDotNet.Substrait/CustomProto/custom_proto.proto b/src/FlowtideDotNet.Substrait/CustomProto/custom_proto.proto index 0e22c2da4..117444dfc 100644 --- a/src/FlowtideDotNet.Substrait/CustomProto/custom_proto.proto +++ b/src/FlowtideDotNet.Substrait/CustomProto/custom_proto.proto @@ -25,4 +25,10 @@ message ReferenceRelation { message BufferRelation { +} + +message TopNRelation { + repeated substrait.SortField sorts = 1; + int32 offset = 2; + int32 count = 3; } \ No newline at end of file diff --git a/src/FlowtideDotNet.Substrait/Expressions/SortField.cs b/src/FlowtideDotNet.Substrait/Expressions/SortField.cs new file mode 100644 index 000000000..594403071 --- /dev/null +++ b/src/FlowtideDotNet.Substrait/Expressions/SortField.cs @@ -0,0 +1,37 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Substrait.Expressions +{ + public enum SortDirection + { + SortDirectionUnspecified = 0, + SortDirectionAscNullsFirst = 1, + SortDirectionAscNullsLast = 2, + SortDirectionDescNullsFirst = 3, + SortDirectionDescNullsLast = 4, + SortDirectionClustered = 5 + } + + public class SortField + { + public required Expression Expression { get; set; } + + public SortDirection SortDirection { get; set; } + } +} diff --git a/src/FlowtideDotNet.Substrait/Modifier/ModifierVisitor.cs b/src/FlowtideDotNet.Substrait/Modifier/ModifierVisitor.cs index cbef07183..e9632f679 100644 --- a/src/FlowtideDotNet.Substrait/Modifier/ModifierVisitor.cs +++ b/src/FlowtideDotNet.Substrait/Modifier/ModifierVisitor.cs @@ -133,5 +133,23 @@ public override Relation VisitBufferRelation(BufferRelation bufferRelation, obje bufferRelation.Input = Visit(bufferRelation.Input, state); return bufferRelation; } + + public override Relation VisitFetchRelation(FetchRelation fetchRelation, object? state) + { + fetchRelation.Input = Visit(fetchRelation.Input, state); + return fetchRelation; + } + + public override Relation VisitSortRelation(SortRelation sortRelation, object? state) + { + sortRelation.Input = Visit(sortRelation.Input, state); + return sortRelation; + } + + public override Relation VisitTopNRelation(TopNRelation topNRelation, object? state) + { + topNRelation.Input = Visit(topNRelation.Input, state); + return topNRelation; + } } } diff --git a/src/FlowtideDotNet.Substrait/Relations/FetchRelation.cs b/src/FlowtideDotNet.Substrait/Relations/FetchRelation.cs new file mode 100644 index 000000000..a7b51d020 --- /dev/null +++ b/src/FlowtideDotNet.Substrait/Relations/FetchRelation.cs @@ -0,0 +1,46 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Substrait.Relations +{ + public class FetchRelation : Relation + { + public override int OutputLength + { + get + { + if (EmitSet) + { + return Emit!.Count; + } + return Input.OutputLength; + } + } + + public required Relation Input { get; set; } + + public int Offset { get; set; } + + public int Count { get; set; } + + public override TReturn Accept(RelationVisitor visitor, TState state) + { + return visitor.VisitFetchRelation(this, state); + } + } +} diff --git a/src/FlowtideDotNet.Substrait/Relations/RelationVisitor.cs b/src/FlowtideDotNet.Substrait/Relations/RelationVisitor.cs index 9ee839381..1d9e5f3a5 100644 --- a/src/FlowtideDotNet.Substrait/Relations/RelationVisitor.cs +++ b/src/FlowtideDotNet.Substrait/Relations/RelationVisitor.cs @@ -103,5 +103,20 @@ public virtual TReturn VisitBufferRelation(BufferRelation bufferRelation, TState { throw new NotImplementedException("Buffer relation is not implemented"); } + + public virtual TReturn VisitTopNRelation(TopNRelation topNRelation, TState state) + { + throw new NotImplementedException("TopN relation is not implemented"); + } + + public virtual TReturn VisitSortRelation(SortRelation sortRelation, TState state) + { + throw new NotImplementedException("Sort relation is not implemented"); + } + + public virtual TReturn VisitFetchRelation(FetchRelation fetchRelation, TState state) + { + throw new NotImplementedException("Fetch relation is not implemented"); + } } } diff --git a/src/FlowtideDotNet.Substrait/Relations/SortRelation.cs b/src/FlowtideDotNet.Substrait/Relations/SortRelation.cs new file mode 100644 index 000000000..039b3a50a --- /dev/null +++ b/src/FlowtideDotNet.Substrait/Relations/SortRelation.cs @@ -0,0 +1,45 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlowtideDotNet.Substrait.Expressions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Substrait.Relations +{ + public class SortRelation : Relation + { + public override int OutputLength + { + get + { + if (EmitSet) + { + return Emit!.Count; + } + return Input.OutputLength; + } + } + + public required Relation Input { get; set; } + + public required List Sorts { get; set; } + + public override TReturn Accept(RelationVisitor visitor, TState state) + { + return visitor.VisitSortRelation(this, state); + } + } +} diff --git a/src/FlowtideDotNet.Substrait/Relations/TopNRelation.cs b/src/FlowtideDotNet.Substrait/Relations/TopNRelation.cs new file mode 100644 index 000000000..0ed5ceabf --- /dev/null +++ b/src/FlowtideDotNet.Substrait/Relations/TopNRelation.cs @@ -0,0 +1,49 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlowtideDotNet.Substrait.Expressions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Substrait.Relations +{ + public class TopNRelation : Relation + { + public override int OutputLength + { + get + { + if (EmitSet) + { + return Emit!.Count; + } + return Input.OutputLength; + } + } + + public required Relation Input { get; set; } + + public required List Sorts { get; set; } + + public int Offset { get; set; } + + public int Count { get; set; } + + public override TReturn Accept(RelationVisitor visitor, TState state) + { + return visitor.VisitTopNRelation(this, state); + } + } +} diff --git a/src/FlowtideDotNet.Substrait/Sql/Internal/SqlSubstraitVisitor.cs b/src/FlowtideDotNet.Substrait/Sql/Internal/SqlSubstraitVisitor.cs index 30c04ee02..30fbb86e6 100644 --- a/src/FlowtideDotNet.Substrait/Sql/Internal/SqlSubstraitVisitor.cs +++ b/src/FlowtideDotNet.Substrait/Sql/Internal/SqlSubstraitVisitor.cs @@ -169,7 +169,107 @@ public SqlSubstraitVisitor(SqlPlanBuilder sqlPlanBuilder, SqlFunctionRegister sq tablesMetadata.AddTable(alias, cteEmitData.GetNames()); } } - return Visit(query.Body, state); + var node = Visit(query.Body, state); + + if (node == null) + { + throw new InvalidOperationException("Could not create a plan from the query"); + } + if (query.OrderBy != null) + { + var exprVisitor = new SqlExpressionVisitor(sqlFunctionRegister); + List sortFields = new List(); + foreach (var o in query.OrderBy) + { + var expr = exprVisitor.Visit(o.Expression, node.EmitData); + var sortDirection = GetSortDirection(o); + + sortFields.Add(new Expressions.SortField() + { + Expression = expr.Expr, + SortDirection = sortDirection + }); + } + + if (node.Relation is FetchRelation fetch) + { + var rel = new TopNRelation() + { + Input = fetch.Input, + Sorts = sortFields, + Count = fetch.Count, + Offset = fetch.Offset + }; + // Add the order by before the fetch, since the fetch can come from the TOP N in the select. + node = new RelationData(rel, node.EmitData); + } + } + return node; + } + + private static Expressions.SortDirection GetSortDirection(OrderByExpression o) + { + Expressions.SortDirection sortDirection; + + // Find the sort direction of this field + if (o.Asc != null) + { + if (o.Asc.Value) + { + if (o.NullsFirst != null) + { + if (o.NullsFirst.Value) + { + sortDirection = Expressions.SortDirection.SortDirectionAscNullsFirst; + } + else + { + sortDirection = Expressions.SortDirection.SortDirectionAscNullsLast; + } + } + else + { + sortDirection = Expressions.SortDirection.SortDirectionAscNullsFirst; + } + } + else + { + if (o.NullsFirst != null) + { + if (o.NullsFirst.Value) + { + sortDirection = Expressions.SortDirection.SortDirectionDescNullsFirst; + } + else + { + sortDirection = Expressions.SortDirection.SortDirectionDescNullsLast; + } + } + else + { + sortDirection = Expressions.SortDirection.SortDirectionDescNullsLast; + } + } + } + else + { + if (o.NullsFirst != null) + { + if (o.NullsFirst.Value) + { + sortDirection = Expressions.SortDirection.SortDirectionAscNullsFirst; + } + else + { + sortDirection = Expressions.SortDirection.SortDirectionAscNullsLast; + } + } + else + { + sortDirection = Expressions.SortDirection.SortDirectionAscNullsFirst; + } + } + return sortDirection; } protected override RelationData? VisitSelect(Select select, object? state) @@ -228,6 +328,24 @@ public SqlSubstraitVisitor(SqlPlanBuilder sqlPlanBuilder, SqlFunctionRegister sq outNode = VisitProjection(select.Projection, outNode); } + if (select.Top != null) + { + if (outNode == null) + { + throw new InvalidOperationException("TOP statement is not supported without a FROM statement"); + } + var literal = select.Top.Quantity?.AsLiteral()?.Value?.AsNumber(); + if (literal == null) + { + throw new NotSupportedException("Only numeric literal values are supported in the TOP statement"); + } + outNode = new RelationData(new FetchRelation() + { + Input = outNode.Relation, + Count = int.Parse(literal.Value) + }, outNode.EmitData); + } + return outNode; } diff --git a/src/FlowtideDotNet.Substrait/SubstraitSerializer.cs b/src/FlowtideDotNet.Substrait/SubstraitSerializer.cs index 43176f8ad..b4d63179a 100644 --- a/src/FlowtideDotNet.Substrait/SubstraitSerializer.cs +++ b/src/FlowtideDotNet.Substrait/SubstraitSerializer.cs @@ -861,6 +861,69 @@ public override Protobuf.Rel VisitWriteRelation(WriteRelation writeRelation, Ser Write = writeRel }; } + + public override Rel VisitTopNRelation(TopNRelation topNRelation, SerializerVisitorState state) + { + var rel = new Protobuf.ExtensionSingleRel(); + var topRel = new CustomProtobuf.TopNRelation(); + topRel.Offset = topNRelation.Offset; + topRel.Count = topNRelation.Count; + + var exprVisitor = new SerializerExpressionVisitor(); + + foreach (var sortField in topNRelation.Sorts) + { + Protobuf.SortField.Types.SortDirection sortDir; + switch (sortField.SortDirection) + { + case SortDirection.SortDirectionUnspecified: + sortDir = Protobuf.SortField.Types.SortDirection.Unspecified; + break; + case SortDirection.SortDirectionAscNullsFirst: + sortDir = Protobuf.SortField.Types.SortDirection.AscNullsFirst; + break; + case SortDirection.SortDirectionAscNullsLast: + sortDir = Protobuf.SortField.Types.SortDirection.AscNullsLast; + break; + case SortDirection.SortDirectionDescNullsFirst: + sortDir = Protobuf.SortField.Types.SortDirection.DescNullsFirst; + break; + case SortDirection.SortDirectionDescNullsLast: + sortDir = Protobuf.SortField.Types.SortDirection.DescNullsLast; + break; + case SortDirection.SortDirectionClustered: + sortDir = Protobuf.SortField.Types.SortDirection.Clustered; + break; + default: + throw new NotImplementedException(); + } + + topRel.Sorts.Add(new Protobuf.SortField() + { + Direction = sortDir, + Expr = exprVisitor.Visit(sortField.Expression, state) + }); + } + + rel.Detail = new Google.Protobuf.WellKnownTypes.Any() + { + TypeUrl = "flowtide/flowtide.TopNRelation", + Value = topRel.ToByteString() + }; + + if (topNRelation.EmitSet) + { + rel.Common = new Protobuf.RelCommon(); + rel.Common.Emit = new Protobuf.RelCommon.Types.Emit(); + rel.Common.Emit.OutputMapping.AddRange(topNRelation.Emit); + } + rel.Input = Visit(topNRelation.Input, state); + + return new Protobuf.Rel() + { + ExtensionSingle = rel + }; + } } public static Protobuf.Plan Serialize(Plan plan) @@ -885,7 +948,8 @@ public static string SerializeToJson(Plan plan) CustomProtobuf.IterationReferenceReadRelation.Descriptor, CustomProtobuf.IterationRelation.Descriptor, CustomProtobuf.NormalizationRelation.Descriptor, - CustomProtobuf.ReferenceRelation.Descriptor); + CustomProtobuf.ReferenceRelation.Descriptor, + CustomProtobuf.TopNRelation.Descriptor); var settings = new Google.Protobuf.JsonFormatter.Settings(true, typeRegistry) .WithIndentation(); var formatter = new Google.Protobuf.JsonFormatter(settings); diff --git a/tests/FlowtideDotNet.AcceptanceTests/TopNTests.cs b/tests/FlowtideDotNet.AcceptanceTests/TopNTests.cs new file mode 100644 index 000000000..00505d322 --- /dev/null +++ b/tests/FlowtideDotNet.AcceptanceTests/TopNTests.cs @@ -0,0 +1,249 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit.Abstractions; + +namespace FlowtideDotNet.AcceptanceTests +{ + public class TopNTests : FlowtideAcceptanceBase + { + public TopNTests(ITestOutputHelper testOutputHelper) : base(testOutputHelper) + { + } + + [Fact] + public async Task TestTopWithoutOrderBy() + { + GenerateData(); + var ex = await Assert.ThrowsAsync(async () => + { + await StartStream(@" + INSERT INTO output + SELECT TOP 1 userkey + FROM users"); + }); + + Assert.Equal("Fetch operation (top or limit) is not supported without an order by", ex.Message); + } + + [Fact] + public async Task TestTop1Asc() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 userkey + FROM users + ORDER BY userkey"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.UserKey).Take(1).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop1Desc() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 userkey + FROM users + ORDER BY userkey DESC"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderByDescending(x => x.UserKey).Take(1).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop10Asc() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 10 userkey + FROM users + ORDER BY userkey"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.UserKey).Take(10).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop10Desc() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 10 userkey + FROM users + ORDER BY userkey DESC"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderByDescending(x => x.UserKey).Take(10).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop63Desc() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 63 userkey + FROM users + ORDER BY userkey DESC"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderByDescending(x => x.UserKey).Take(63).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop63Asc() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 63 userkey + FROM users + ORDER BY userkey"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.UserKey).Take(63).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop1AscNullsFirstWithDuplicates() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 companyId + FROM users + ORDER BY companyId ASC NULLS FIRST"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.CompanyId).Take(1).Select(x => new { x.CompanyId })); + } + + [Fact] + public async Task TestTop1AscNullsLastWithDuplicates() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 companyId + FROM users + ORDER BY companyId ASC NULLS LAST"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderByDescending(x => x.CompanyId != null).ThenBy(x => x.CompanyId).Take(1).Select(x => new { x.CompanyId })); + } + + [Fact] + public async Task TestTop1DescNullsFirstWithDuplicates() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 companyId + FROM users + ORDER BY companyId DESC NULLS FIRST"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.CompanyId != null).ThenByDescending(x => x.CompanyId).Take(1).Select(x => new { x.CompanyId })); + } + + [Fact] + public async Task TestTop1DescNullsLastWithDuplicates() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 companyId + FROM users + ORDER BY companyId DESC NULLS LAST"); + + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderByDescending(x => x.CompanyId != null).ThenByDescending(x => x.CompanyId).Take(1).Select(x => new { x.CompanyId })); + } + + [Fact] + public async Task TestTop1AscDeleteFirstRow() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 userkey + FROM users + ORDER BY userkey"); + + await WaitForUpdate(); + + var firstUser = Users[0]; + DeleteUser(firstUser); + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.UserKey).Take(1).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop10AscDeleteFirstRow() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 10 userkey + FROM users + ORDER BY userkey"); + + await WaitForUpdate(); + + var firstUser = Users[0]; + DeleteUser(firstUser); + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.UserKey).Take(10).Select(x => new { x.UserKey })); + } + + [Fact] + public async Task TestTop1AscNullsFirstWithDuplicatesDeleteNull() + { + GenerateData(); + await StartStream(@" + INSERT INTO output + SELECT TOP 1 companyId + FROM users + ORDER BY companyId ASC NULLS FIRST"); + + await WaitForUpdate(); + var firstWithNull = Users.First(x => x.CompanyId == null); + DeleteUser(firstWithNull); + await WaitForUpdate(); + + AssertCurrentDataEqual(Users.OrderBy(x => x.CompanyId).Take(1).Select(x => new { x.CompanyId })); + } + } +} diff --git a/tests/FlowtideDotNet.Substrait.Tests/SerializeTests.cs b/tests/FlowtideDotNet.Substrait.Tests/SerializeTests.cs index d09c624c8..2871eba82 100644 --- a/tests/FlowtideDotNet.Substrait.Tests/SerializeTests.cs +++ b/tests/FlowtideDotNet.Substrait.Tests/SerializeTests.cs @@ -42,5 +42,34 @@ select lower(a) FROM table1 t1 var formatter = new Google.Protobuf.JsonFormatter(settings); var json = formatter.Format(protoPlan); } + + [Fact] + public void SerializeTopNRelation() + { + SqlPlanBuilder sqlPlanBuilder = new SqlPlanBuilder(); + sqlPlanBuilder.Sql(@" + create table table1 (a any); + + insert into out + select TOP (1) a FROM table1 t1 ORDER BY a + "); + var plan = sqlPlanBuilder.GetPlan(); + + var ex = Record.Exception(() => + { + var protoPlan = SubstraitSerializer.Serialize(plan); + + var typeRegistry = Google.Protobuf.Reflection.TypeRegistry.FromMessages( + CustomProtobuf.IterationReferenceReadRelation.Descriptor, + CustomProtobuf.IterationRelation.Descriptor, + CustomProtobuf.NormalizationRelation.Descriptor, + CustomProtobuf.ReferenceRelation.Descriptor); + var settings = new Google.Protobuf.JsonFormatter.Settings(true, typeRegistry) + .WithIndentation(); + var formatter = new Google.Protobuf.JsonFormatter(settings); + var json = formatter.Format(protoPlan); + }); + Assert.Null(ex); + } } }