diff --git a/src/FlowtideDotNet.Core/Compute/Internal/BuiltInStringFunctions.cs b/src/FlowtideDotNet.Core/Compute/Internal/BuiltInStringFunctions.cs index 4f39f91bf..ce975c0a3 100644 --- a/src/FlowtideDotNet.Core/Compute/Internal/BuiltInStringFunctions.cs +++ b/src/FlowtideDotNet.Core/Compute/Internal/BuiltInStringFunctions.cs @@ -47,6 +47,9 @@ public static void AddStringFunctions(FunctionsRegister functionsRegister) functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.Lower, (x) => LowerImplementation(x)); functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.Upper, (x) => UpperImplementation(x)); + /* Registers trim, ltrim and rtrim as scalar functions under FunctionsString.Uri; each lambda forwards its single argument to the matching private helper added in this change. */ functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.Trim, (x) => TrimImplementation(x)); + functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.LTrim, (x) => LTrimImplementation(x)); + functionsRegister.RegisterScalarFunctionWithExpression(FunctionsString.Uri, FunctionsString.RTrim, (x) => RTrimImplementation(x)); } private static FlxValue LowerImplementation(in FlxValue val) @@ -68,5 +71,35 @@ private static FlxValue UpperImplementation(in FlxValue val) return FlxValue.FromBytes(FlexBuffer.SingleValue(val.AsString.ToUpper())); } + + /* Returns NullValue unless the value type is String; otherwise calls the parameterless Trim() on the string and wraps the result via FlexBuffer.SingleValue and FlxValue.FromBytes. */ private static FlxValue TrimImplementation(in FlxValue val) + { + if (val.ValueType != FlexBuffers.Type.String) + { + return NullValue; + } + + return FlxValue.FromBytes(FlexBuffer.SingleValue(val.AsString.Trim())); + } + + /* Same shape as TrimImplementation, but calls TrimStart() so only the leading side is trimmed; non-String input yields NullValue. */ private static FlxValue LTrimImplementation(in FlxValue val) + { + if (val.ValueType != FlexBuffers.Type.String) + { + return NullValue; + } + + return FlxValue.FromBytes(FlexBuffer.SingleValue(val.AsString.TrimStart())); + } + + /* Same shape as TrimImplementation, but calls TrimEnd() so only the trailing side is trimmed; non-String input yields NullValue. */ private static FlxValue RTrimImplementation(in FlxValue val) + { + if (val.ValueType != FlexBuffers.Type.String) + { + return NullValue; + } + + return FlxValue.FromBytes(FlexBuffer.SingleValue(val.AsString.TrimEnd())); + } } } diff --git 
a/src/FlowtideDotNet.Substrait/FunctionExtensions/FunctionsString.cs b/src/FlowtideDotNet.Substrait/FunctionExtensions/FunctionsString.cs index 6e5ed975f..513de7a23 100644 --- a/src/FlowtideDotNet.Substrait/FunctionExtensions/FunctionsString.cs +++ b/src/FlowtideDotNet.Substrait/FunctionExtensions/FunctionsString.cs @@ -24,5 +24,8 @@ public static class FunctionsString public const string Concat = "concat"; public const string Lower = "lower"; public const string Upper = "upper"; + /* Extension function names for the trim family; the compute-side registration and the SQL visitors in this change refer to these constants. */ public const string Trim = "trim"; + public const string LTrim = "ltrim"; + public const string RTrim = "rtrim"; } } diff --git a/src/FlowtideDotNet.Substrait/Sql/BaseExpressionVisitor.cs b/src/FlowtideDotNet.Substrait/Sql/BaseExpressionVisitor.cs index 88854584e..7546a95ab 100644 --- a/src/FlowtideDotNet.Substrait/Sql/BaseExpressionVisitor.cs +++ b/src/FlowtideDotNet.Substrait/Sql/BaseExpressionVisitor.cs @@ -73,9 +73,18 @@ public virtual TReturn Visit(Expression expression, TState state) { return VisitInList(inList, state); } + /* Route Expression.Trim nodes to the VisitTrim hook before falling through to the generic not-supported error. */ if (expression is Expression.Trim trim) + { + return VisitTrim(trim, state); + } throw new NotImplementedException($"The expression '{expression.GetType().Name}' is not supported in SQL"); } + /* Default TRIM handler: always throws NotImplementedException naming the node type; being virtual, derived visitors can override it to add support. */ protected virtual TReturn VisitTrim(Expression.Trim trim, TState state) + { + throw new NotImplementedException($"The expression '{trim.GetType().Name}' is not supported in SQL"); + } + protected virtual TReturn VisitInList(Expression.InList inList, TState state) { throw new NotImplementedException($"The expression '{inList.GetType().Name}' is not supported in SQL"); diff --git a/src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs b/src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs index d8b97fc80..0e3cb2bf3 100644 --- a/src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs +++ b/src/FlowtideDotNet.Substrait/Sql/Internal/BuiltInSqlFunctions.cs @@ -219,6 +219,50 @@ public static void AddBuiltInFunctions(SqlFunctionRegister 
sqlFunctionRegister) } }); + /* SQL ltrim(x): requires exactly one argument (otherwise InvalidOperationException). An unnamed expression argument is visited and wrapped in a ScalarFunction named FunctionsString.LTrim; any other argument form throws NotImplementedException. NOTE(review): new List() has no type argument as shown here - possibly lost in extraction; confirm against the real file. */ sqlFunctionRegister.RegisterScalarFunction("ltrim", (f, visitor, emitData) => + { + if (f.Args == null || f.Args.Count != 1) + { + throw new InvalidOperationException("ltrim must have exactly one argument"); + } + if (f.Args[0] is FunctionArg.Unnamed unnamed && unnamed.FunctionArgExpression is FunctionArgExpression.FunctionExpression funcExpr) + { + var expr = visitor.Visit(funcExpr.Expression, emitData); + return new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.LTrim, + Arguments = new List() { expr.Expr } + }; + } + else + { + throw new NotImplementedException("ltrim does not support the input parameter"); + } + }); + + /* SQL rtrim(x): same argument checks as the ltrim registration, but emits a ScalarFunction named FunctionsString.RTrim. */ sqlFunctionRegister.RegisterScalarFunction("rtrim", (f, visitor, emitData) => + { + if (f.Args == null || f.Args.Count != 1) + { + throw new InvalidOperationException("rtrim must have exactly one argument"); + } + if (f.Args[0] is FunctionArg.Unnamed unnamed && unnamed.FunctionArgExpression is FunctionArgExpression.FunctionExpression funcExpr) + { + var expr = visitor.Visit(funcExpr.Expression, emitData); + return new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.RTrim, + Arguments = new List() { expr.Expr } + }; + } + else + { + throw new NotImplementedException("rtrim does not support the input parameter"); + } + }); + + sqlFunctionRegister.RegisterScalarFunction("strftime", (f, visitor, emitData) => { if (f.Args == null || f.Args.Count != 2) diff --git a/src/FlowtideDotNet.Substrait/Sql/Internal/ContainsAggregateVisitor.cs b/src/FlowtideDotNet.Substrait/Sql/Internal/ContainsAggregateVisitor.cs index 4fc964ff4..d367fc51b 100644 --- a/src/FlowtideDotNet.Substrait/Sql/Internal/ContainsAggregateVisitor.cs +++ b/src/FlowtideDotNet.Substrait/Sql/Internal/ContainsAggregateVisitor.cs @@ -154,5 +154,10 @@ protected override bool VisitInList(Expression.InList inList, object state) } return containsAggregate; } + + /* TRIM adds nothing of its own to the check: the result is whatever visiting the operand expression returns. */ protected override 
bool VisitTrim(Expression.Trim trim, object state) + { + return Visit(trim.Expression, state); + } } } diff --git a/src/FlowtideDotNet.Substrait/Sql/SqlExpressionVisitor.cs b/src/FlowtideDotNet.Substrait/Sql/SqlExpressionVisitor.cs index 866817735..12d794d3f 100644 --- a/src/FlowtideDotNet.Substrait/Sql/SqlExpressionVisitor.cs +++ b/src/FlowtideDotNet.Substrait/Sql/SqlExpressionVisitor.cs @@ -263,6 +263,53 @@ protected override ExpressionData VisitBinaryOperation(SqlParser.Ast.Expression. } } + /* Translates a SQL TRIM node into a string scalar function over the visited operand. TrimWhere selects the function: Both or None -> FunctionsString.Trim, Trailing -> FunctionsString.RTrim, Leading -> FunctionsString.LTrim; any other value throws NotSupportedException (with no message). Every branch passes "$trim" as the second ExpressionData argument. NOTE(review): only trim.Expression and trim.TrimWhere are read - if the parsed node can also carry a characters-to-trim operand it is ignored here; confirm that is intended or reject such input explicitly. */ protected override ExpressionData VisitTrim(Trim trim, EmitData state) + { + var expr = Visit(trim.Expression, state); + + if (trim.TrimWhere == TrimWhereField.Both || trim.TrimWhere == TrimWhereField.None) + { + return new ExpressionData( + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.Trim, + Arguments = new List() { expr.Expr } + }, + "$trim" + ); + } + else if (trim.TrimWhere == TrimWhereField.Trailing) + { + return new ExpressionData( + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.RTrim, + Arguments = new List() { expr.Expr } + }, + "$trim" + ); + } + else if (trim.TrimWhere == TrimWhereField.Leading) + { + return new ExpressionData( + new ScalarFunction() + { + ExtensionUri = FunctionsString.Uri, + ExtensionName = FunctionsString.LTrim, + Arguments = new List() { expr.Expr } + }, + "$trim" + ); + } + else + { + throw new NotSupportedException(); + } + + } + protected override ExpressionData VisitCompoundIdentifier(SqlParser.Ast.Expression.CompoundIdentifier compoundIdentifier, EmitData state) { var removedQuotaIdentifier = new SqlParser.Ast.Expression.CompoundIdentifier(new Sequence(compoundIdentifier.Idents.Select(x => new Ident(x.Value)))); diff --git a/tests/FlowtideDotNet.AcceptanceTests/Entities/User.cs b/tests/FlowtideDotNet.AcceptanceTests/Entities/User.cs index 8646fdeee..509f3ac46 100644 --- a/tests/FlowtideDotNet.AcceptanceTests/Entities/User.cs +++ 
b/tests/FlowtideDotNet.AcceptanceTests/Entities/User.cs @@ -37,5 +37,7 @@ public class User public int? Visits { get; set; } public int? ManagerKey { get; set; } + + /* Test-data field for the trim tests; the DatasetGenerator rule in this change fills it with NullableString surrounded by spaces, or null. */ public string? TrimmableNullableString { get; set; } } } diff --git a/tests/FlowtideDotNet.AcceptanceTests/Internal/DatasetGenerator.cs b/tests/FlowtideDotNet.AcceptanceTests/Internal/DatasetGenerator.cs index 39bdefb53..fceeea0b9 100644 --- a/tests/FlowtideDotNet.AcceptanceTests/Internal/DatasetGenerator.cs +++ b/tests/FlowtideDotNet.AcceptanceTests/Internal/DatasetGenerator.cs @@ -94,7 +94,8 @@ private void GenerateUsers(int count) availableManagers.Add(u.UserKey); return managerKey; } - }); + }) + /* Derive the trimmable value from the already generated NullableString: wrap it in leading and trailing spaces, and keep null as null so the null path is exercised too. */ .RuleFor(x => x.TrimmableNullableString, (f, u) => u.NullableString != null ? $" {u.NullableString} " : null); var newUsers = testUsers.Generate(count); diff --git a/tests/FlowtideDotNet.AcceptanceTests/StringFunctionTests.cs b/tests/FlowtideDotNet.AcceptanceTests/StringFunctionTests.cs index bfc34c84a..1dfb212bb 100644 --- a/tests/FlowtideDotNet.AcceptanceTests/StringFunctionTests.cs +++ b/tests/FlowtideDotNet.AcceptanceTests/StringFunctionTests.cs @@ -60,5 +60,32 @@ public async Task SelectWithUpper() await WaitForUpdate(); AssertCurrentDataEqual(Users.Select(x => new { Name = x.FirstName.ToUpper() })); } + + /* SQL trim(col) must equal Trim() on non-null values and stay null for null input. */ [Fact] + public async Task SelectWithTrim() + { + GenerateData(); + await StartStream("INSERT INTO output SELECT trim(TrimmableNullableString) as Name FROM users"); + await WaitForUpdate(); + AssertCurrentDataEqual(Users.Select(x => new { Name = x.TrimmableNullableString?.Trim() })); + } + + /* SQL ltrim(col) must equal TrimStart() on non-null values and stay null for null input. */ [Fact] + public async Task SelectWithLTrim() + { + GenerateData(); + await StartStream("INSERT INTO output SELECT ltrim(TrimmableNullableString) as Name FROM users"); + await WaitForUpdate(); + AssertCurrentDataEqual(Users.Select(x => new { Name = x.TrimmableNullableString?.TrimStart() })); + } + + /* SQL rtrim(col) must equal TrimEnd() on non-null values and stay null for null input. */ [Fact] + public async Task SelectWithRTrim() + { + GenerateData(); + await StartStream("INSERT INTO output SELECT 
rtrim(TrimmableNullableString) as Name FROM users"); + await WaitForUpdate(); + AssertCurrentDataEqual(Users.Select(x => new { Name = x.TrimmableNullableString?.TrimEnd() })); + } } }