Skip to content

Commit

Permalink
Push down complex expressions to a projection from join condition (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulimo authored Nov 14, 2023
1 parent b4cef38 commit c5ea1be
Show file tree
Hide file tree
Showing 9 changed files with 541 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,27 @@ public override Relation VisitProjectRelation(ProjectRelation projectRelation, o
// Remap the expression emits also to reflect the changes
Dictionary<int, int> oldToNew = new Dictionary<int, int>();
List<int> emit = new List<int>();

Dictionary<int, int> inputEmitToInternal = new Dictionary<int, int>();
if (input.EmitSet)
{
for (int i = 0; i < input.Emit.Count; i++)
{
inputEmitToInternal.Add(i, input.Emit[i]);
}
}
else
{
for (int i = 0; i < input.OutputLength; i++)
{
inputEmitToInternal.Add(i, i);
}
}

int count = 0;
foreach(var field in usedFields.OrderBy(x => x))
{
emit.Add(field);
emit.Add(inputEmitToInternal[field]);
oldToNew.Add(field, count);
count++;
}
Expand Down
28 changes: 0 additions & 28 deletions src/FlowtideDotNet.Core/Optimizer/JoinProjectionPushDownVisitor.cs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using FlowtideDotNet.Core.Optimizer.EmitPushdown;
using FlowtideDotNet.Substrait.Expressions;
using FlowtideDotNet.Substrait.Expressions.Literals;
using FlowtideDotNet.Substrait.FunctionExtensions;
using FlowtideDotNet.Substrait.Relations;
using SqlParser;

namespace FlowtideDotNet.Core.Optimizer.JoinProjectionPushdown
{
/// <summary>
/// Optimizer visitor that looks for complex sub-expressions in a join condition which only read
/// fields from one side of the join, and moves them into a projection placed in front of that side.
/// The join condition then refers to the projected column instead, which helps the performance of join queries.
/// </summary>
internal class JoinProjectionPushDownVisitor : OptimizerBaseVisitor
{
    /// <summary>
    /// Records <paramref name="expression"/> for push down on one side of the join and returns the
    /// field reference that takes its place in the join condition.
    /// </summary>
    /// <param name="expression">The expression that is moved out of the join condition.</param>
    /// <param name="sideExpressions">Collected push down expressions for the chosen side.</param>
    /// <param name="sideIds">Collected temporary field ids for the chosen side.</param>
    /// <param name="newIdCounter">Next free temporary field id, incremented by this call.</param>
    /// <returns>A direct field reference that carries the temporary field id.</returns>
    private static Expression ReplaceWithPlaceholder(
        Expression expression,
        List<Expression> sideExpressions,
        List<int> sideIds,
        ref int newIdCounter)
    {
        sideExpressions.Add(expression);
        var temporaryId = newIdCounter;
        newIdCounter++;
        sideIds.Add(temporaryId);

        var segment = new StructReferenceSegment()
        {
            Field = temporaryId
        };
        return new DirectFieldReference()
        {
            ReferenceSegment = segment
        };
    }

    /// <summary>
    /// Looks up the new index for an old field index, failing when no mapping exists.
    /// </summary>
    /// <param name="indexMapping">Mapping from old field index to new field index.</param>
    /// <param name="oldIndex">The index to translate.</param>
    /// <returns>The translated index.</returns>
    /// <exception cref="NotImplementedException">Thrown when <paramref name="oldIndex"/> is not in the mapping.</exception>
    private static int TranslateIndex(Dictionary<int, int> indexMapping, int oldIndex)
    {
        if (!indexMapping.TryGetValue(oldIndex, out var translated))
        {
            throw new NotImplementedException("Emit optimizer does not support this case yet");
        }
        return translated;
    }

    /// <summary>
    /// Goes through an expression and tries to find complex expressions that only use the left or the right side of the data.
    /// Such an expression is collected so it can be pushed down to a projection in front of the join, and is
    /// replaced by a direct field reference that carries a temporary field id taken from <paramref name="newIdCounter"/>.
    /// Expressions that use both sides (or that the side analysis flags as unknown) are kept, but the arguments
    /// of scalar functions are checked recursively and replaced in place.
    /// </summary>
    /// <param name="expression">The expression to inspect. Direct field references are returned unchanged.</param>
    /// <param name="leftSize">Number of fields on the left side; handed to <see cref="JoinExpressionVisitor"/>.</param>
    /// <param name="rightSize">Number of fields on the right side; only forwarded to recursive calls.</param>
    /// <param name="leftSideExpressions">Receives the expressions to push down on the left side.</param>
    /// <param name="rightSideExpressions">Receives the expressions to push down on the right side.</param>
    /// <param name="newIdCounter">Next free temporary field id, incremented once per collected expression.</param>
    /// <param name="leftSideIds">Receives the temporary ids handed out for left side expressions.</param>
    /// <param name="rightSideIds">Receives the temporary ids handed out for right side expressions.</param>
    /// <returns>The expression that should be used in place of <paramref name="expression"/>.</returns>
    private Expression Check(
        Expression expression,
        int leftSize,
        int rightSize,
        List<Expression> leftSideExpressions,
        List<Expression> rightSideExpressions,
        ref int newIdCounter,
        List<int> leftSideIds,
        List<int> rightSideIds)
    {
        // A plain field reference is already as cheap as it gets, nothing to push down.
        if (expression is DirectFieldReference)
        {
            return expression;
        }

        var sideFinder = new JoinExpressionVisitor(leftSize);
        sideFinder.Visit(expression, default);

        if (sideFinder.unknownCase || (sideFinder.fieldInLeft && sideFinder.fieldInRight))
        {
            // The expression as a whole can not be moved, but parts of a scalar function might be.
            if (expression is ScalarFunction function)
            {
                for (int argumentIndex = 0; argumentIndex < function.Arguments.Count; argumentIndex++)
                {
                    function.Arguments[argumentIndex] = Check(
                        function.Arguments[argumentIndex],
                        leftSize,
                        rightSize,
                        leftSideExpressions,
                        rightSideExpressions,
                        ref newIdCounter,
                        leftSideIds,
                        rightSideIds);
                }
                return function;
            }
            return expression;
        }

        if (sideFinder.fieldInLeft)
        {
            return ReplaceWithPlaceholder(expression, leftSideExpressions, leftSideIds, ref newIdCounter);
        }
        if (sideFinder.fieldInRight)
        {
            return ReplaceWithPlaceholder(expression, rightSideExpressions, rightSideIds, ref newIdCounter);
        }

        // No field from either side is used, leave the expression where it is.
        return expression;
    }

    /// <summary>
    /// Visits both join inputs, then extracts one-sided complex expressions from the join condition into
    /// projections that wrap the left and/or right input. Field references in the join condition and the
    /// join emit are rewritten to match the widened inputs; if the join had no emit, one is created so the
    /// join keeps producing its original columns.
    /// </summary>
    /// <param name="joinRelation">The join relation to optimize; it is modified in place.</param>
    /// <param name="state">Visitor state, passed on when visiting the inputs.</param>
    /// <returns>The same join relation instance.</returns>
    /// <exception cref="NotImplementedException">Thrown when an emitted index has no entry in the index mapping.</exception>
    public override Relation VisitJoinRelation(JoinRelation joinRelation, object state)
    {
        joinRelation.Left = Visit(joinRelation.Left, state);
        joinRelation.Right = Visit(joinRelation.Right, state);

        int leftSizeBefore = joinRelation.Left.OutputLength;
        int rightSizeBefore = joinRelation.Right.OutputLength;

        // Temporary ids for pushed down expressions start after all existing field ids.
        int counter = leftSizeBefore + rightSizeBefore;

        var leftExpressions = new List<Expression>();
        var rightExpressions = new List<Expression>();
        var leftSideEmits = new List<int>();
        var rightSideEmits = new List<int>();

        joinRelation.Expression = Check(
            joinRelation.Expression,
            leftSizeBefore,
            rightSizeBefore,
            leftExpressions,
            rightExpressions,
            ref counter,
            leftSideEmits,
            rightSideEmits);

        bool nothingToPushDown = leftExpressions.Count == 0 && rightExpressions.Count == 0;
        if (nothingToPushDown)
        {
            return joinRelation;
        }

        // Build the mapping from old field index to new field index. New indices are handed out in order:
        // left fields, left pushed down expressions, right fields, right pushed down expressions.
        var indexMapping = new Dictionary<int, int>();
        int nextIndex = 0;
        for (int leftField = 0; leftField < leftSizeBefore; leftField++)
        {
            indexMapping.Add(leftField, nextIndex++);
        }
        foreach (var temporaryId in leftSideEmits)
        {
            indexMapping.Add(temporaryId, nextIndex++);
        }
        for (int rightField = 0; rightField < rightSizeBefore; rightField++)
        {
            indexMapping.Add(leftSizeBefore + rightField, nextIndex++);
        }
        foreach (var temporaryId in rightSideEmits)
        {
            indexMapping.Add(temporaryId, nextIndex++);
        }

        var conditionRewriter = new ExpressionFieldReplaceVisitor(indexMapping);
        conditionRewriter.Visit(joinRelation.Expression, default);

        if (joinRelation.EmitSet)
        {
            // Translate the existing emit in place.
            for (int position = 0; position < joinRelation.Emit.Count; position++)
            {
                joinRelation.Emit[position] = TranslateIndex(indexMapping, joinRelation.Emit[position]);
            }
        }
        else
        {
            // No emit yet, create one that selects exactly the columns the join had before.
            List<int> createdEmit = new List<int>();
            int originalWidth = leftSizeBefore + rightSizeBefore;
            for (int oldIndex = 0; oldIndex < originalWidth; oldIndex++)
            {
                createdEmit.Add(TranslateIndex(indexMapping, oldIndex));
            }
            joinRelation.Emit = createdEmit;
        }

        if (leftExpressions.Count > 0)
        {
            // Field usage on the left side does not need to be updated.
            var leftProjection = new ProjectRelation()
            {
                Expressions = leftExpressions,
                Input = joinRelation.Left
            };
            joinRelation.Left = leftProjection;
        }

        if (rightExpressions.Count > 0)
        {
            // Right side field ids are offset by the left side length inside the join,
            // remove that offset before the expressions move in front of the right input.
            var removeLeftOffset = new Dictionary<int, int>();
            for (int rightField = 0; rightField < rightSizeBefore; rightField++)
            {
                removeLeftOffset.Add(leftSizeBefore + rightField, rightField);
            }

            var rightRewriter = new ExpressionFieldReplaceVisitor(removeLeftOffset);
            foreach (var rightExpression in rightExpressions)
            {
                rightRewriter.Visit(rightExpression, default);
            }

            var rightProjection = new ProjectRelation()
            {
                Expressions = rightExpressions,
                Input = joinRelation.Right
            };
            joinRelation.Right = rightProjection;
        }

        return joinRelation;
    }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Licensed under the Apache License, Version 2.0 (the "License")
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

using FlowtideDotNet.Core.Optimizer.FIlterPushdown;
using FlowtideDotNet.Substrait;
using FlowtideDotNet.Substrait.Relations;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace FlowtideDotNet.Core.Optimizer.JoinProjectionPushdown
{
/// <summary>
/// Entry point that runs <see cref="JoinProjectionPushDownVisitor"/> over every relation in a plan.
/// </summary>
internal static class JoinProjectionPushdown
{
    // Shared placeholder passed as the visitor state argument.
    private static readonly object _emptyObject = new object();

    /// <summary>
    /// Visits each relation of the plan with a fresh <see cref="JoinProjectionPushDownVisitor"/>
    /// and stores the visitor result back at the same position.
    /// </summary>
    /// <param name="plan">The plan whose relations are rewritten in place.</param>
    /// <returns>The same plan instance that was passed in.</returns>
    public static Plan Optimize(Plan plan)
    {
        for (int index = 0; index < plan.Relations.Count; index++)
        {
            var pushDownVisitor = new JoinProjectionPushDownVisitor();
            plan.Relations[index] = pushDownVisitor.Visit(plan.Relations[index], _emptyObject);
        }

        return plan;
    }
}
}
2 changes: 2 additions & 0 deletions src/FlowtideDotNet.Core/Optimizer/MergeJoinFindVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public override Relation VisitJoinRelation(JoinRelation joinRelation, object sta
{
return new MergeJoinRelation()
{
Emit = joinRelation.Emit,
Left = joinRelation.Left,
Right = joinRelation.Right,
LeftKeys = new List<FieldReference>()
Expand Down Expand Up @@ -100,6 +101,7 @@ public override Relation VisitJoinRelation(JoinRelation joinRelation, object sta
{
return new MergeJoinRelation()
{
Emit = joinRelation.Emit,
Left = joinRelation.Left,
Right = joinRelation.Right,
LeftKeys = leftKeys,
Expand Down
9 changes: 9 additions & 0 deletions src/FlowtideDotNet.Core/Optimizer/PlanOptimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ public static Plan Optimize(Plan plan, PlanOptimizerSettings? settings = null)
var filterIntoRead = new FilterIntoReadOptimizer();
relation = filterIntoRead.Visit(relation, null);

plan.Relations[i] = relation;
}

plan = JoinProjectionPushdown.JoinProjectionPushdown.Optimize(plan);

for (int i = 0; i < plan.Relations.Count; i++)
{
var relation = plan.Relations[i];

if (!settings.NoMergeJoin)
{
var mergeJoinOptimize = new MergeJoinFindVisitor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ public void AssertCurrentDataEqual<T>(IEnumerable<T> data)

Assert.Equal(expectedData.Count, _actualData!.Count);

bool fail = false;
for (int i = 0; i < expectedData.Count; i++)
{
var expectedRow = expectedData[i];
Expand All @@ -201,10 +202,30 @@ public void AssertCurrentDataEqual<T>(IEnumerable<T> data)
var actualRowJson = FlxValue.FromMemory(actualRow).ToJson;
if (!expectedRowJson.Equals(actualRowJson))
{
Assert.Fail($"Expected:{Environment.NewLine}{expectedRowJson}{Environment.NewLine}but got:{Environment.NewLine}{actualRowJson}");
fail = true;
}
}
}

if (fail)
{
List<string> expected = new List<string>();
List<string> actual = new List<string>();

for (int i = 0; i < expectedData.Count; i++)
{
var expectedRow = expectedData[i];
var actualRow = _actualData[i];
expected.Add(FlxValue.FromMemory(expectedRow).ToJson);
actual.Add(FlxValue.FromMemory(actualRow).ToJson);
}
expected.Sort();
actual.Sort();
for (int i = 0; i < expected.Count; i++)
{
Assert.Equal(expected[i], actual[i]);
}
}
}

public List<FlxVector> GetActualRowsAsVectors()
Expand Down
Loading

0 comments on commit c5ea1be

Please sign in to comment.