Skip to content

Commit

Permalink
Add support for string trim operations (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulimo authored Nov 14, 2023
1 parent 440b6a0 commit b4cef38
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 1 deletion.
33 changes: 33 additions & 0 deletions src/FlowtideDotNet.Core/Compute/Internal/BuiltInStringFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ public static void AddStringFunctions(FunctionsRegister functionsRegister)

functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.Lower, (x) => LowerImplementation(x));
functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.Upper, (x) => UpperImplementation(x));
functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.Trim, (x) => TrimImplementation(x));
functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.LTrim, (x) => LTrimImplementation(x));
functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.RTrim, (x) => RTrimImplementation(x));
}

private static FlxValue LowerImplementation(in FlxValue val)
Expand All @@ -68,5 +71,35 @@ private static FlxValue UpperImplementation(in FlxValue val)

return FlxValue.FromBytes(FlexBuffer.SingleValue(val.AsString.ToUpper()));
}

/// <summary>
/// Removes leading and trailing white space from a string value.
/// </summary>
/// <param name="val">The value to trim.</param>
/// <returns>
/// A new string value with both ends trimmed, or <c>NullValue</c> when
/// <paramref name="val"/> is not of the string type.
/// </returns>
private static FlxValue TrimImplementation(in FlxValue val)
{
    if (val.ValueType == FlexBuffers.Type.String)
    {
        var trimmed = val.AsString.Trim();
        return FlxValue.FromBytes(FlexBuffer.SingleValue(trimmed));
    }

    // Anything that is not a string yields null.
    return NullValue;
}

/// <summary>
/// Removes leading white space from a string value.
/// </summary>
/// <param name="val">The value to trim.</param>
/// <returns>
/// A new string value with its start trimmed, or <c>NullValue</c> when
/// <paramref name="val"/> is not of the string type.
/// </returns>
private static FlxValue LTrimImplementation(in FlxValue val)
{
    if (val.ValueType == FlexBuffers.Type.String)
    {
        var trimmed = val.AsString.TrimStart();
        return FlxValue.FromBytes(FlexBuffer.SingleValue(trimmed));
    }

    // Anything that is not a string yields null.
    return NullValue;
}

/// <summary>
/// Removes trailing white space from a string value.
/// </summary>
/// <param name="val">The value to trim.</param>
/// <returns>
/// A new string value with its end trimmed, or <c>NullValue</c> when
/// <paramref name="val"/> is not of the string type.
/// </returns>
private static FlxValue RTrimImplementation(in FlxValue val)
{
    if (val.ValueType == FlexBuffers.Type.String)
    {
        var trimmed = val.AsString.TrimEnd();
        return FlxValue.FromBytes(FlexBuffer.SingleValue(trimmed));
    }

    // Anything that is not a string yields null.
    return NullValue;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,8 @@ public static class FunctionsString
public const string Concat = "concat";
public const string Lower = "lower";
public const string Upper = "upper";
public const string Trim = "trim";
public const string LTrim = "ltrim";
public const string RTrim = "rtrim";
}
}
9 changes: 9 additions & 0 deletions src/FlowtideDotNet.Substrait/Sql/BaseExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,18 @@ public virtual TReturn Visit(Expression expression, TState state)
{
return VisitInList(inList, state);
}
if (expression is Expression.Trim trim)
{
return VisitTrim(trim, state);
}
throw new NotImplementedException($"The expression '{expression.GetType().Name}' is not supported in SQL");
}

/// <summary>
/// Default handler for TRIM expressions; always throws so that visitors
/// which do not override it report the expression as unsupported.
/// </summary>
/// <param name="trim">The TRIM expression being visited.</param>
/// <param name="state">Visitor state (unused by this default handler).</param>
/// <exception cref="NotImplementedException">Always thrown.</exception>
protected virtual TReturn VisitTrim(Expression.Trim trim, TState state)
    => throw new NotImplementedException($"The expression '{trim.GetType().Name}' is not supported in SQL");

