From c5ea1beadc6c26a4d3ae947530d7c280aef30b1e Mon Sep 17 00:00:00 2001 From: Ulimo Date: Tue, 14 Nov 2023 20:49:07 +0100 Subject: [PATCH] Push down complex expressions to a projection from join condition (#120) --- .../EmitPushdown/EmitPushdownVisitor.cs | 19 +- .../JoinProjectionPushDownVisitor.cs | 28 --- .../JoinProjectionPushDownVisitor.cs | 211 +++++++++++++++++ .../JoinProjectionPushdown.cs | 42 ++++ .../Optimizer/MergeJoinFindVisitor.cs | 2 + .../Optimizer/PlanOptimizer.cs | 9 + .../Internal/FlowtideTestStream.cs | 23 +- .../JoinTests.cs | 24 ++ .../JoinProjectionPushdownTests.cs | 213 ++++++++++++++++++ 9 files changed, 541 insertions(+), 30 deletions(-) delete mode 100644 src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushDownVisitor.cs create mode 100644 src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushDownVisitor.cs create mode 100644 src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushdown.cs create mode 100644 tests/FlowtideDotNet.Core.Tests/OptimizerTests/JoinProjectionPushdownTests.cs diff --git a/src/FlowtideDotNet.Core/Optimizer/EmitPushdown/EmitPushdownVisitor.cs b/src/FlowtideDotNet.Core/Optimizer/EmitPushdown/EmitPushdownVisitor.cs index af1c4fcce..fe75d9315 100644 --- a/src/FlowtideDotNet.Core/Optimizer/EmitPushdown/EmitPushdownVisitor.cs +++ b/src/FlowtideDotNet.Core/Optimizer/EmitPushdown/EmitPushdownVisitor.cs @@ -247,10 +247,27 @@ public override Relation VisitProjectRelation(ProjectRelation projectRelation, o // Remap the expression emits also to reflect the changes Dictionary oldToNew = new Dictionary(); List emit = new List(); + + Dictionary inputEmitToInternal = new Dictionary(); + if (input.EmitSet) + { + for (int i = 0; i < input.Emit.Count; i++) + { + inputEmitToInternal.Add(i, input.Emit[i]); + } + } + else + { + for (int i = 0; i < input.OutputLength; i++) + { + inputEmitToInternal.Add(i, i); + } + } + int count = 0; foreach(var field in usedFields.OrderBy(x => x)) { - emit.Add(field); + emit.Add(inputEmitToInternal[field]); oldToNew.Add(field, count); count++; } diff --git a/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushDownVisitor.cs b/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushDownVisitor.cs deleted file mode 100644 index c5daa44b9..000000000 --- a/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushDownVisitor.cs +++ /dev/null @@ -1,28 +0,0 @@ -// Licensed under the Apache License, Version 2.0 (the "License") -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using FlowtideDotNet.Substrait.Relations; - -namespace FlowtideDotNet.Core.Optimizer -{ - /// - /// Finds expressions that contain projections can be pushed down, this helps performance of the join queries. - /// - internal class JoinProjectionPushDownVisitor : RelationVisitor - { - public override Relation VisitJoinRelation(JoinRelation joinRelation, object state) - { - - return base.VisitJoinRelation(joinRelation, state); - } - } -} diff --git a/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushDownVisitor.cs b/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushDownVisitor.cs new file mode 100644 index 000000000..107c7d18b --- /dev/null +++ b/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushDownVisitor.cs @@ -0,0 +1,211 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlowtideDotNet.Core.Optimizer.EmitPushdown; +using FlowtideDotNet.Substrait.Expressions; +using FlowtideDotNet.Substrait.Expressions.Literals; +using FlowtideDotNet.Substrait.FunctionExtensions; +using FlowtideDotNet.Substrait.Relations; +using SqlParser; + +namespace FlowtideDotNet.Core.Optimizer.JoinProjectionPushdown +{ + /// + /// Finds expressions that contain projections can be pushed down, this helps performance of the join queries. + /// + internal class JoinProjectionPushDownVisitor : OptimizerBaseVisitor + { + /// + /// Goes through an expression and tries to find complex expressions that only use left or right side of the data. + /// In that case it can be pushed down to a projection infront of a join + /// + /// + /// + /// + /// + /// + /// + private Expression Check( + Expression expression, + int leftSize, + int rightSize, + List leftSideExpressions, + List rightSideExpressions, + ref int newIdCounter, + List leftSideIds, + List rightSideIds) + { + if (expression is DirectFieldReference) + { + return expression; + } + var visitor = new JoinExpressionVisitor(leftSize); + visitor.Visit(expression, default); + + if (visitor.unknownCase || visitor.fieldInLeft && visitor.fieldInRight) + { + if (expression is ScalarFunction scalar) + { + for (int i = 0; i < scalar.Arguments.Count; i++) + { + var arg = scalar.Arguments[i]; + scalar.Arguments[i] = Check(arg, leftSize, rightSize, leftSideExpressions, rightSideExpressions, ref newIdCounter, leftSideIds, rightSideIds); + } + return scalar; + } + return expression; + } + if (visitor.fieldInLeft) + { + leftSideExpressions.Add(expression); + var fieldId = newIdCounter; + newIdCounter++; + leftSideIds.Add(fieldId); + return new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = fieldId + } + }; + } + else if (visitor.fieldInRight) + { + rightSideExpressions.Add(expression); + var fieldId = newIdCounter; + newIdCounter++; + rightSideIds.Add(fieldId); + return new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = fieldId + } + }; + } + return expression; + } + + public override Relation VisitJoinRelation(JoinRelation joinRelation, object state) + { + joinRelation.Left = Visit(joinRelation.Left, state); + joinRelation.Right = Visit(joinRelation.Right, state); + + int counter = joinRelation.Left.OutputLength + joinRelation.Right.OutputLength; + int leftSizeBefore = joinRelation.Left.OutputLength; + int rightSizeBefore = joinRelation.Right.OutputLength; + List leftExpressions = new List(); + List rightExpressions = new List(); + List leftSideEmits = new List(); + List rightSideEmits = new List(); + joinRelation.Expression = Check( + joinRelation.Expression, + joinRelation.Left.OutputLength, + joinRelation.Right.OutputLength, + leftExpressions, + rightExpressions, + ref counter, + leftSideEmits, + rightSideEmits); + + if (leftExpressions.Count == 0 && rightExpressions.Count == 0) + { + return joinRelation; + } + + // Create mapping from old emit to new emit + Dictionary oldEmitToNew = new Dictionary(); + for (int i = 0; i < joinRelation.Left.OutputLength; i++) + { + oldEmitToNew.Add(i, i); + } + for (int i = 0; i < leftSideEmits.Count; i++) + { + oldEmitToNew.Add(leftSideEmits[i], oldEmitToNew.Count); + } + for (int i = 0; i < joinRelation.Right.OutputLength; i++) + { + oldEmitToNew.Add(i + joinRelation.Left.OutputLength, oldEmitToNew.Count); + } + for (int i = 0; i < rightSideEmits.Count; i++) + { + oldEmitToNew.Add(rightSideEmits[i], oldEmitToNew.Count); + } + var replaceVisitor = new ExpressionFieldReplaceVisitor(oldEmitToNew); + replaceVisitor.Visit(joinRelation.Expression, default); + + if (joinRelation.EmitSet) + { + for (int i = 0; i < joinRelation.Emit.Count; i++) + { + if (oldEmitToNew.TryGetValue(joinRelation.Emit[i], out var newEmit)) + { + joinRelation.Emit[i] = newEmit; + } + else + { + throw new NotImplementedException("Emit optimizer does not support this case yet"); + } + } + } + else + { + // We must create an emit + List newEmit = new List(); + for (int i = 0; i < leftSizeBefore + rightSizeBefore; i++) + { + if (oldEmitToNew.TryGetValue(i, out var newId)) + { + newEmit.Add(newId); + } + else + { + throw new NotImplementedException("Emit optimizer does not support this case yet"); + } + } + joinRelation.Emit = newEmit; + } + + if (leftExpressions.Count > 0) + { + // FIeld usage on left side does not need to be updated + joinRelation.Left = new ProjectRelation() + { + Expressions = leftExpressions, + Input = joinRelation.Left + }; + } + if (rightExpressions.Count > 0) + { + // Update field id on right side to remove length of left side + Dictionary rightSideOldToNew = new Dictionary(); + for (int i = 0; i < rightSizeBefore; i++) + { + rightSideOldToNew.Add(leftSizeBefore + i, i); + } + replaceVisitor = new ExpressionFieldReplaceVisitor(rightSideOldToNew); + for (int i = 0; i < rightExpressions.Count; i++) + { + replaceVisitor.Visit(rightExpressions[i], default); + } + joinRelation.Right = new ProjectRelation() + { + Expressions = rightExpressions, + Input = joinRelation.Right + }; + } + + + return joinRelation; + } + } +} diff --git a/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushdown.cs b/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushdown.cs new file mode 100644 index 000000000..1bcf72b52 --- /dev/null +++ b/src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushdown/JoinProjectionPushdown.cs @@ -0,0 +1,42 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlowtideDotNet.Core.Optimizer.FIlterPushdown; +using FlowtideDotNet.Substrait; +using FlowtideDotNet.Substrait.Relations; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Core.Optimizer.JoinProjectionPushdown +{ + internal static class JoinProjectionPushdown + { + private static readonly object _emptyObject = new object(); + + public static Plan Optimize(Plan plan) + { + for (int i = 0; i < plan.Relations.Count; i++) + { + var relation = plan.Relations[i]; + var visitor = new JoinProjectionPushDownVisitor(); + relation = visitor.Visit(relation, _emptyObject); + + plan.Relations[i] = relation; + } + + return plan; + } + } +} diff --git a/src/FlowtideDotNet.Core/Optimizer/MergeJoinFindVisitor.cs b/src/FlowtideDotNet.Core/Optimizer/MergeJoinFindVisitor.cs index bedabdd13..a6f380fd9 100644 --- a/src/FlowtideDotNet.Core/Optimizer/MergeJoinFindVisitor.cs +++ b/src/FlowtideDotNet.Core/Optimizer/MergeJoinFindVisitor.cs @@ -62,6 +62,7 @@ public override Relation VisitJoinRelation(JoinRelation joinRelation, object sta { return new MergeJoinRelation() { + Emit = joinRelation.Emit, Left = joinRelation.Left, Right = joinRelation.Right, LeftKeys = new List() @@ -100,6 +101,7 @@ public override Relation VisitJoinRelation(JoinRelation joinRelation, object sta { return new MergeJoinRelation() { + Emit = joinRelation.Emit, Left = joinRelation.Left, Right = joinRelation.Right, LeftKeys = leftKeys, diff --git a/src/FlowtideDotNet.Core/Optimizer/PlanOptimizer.cs b/src/FlowtideDotNet.Core/Optimizer/PlanOptimizer.cs index e4c754461..9dfe04992 100644 --- a/src/FlowtideDotNet.Core/Optimizer/PlanOptimizer.cs +++ b/src/FlowtideDotNet.Core/Optimizer/PlanOptimizer.cs @@ -33,6 +33,15 @@ public static Plan Optimize(Plan plan, PlanOptimizerSettings? settings = null) var filterIntoRead = new FilterIntoReadOptimizer(); relation = filterIntoRead.Visit(relation, null); + plan.Relations[i] = relation; + } + + plan = JoinProjectionPushdown.JoinProjectionPushdown.Optimize(plan); + + for (int i = 0; i < plan.Relations.Count; i++) + { + var relation = plan.Relations[i]; + if (!settings.NoMergeJoin) { var mergeJoinOptimize = new MergeJoinFindVisitor(); diff --git a/tests/FlowtideDotNet.AcceptanceTests/Internal/FlowtideTestStream.cs b/tests/FlowtideDotNet.AcceptanceTests/Internal/FlowtideTestStream.cs index cfd70f826..593759ee1 100644 --- a/tests/FlowtideDotNet.AcceptanceTests/Internal/FlowtideTestStream.cs +++ b/tests/FlowtideDotNet.AcceptanceTests/Internal/FlowtideTestStream.cs @@ -190,6 +190,7 @@ public void AssertCurrentDataEqual(IEnumerable data) Assert.Equal(expectedData.Count, _actualData!.Count); + bool fail = false; for (int i = 0; i < expectedData.Count; i++) { var expectedRow = expectedData[i]; @@ -201,10 +202,30 @@ public void AssertCurrentDataEqual(IEnumerable data) var actualRowJson = FlxValue.FromMemory(actualRow).ToJson; if (!expectedRowJson.Equals(actualRowJson)) { - Assert.Fail($"Expected:{Environment.NewLine}{expectedRowJson}{Environment.NewLine}but got:{Environment.NewLine}{actualRowJson}"); + fail = true; } } } + + if (fail) + { + List expected = new List(); + List actual = new List(); + + for (int i = 0; i < expectedData.Count; i++) + { + var expectedRow = expectedData[i]; + var actualRow = _actualData[i]; + expected.Add(FlxValue.FromMemory(expectedRow).ToJson); + actual.Add(FlxValue.FromMemory(actualRow).ToJson); + } + expected.Sort(); + actual.Sort(); + for (int i = 0; i < expected.Count; i++) + { + Assert.Equal(expected[i], actual[i]); + } + } } public List GetActualRowsAsVectors() diff --git a/tests/FlowtideDotNet.AcceptanceTests/JoinTests.cs b/tests/FlowtideDotNet.AcceptanceTests/JoinTests.cs index 4e252bd30..cf86f25d6 100644 --- a/tests/FlowtideDotNet.AcceptanceTests/JoinTests.cs +++ b/tests/FlowtideDotNet.AcceptanceTests/JoinTests.cs @@ -183,5 +183,29 @@ INNER JOIN users u AssertCurrentDataEqual(Orders.Join(Users, x => x.UserKey, x => x.UserKey, (l, r) => new { l.OrderKey, r.FirstName, r.LastName })); } + + [Fact] + public async Task LeftJoinMergeJoinWithPushdown() + { + GenerateData(100); + await StartStream(@" + INSERT INTO output + SELECT + u.userkey, c.name + FROM users u + LEFT JOIN companies c + ON trim(u.companyid) = trim(c.companyid)"); + await WaitForUpdate(); + + AssertCurrentDataEqual( + from user in Users + join company in Companies on user.CompanyId equals company.CompanyId into gj + from subcompany in gj.DefaultIfEmpty() + select new + { + user.UserKey, + companyName = subcompany?.Name ?? default(string) + }); + } } } diff --git a/tests/FlowtideDotNet.Core.Tests/OptimizerTests/JoinProjectionPushdownTests.cs b/tests/FlowtideDotNet.Core.Tests/OptimizerTests/JoinProjectionPushdownTests.cs new file mode 100644 index 000000000..e4d6ffaf1 --- /dev/null +++ b/tests/FlowtideDotNet.Core.Tests/OptimizerTests/JoinProjectionPushdownTests.cs @@ -0,0 +1,213 @@ +// Licensed under the Apache License, Version 2.0 (the "License") +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using FlowtideDotNet.Core.Optimizer; +using FlowtideDotNet.Core.Optimizer.JoinProjectionPushdown; +using FlowtideDotNet.Substrait; +using FlowtideDotNet.Substrait.Expressions; +using FlowtideDotNet.Substrait.FunctionExtensions; +using FlowtideDotNet.Substrait.Relations; +using FlowtideDotNet.Substrait.Type; +using FluentAssertions; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace FlowtideDotNet.Core.Tests.OptimizerTests +{ + public class JoinProjectionPushdownTests + { + [Fact] + public void TestPushdownTrimOnBothSides() + { + var plan = new Plan() + { + Relations = new List() + { + new JoinRelation() + { + Type = JoinType.Inner, + Expression = new ScalarFunction() + { + ExtensionUri = FunctionsComparison.Uri, + ExtensionName = FunctionsComparison.Equal, + Arguments = new List() + { + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.Trim, + Arguments = new List() + { + new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = 0 + } + } + } + }, + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.Trim, + Arguments = new List() + { + new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = 1 + } + } + } + } + } + }, + Left = new ReadRelation() + { + BaseSchema = new Substrait.Type.NamedStruct() + { + Names = new List(){ "a" }, + Struct = new Substrait.Type.Struct() { Types = new List() { new AnyType() } } + }, + NamedTable = new Substrait.Type.NamedTable + { + Names = new List(){ "a" } + } + }, + Right = new ReadRelation() + { + BaseSchema = new Substrait.Type.NamedStruct() + { + Names = new List(){ "a" }, + Struct = new Substrait.Type.Struct() { Types = new List() { new AnyType() } } + }, + NamedTable = new Substrait.Type.NamedTable + { + Names = new List(){ "b" } + } + }, + } + } + }; + + plan = JoinProjectionPushdown.Optimize(plan); + + var expected = new Plan() + { + Relations = new List() + { + new JoinRelation() + { + Emit = new List() { 0, 2 }, + Type = JoinType.Inner, + Expression = new ScalarFunction() + { + ExtensionUri = FunctionsComparison.Uri, + ExtensionName = FunctionsComparison.Equal, + Arguments = new List() + { + new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = 1 + } + }, + new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = 3 + } + } + } + }, + Left = new ProjectRelation() + { + Expressions = new List() + { + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.Trim, + Arguments = new List() + { + new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = 0 + } + } + } + } + }, + Input = new ReadRelation() + { + BaseSchema = new Substrait.Type.NamedStruct() + { + Names = new List(){ "a" }, + Struct = new Substrait.Type.Struct() { Types = new List() { new AnyType() } } + }, + NamedTable = new Substrait.Type.NamedTable + { + Names = new List(){ "a" } + } + } + }, + Right = new ProjectRelation() + { + Expressions = new List() + { + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.Trim, + Arguments = new List() + { + new DirectFieldReference() + { + ReferenceSegment = new StructReferenceSegment() + { + Field = 0 + } + } + } + } + }, + Input = new ReadRelation() + { + BaseSchema = new Substrait.Type.NamedStruct() + { + Names = new List(){ "a" }, + Struct = new Substrait.Type.Struct() { Types = new List() { new AnyType() } } + }, + NamedTable = new Substrait.Type.NamedTable + { + Names = new List(){ "b" } + } + } + } + } + } + }; + + plan.Should().BeEquivalentTo(expected, + opt => opt.AllowingInfiniteRecursion().IncludingNestedObjects().ThrowingOnMissingMembers().RespectingRuntimeTypes()); + } + } +}