protected virtual TReturn VisitInList(Expression.InList inList, TState state)
{
throw new NotImplementedException($"The expression '{inList.GetType().Name}' is not supported in SQL");
Expand Down
44 changes: 44 additions & 0 deletions src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,50 @@ public static void AddBuiltInFunctions(SqlFunctionRegister sqlFunctionRegister)
}
});

// Maps the SQL function ltrim(x) onto the LTrim string scalar function.
sqlFunctionRegister.RegisterScalarFunction("ltrim", (f, visitor, emitData) =>
{
    // ltrim takes exactly one operand.
    if (f.Args == null || f.Args.Count != 1)
    {
        throw new InvalidOperationException("ltrim must have exactly one argument");
    }

    // Only an unnamed argument that carries an expression can be translated.
    if (!(f.Args[0] is FunctionArg.Unnamed unnamed) ||
        !(unnamed.FunctionArgExpression is FunctionArgExpression.FunctionExpression funcExpr))
    {
        throw new NotImplementedException("ltrim does not support the input parameter");
    }

    var operand = visitor.Visit(funcExpr.Expression, emitData);
    return new ScalarFunction
    {
        ExtensionUri = FunctionsString.Uri,
        ExtensionName = FunctionsString.LTrim,
        Arguments = new List<Expressions.Expression> { operand.Expr }
    };
});

// Maps the SQL function rtrim(x) onto the RTrim string scalar function.
sqlFunctionRegister.RegisterScalarFunction("rtrim", (f, visitor, emitData) =>
{
    // rtrim takes exactly one operand.
    if (f.Args == null || f.Args.Count != 1)
    {
        throw new InvalidOperationException("rtrim must have exactly one argument");
    }

    // Only an unnamed argument that carries an expression can be translated.
    if (!(f.Args[0] is FunctionArg.Unnamed unnamed) ||
        !(unnamed.FunctionArgExpression is FunctionArgExpression.FunctionExpression funcExpr))
    {
        throw new NotImplementedException("rtrim does not support the input parameter");
    }

    var operand = visitor.Visit(funcExpr.Expression, emitData);
    return new ScalarFunction
    {
        ExtensionUri = FunctionsString.Uri,
        ExtensionName = FunctionsString.RTrim,
        Arguments = new List<Expressions.Expression> { operand.Expr }
    };
});

sqlFunctionRegister.RegisterScalarFunction("strftime", (f, visitor, emitData) =>
{
if (f.Args == null || f.Args.Count != 2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,10 @@ protected override bool VisitInList(Expression.InList inList, object state)
}
return containsAggregate;
}

/// <summary>
/// Visits a TRIM expression by delegating to its operand; the trim itself
/// contributes nothing to the result.
/// </summary>
/// <param name="trim">The TRIM expression being visited.</param>
/// <param name="state">Visitor state, passed through unchanged.</param>
/// <returns>Whatever visiting the operand expression returns.</returns>
protected override bool VisitTrim(Expression.Trim trim, object state)
    => Visit(trim.Expression, state);
}
}
47 changes: 47 additions & 0 deletions src/FlowtideDotNet.Substrait/Sql/SqlExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,53 @@ protected override ExpressionData VisitBinaryOperation(SqlParser.Ast.Expression.
}
}

/// <summary>
/// Translates a SQL TRIM expression into the matching string scalar function:
/// Trim for both ends (or no direction given), RTrim for trailing and
/// LTrim for leading.
/// </summary>
/// <param name="trim">The parsed TRIM expression.</param>
/// <param name="state">Emit data used when visiting the operand.</param>
/// <returns>
/// Expression data wrapping the scalar function call, created with "$trim"
/// as its second argument (presumably the output name; confirm against
/// <c>ExpressionData</c>).
/// </returns>
/// <exception cref="NotSupportedException">
/// The trim direction is not one of the handled values.
/// </exception>
protected override ExpressionData VisitTrim(Trim trim, EmitData state)
{
    // Visit the operand first so errors inside it surface before the
    // direction is examined (same order as before).
    var operand = Visit(trim.Expression, state);

    // NOTE(review): only the operand and the direction are read here. If the
    // parser exposes custom trim characters they are not consulted, so such
    // queries would trim white space instead; confirm whether that should be rejected.
    string functionName;
    switch (trim.TrimWhere)
    {
        case TrimWhereField.Both:
        case TrimWhereField.None:
            functionName = FunctionsString.Trim;
            break;
        case TrimWhereField.Trailing:
            functionName = FunctionsString.RTrim;
            break;
        case TrimWhereField.Leading:
            functionName = FunctionsString.LTrim;
            break;
        default:
            // Same exception type as before, now with a message that names the value.
            throw new NotSupportedException($"The trim direction '{trim.TrimWhere}' is not supported in SQL");
    }

    return new ExpressionData(
        new ScalarFunction()
        {
            ExtensionUri = FunctionsString.Uri,
            ExtensionName = functionName,
            Arguments = new List<Expressions.Expression>() { operand.Expr }
        },
        "$trim"
    );
}

protected override ExpressionData VisitCompoundIdentifier(SqlParser.Ast.Expression.CompoundIdentifier compoundIdentifier, EmitData state)
{
var removedQuotaIdentifier = new SqlParser.Ast.Expression.CompoundIdentifier(new Sequence<Ident>(compoundIdentifier.Idents.Select(x => new Ident(x.Value))));
Expand Down
2 changes: 2 additions & 0 deletions tests/FlowtideDotNet.AcceptanceTests/Entities/User.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,7 @@ public class User
public int? Visits { get; set; }

public int? ManagerKey { get; set; }

public string? TrimmableNullableString { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ private void GenerateUsers(int count)
availableManagers.Add(u.UserKey);
return managerKey;
}
});
})
.RuleFor(x => x.TrimmableNullableString, (f, u) => u.NullableString != null ? $" {u.NullableString} " : null);


var newUsers = testUsers.Generate(count);
Expand Down
27 changes: 27 additions & 0 deletions tests/FlowtideDotNet.AcceptanceTests/StringFunctionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,5 +60,32 @@ public async Task SelectWithUpper()
await WaitForUpdate();
AssertCurrentDataEqual(Users.Select(x => new { Name = x.FirstName.ToUpper() }));
}

/// <summary>
/// Checks that SQL trim() matches <see cref="string.Trim()"/> on every
/// generated user, including users whose value is null.
/// </summary>
[Fact]
public async Task SelectWithTrim()
{
    GenerateData();

    await StartStream("INSERT INTO output SELECT trim(TrimmableNullableString) as Name FROM users");
    await WaitForUpdate();

    var expected = Users.Select(user => new { Name = user.TrimmableNullableString?.Trim() });
    AssertCurrentDataEqual(expected);
}

/// <summary>
/// Checks that SQL ltrim() matches <see cref="string.TrimStart()"/> on every
/// generated user, including users whose value is null.
/// </summary>
[Fact]
public async Task SelectWithLTrim()
{
    GenerateData();

    await StartStream("INSERT INTO output SELECT ltrim(TrimmableNullableString) as Name FROM users");
    await WaitForUpdate();

    var expected = Users.Select(user => new { Name = user.TrimmableNullableString?.TrimStart() });
    AssertCurrentDataEqual(expected);
}

/// <summary>
/// Checks that SQL rtrim() matches <see cref="string.TrimEnd()"/> on every
/// generated user, including users whose value is null.
/// </summary>
[Fact]
public async Task SelectWithRTrim()
{
    GenerateData();

    await StartStream("INSERT INTO output SELECT rtrim(TrimmableNullableString) as Name FROM users");
    await WaitForUpdate();

    var expected = Users.Select(user => new { Name = user.TrimmableNullableString?.TrimEnd() });
    AssertCurrentDataEqual(expected);
}
}
}

0 comments on commit b4cef38

Please sign in to comment.