From 96e686af9bf540567408bcd67e250f6167b631c0 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Mon, 5 Aug 2024 19:14:47 +0300 Subject: [PATCH 01/58] initial commit --- cpp/src/arrow/flight/sql/client.h | 2 +- ...he.Arrow.Flight.Sql.IntegrationTest.csproj | 15 + .../Program.cs | 304 ++++++++ csharp/Apache.Arrow.sln | 6 + csharp/examples/Examples.sln | 6 + .../Client/FlightSqlClient.cs | 719 ++++++++++++++++++ .../Apache.Arrow.Flight.Sql/DoPutResult.cs | 16 + .../FlightCallOptions.cs | 10 + .../PreparedStatement.cs | 24 + .../src/Apache.Arrow.Flight.Sql/Savepoint.cs | 13 + .../src/Apache.Arrow.Flight.Sql/TableRef.cs | 8 + .../Apache.Arrow.Flight.Sql/Transaction.cs | 8 + .../Ipc/ArrowTypeFlatbufferBuilder.cs | 2 +- 13 files changed, 1131 insertions(+), 2 deletions(-) create mode 100644 csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj create mode 100644 csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index c37c640e653a4..7c4807f882a72 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -434,7 +434,7 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { PreparedStatement(FlightSqlClient* client, std::string handle, std::shared_ptr dataset_schema, std::shared_ptr parameter_schema); - + /// \brief Default destructor for the PreparedStatement class. /// The destructor will call the Close method from the class in order, /// to send a request to close the PreparedStatement. diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj new file mode 100644 index 0000000000000..0c5b923d72880 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj @@ -0,0 +1,15 @@ + + + + Exe + net6.0 + enable + enable + + + + + + + + diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs new file mode 100644 index 0000000000000..83f3021d12882 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -0,0 +1,304 @@ +/*using Apache.Arrow; +using Apache.Arrow.Flight; +using Apache.Arrow.Flight.Client; +using Grpc.Core; +using Grpc.Net.Client; + +namespace FlightClientExample +{ + public class Program + { + public static async Task Main(string[] args) + { + string host = args.Length > 0 ? args[0] : "localhost"; + string port = args.Length > 1 ? args[1] : "5000"; + + // Create client + // (In production systems, you should use https not http) + var address = $"http://{host}:{port}"; + Console.WriteLine($"Connecting to: {address}"); + var channel = GrpcChannel.ForAddress(address); + var client = new FlightClient(channel); + + var recordBatches = new[] { CreateTestBatch(0, 2000), CreateTestBatch(50, 9000) }; + + // Particular flights are identified by a descriptor. This might be a name, + // a SQL query, or a path. Here, just using the name "test". + var descriptor = FlightDescriptor.CreatePathDescriptor("//SYSDB/Info"); //.CreateCommandDescriptor("SELECT * FROM SYSDB.`Info` "); + + // Upload data with StartPut + // var batchStreamingCall = client.StartPut(descriptor); + // foreach (var batch in recordBatches) + // { + // await batchStreamingCall.RequestStream.WriteAsync(batch); + // } + // + // // Signal we are done sending record batches + // await batchStreamingCall.RequestStream.CompleteAsync(); + // // Retrieve final response + // await batchStreamingCall.ResponseStream.MoveNext(); + // Console.WriteLine(batchStreamingCall.ResponseStream.Current.ApplicationMetadata.ToStringUtf8()); + // Console.WriteLine($"Wrote {recordBatches.Length} batches to server."); + + // Request information: + //var schema = await client.GetSchema(descriptor).ResponseAsync; + //Console.WriteLine($"Schema saved as: \n {schema}"); + + var info = await client.GetInfo(descriptor).ResponseAsync; + Console.WriteLine($"Info provided: \n {info.TotalRecords}"); + + Console.WriteLine($"Available flights:"); + // var flights_call = client.ListFlights(); + // + // while (await flights_call.ResponseStream.MoveNext()) + // { + // Console.WriteLine(" " + flights_call.ResponseStream.Current); + // } + + // // Download data + // await foreach (var batch in StreamRecordBatches(info)) + // { + // Console.WriteLine($"Read batch from flight server: \n {batch}"); + // } + + // See available commands on this server + // var action_stream = client.ListActions(); + // Console.WriteLine("Actions:"); + // while (await action_stream.ResponseStream.MoveNext()) + // { + // var action = action_stream.ResponseStream.Current; + // Console.WriteLine($" {action.Type}: {action.Description}"); + // } + // + // // Send clear command to drop all data from the server. + // var clear_result = client.DoAction(new FlightAction("clear")); + // await clear_result.ResponseStream.MoveNext(default); + } + + public static async IAsyncEnumerable StreamRecordBatches( + FlightInfo info + ) + { + // There might be multiple endpoints hosting part of the data. In simple services, + // the only endpoint might be the same server we initially queried. + foreach (var endpoint in info.Endpoints) + { + // We may have multiple locations to choose from. Here we choose the first. + var download_channel = GrpcChannel.ForAddress(endpoint.Locations.First().Uri); + var download_client = new FlightClient(download_channel); + + var stream = download_client.GetStream(endpoint.Ticket); + + while (await stream.ResponseStream.MoveNext()) + { + yield return stream.ResponseStream.Current; + } + } + } + + public static RecordBatch CreateTestBatch(int start, int length) + { + return new RecordBatch.Builder() + .Append("Column A", false, + col => col.Int32(array => array.AppendRange(Enumerable.Range(start, start + length)))) + .Append("Column B", false, + col => col.Float(array => + array.AppendRange(Enumerable.Range(start, start + length) + .Select(x => Convert.ToSingle(x * 2))))) + .Append("Column C", false, + col => col.String(array => + array.AppendRange(Enumerable.Range(start, start + length).Select(x => $"Item {x + 1}")))) + .Append("Column D", false, + col => col.Boolean(array => + array.AppendRange(Enumerable.Range(start, start + length).Select(x => x % 2 == 0)))) + .Build(); + } + } +}*/ + +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Client; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Grpc.Net.Client; + +namespace Apache.Arrow.Flight.Sql.IntegrationTest; + +class Program +{ + static async Task Main(string[] args) + { + var httpHandler = new SocketsHttpHandler + { + PooledConnectionIdleTimeout = TimeSpan.FromMinutes(1), + KeepAlivePingDelay = TimeSpan.FromSeconds(60), + KeepAlivePingTimeout = TimeSpan.FromSeconds(30), + EnableMultipleHttp2Connections = true + }; + // Initialize the gRPC channel to connect to the Flight server + using var channel = GrpcChannel.ForAddress("http://localhost:5000", + new GrpcChannelOptions { HttpHandler = httpHandler, Credentials = ChannelCredentials.Insecure }); + + // Initialize the Flight client + var flightClient = new FlightClient(channel); + var sqlClient = new FlightSqlClient(flightClient); + + // Define the SQL query + string query = "SELECT * FROM SYSDB.`Info`"; + + try + { + // ExecuteAsync + Console.WriteLine("ExecuteAsync:"); + var flightInfo = await sqlClient.ExecuteAsync(new FlightCallOptions(), query); + // Handle the ExecuteAsync result + Console.WriteLine($@"Query executed successfully. Records count: {flightInfo.TotalRecords}"); + + // ExecuteUpdate + Console.WriteLine("ExecuteUpdate:"); + string updateQuery = "UPDATE SYSDB.`Info` SET Key = 1, Val=10 WHERE Id=1"; + long affectedRows = await sqlClient.ExecuteUpdateAsync(new FlightCallOptions(), updateQuery); + // Handle the ExecuteUpdate result + Console.WriteLine($@"Number of affected d rows: {affectedRows}"); + + // GetExecuteSchema + Console.WriteLine("GetExecuteSchema:"); + var schemaResult = await sqlClient.GetExecuteSchemaAsync(new FlightCallOptions(), query); + // Process the schemaResult as needed + Console.WriteLine($"Schema retrieved successfully:{schemaResult}"); + + // ExecuteIngest + + // GetCatalogs + Console.WriteLine("GetCatalogs:"); + var catalogsInfo = await sqlClient.GetCatalogs(new FlightCallOptions()); + // Print catalog details + Console.WriteLine("Catalogs retrieved:"); + foreach (var endpoint in catalogsInfo.Endpoints) + { + var ticket = endpoint.Ticket; + Console.WriteLine($"- Ticket: {ticket}"); + } + + // GetCatalogsSchema + // Console.WriteLine("GetCatalogsSchema:"); + // Schema schemaCatalogResult = await sqlClient.GetCatalogsSchema(new FlightCallOptions()); + // // Print schema details + // Console.WriteLine("Catalogs Schema retrieved:"); + // Console.WriteLine(schemaCatalogResult); + + // GetDbSchemasAsync + // Console.WriteLine("GetDbSchemasAsync:"); + // FlightInfo flightInfoDbSchemas = + // await sqlClient.GetDbSchemasAsync(new FlightCallOptions(), "default_catalog", "public"); + // // Process the FlightInfoDbSchemas + // Console.WriteLine("Database schemas retrieved:"); + // Console.WriteLine(flightInfoDbSchemas); + + + // GetDbSchemasSchemaAsync + // Console.WriteLine("GetDbSchemasSchemaAsync:"); + // Schema schema = await sqlClient.GetDbSchemasSchemaAsync(new FlightCallOptions()); + // // Process the Schema + // Console.WriteLine("Database schemas schema retrieved:"); + // Console.WriteLine(schema); + + // DoPut + // Console.WriteLine("DoPut:"); + // await PutExample(sqlClient, query); + + // GetPrimaryKeys + // Console.WriteLine("GetPrimaryKeys:"); + // var tableRef = new TableRef + // { + // DbSchema = "SYSDB", + // Table = "Info" + // }; + // var getPrimaryKeysInfo = await sqlClient.GetPrimaryKeys(new FlightCallOptions(), tableRef); + // Console.WriteLine("Primary keys information retrieved successfully."); + + + // Call GetTablesAsync method + IEnumerable tables = await sqlClient.GetTablesAsync( + new FlightCallOptions(), + catalog: "", + dbSchemaFilterPattern: "public", + tableFilterPattern: "SYSDB", + includeSchema: true, + tableTypes: new List { "TABLE", "VIEW" }); + + // Process and print the results + foreach (var table in tables) + { + Console.WriteLine($"Table URI: {table.Descriptor.Paths}"); + foreach (var endpoint in table.Endpoints) + { + Console.WriteLine($"Endpoint Ticket: {endpoint.Ticket}"); + } + } + } + catch (Exception ex) + { + Console.WriteLine($"Error executing query: {ex.Message}"); + } + + } + + static async Task PutExample(FlightSqlClient client, string query) + { + // TODO: Talk with Jeremy about the implementation: DoPut - seems that needed to resolve missing part + var options = new FlightCallOptions(); + var body = new ActionCreatePreparedStatementRequest { Query = query }.PackAndSerialize(); + var action = new FlightAction(SqlAction.CreateRequest, body); + await foreach (FlightResult flightResult in client.DoActionAsync(action)) + { + var preparedStatementResponse = + FlightSqlUtils.ParseAndUnpack(flightResult.Body); + + var command = new CommandPreparedStatementUpdate + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + // Define schema + var fields = new List + { + new("id", Int32Type.Default, nullable: false), new("name", StringType.Default, nullable: false) + }; + var metadata = new List> { new("db_name", "SYSDB"), new("table_name", "Info") }; + var schema = new Schema(fields, metadata); + var doPutResult = await client.DoPut(options, descriptor, schema).ConfigureAwait(false); + + // Example data to write + var col1 = new Int32Array.Builder().AppendRange(new[] { 8, 9, 10, 11 }).Build(); + var col2 = new StringArray.Builder().AppendRange(new[] { "a", "b", "c", "d" }).Build(); + var col3 = new StringArray.Builder().AppendRange(new[] { "x", "y", "z", "q" }).Build(); + var batch = new RecordBatch(schema, new IArrowArray[] { col1, col2, col3 }, 4); + + await doPutResult.Writer.WriteAsync(batch); + await doPutResult.Writer.CompleteAsync(); + + // Handle metadata response (if any) + while (await doPutResult.Reader.MoveNext()) + { + var receivedMetadata = doPutResult.Reader.Current.ApplicationMetadata; + if (receivedMetadata != null) + { + Console.WriteLine("Received metadata: " + receivedMetadata.ToStringUtf8()); + } + } + } + } +} + +internal static class FlightDescriptorExtensions +{ + public static byte[] PackAndSerialize(this IMessage command) + { + return Any.Pack(command).Serialize().ToByteArray(); + } +} diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index 7e7f7c6331e88..1de7202780060 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -27,6 +27,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.Tes EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", "src\Apache.Arrow.Flight.Sql\Apache.Arrow.Flight.Sql.csproj", "{2ADE087A-B424-4895-8CC5-10170D10BA62}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.IntegrationTest", "Apache.Arrow.Flight.Sql.IntegrationTest\Apache.Arrow.Flight.Sql.IntegrationTest.csproj", "{45416D7D-F12B-4524-B641-AD0E1A33B3B0}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -81,6 +83,10 @@ Global {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.Build.0 = Debug|Any CPU {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.ActiveCfg = Release|Any CPU {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.Build.0 = Release|Any CPU + {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/csharp/examples/Examples.sln b/csharp/examples/Examples.sln index c0a4199ca5605..1485c9bfbf3b1 100644 --- a/csharp/examples/Examples.sln +++ b/csharp/examples/Examples.sln @@ -7,6 +7,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FluentBuilderExample", "Flu EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow", "..\src\Apache.Arrow\Apache.Arrow.csproj", "{1FE1DE95-FF6E-4895-82E7-909713C53524}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "FlightClientExample", "FlightClientExample\FlightClientExample.csproj", "{CBB46C39-530D-465A-9367-4E771595209A}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -21,6 +23,10 @@ Global {1FE1DE95-FF6E-4895-82E7-909713C53524}.Debug|Any CPU.Build.0 = Debug|Any CPU {1FE1DE95-FF6E-4895-82E7-909713C53524}.Release|Any CPU.ActiveCfg = Release|Any CPU {1FE1DE95-FF6E-4895-82E7-909713C53524}.Release|Any CPU.Build.0 = Release|Any CPU + {CBB46C39-530D-465A-9367-4E771595209A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {CBB46C39-530D-465A-9367-4E771595209A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {CBB46C39-530D-465A-9367-4E771595209A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {CBB46C39-530D-465A-9367-4E771595209A}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs new file mode 100644 index 0000000000000..68ddc91481490 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -0,0 +1,719 @@ +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Grpc.Core; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Arrow.Flight.Protocol.Sql; + +namespace Apache.Arrow.Flight.Sql.Client; + +public class FlightSqlClient +{ + private readonly FlightClient _client; + + public FlightSqlClient(FlightClient client) + { + _client = client ?? throw new ArgumentNullException(nameof(client)); + } + + public static Transaction NoTransaction() => new(null); + + /// + /// Execute a SQL query on the server. + /// + /// RPC-layer hints for this call. + /// The UTF8-encoded SQL query to be executed. + /// A transaction to associate this query with. + /// The FlightInfo describing where to access the dataset. + public async Task ExecuteAsync(FlightCallOptions options, string query, Transaction? transaction = null) + { + // todo: return FlightInfo + transaction ??= NoTransaction(); + + FlightInfo? flightInfo = null; + + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); + } + + try + { + Console.WriteLine($@"Executing query: {query}"); + var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query }; + var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); + var call = _client.DoAction(action, options.Headers); + + // Process the response + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var preparedStatementResponse = + FlightSqlUtils.ParseAndUnpack(result.Body); + var commandSqlCall = new CommandPreparedStatementQuery + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); + flightInfo = await GetFlightInfoAsync(options, descriptor); + var doGetResult = DoGetAsync(options, flightInfo.Endpoints[0].Ticket); + await foreach (var recordBatch in doGetResult) + { + Console.WriteLine(recordBatch); + } + } + + return flightInfo!; + } + catch (RpcException ex) + { + // Handle gRPC exceptions + Console.WriteLine($@"gRPC Error: {ex.Status}"); + throw new InvalidOperationException("Failed to execute query", ex); + } + catch (Exception ex) + { + // Handle other exceptions + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Executes an update query on the server. + /// + /// RPC-layer hints for this call. + /// The UTF8-encoded SQL query to be executed. + /// A transaction to associate this query with. Defaults to no transaction if not provided. + /// The number of rows affected by the operation. + public async Task ExecuteUpdateAsync(FlightCallOptions options, string query, Transaction? transaction = null) + { + transaction ??= NoTransaction(); + + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentException("Query cannot be null or empty", nameof(query)); + } + + try + { + // Step 1: Create statement query + Console.WriteLine($@"Executing query: {query}"); + var updateRequestCommand = new ActionCreatePreparedStatementRequest { Query = query }; + byte[] serializedUpdateRequestCommand = updateRequestCommand.PackAndSerialize(); + var action = new FlightAction(SqlAction.CreateRequest, serializedUpdateRequestCommand); + var call = _client.DoAction(action, options.Headers); + long affectedRows = 0; + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var preparedStatementResponse = + FlightSqlUtils.ParseAndUnpack(result.Body); + + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + + var doGetResult = DoGetAsync(options, flightInfo.Endpoints[0].Ticket); + await foreach (var recordBatch in doGetResult) + { + Console.WriteLine(recordBatch); + Interlocked.Increment(ref affectedRows); + } + } + + return affectedRows; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to execute update query", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Asynchronously retrieves flight information for a given flight descriptor. + /// + /// RPC-layer hints for this call. + /// The descriptor of the dataset request, whether a named dataset or a command. + /// A task that represents the asynchronous operation. The task result contains the FlightInfo describing where to access the dataset. + public async Task GetFlightInfoAsync(FlightCallOptions options, FlightDescriptor descriptor) + { + if (descriptor is null) + { + throw new ArgumentNullException(nameof(descriptor)); + } + + try + { + var flightInfoCall = _client.GetInfo(descriptor, options.Headers); + var flightInfo = await flightInfoCall.ResponseAsync.ConfigureAwait(false); + return flightInfo; + } + catch (RpcException ex) + { + // Handle gRPC exceptions + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get flight info", ex); + } + catch (Exception ex) + { + // Handle other exceptions + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Asynchronously retrieves flight information for a given flight descriptor. + /// + /// The descriptor of the dataset request, whether a named dataset or a command. + /// A task that represents the asynchronous operation. The task result contains the FlightInfo describing where to access the dataset. + public Task GetFlightInfoAsync(FlightDescriptor descriptor) + { + var options = new FlightCallOptions(); + return GetFlightInfoAsync(options, descriptor); + } + + /// + /// Perform the indicated action, returning an iterator to the stream of results, if any. + /// + /// Per-RPC options + /// The action to be performed + /// An async enumerable of results + public async IAsyncEnumerable DoActionAsync(FlightCallOptions options, FlightAction action) + { + if (options is null) + throw new ArgumentNullException(nameof(options)); + + if (action is null) + throw new ArgumentNullException(nameof(action)); + + var call = _client.DoAction(action, options.Headers); + + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + yield return result; + } + } + + /// + /// Perform the indicated action with default options, returning an iterator to the stream of results, if any. + /// + /// The action to be performed + /// An async enumerable of results + public async IAsyncEnumerable DoActionAsync(FlightAction action) + { + await foreach (var result in DoActionAsync(new FlightCallOptions(), action)) + { + yield return result; + } + } + + /// + /// Get the result set schema from the server for the given query. + /// + /// Per-RPC options + /// The UTF8-encoded SQL query + /// A transaction to associate this query with + /// The SchemaResult describing the schema of the result set + public async Task GetExecuteSchemaAsync(FlightCallOptions options, string query, + Transaction? transaction = null) + { + transaction ??= NoTransaction(); + + if (options is null) + throw new ArgumentNullException(nameof(options)); + + if (string.IsNullOrEmpty(query)) + throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); + + FlightInfo schemaResult = null!; + try + { + var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query }; + var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); + var call = _client.DoAction(action, options.Headers); + + // Process the response + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var preparedStatementResponse = + FlightSqlUtils.ParseAndUnpack(result.Body); + var commandSqlCall = new CommandPreparedStatementQuery + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); + schemaResult = await GetFlightInfoAsync(options, descriptor); + } + + return schemaResult.Schema; + } + catch (RpcException ex) + { + // Handle gRPC exceptions + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get execute schema", ex); + } + catch (Exception ex) + { + // Handle other exceptions + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Request a list of catalogs. + /// + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetCatalogs(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetCatalogs(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var catalogsInfo = await GetFlightInfoAsync(options, descriptor); + return catalogsInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status}"); + throw new InvalidOperationException("Failed to get catalogs", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the catalogs schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the catalogs. + public async Task GetCatalogsSchema(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var commandGetCatalogsSchema = new CommandGetCatalogs(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCatalogsSchema.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(options, descriptor); + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status}"); + throw new InvalidOperationException("Failed to get catalogs schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Asynchronously retrieves schema information for a given flight descriptor. + /// + /// RPC-layer hints for this call. + /// The descriptor of the dataset request, whether a named dataset or a command. + /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. + public async Task GetSchemaAsync(FlightCallOptions options, FlightDescriptor descriptor) + { + if (descriptor is null) + { + throw new ArgumentNullException(nameof(descriptor)); + } + + try + { + var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); + return schemaResult; + } + catch (RpcException ex) + { + // Handle gRPC exceptions + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get schema", ex); + } + catch (Exception ex) + { + // Handle other exceptions + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Asynchronously retrieves schema information for a given flight descriptor. + /// + /// The descriptor of the dataset request, whether a named dataset or a command. + /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. + public Task GetSchemaAsync(FlightDescriptor descriptor) + { + var options = new FlightCallOptions(); + return GetSchemaAsync(options, descriptor); + } + + /// + /// Request a list of database schemas. + /// + /// RPC-layer hints for this call. + /// The catalog. + /// The schema filter pattern. + /// The FlightInfo describing where to access the dataset. + public async Task GetDbSchemasAsync(FlightCallOptions options, string? catalog = null, + string? dbSchemaFilterPattern = null) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetDbSchemas(); + + if (catalog != null) + { + command.Catalog = catalog; + } + + if (dbSchemaFilterPattern != null) + { + command.DbSchemaFilterPattern = dbSchemaFilterPattern; + } + + byte[] serializedAndPackedCommand = command.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(serializedAndPackedCommand); + var flightInfoCall = GetFlightInfoAsync(options, descriptor); + var flightInfo = await flightInfoCall.ConfigureAwait(false); + + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get database schemas", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the database schemas schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the database schemas. + public async Task GetDbSchemasSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetDbSchemas(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + + var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); + + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get database schemas schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Given a flight ticket and schema, request to be sent the stream. Returns record batch stream reader. + /// + /// Per-RPC options + /// The flight ticket to use + /// The returned RecordBatchReader + public async IAsyncEnumerable DoGetAsync(FlightCallOptions options, FlightTicket ticket) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (ticket == null) + { + throw new ArgumentNullException(nameof(ticket)); + } + + var call = _client.GetStream(ticket, options.Headers); + await foreach (var recordBatch in call.ResponseStream.ReadAllAsync()) + { + yield return recordBatch; + } + } + + /// + /// Upload data to a Flight described by the given descriptor. The caller must call Close() on the returned stream + /// once they are done writing. + /// + /// RPC-layer hints for this call. + /// The descriptor of the stream. + /// The schema for the data to upload. + /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. + public Task DoPut(FlightCallOptions options, FlightDescriptor descriptor, Schema schema) + { + if (descriptor is null) + throw new ArgumentNullException(nameof(descriptor)); + + if (schema is null) + throw new ArgumentNullException(nameof(schema)); + try + { + var doPutResult = _client.StartPut(descriptor, options.Headers); + // Get the writer and reader + var writer = doPutResult.RequestStream; + var reader = doPutResult.ResponseStream; + + // TODO: After Re-Check it with Jeremy + // Create an empty RecordBatch to begin the writer with the schema + // var emptyRecordBatch = new RecordBatch(schema, new List(), 0); + // await writer.WriteAsync(emptyRecordBatch); + + // Begin the writer with the schema + return Task.FromResult(new DoPutResult(writer, reader)); + } + catch (RpcException ex) + { + // Handle gRPC exceptions + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to perform DoPut operation", ex); + } + catch (Exception ex) + { + // Handle other exceptions + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Upload data to a Flight described by the given descriptor. The caller must call Close() on the returned stream + /// once they are done writing. Uses default options. + /// + /// The descriptor of the stream. + /// The schema for the data to upload. + /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. + public Task DoPutAsync(FlightDescriptor descriptor, Schema schema) + { + return DoPut(new FlightCallOptions(), descriptor, schema); + } + + /// + /// Request the primary keys for a table. + /// + /// RPC-layer hints for this call. + /// The table reference. + /// The FlightInfo describing where to access the dataset. + public async Task GetPrimaryKeys(FlightCallOptions options, TableRef tableRef) + { + if (tableRef == null) + throw new ArgumentNullException(nameof(tableRef)); + + try + { + var getPrimaryKeysRequest = new CommandGetPrimaryKeys + { + Catalog = tableRef.Catalog ?? string.Empty, + DbSchema = tableRef.DbSchema, + Table = tableRef.Table + }; + var action = new FlightAction("GetPrimaryKeys", getPrimaryKeysRequest.PackAndSerialize()); + var doActionResult = DoActionAsync(options, action); + + await foreach (var result in doActionResult) + { + var getPrimaryKeysResponse = + result.Body.ParseAndUnpack(); + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = getPrimaryKeysResponse.PreparedStatementHandle + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + throw new InvalidOperationException("Failed to retrieve primary keys information."); + + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get primary keys", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Request a list of tables. + /// + /// RPC-layer hints for this call. + /// The catalog. + /// The schema filter pattern. + /// The table filter pattern. + /// True to include the schema upon return, false to not include the schema. + /// The table types to include. + /// The FlightInfo describing where to access the dataset. + public async Task> GetTablesAsync(FlightCallOptions options, + string? catalog = null, + string? dbSchemaFilterPattern = null, + string? tableFilterPattern = null, + bool includeSchema = false, + IEnumerable? tableTypes = null) + { + if (options == null) + throw new ArgumentNullException(nameof(options)); + + var command = new CommandGetTables + { + Catalog = catalog ?? string.Empty, + DbSchemaFilterPattern = dbSchemaFilterPattern ?? string.Empty, + TableNameFilterPattern = tableFilterPattern ?? string.Empty, + IncludeSchema = includeSchema + }; + command.TableTypes.AddRange(tableTypes); + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfoCall = GetFlightInfoAsync(options, descriptor); + var flightInfo = await flightInfoCall.ConfigureAwait(false); + var flightInfos = new List { flightInfo }; + + return flightInfos; + } + + + /// + /// Execute a bulk ingestion to the server. + /// + /// RPC-layer hints for this call. + /// The records to ingest. + /// The behavior for handling the table definition. + /// The destination table to load into. + /// The DB schema of the destination table. + /// The catalog of the destination table. + /// Use a temporary table. + /// Ingest as part of this transaction. + /// Additional, backend-specific options. + /// The number of rows ingested to the server. + public async Task ExecuteIngestAsync(FlightCallOptions options, FlightClientRecordBatchStreamReader reader, + CommandStatementIngest.Types.TableDefinitionOptions tableDefinitionOptions, string table, string? schema = null, + string? catalog = null, bool temporary = false, Transaction? transaction = null, + Dictionary? ingestOptions = null) + { + transaction ??= NoTransaction(); + + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (reader == null) + { + throw new ArgumentNullException(nameof(reader)); + } + + var ingestRequest = new CommandStatementIngest + { + Table = table, + Schema = schema ?? string.Empty, + Catalog = catalog ?? string.Empty, + Temporary = temporary, + // TransactionId = transaction?.TransactionId, + TableDefinitionOptions = tableDefinitionOptions, + }; + + if (ingestOptions != null) + { + foreach (var option in ingestOptions) + { + ingestRequest.Options.Add(option.Key, option.Value); + } + } + + var action = new FlightAction(SqlAction.CreateRequest, ingestRequest.PackAndSerialize()); + var call = _client.DoAction(action, options.Headers); + + long ingestedRows = 0; + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var response = result.Body.ParseAndUnpack(); + } + return ingestedRows; + } + + +} + +internal static class FlightDescriptorExtensions +{ + public static byte[] PackAndSerialize(this IMessage command) + { + return Any.Pack(command).Serialize().ToByteArray(); + } + + public static T ParseAndUnpack(this ByteString body) where T : IMessage, new() + { + return Any.Parser.ParseFrom(body).Unpack(); + } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs new file mode 100644 index 0000000000000..48dcf78328416 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs @@ -0,0 +1,16 @@ +using Apache.Arrow.Flight.Client; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +public class DoPutResult +{ + public FlightClientRecordBatchStreamWriter Writer { get; } + public IAsyncStreamReader Reader { get; } + + public DoPutResult(FlightClientRecordBatchStreamWriter writer, IAsyncStreamReader reader) + { + Writer = writer; + Reader = reader; + } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs new file mode 100644 index 0000000000000..98b436a34723f --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs @@ -0,0 +1,10 @@ +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +public class FlightCallOptions +{ + // Implement any necessary options for RPC calls + public Metadata Headers { get; set; } = new(); + +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs new file mode 100644 index 0000000000000..9051d103bdbfd --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -0,0 +1,24 @@ +using System.Threading.Tasks; + +namespace Apache.Arrow.Flight.Sql; + +public class PreparedStatement +{ + public Task SetParameters(RecordBatch parameterBatch) + { + // Implement setting parameters + return Task.CompletedTask; + } + + public Task ExecuteUpdateAsync(FlightCallOptions options) + { + // Implement execution of the prepared statement + return Task.CompletedTask; + } + + public Task CloseAsync(FlightCallOptions options) + { + // Implement closing the prepared statement + return Task.CompletedTask; + } +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs b/csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs new file mode 100644 index 0000000000000..212e4d73fbd61 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs @@ -0,0 +1,13 @@ +namespace Apache.Arrow.Flight.Sql; + +public class Savepoint +{ + public string SavepointId { get; private set; } + + public Savepoint(string savepointId) + { + SavepointId = savepointId; + } + + public bool IsValid() => !string.IsNullOrEmpty(SavepointId); +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs b/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs new file mode 100644 index 0000000000000..704e9fedcd44d --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs @@ -0,0 +1,8 @@ +namespace Apache.Arrow.Flight.Sql; + +public class TableRef +{ + public string? Catalog { get; set; } + public string DbSchema { get; set; } = null!; + public string Table { get; set; } = null!; +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs new file mode 100644 index 0000000000000..1916f0ec3d50a --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs @@ -0,0 +1,8 @@ +namespace Apache.Arrow.Flight.Sql; + +public class Transaction(string? transactionId) +{ + public string? TransactionId { get; private set; } = transactionId; + + public bool IsValid() => !string.IsNullOrEmpty(TransactionId); +} diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index adc229a051227..c66c7401c7cef 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -304,7 +304,7 @@ public void Visit(DictionaryType type) // type in the DictionaryEncoding metadata in the parent field type.ValueType.Accept(this); } - + public void Visit(FixedSizeBinaryType type) { Result = FieldType.Build( From 24981b42e4d42f98b8dc50bf005187ac5fddb36b Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Tue, 6 Aug 2024 17:41:16 +0300 Subject: [PATCH 02/58] feat(sqlclient): implemented methods **GetExportedKeys** - Retrieves a description about the foreign key columns that reference the primary key columns of the given table. - **GetExportedKeysSchema** - Get the exported keys schema from the server. - **GetImportedKeys** - Retrieves the foreign key columns for the given table. - **GetImportedKeysSchema** - Get the imported keys schema from the server. - **GetCrossReference** - Retrieves a description of the foreign key columns in the given foreign key table that reference the primary key or the columns representing a unique constraint of the parent table. - **GetCrossReferenceSchema** - Get the cross reference schema from the server. - **GetTableTypes** - Request a list of table types. - **GetTableTypesSchema** - Get the table types schema from the server. - **GetXdbcTypeInfo** - Request the information about all the data types supported. - **GetXdbcTypeInfo (with data_type parameter)** - Request the information about a specific data type supported. - **GetXdbcTypeInfoSchema** - Get the type info schema from the server. - **GetSqlInfo** - Request a list of SQL information. - **GetSqlInfoSchema** --- .../Program.cs | 99 +++- .../Client/FlightSqlClient.cs | 468 +++++++++++++++++- 2 files changed, 555 insertions(+), 12 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 83f3021d12882..4e22fb7fc9e35 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -222,6 +222,7 @@ static async Task Main(string[] args) // Call GetTablesAsync method + Console.WriteLine("GetTablesAsync:"); IEnumerable tables = await sqlClient.GetTablesAsync( new FlightCallOptions(), catalog: "", @@ -229,8 +230,6 @@ static async Task Main(string[] args) tableFilterPattern: "SYSDB", includeSchema: true, tableTypes: new List { "TABLE", "VIEW" }); - - // Process and print the results foreach (var table in tables) { Console.WriteLine($"Table URI: {table.Descriptor.Paths}"); @@ -239,12 +238,103 @@ static async Task Main(string[] args) Console.WriteLine($"Endpoint Ticket: {endpoint.Ticket}"); } } + + var tableRef = new TableRef { Catalog = "", DbSchema = "SYSDB", Table = "Info" }; + + // Get exported keys + // Console.WriteLine("GetExportedKeysAsync:"); + // var tableRef = new TableRef { Catalog = "", DbSchema = "SYSDB", Table = "Info" }; + // var flightInfoExportedKeys = await sqlClient.GetExportedKeysAsync(new FlightCallOptions(), tableRef); + // Console.WriteLine("FlightInfo obtained:"); + // Console.WriteLine($" FlightDescriptor: {flightInfoExportedKeys.Descriptor}"); + // Console.WriteLine($" Total records: {flightInfoExportedKeys.TotalRecords}"); + // Console.WriteLine($" Total bytes: {flightInfoExportedKeys.TotalBytes}"); + + // Get exported keys schema + // var schema = await sqlClient.GetExportedKeysSchemaAsync(new FlightCallOptions()); + // Console.WriteLine("Schema obtained:"); + // Console.WriteLine($" Fields: {string.Join(", ", schema.FieldsList)}"); + + // Get imported keys + // Console.WriteLine("GetImportedKeys"); + // var flightInfoGetImportedKeys = sqlClient.GetImportedKeysAsync(new FlightCallOptions(), tableRef); + // Console.WriteLine("FlightInfo obtained:"); + // Console.WriteLine($@" Location: {flightInfoGetImportedKeys.Result.Endpoints[0]}"); + + // Get imported keys schema + // Console.WriteLine("GetImportedKeysSchemaAsync:"); + // var schema = await sqlClient.GetImportedKeysSchemaAsync(new FlightCallOptions()); + // Console.WriteLine("Imported Keys Schema obtained:"); + // Console.WriteLine($"Schema Fields: {string.Join(", ", schema.FieldsList)}"); + + // Get cross reference + // Console.WriteLine("GetCrossReferenceAsync:"); + // var flightInfoGetCrossReference = await sqlClient.GetCrossReferenceAsync(new FlightCallOptions(), tableRef, new TableRef + // { + // Catalog = "catalog2", + // DbSchema = "schema2", + // Table = "table2" + // }); + // Console.WriteLine("Cross Reference Information obtained:"); + // Console.WriteLine($"Flight Descriptor: {flightInfoGetCrossReference.Descriptor}"); + // Console.WriteLine($"Endpoints: {string.Join(", ", flightInfoGetCrossReference.Endpoints)}"); + + // Get cross-reference schema + // Console.WriteLine("GetCrossReferenceSchemaAsync:"); + // var schema = await sqlClient.GetCrossReferenceSchemaAsync(new FlightCallOptions()); + // Console.WriteLine("Cross Reference Schema obtained:"); + // Console.WriteLine($"Schema: {schema}"); + + + // Get table types + // Console.WriteLine("GetTableTypesAsync:"); + // var tableTypesInfo = await sqlClient.GetTableTypesAsync(new FlightCallOptions()); + // Console.WriteLine("Table Types Info obtained:"); + // Console.WriteLine($"FlightInfo: {tableTypesInfo}"); + + // Get table types schema + // Console.WriteLine("GetTableTypesSchemaAsync:"); + // var tableTypesSchema = await sqlClient.GetTableTypesSchemaAsync(new FlightCallOptions()); + // Console.WriteLine("Table Types Schema obtained:"); + // Console.WriteLine($"Schema: {tableTypesSchema}"); + + // Get XDBC type info (with DataType) + Console.WriteLine("GetXdbcTypeInfoAsync: (With DataType)"); + var flightInfoGetXdbcTypeInfoWithoutDataType = + await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions(), 4); + Console.WriteLine("XDBC With DataType Info obtained:"); + Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfoWithoutDataType}"); + + // Get XDBC type info + Console.WriteLine("GetXdbcTypeInfoAsync:"); + var flightInfoGetXdbcTypeInfo = await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions()); + Console.WriteLine("XDBC Type Info obtained:"); + Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfo}"); + + // Get XDBC type info schema + // Console.WriteLine("GetXdbcTypeInfoSchemaAsync:"); + // var flightInfoGetXdbcTypeSchemaInfo = await sqlClient.GetXdbcTypeInfoSchemaAsync(new FlightCallOptions()); + // Console.WriteLine("XDBC Type Info obtained:"); + // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeSchemaInfo}"); + + // Get SQL info + Console.WriteLine("GetSqlInfoAsync:"); + // Define SQL info list + var sqlInfo = new List { 1, 2, 3 }; + var flightInfoGetSqlInfo = sqlClient.GetSqlInfoAsync(new FlightCallOptions(), sqlInfo); + Console.WriteLine("SQL Info obtained:"); + Console.WriteLine($"FlightInfo: {flightInfoGetSqlInfo}"); + + // Get SQL info schema + Console.WriteLine("GetSqlInfoSchemaAsync:"); + var schema = await sqlClient.GetSqlInfoSchemaAsync(new FlightCallOptions()); + Console.WriteLine("SQL Info Schema obtained:"); + Console.WriteLine($"Schema: {schema}"); } catch (Exception ex) { Console.WriteLine($"Error executing query: {ex.Message}"); } - } static async Task PutExample(FlightSqlClient client, string query) @@ -269,7 +359,8 @@ static async Task PutExample(FlightSqlClient client, string query) { new("id", Int32Type.Default, nullable: false), new("name", StringType.Default, nullable: false) }; - var metadata = new List> { new("db_name", "SYSDB"), new("table_name", "Info") }; + var metadata = + new List> { new("db_name", "SYSDB"), new("table_name", "Info") }; var schema = new Schema(fields, metadata); var doPutResult = await client.DoPut(options, descriptor, schema).ConfigureAwait(false); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 68ddc91481490..b1a8e4bd52052 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -570,9 +570,7 @@ public async Task GetPrimaryKeys(FlightCallOptions options, TableRef { var getPrimaryKeysRequest = new CommandGetPrimaryKeys { - Catalog = tableRef.Catalog ?? string.Empty, - DbSchema = tableRef.DbSchema, - Table = tableRef.Table + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; var action = new FlightAction("GetPrimaryKeys", getPrimaryKeysRequest.PackAndSerialize()); var doActionResult = DoActionAsync(options, action); @@ -590,8 +588,8 @@ public async Task GetPrimaryKeys(FlightCallOptions options, TableRef var flightInfo = await GetFlightInfoAsync(options, descriptor); return flightInfo; } - throw new InvalidOperationException("Failed to retrieve primary keys information."); + throw new InvalidOperationException("Failed to retrieve primary keys information."); } catch (RpcException ex) { @@ -643,6 +641,459 @@ public async Task> GetTablesAsync(FlightCallOptions opti } + /// + /// Retrieves a description about the foreign key columns that reference the primary key columns of the given table. + /// + /// RPC-layer hints for this call. + /// The table reference. + /// The FlightInfo describing where to access the dataset. + public async Task GetExportedKeysAsync(FlightCallOptions options, TableRef tableRef) + { + if (tableRef == null) + throw new ArgumentNullException(nameof(tableRef)); + + try + { + var getExportedKeysRequest = new CommandGetExportedKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(getExportedKeysRequest.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get exported keys", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the exported keys schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the exported keys. + public async Task GetExportedKeysSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var commandGetExportedKeysSchema = new CommandGetExportedKeys(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetExportedKeysSchema.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(options, descriptor); + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get exported keys schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Retrieves the foreign key columns for the given table. + /// + /// RPC-layer hints for this call. + /// The table reference. + /// The FlightInfo describing where to access the dataset. + public async Task GetImportedKeysAsync(FlightCallOptions options, TableRef tableRef) + { + if (tableRef == null) + throw new ArgumentNullException(nameof(tableRef)); + + try + { + var getImportedKeysRequest = new CommandGetImportedKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; + + var action = + new FlightAction("GetImportedKeys", + getImportedKeysRequest.PackAndSerialize()); // check: whether using SqlAction.Enum + var doActionResult = DoActionAsync(options, action); + + await foreach (var result in doActionResult) + { + var getImportedKeysResponse = + result.Body.ParseAndUnpack(); + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = getImportedKeysResponse.PreparedStatementHandle + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + + throw new InvalidOperationException("Failed to retrieve imported keys information."); + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get imported keys", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the imported keys schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the imported keys. + public async Task GetImportedKeysSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var commandGetImportedKeysSchema = new CommandGetImportedKeys(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetImportedKeysSchema.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(options, descriptor); + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get imported keys schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Retrieves a description of the foreign key columns in the given foreign key table that reference the primary key or the columns representing a unique constraint of the parent table. + /// + /// RPC-layer hints for this call. + /// The table reference that exports the key. + /// The table reference that imports the key. + /// The FlightInfo describing where to access the dataset. + public async Task GetCrossReferenceAsync(FlightCallOptions options, TableRef pkTableRef, + TableRef fkTableRef) + { + if (pkTableRef == null) + throw new ArgumentNullException(nameof(pkTableRef)); + + if (fkTableRef == null) + throw new ArgumentNullException(nameof(fkTableRef)); + + try + { + var commandGetCrossReference = new CommandGetCrossReference + { + PkCatalog = pkTableRef.Catalog ?? string.Empty, + PkDbSchema = pkTableRef.DbSchema, + PkTable = pkTableRef.Table, + FkCatalog = fkTableRef.Catalog ?? string.Empty, + FkDbSchema = fkTableRef.DbSchema, + FkTable = fkTableRef.Table + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCrossReference.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get cross reference", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the cross-reference schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the cross-reference. + public async Task GetCrossReferenceSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var commandGetCrossReferenceSchema = new CommandGetCrossReference(); + var descriptor = + FlightDescriptor.CreateCommandDescriptor(commandGetCrossReferenceSchema.PackAndSerialize()); + var schemaResultCall = GetSchemaAsync(options, descriptor); + var schemaResult = await schemaResultCall.ConfigureAwait(false); + + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get cross-reference schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Request a list of table types. + /// + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetTableTypesAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetTableTypes(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get table types", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the table types schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the table types. + public async Task GetTableTypesSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetTableTypes(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(options, descriptor); + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get table types schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Request the information about all the data types supported with filtering by data type. + /// + /// RPC-layer hints for this call. + /// The data type to search for as filtering. + /// The FlightInfo describing where to access the dataset. + public async Task GetXdbcTypeInfoAsync(FlightCallOptions options, int dataType) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetXdbcTypeInfo { DataType = dataType }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get XDBC type info", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Request the information about all the data types supported. + /// + /// RPC-layer hints for this call. + /// The FlightInfo describing where to access the dataset. + public async Task GetXdbcTypeInfoAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetXdbcTypeInfo(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get XDBC type info", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the type info schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the type info. + public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetXdbcTypeInfo(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResult = await GetSchemaAsync(options, descriptor); + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get XDBC type info schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Request a list of SQL information. + /// + /// RPC-layer hints for this call. + /// The SQL info required. + /// The FlightInfo describing where to access the dataset. + public async Task GetSqlInfoAsync(FlightCallOptions options, List sqlInfo) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (sqlInfo == null || sqlInfo.Count == 0) + { + throw new ArgumentException("SQL info list cannot be null or empty", nameof(sqlInfo)); + } + + try + { + var command = new CommandGetSqlInfo(); + command.Info.AddRange(sqlInfo.ConvertAll(item => (uint)item)); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get SQL info", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Get the SQL information schema from the server. + /// + /// RPC-layer hints for this call. + /// The SchemaResult describing the schema of the SQL information. + public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + try + { + var command = new CommandGetSqlInfo(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); + + return schemaResult; + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to get SQL info schema", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + /// /// Execute a bulk ingestion to the server. /// @@ -656,8 +1107,10 @@ public async Task> GetTablesAsync(FlightCallOptions opti /// Ingest as part of this transaction. /// Additional, backend-specific options. /// The number of rows ingested to the server. - public async Task ExecuteIngestAsync(FlightCallOptions options, FlightClientRecordBatchStreamReader reader, - CommandStatementIngest.Types.TableDefinitionOptions tableDefinitionOptions, string table, string? schema = null, + public async Task ExecuteIngestAsync(FlightCallOptions options, + FlightClientRecordBatchStreamReader reader, + CommandStatementIngest.Types.TableDefinitionOptions tableDefinitionOptions, string table, + string? schema = null, string? catalog = null, bool temporary = false, Transaction? transaction = null, Dictionary? ingestOptions = null) { @@ -699,10 +1152,9 @@ public async Task ExecuteIngestAsync(FlightCallOptions options, FlightClie { var response = result.Body.ParseAndUnpack(); } + return ingestedRows; } - - } internal static class FlightDescriptorExtensions From d7a535cab646ea09a63ba9d587cc9aba66404249 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Sun, 18 Aug 2024 19:09:06 +0300 Subject: [PATCH 03/58] feat: cancel functionality --- cpp/CMakeLists.txt | 10 +- cpp/src/arrow/flight/sql/CMakeLists.txt | 8 +- .../Program.cs | 44 ++- .../CancelFlightInfoRequest.cs | 53 ++++ .../CancelFlightInfoResult.cs | 32 ++ .../Client/FlightSqlClient.cs | 300 +++++++++++++++--- .../FlightCallOptions.cs | 22 ++ .../PreparedStatement.cs | 13 + .../Apache.Arrow.Flight.Sql/Transaction.cs | 2 +- 9 files changed, 429 insertions(+), 55 deletions(-) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a1e3138da9e0b..e161bd308ebc1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -414,9 +414,9 @@ else() endif() endif() -if(NOT ARROW_BUILD_EXAMPLES) - set(NO_EXAMPLES 1) -endif() +#if(NOT ARROW_BUILD_EXAMPLES) +# set(NO_EXAMPLES 1) +#endif() if(ARROW_FUZZING) # Fuzzing builds enable ASAN without setting our home-grown option for it. @@ -738,10 +738,10 @@ if(ARROW_SKYHOOK) add_subdirectory(src/skyhook) endif() -if(ARROW_BUILD_EXAMPLES) +#if(ARROW_BUILD_EXAMPLES) add_custom_target(runexample ctest -L example) add_subdirectory(examples/arrow) -endif() +#endif() install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index b32f731496749..eed30f2c47ca0 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -94,7 +94,7 @@ endif() list(APPEND ARROW_FLIGHT_SQL_TEST_LINK_LIBS ${ARROW_FLIGHT_TEST_LINK_LIBS}) # Build test server for unit tests -if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) +#if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) find_package(SQLite3Alt REQUIRED) set(ARROW_FLIGHT_SQL_TEST_SERVER_SRCS @@ -121,14 +121,14 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared) endif() - if(ARROW_BUILD_EXAMPLES) + #if(ARROW_BUILD_EXAMPLES) add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS} example/acero_main.cc) target_link_libraries(acero-flight-sql-server PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES}) - endif() - endif() + #endif() + #endif() add_arrow_test(flight_sql_test SOURCES diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 4e22fb7fc9e35..11c88a93dd9b9 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -186,7 +186,6 @@ static async Task Main(string[] args) // GetCatalogsSchema // Console.WriteLine("GetCatalogsSchema:"); // Schema schemaCatalogResult = await sqlClient.GetCatalogsSchema(new FlightCallOptions()); - // // Print schema details // Console.WriteLine("Catalogs Schema retrieved:"); // Console.WriteLine(schemaCatalogResult); @@ -198,7 +197,6 @@ static async Task Main(string[] args) // Console.WriteLine("Database schemas retrieved:"); // Console.WriteLine(flightInfoDbSchemas); - // GetDbSchemasSchemaAsync // Console.WriteLine("GetDbSchemasSchemaAsync:"); // Schema schema = await sqlClient.GetDbSchemasSchemaAsync(new FlightCallOptions()); @@ -326,10 +324,44 @@ static async Task Main(string[] args) Console.WriteLine($"FlightInfo: {flightInfoGetSqlInfo}"); // Get SQL info schema - Console.WriteLine("GetSqlInfoSchemaAsync:"); - var schema = await sqlClient.GetSqlInfoSchemaAsync(new FlightCallOptions()); - Console.WriteLine("SQL Info Schema obtained:"); - Console.WriteLine($"Schema: {schema}"); + // Console.WriteLine("GetSqlInfoSchemaAsync:"); + // var schema = await sqlClient.GetSqlInfoSchemaAsync(new FlightCallOptions()); + // Console.WriteLine("SQL Info Schema obtained:"); + // Console.WriteLine($"Schema: {schema}"); + + // Prepare a SQL statement + Console.WriteLine("PrepareAsync:"); + var preparedStatement = await sqlClient.PrepareAsync(new FlightCallOptions(), query); + Console.WriteLine("Prepared statement created successfully."); + + + // Cancel FlightInfo Request + Console.WriteLine("CancelFlightInfoRequest:"); + var cancelRequest = new CancelFlightInfoRequest(flightInfo); + var cancelResult = await sqlClient.CancelFlightInfoAsync(new FlightCallOptions(), cancelRequest); + Console.WriteLine($"Cancellation Status: {cancelResult.Status}"); + + // Begin Transaction + // Console.WriteLine("BeginTransaction:"); + // Transaction transaction = await sqlClient.BeginTransactionAsync(new FlightCallOptions()); + // Console.WriteLine($"Transaction started with ID: {transaction.TransactionId}"); + // FlightInfo flightInfoBeginTransaction = + // await sqlClient.ExecuteAsync(new FlightCallOptions(), query, transaction); + // Console.WriteLine("Query executed within transaction"); + // + // // Commit Transaction + // Console.WriteLine("CommitTransaction:"); + // await sqlClient.CommitAsync(new FlightCallOptions(), new Transaction("transaction-id")); + // Console.WriteLine("Transaction committed successfully."); + // + // // Rollback Transaction + // Console.WriteLine("RollbackTransaction"); + // await sqlClient.RollbackAsync(new FlightCallOptions(), new Transaction("transaction-id")); + // Console.WriteLine("Transaction rolled back successfully."); + + // Cancel Query + // var cancelResult = await sqlClient.CancelQueryAsync(new FlightCallOptions(), flightInfo); + // Console.WriteLine($"Cancellation Status: {cancelResult.Status}"); } catch (Exception ex) { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs new file mode 100644 index 0000000000000..17d0958c2daa4 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs @@ -0,0 +1,53 @@ +using System; +using Apache.Arrow.Flight; +using Google.Protobuf; +using Google.Protobuf.Reflection; + +public sealed class CancelFlightInfoRequest : IMessage +{ + public FlightInfo FlightInfo { get; set; } + + // Overloaded constructor + public CancelFlightInfoRequest(FlightInfo flightInfo) => + FlightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); + + + public void MergeFrom(CancelFlightInfoRequest message) + { + if (message != null) + { + + } + } + + public void MergeFrom(CodedInputStream input) + { + while (input.ReadTag() != 0) + { + // Assuming FlightInfo is serialized as a field + } + } + + + public void WriteTo(CodedOutputStream output) + { + if (FlightInfo != null) + { + output.WriteRawMessage(this); + } + } + + public int CalculateSize() + { + int size = 0; + if (FlightInfo != null) + { + size += CodedOutputStream.ComputeMessageSize(this); + } + return size; + } + + public MessageDescriptor Descriptor => null!; + public bool Equals(CancelFlightInfoRequest other) => other != null && FlightInfo.Equals(other.FlightInfo); + public CancelFlightInfoRequest Clone() => new CancelFlightInfoRequest(FlightInfo); +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs new file mode 100644 index 0000000000000..28fe476ef4405 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs @@ -0,0 +1,32 @@ +using System; +using Google.Protobuf; +using Google.Protobuf.Reflection; + +namespace Apache.Arrow.Flight.Sql; + +public enum CancelStatus +{ + Unspecified = 0, + Cancelled = 1, + Cancelling = 2, + NotCancellable = 3, + Unrecognized = -1 +} + +public sealed class CancelFlightInfoResult : IMessage +{ + public CancelStatus CancelStatus { get; } + + public void MergeFrom(CancelFlightInfoResult message) => throw new NotImplementedException(); + + public void MergeFrom(CodedInputStream input) => throw new NotImplementedException(); + + public void WriteTo(CodedOutputStream output) => throw new NotImplementedException(); + + public int CalculateSize() => throw new NotImplementedException(); + + public MessageDescriptor Descriptor { get; } + public bool Equals(CancelFlightInfoResult other) => throw new NotImplementedException(); + + public CancelFlightInfoResult Clone() => throw new NotImplementedException(); +} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index b1a8e4bd52052..fbac6cd20c093 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -3,10 +3,10 @@ using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; -using Grpc.Core; +using Arrow.Flight.Protocol.Sql; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; -using Arrow.Flight.Protocol.Sql; +using Grpc.Core; namespace Apache.Arrow.Flight.Sql.Client; @@ -30,7 +30,6 @@ public FlightSqlClient(FlightClient client) /// The FlightInfo describing where to access the dataset. public async Task ExecuteAsync(FlightCallOptions options, string query, Transaction? transaction = null) { - // todo: return FlightInfo transaction ??= NoTransaction(); FlightInfo? flightInfo = null; @@ -572,6 +571,7 @@ public async Task GetPrimaryKeys(FlightCallOptions options, TableRef { Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; + //TODO: Refactor var action = new FlightAction("GetPrimaryKeys", getPrimaryKeysRequest.PackAndSerialize()); var doActionResult = DoActionAsync(options, action); @@ -1095,65 +1095,287 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) } /// - /// Execute a bulk ingestion to the server. + /// Explicitly cancel a FlightInfo. /// /// RPC-layer hints for this call. - /// The records to ingest. - /// The behavior for handling the table definition. - /// The destination table to load into. - /// The DB schema of the destination table. - /// The catalog of the destination table. - /// Use a temporary table. - /// Ingest as part of this transaction. - /// Additional, backend-specific options. - /// The number of rows ingested to the server. - public async Task ExecuteIngestAsync(FlightCallOptions options, - FlightClientRecordBatchStreamReader reader, - CommandStatementIngest.Types.TableDefinitionOptions tableDefinitionOptions, string table, - string? schema = null, - string? catalog = null, bool temporary = false, Transaction? transaction = null, - Dictionary? ingestOptions = null) + /// The CancelFlightInfoRequest. + /// A Task representing the asynchronous operation. The task result contains the CancelFlightInfoResult describing the canceled result. + public async Task CancelFlightInfoAsync(FlightCallOptions options, + CancelFlightInfoRequest request) { - transaction ??= NoTransaction(); + // TODO: fix the CancelFlightInfoResult missing implementation of MessageDescriptor + if (options == null) throw new ArgumentNullException(nameof(options)); + if (request == null) throw new ArgumentNullException(nameof(request)); + + try + { + var action = new FlightAction("CancelFlightInfo", request.ToByteString()); + var call = _client.DoAction(action, options.Headers); + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var cancelResult = FlightSqlUtils.ParseAndUnpack(result.Body); + return cancelResult; + } + throw new InvalidOperationException("No response received for the CancelFlightInfo request."); + } + catch (RpcException ex) + { + // Handle gRPC exceptions + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to cancel flight info", ex); + } + catch (Exception ex) + { + // Handle other exceptions + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /// + /// Begin a new transaction. + /// + /// RPC-layer hints for this call. + /// A Task representing the asynchronous operation. The task result contains the Transaction object representing the new transaction. + public async Task BeginTransactionAsync(FlightCallOptions options) + { if (options == null) { throw new ArgumentNullException(nameof(options)); } - if (reader == null) + try { - throw new ArgumentNullException(nameof(reader)); + var actionBeginTransaction = new ActionBeginTransactionRequest(); + var action = new FlightAction("BeginTransaction", actionBeginTransaction.PackAndSerialize()); + var responseStream = _client.DoAction(action, options.Headers); + await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) + { + var beginTransactionResult = + ActionBeginTransactionResult.Parser.ParseFrom(result.Body.Span); + + var transaction = new Transaction(beginTransactionResult?.TransactionId.ToBase64()); + return transaction; + } + throw new InvalidOperationException("Failed to begin transaction: No response received."); } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to begin transaction", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } - var ingestRequest = new CommandStatementIngest + /// + /// Commit a transaction. + /// After this, the transaction and all associated savepoints will be invalidated. + /// + /// RPC-layer hints for this call. + /// The transaction. + /// A Task representing the asynchronous operation. + public async Task CommitAsync(FlightCallOptions options, Transaction transaction) + { + if (options == null) { - Table = table, - Schema = schema ?? string.Empty, - Catalog = catalog ?? string.Empty, - Temporary = temporary, - // TransactionId = transaction?.TransactionId, - TableDefinitionOptions = tableDefinitionOptions, - }; + throw new ArgumentNullException(nameof(options)); + } - if (ingestOptions != null) + if (transaction == null) { - foreach (var option in ingestOptions) + throw new ArgumentNullException(nameof(transaction)); + } + + try + { + var actionCommit = new FlightAction("Commit", transaction.TransactionId); + var responseStream = _client.DoAction(actionCommit, options.Headers); + await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) { - ingestRequest.Options.Add(option.Key, option.Value); + Console.WriteLine("Transaction committed successfully."); } } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to commit transaction", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } - var action = new FlightAction(SqlAction.CreateRequest, ingestRequest.PackAndSerialize()); - var call = _client.DoAction(action, options.Headers); - long ingestedRows = 0; - await foreach (var result in call.ResponseStream.ReadAllAsync()) + /// + /// Rollback a transaction. + /// After this, the transaction and all associated savepoints will be invalidated. + /// + /// RPC-layer hints for this call. + /// The transaction to rollback. + /// A Task representing the asynchronous operation. + public async Task RollbackAsync(FlightCallOptions options, Transaction transaction) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (transaction == null) { - var response = result.Body.ParseAndUnpack(); + throw new ArgumentNullException(nameof(transaction)); } - return ingestedRows; + try + { + var actionRollback = new FlightAction("Rollback", transaction.TransactionId); + var responseStream = _client.DoAction(actionRollback, options.Headers); + await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) + { + Console.WriteLine("Transaction rolled back successfully."); + } + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to rollback transaction", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + + /// + /// Explicitly cancel a query. + /// + /// RPC-layer hints for this call. + /// The FlightInfo of the query to cancel. + /// A Task representing the asynchronous operation. + public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) + { + // TODO: Should reconsider more appropriate implementation + // TODO: fix the CancelFlightInfoResult missing implementation of MessageDescriptor + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (info == null) + { + throw new ArgumentNullException(nameof(info)); + } + try + { + var cancelRequest = new CancelFlightInfoRequest(info); + var actionCancelFlightInfo = new FlightAction("CancelFlightInfo", cancelRequest.PackAndSerialize()); + var call = _client.DoAction(actionCancelFlightInfo, options.Headers); + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var cancelResult = FlightSqlUtils.ParseAndUnpack(result.Body); + return cancelResult; + } + throw new InvalidOperationException("Failed to cancel query: No response received."); + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); + throw new InvalidOperationException("Failed to cancel query", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } + } + + /*public async Task SetSessionOptionsAsync(FlightCallOptions options, + Dictionary sessionOptions) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (sessionOptions == null) + { + throw new ArgumentNullException(nameof(sessionOptions)); + } + }*/ + + /// + /// Create a prepared statement object. + /// + /// RPC-layer hints for this call. + /// The query that will be executed. + /// A transaction to associate this query with. + /// The created prepared statement. + public async Task PrepareAsync(FlightCallOptions options, string query, + Transaction? transaction = null) + { + transaction ??= NoTransaction(); + + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (string.IsNullOrEmpty(query)) + { + throw new ArgumentException("Query cannot be null or empty", nameof(query)); + } + + try + { + var preparedStatementRequest = new ActionCreatePreparedStatementRequest + { + Query = query, + // TransactionId = transaction?.TransactionId + }; + + var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); + var call = _client.DoAction(action, options.Headers); + + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var preparedStatementResponse = + FlightSqlUtils.ParseAndUnpack(result.Body); + + var commandSqlCall = new CommandPreparedStatementQuery + { + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + await foreach (var recordBatch in DoGetAsync(options, flightInfo.Endpoints[0].Ticket)) + { + Console.WriteLine(recordBatch); + } + + return new PreparedStatement(this, flightInfo, query); + } + + throw new NullReferenceException($"{nameof(PreparedStatement)} was not able to be created"); + } + catch (RpcException ex) + { + Console.WriteLine($@"gRPC Error: {ex.Status}"); + throw new InvalidOperationException("Failed to prepare statement", ex); + } + catch (Exception ex) + { + Console.WriteLine($@"Unexpected Error: {ex.Message}"); + throw; + } } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs index 98b436a34723f..461bee550a97d 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs @@ -1,10 +1,32 @@ +using System; +using System.Buffers; +using System.Threading; using Grpc.Core; namespace Apache.Arrow.Flight.Sql; public class FlightCallOptions { + public FlightCallOptions() + { + Timeout = TimeSpan.FromSeconds(-1); + } // Implement any necessary options for RPC calls public Metadata Headers { get; set; } = new(); + /// + /// Gets or sets a token to enable interactive user cancellation of long-running requests. + /// + public CancellationToken StopToken { get; set; } + + /// + /// Gets or sets the optional timeout for this call. + /// Negative durations mean an implementation-defined default behavior will be used instead. + /// + public TimeSpan Timeout { get; set; } + + /// + /// Gets or sets an optional memory manager to control where to allocate incoming data. + /// + public MemoryManager? MemoryManager { get; set; } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 9051d103bdbfd..342b23ef3613d 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -1,9 +1,22 @@ using System.Threading.Tasks; +using Apache.Arrow.Flight.Sql.Client; namespace Apache.Arrow.Flight.Sql; +// TODO: Refactor this to match C++ implementation public class PreparedStatement { + private readonly FlightSqlClient _client; + private readonly FlightInfo _flightInfo; + private readonly string _query; + + public PreparedStatement(FlightSqlClient client, FlightInfo flightInfo, string query) + { + _client = client; + _flightInfo = flightInfo; + _query = query; + } + public Task SetParameters(RecordBatch parameterBatch) { // Implement setting parameters diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs index 1916f0ec3d50a..1f725074bb489 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs @@ -2,7 +2,7 @@ namespace Apache.Arrow.Flight.Sql; public class Transaction(string? transactionId) { - public string? TransactionId { get; private set; } = transactionId; + public string? TransactionId { get; } = transactionId; public bool IsValid() => !string.IsNullOrEmpty(TransactionId); } From b6c3e9e98af67623fa05888621b8fa76ba6b61bd Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Mon, 19 Aug 2024 15:56:22 +0300 Subject: [PATCH 04/58] feat: cancel flight info --- .../Program.cs | 3 +- .../CancelFlightInfoRequest.cs | 38 +++++---- .../CancelFlightInfoResult.cs | 81 ++++++++++++++++--- .../Client/FlightSqlClient.cs | 24 +++--- 4 files changed, 111 insertions(+), 35 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 11c88a93dd9b9..3b41b7ef64df5 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -124,6 +124,7 @@ public static RecordBatch CreateTestBatch(int start, int length) using Google.Protobuf.WellKnownTypes; using Grpc.Core; using Grpc.Net.Client; +using Google.Protobuf.Reflection; namespace Apache.Arrow.Flight.Sql.IntegrationTest; @@ -339,7 +340,7 @@ static async Task Main(string[] args) Console.WriteLine("CancelFlightInfoRequest:"); var cancelRequest = new CancelFlightInfoRequest(flightInfo); var cancelResult = await sqlClient.CancelFlightInfoAsync(new FlightCallOptions(), cancelRequest); - Console.WriteLine($"Cancellation Status: {cancelResult.Status}"); + Console.WriteLine($"Cancellation Status: {cancelResult.CancelStatus}"); // Begin Transaction // Console.WriteLine("BeginTransaction:"); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs index 17d0958c2daa4..586f8cece913a 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs @@ -1,22 +1,27 @@ using System; -using Apache.Arrow.Flight; using Google.Protobuf; using Google.Protobuf.Reflection; +namespace Apache.Arrow.Flight.Sql; + public sealed class CancelFlightInfoRequest : IMessage { public FlightInfo FlightInfo { get; set; } // Overloaded constructor - public CancelFlightInfoRequest(FlightInfo flightInfo) => + public CancelFlightInfoRequest(FlightInfo flightInfo) + { FlightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); + Descriptor = + DescriptorReflection.Descriptor.MessageTypes[0]; + } public void MergeFrom(CancelFlightInfoRequest message) { if (message != null) { - + FlightInfo = message.FlightInfo; } } @@ -24,30 +29,33 @@ public void MergeFrom(CodedInputStream input) { while (input.ReadTag() != 0) { - // Assuming FlightInfo is serialized as a field + switch (input.Position) + { + case 1: + input.ReadMessage(this); + break; + default: + input.SkipLastField(); + break; + } } } - public void WriteTo(CodedOutputStream output) { - if (FlightInfo != null) - { - output.WriteRawMessage(this); - } + output.WriteTag(1, WireFormat.WireType.LengthDelimited); + output.WriteMessage(FlightInfo.Descriptor.ParsedAndUnpackedMessage()); } public int CalculateSize() { int size = 0; - if (FlightInfo != null) - { - size += CodedOutputStream.ComputeMessageSize(this); - } + size += 1 + CodedOutputStream.ComputeMessageSize(FlightInfo.Descriptor.ParsedAndUnpackedMessage()); return size; } - public MessageDescriptor Descriptor => null!; + public MessageDescriptor Descriptor { get; } + public bool Equals(CancelFlightInfoRequest other) => other != null && FlightInfo.Equals(other.FlightInfo); - public CancelFlightInfoRequest Clone() => new CancelFlightInfoRequest(FlightInfo); + public CancelFlightInfoRequest Clone() => new(FlightInfo); } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs index 28fe476ef4405..5e61fc6f975f5 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs @@ -1,4 +1,3 @@ -using System; using Google.Protobuf; using Google.Protobuf.Reflection; @@ -15,18 +14,82 @@ public enum CancelStatus public sealed class CancelFlightInfoResult : IMessage { - public CancelStatus CancelStatus { get; } + public CancelStatus CancelStatus { get; private set; } - public void MergeFrom(CancelFlightInfoResult message) => throw new NotImplementedException(); + // Public parameterless constructor + public CancelFlightInfoResult() + { + CancelStatus = CancelStatus.Unspecified; + Descriptor = + DescriptorReflection.Descriptor.MessageTypes[0]; + } - public void MergeFrom(CodedInputStream input) => throw new NotImplementedException(); + public void MergeFrom(CancelFlightInfoResult message) + { + if (message != null) + { + CancelStatus = message.CancelStatus; + } + } - public void WriteTo(CodedOutputStream output) => throw new NotImplementedException(); + public void MergeFrom(CodedInputStream input) + { + while (input.ReadTag() != 0) + { + switch (input.Position) + { + case 1: + CancelStatus = (CancelStatus)input.ReadEnum(); + break; + default: + input.SkipLastField(); + break; + } + } + } - public int CalculateSize() => throw new NotImplementedException(); + public void WriteTo(CodedOutputStream output) + { + if (CancelStatus != CancelStatus.Unspecified) + { + output.WriteRawTag(8); // Field number 1, wire type 0 (varint) + output.WriteEnum((int)CancelStatus); + } + } - public MessageDescriptor Descriptor { get; } - public bool Equals(CancelFlightInfoResult other) => throw new NotImplementedException(); + public int CalculateSize() + { + int size = 0; + if (CancelStatus != CancelStatus.Unspecified) + { + size += 1 + CodedOutputStream.ComputeEnumSize((int)CancelStatus); + } - public CancelFlightInfoResult Clone() => throw new NotImplementedException(); + return size; + } + + public MessageDescriptor? Descriptor { get; } + + + public CancelFlightInfoResult Clone() => new() { CancelStatus = CancelStatus }; + + public bool Equals(CancelFlightInfoResult other) + { + if (other == null) + { + return false; + } + + return CancelStatus == other.CancelStatus; + } + + public override int GetHashCode() + { + return CancelStatus.GetHashCode(); + } + + public override string ToString() + { + return $"CancelFlightInfoResult {{ CancelStatus = {CancelStatus} }}"; + } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index fbac6cd20c093..1598c6ea15f51 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -1104,12 +1104,14 @@ public async Task CancelFlightInfoAsync(FlightCallOption CancelFlightInfoRequest request) { // TODO: fix the CancelFlightInfoResult missing implementation of MessageDescriptor + // NOTE: I modified the CancelFlightInfoRequest/CancelFlightInfoResult with an Descriptor in the constructor, + // NOTE: Can't see any other way to solve the nullified property if (options == null) throw new ArgumentNullException(nameof(options)); if (request == null) throw new ArgumentNullException(nameof(request)); try { - var action = new FlightAction("CancelFlightInfo", request.ToByteString()); + var action = new FlightAction("CancelFlightInfo", request.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync()) { @@ -1158,6 +1160,7 @@ public async Task BeginTransactionAsync(FlightCallOptions options) var transaction = new Transaction(beginTransactionResult?.TransactionId.ToBase64()); return transaction; } + throw new InvalidOperationException("Failed to begin transaction: No response received."); } catch (RpcException ex) @@ -1260,7 +1263,7 @@ public async Task RollbackAsync(FlightCallOptions options, Transaction transacti /// RPC-layer hints for this call. /// The FlightInfo of the query to cancel. /// A Task representing the asynchronous operation. - public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) + public Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) { // TODO: Should reconsider more appropriate implementation // TODO: fix the CancelFlightInfoResult missing implementation of MessageDescriptor @@ -1273,16 +1276,17 @@ public async Task CancelQueryAsync(FlightCallOptions opt { throw new ArgumentNullException(nameof(info)); } + try { - var cancelRequest = new CancelFlightInfoRequest(info); - var actionCancelFlightInfo = new FlightAction("CancelFlightInfo", cancelRequest.PackAndSerialize()); - var call = _client.DoAction(actionCancelFlightInfo, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) - { - var cancelResult = FlightSqlUtils.ParseAndUnpack(result.Body); - return cancelResult; - } + // var cancelRequest = new CancelFlightInfoRequest(info); + // var actionCancelFlightInfo = new FlightAction("CancelFlightInfo", cancelRequest.PackAndSerialize()); + // var call = _client.DoAction(actionCancelFlightInfo, options.Headers); + // await foreach (var result in call.ResponseStream.ReadAllAsync()) + // { + // var cancelResult = FlightSqlUtils.ParseAndUnpack(result.Body); + // return cancelResult; + // } throw new InvalidOperationException("Failed to cancel query: No response received."); } catch (RpcException ex) From d43aa7695fcf042600c9d808606fbd4c4d6fe888 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Mon, 19 Aug 2024 17:30:04 +0300 Subject: [PATCH 05/58] feat: cancel query --- .../Program.cs | 13 ++++---- .../Client/FlightSqlClient.cs | 32 +++++++++++-------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 3b41b7ef64df5..7a2cd2dafe702 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -337,10 +337,10 @@ static async Task Main(string[] args) // Cancel FlightInfo Request - Console.WriteLine("CancelFlightInfoRequest:"); - var cancelRequest = new CancelFlightInfoRequest(flightInfo); - var cancelResult = await sqlClient.CancelFlightInfoAsync(new FlightCallOptions(), cancelRequest); - Console.WriteLine($"Cancellation Status: {cancelResult.CancelStatus}"); + // Console.WriteLine("CancelFlightInfoRequest:"); + // var cancelRequest = new CancelFlightInfoRequest(flightInfo); + // var cancelResult = await sqlClient.CancelFlightInfoAsync(new FlightCallOptions(), cancelRequest); + // Console.WriteLine($"Cancellation Status: {cancelResult.CancelStatus}"); // Begin Transaction // Console.WriteLine("BeginTransaction:"); @@ -361,8 +361,9 @@ static async Task Main(string[] args) // Console.WriteLine("Transaction rolled back successfully."); // Cancel Query - // var cancelResult = await sqlClient.CancelQueryAsync(new FlightCallOptions(), flightInfo); - // Console.WriteLine($"Cancellation Status: {cancelResult.Status}"); + Console.WriteLine("CancelQuery:"); + var cancelResult = await sqlClient.CancelQueryAsync(new FlightCallOptions(), flightInfo); + Console.WriteLine($"Cancellation Status: {cancelResult}"); } catch (Exception ex) { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 1598c6ea15f51..8470f1c631f08 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -1103,9 +1103,6 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) public async Task CancelFlightInfoAsync(FlightCallOptions options, CancelFlightInfoRequest request) { - // TODO: fix the CancelFlightInfoResult missing implementation of MessageDescriptor - // NOTE: I modified the CancelFlightInfoRequest/CancelFlightInfoResult with an Descriptor in the constructor, - // NOTE: Can't see any other way to solve the nullified property if (options == null) throw new ArgumentNullException(nameof(options)); if (request == null) throw new ArgumentNullException(nameof(request)); @@ -1263,10 +1260,8 @@ public async Task RollbackAsync(FlightCallOptions options, Transaction transacti /// RPC-layer hints for this call. /// The FlightInfo of the query to cancel. /// A Task representing the asynchronous operation. - public Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) + public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) { - // TODO: Should reconsider more appropriate implementation - // TODO: fix the CancelFlightInfoResult missing implementation of MessageDescriptor if (options == null) { throw new ArgumentNullException(nameof(options)); @@ -1279,14 +1274,23 @@ public Task CancelQueryAsync(FlightCallOptions options, try { - // var cancelRequest = new CancelFlightInfoRequest(info); - // var actionCancelFlightInfo = new FlightAction("CancelFlightInfo", cancelRequest.PackAndSerialize()); - // var call = _client.DoAction(actionCancelFlightInfo, options.Headers); - // await foreach (var result in call.ResponseStream.ReadAllAsync()) - // { - // var cancelResult = FlightSqlUtils.ParseAndUnpack(result.Body); - // return cancelResult; - // } + var cancelRequest = new CancelFlightInfoRequest(info); + var action = new FlightAction("CancelFlightInfo", cancelRequest.ToByteString()); + var call = _client.DoAction(action, options.Headers); + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var cancelResult = Any.Parser.ParseFrom(result.Body); + if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) + { + return cancelFlightInfoResult.CancelStatus switch + { + CancelStatus.Cancelled => CancelStatus.Cancelled, + CancelStatus.Cancelling => CancelStatus.Cancelling, + CancelStatus.NotCancellable => CancelStatus.NotCancellable, + _ => CancelStatus.Unspecified + }; + } + } throw new InvalidOperationException("Failed to cancel query: No response received."); } catch (RpcException ex) From 728957e2e43618a107c768ca9d353a51754a371e Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Mon, 26 Aug 2024 18:15:43 +0300 Subject: [PATCH 06/58] test(FlightSqlClient): infra and initial testing --- .../Program.cs | 1 - .../Client/FlightSqlClient.cs | 22 +-- .../Client/FlightClient.cs | 57 +++++-- .../Apache.Arrow.Flight.Sql.Tests.csproj | 1 + .../FlightSqlClientTests.cs | 140 ++++++++++++++++++ .../FlightSqlTestExtensions.cs | 1 + .../TestFlightSqlSever.cs | 14 ++ 7 files changed, 200 insertions(+), 36 deletions(-) create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 7a2cd2dafe702..0b443250e68d7 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -124,7 +124,6 @@ public static RecordBatch CreateTestBatch(int start, int length) using Google.Protobuf.WellKnownTypes; using Grpc.Core; using Grpc.Net.Client; -using Google.Protobuf.Reflection; namespace Apache.Arrow.Flight.Sql.IntegrationTest; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 8470f1c631f08..c9b4d4672df39 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -12,9 +12,9 @@ namespace Apache.Arrow.Flight.Sql.Client; public class FlightSqlClient { - private readonly FlightClient _client; + private readonly IFlightClient _client; - public FlightSqlClient(FlightClient client) + public FlightSqlClient(IFlightClient client) { _client = client ?? throw new ArgumentNullException(nameof(client)); } @@ -571,10 +571,8 @@ public async Task GetPrimaryKeys(FlightCallOptions options, TableRef { Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; - //TODO: Refactor var action = new FlightAction("GetPrimaryKeys", getPrimaryKeysRequest.PackAndSerialize()); var doActionResult = DoActionAsync(options, action); - await foreach (var result in doActionResult) { var getPrimaryKeysResponse = @@ -1153,7 +1151,6 @@ public async Task BeginTransactionAsync(FlightCallOptions options) { var beginTransactionResult = ActionBeginTransactionResult.Parser.ParseFrom(result.Body.Span); - var transaction = new Transaction(beginTransactionResult?.TransactionId.ToBase64()); return transaction; } @@ -1212,7 +1209,6 @@ public async Task CommitAsync(FlightCallOptions options, Transaction transaction } } - /// /// Rollback a transaction. /// After this, the transaction and all associated savepoints will be invalidated. @@ -1305,20 +1301,6 @@ public async Task CancelQueryAsync(FlightCallOptions options, Flig } } - /*public async Task SetSessionOptionsAsync(FlightCallOptions options, - Dictionary sessionOptions) - { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - - if (sessionOptions == null) - { - throw new ArgumentNullException(nameof(sessionOptions)); - } - }*/ - /// /// Create a prepared statement object. /// diff --git a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs index efb22b1948a01..358c461380cc6 100644 --- a/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs +++ b/csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs @@ -21,7 +21,20 @@ namespace Apache.Arrow.Flight.Client { - public class FlightClient + public interface IFlightClient + { + AsyncServerStreamingCall ListFlights(FlightCriteria criteria = null, Metadata headers = null); + AsyncServerStreamingCall ListActions(Metadata headers = null); + FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null); + AsyncUnaryCall GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null); + FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null); + AsyncDuplexStreamingCall Handshake(Metadata headers = null); + FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers = null); + AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null); + AsyncUnaryCall GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null); + } + + public class FlightClient : IFlightClient { internal static readonly Empty EmptyInstance = new Empty(); @@ -34,30 +47,36 @@ public FlightClient(ChannelBase grpcChannel) public AsyncServerStreamingCall ListFlights(FlightCriteria criteria = null, Metadata headers = null) { - if(criteria == null) + if (criteria == null) { criteria = FlightCriteria.Empty; } - + var response = _client.ListFlights(criteria.ToProtocol(), headers); - var convertStream = new StreamReader(response.ResponseStream, inFlight => new FlightInfo(inFlight)); + var convertStream = + new StreamReader(response.ResponseStream, + inFlight => new FlightInfo(inFlight)); - return new AsyncServerStreamingCall(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); + return new AsyncServerStreamingCall(convertStream, response.ResponseHeadersAsync, + response.GetStatus, response.GetTrailers, response.Dispose); } public AsyncServerStreamingCall ListActions(Metadata headers = null) { var response = _client.ListActions(EmptyInstance, headers); - var convertStream = new StreamReader(response.ResponseStream, actionType => new FlightActionType(actionType)); + var convertStream = new StreamReader(response.ResponseStream, + actionType => new FlightActionType(actionType)); - return new AsyncServerStreamingCall(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose); + return new AsyncServerStreamingCall(convertStream, response.ResponseHeadersAsync, + response.GetStatus, response.GetTrailers, response.Dispose); } public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null) { - var stream = _client.DoGet(ticket.ToProtocol(), headers); + var stream = _client.DoGet(ticket.ToProtocol(), headers); var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream); - return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose); + return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, + stream.GetTrailers, stream.Dispose); } public AsyncUnaryCall GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null) @@ -81,7 +100,8 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc { var channels = _client.DoPut(headers); var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor); - var readStream = new StreamReader(channels.ResponseStream, putResult => new FlightPutResult(putResult)); + var readStream = new StreamReader(channels.ResponseStream, + putResult => new FlightPutResult(putResult)); return new FlightRecordBatchDuplexStreamingCall( requestStream, readStream, @@ -91,10 +111,12 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc channels.Dispose); } - public AsyncDuplexStreamingCall Handshake(Metadata headers = null) + public AsyncDuplexStreamingCall Handshake( + Metadata headers = null) { var channel = _client.Handshake(headers); - var readStream = new StreamReader(channel.ResponseStream, response => new FlightHandshakeResponse(response)); + var readStream = new StreamReader(channel.ResponseStream, + response => new FlightHandshakeResponse(response)); var writeStream = new FlightHandshakeStreamWriterAdapter(channel.RequestStream); var call = new AsyncDuplexStreamingCall( writeStream, @@ -126,8 +148,11 @@ public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescripto public AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null) { var stream = _client.DoAction(action.ToProtocol(), headers); - var streamReader = new StreamReader(stream.ResponseStream, result => new FlightResult(result)); - return new AsyncServerStreamingCall(streamReader, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose); + var streamReader = + new StreamReader(stream.ResponseStream, + result => new FlightResult(result)); + return new AsyncServerStreamingCall(streamReader, stream.ResponseHeadersAsync, + stream.GetStatus, stream.GetTrailers, stream.Dispose); } public AsyncUnaryCall GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null) @@ -136,7 +161,9 @@ public AsyncUnaryCall GetSchema(FlightDescriptor flightDescriptor, Metad var schema = schemaResult .ResponseAsync - .ContinueWith(async schema => FlightMessageSerializer.DecodeSchema((await schemaResult.ResponseAsync.ConfigureAwait(false)).Schema.Memory)) + .ContinueWith(async schema => + FlightMessageSerializer.DecodeSchema((await schemaResult.ResponseAsync.ConfigureAwait(false)).Schema + .Memory)) .Unwrap(); return new AsyncUnaryCall( diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index dc95f9edf9f7f..09e8dd26566bd 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -14,6 +14,7 @@ + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs new file mode 100644 index 0000000000000..80917bdbc69f4 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -0,0 +1,140 @@ +using System.Collections.Generic; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Client; +using Apache.Arrow.Flight.Tests; +using Apache.Arrow.Flight.TestWeb; +using Google.Protobuf; +using Grpc.Core; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class FlightSqlClientTests +{ + readonly TestWebFactory _testWebFactory; + readonly FlightStore _flightStore; + + public FlightSqlClientTests() + { + _flightStore = new FlightStore(); + _testWebFactory = new TestWebFactory(_flightStore); + } + + [Fact] + public async Task CommitAsync_CommitsTransactionSuccessfully() + { + // Arrange + var mockClient = new MockFlightClient(_testWebFactory.GetChannel()); + string transactionId = "sample-transaction-id"; + var expectedResult = + new FlightResult(ByteString.CopyFromUtf8("Transaction committed successfully.")); + + mockClient.SetupActionResult("Commit", [expectedResult]); + var client = new FlightSqlClient(mockClient); + var options = new FlightCallOptions(); + var transaction = new Transaction(transactionId); + + // Act + await client.CommitAsync(options, transaction); + + // Assert + Assert.Single(mockClient.SentActions); + Assert.Equal("Commit", mockClient.SentActions.First().Type); + Assert.Equal(transactionId, mockClient.SentActions.First().Body.ToStringUtf8()); + } + + [Fact] + public async Task CommitAsync_WithNoActions_ShouldNotCommitTransaction() + { + // Arrange + var mockClient = new MockFlightClient(_testWebFactory.GetChannel(), new NoActionStrategy()); + var client = new FlightSqlClient(mockClient); + var options = new FlightCallOptions(); + var transaction = new Transaction("sample-transaction-id"); + + // Act + await client.CommitAsync(options, transaction); + + // Assert + Assert.Empty(mockClient.SentActions); + } +} + +public class MockFlightClient : FlightClient, IFlightClient +{ + private readonly Dictionary> _actionResults; + private readonly List _sentActions = new(); + private readonly IActionStrategy _actionStrategy; + + public MockFlightClient(ChannelBase channel, IActionStrategy actionStrategy = null) : base(channel) + { + _actionResults = new Dictionary>(); + _actionStrategy = actionStrategy ?? new DefaultActionStrategy(); + } + + public void SetupActionResult(string actionType, List results) + { + _actionResults[actionType] = results; + } + + public new AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null) + { + _actionStrategy.HandleAction(this, action); + + var result = _actionResults.TryGetValue(action.Type, out List actionResult) + ? actionResult + : new List(); + + var responseStream = new MockAsyncStreamReader(result); + return new AsyncServerStreamingCall( + responseStream, + Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + public List SentActions => _sentActions; + + private class MockAsyncStreamReader(IEnumerable items) : IAsyncStreamReader + { + private readonly IEnumerator _enumerator = items.GetEnumerator(); + + public T Current => _enumerator.Current; + + + public async Task MoveNext(CancellationToken cancellationToken) + { + return await Task.Run(() => _enumerator.MoveNext(), cancellationToken); + } + + public void Dispose() + { + _enumerator.Dispose(); + } + } +} + +public interface IActionStrategy +{ + void HandleAction(MockFlightClient mockClient, FlightAction action); +} + +internal class DefaultActionStrategy : IActionStrategy +{ + public void HandleAction(MockFlightClient mockClient, FlightAction action) + { + mockClient.SentActions.Add(action); + } +} + +internal class NoActionStrategy : IActionStrategy +{ + public void HandleAction(MockFlightClient mockClient, FlightAction action) + { + // Do nothing, or handle the action differently + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs index 031495fffdcc7..0df6460e56b0e 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -13,6 +13,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +using Apache.Arrow.Flight.Sql.Client; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs index 3dca632b5b761..da86778897ad1 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs @@ -19,12 +19,25 @@ using Apache.Arrow.Flight.Server; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; using Grpc.Core; namespace Apache.Arrow.Flight.Sql.Tests; public class TestFlightSqlSever : FlightSqlServer { + public override async Task DoAction(FlightAction action, IAsyncStreamWriter responseStream, ServerCallContext context) + { + if (action.Type == "Commit") + { + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("Transaction committed successfully."))); + } + else + { + await base.DoAction(action, responseStream, context); + } + } + protected override Task GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); protected override Task GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); @@ -86,4 +99,5 @@ private RecordBatch MockRecordBatch(string name) var schema = new Schema(new List {new(name, StringType.Default, false)}, System.Array.Empty>()); return new RecordBatch(schema, new []{ new StringArray.Builder().Append(name).Build() }, 1); } + } From a3c0c09355e5ded0127420a314faf560ddd698c0 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Wed, 28 Aug 2024 19:33:49 +0300 Subject: [PATCH 07/58] refactor: FlightSqlClient --- .../Program.cs | 126 +++++++++--------- .../Client/FlightSqlClient.cs | 13 +- .../Apache.Arrow.Flight.Sql.Tests.csproj | 6 + .../FlightSqlClientTests.cs | 113 ++-------------- .../TestFlightSqlSever.cs | 13 -- .../TestFlightServer.cs | 10 +- .../Apache.Arrow.Flight.Tests/FlightTests.cs | 2 +- 7 files changed, 93 insertions(+), 190 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 0b443250e68d7..9672de39fcb4a 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -154,34 +154,32 @@ static async Task Main(string[] args) // ExecuteAsync Console.WriteLine("ExecuteAsync:"); var flightInfo = await sqlClient.ExecuteAsync(new FlightCallOptions(), query); - // Handle the ExecuteAsync result - Console.WriteLine($@"Query executed successfully. Records count: {flightInfo.TotalRecords}"); // ExecuteUpdate - Console.WriteLine("ExecuteUpdate:"); - string updateQuery = "UPDATE SYSDB.`Info` SET Key = 1, Val=10 WHERE Id=1"; - long affectedRows = await sqlClient.ExecuteUpdateAsync(new FlightCallOptions(), updateQuery); - // Handle the ExecuteUpdate result - Console.WriteLine($@"Number of affected d rows: {affectedRows}"); - - // GetExecuteSchema - Console.WriteLine("GetExecuteSchema:"); - var schemaResult = await sqlClient.GetExecuteSchemaAsync(new FlightCallOptions(), query); - // Process the schemaResult as needed - Console.WriteLine($"Schema retrieved successfully:{schemaResult}"); - - // ExecuteIngest - - // GetCatalogs - Console.WriteLine("GetCatalogs:"); - var catalogsInfo = await sqlClient.GetCatalogs(new FlightCallOptions()); - // Print catalog details - Console.WriteLine("Catalogs retrieved:"); - foreach (var endpoint in catalogsInfo.Endpoints) - { - var ticket = endpoint.Ticket; - Console.WriteLine($"- Ticket: {ticket}"); - } + // Console.WriteLine("ExecuteUpdate:"); + // string updateQuery = "UPDATE SYSDB.`Info` SET Key = 1, Val=10 WHERE Id=1"; + // long affectedRows = await sqlClient.ExecuteUpdateAsync(new FlightCallOptions(), updateQuery); + // // Handle the ExecuteUpdate result + // Console.WriteLine($@"Number of affected d rows: {affectedRows}"); + // + // // GetExecuteSchema + // Console.WriteLine("GetExecuteSchema:"); + // var schemaResult = await sqlClient.GetExecuteSchemaAsync(new FlightCallOptions(), query); + // // Process the schemaResult as needed + // Console.WriteLine($"Schema retrieved successfully:{schemaResult}"); + // + // // ExecuteIngest + // + // // GetCatalogs + // Console.WriteLine("GetCatalogs:"); + // var catalogsInfo = await sqlClient.GetCatalogs(new FlightCallOptions()); + // // Print catalog details + // Console.WriteLine("Catalogs retrieved:"); + // foreach (var endpoint in catalogsInfo.Endpoints) + // { + // var ticket = endpoint.Ticket; + // Console.WriteLine($"- Ticket: {ticket}"); + // } // GetCatalogsSchema // Console.WriteLine("GetCatalogsSchema:"); @@ -220,24 +218,24 @@ static async Task Main(string[] args) // Call GetTablesAsync method - Console.WriteLine("GetTablesAsync:"); - IEnumerable tables = await sqlClient.GetTablesAsync( - new FlightCallOptions(), - catalog: "", - dbSchemaFilterPattern: "public", - tableFilterPattern: "SYSDB", - includeSchema: true, - tableTypes: new List { "TABLE", "VIEW" }); - foreach (var table in tables) - { - Console.WriteLine($"Table URI: {table.Descriptor.Paths}"); - foreach (var endpoint in table.Endpoints) - { - Console.WriteLine($"Endpoint Ticket: {endpoint.Ticket}"); - } - } - - var tableRef = new TableRef { Catalog = "", DbSchema = "SYSDB", Table = "Info" }; + // Console.WriteLine("GetTablesAsync:"); + // IEnumerable tables = await sqlClient.GetTablesAsync( + // new FlightCallOptions(), + // catalog: "", + // dbSchemaFilterPattern: "public", + // tableFilterPattern: "SYSDB", + // includeSchema: true, + // tableTypes: new List { "TABLE", "VIEW" }); + // foreach (var table in tables) + // { + // Console.WriteLine($"Table URI: {table.Descriptor.Paths}"); + // foreach (var endpoint in table.Endpoints) + // { + // Console.WriteLine($"Endpoint Ticket: {endpoint.Ticket}"); + // } + // } + // + // var tableRef = new TableRef { Catalog = "", DbSchema = "SYSDB", Table = "Info" }; // Get exported keys // Console.WriteLine("GetExportedKeysAsync:"); @@ -297,17 +295,17 @@ static async Task Main(string[] args) // Console.WriteLine($"Schema: {tableTypesSchema}"); // Get XDBC type info (with DataType) - Console.WriteLine("GetXdbcTypeInfoAsync: (With DataType)"); - var flightInfoGetXdbcTypeInfoWithoutDataType = - await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions(), 4); - Console.WriteLine("XDBC With DataType Info obtained:"); - Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfoWithoutDataType}"); + // Console.WriteLine("GetXdbcTypeInfoAsync: (With DataType)"); + // var flightInfoGetXdbcTypeInfoWithoutDataType = + // await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions(), 4); + // Console.WriteLine("XDBC With DataType Info obtained:"); + // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfoWithoutDataType}"); // Get XDBC type info - Console.WriteLine("GetXdbcTypeInfoAsync:"); - var flightInfoGetXdbcTypeInfo = await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions()); - Console.WriteLine("XDBC Type Info obtained:"); - Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfo}"); + // Console.WriteLine("GetXdbcTypeInfoAsync:"); + // var flightInfoGetXdbcTypeInfo = await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions()); + // Console.WriteLine("XDBC Type Info obtained:"); + // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfo}"); // Get XDBC type info schema // Console.WriteLine("GetXdbcTypeInfoSchemaAsync:"); @@ -316,12 +314,12 @@ static async Task Main(string[] args) // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeSchemaInfo}"); // Get SQL info - Console.WriteLine("GetSqlInfoAsync:"); + // Console.WriteLine("GetSqlInfoAsync:"); // Define SQL info list - var sqlInfo = new List { 1, 2, 3 }; - var flightInfoGetSqlInfo = sqlClient.GetSqlInfoAsync(new FlightCallOptions(), sqlInfo); - Console.WriteLine("SQL Info obtained:"); - Console.WriteLine($"FlightInfo: {flightInfoGetSqlInfo}"); + // var sqlInfo = new List { 1, 2, 3 }; + // var flightInfoGetSqlInfo = sqlClient.GetSqlInfoAsync(new FlightCallOptions(), sqlInfo); + // Console.WriteLine("SQL Info obtained:"); + // Console.WriteLine($"FlightInfo: {flightInfoGetSqlInfo}"); // Get SQL info schema // Console.WriteLine("GetSqlInfoSchemaAsync:"); @@ -330,9 +328,9 @@ static async Task Main(string[] args) // Console.WriteLine($"Schema: {schema}"); // Prepare a SQL statement - Console.WriteLine("PrepareAsync:"); - var preparedStatement = await sqlClient.PrepareAsync(new FlightCallOptions(), query); - Console.WriteLine("Prepared statement created successfully."); + // Console.WriteLine("PrepareAsync:"); + // var preparedStatement = await sqlClient.PrepareAsync(new FlightCallOptions(), query); + // Console.WriteLine("Prepared statement created successfully."); // Cancel FlightInfo Request @@ -360,9 +358,9 @@ static async Task Main(string[] args) // Console.WriteLine("Transaction rolled back successfully."); // Cancel Query - Console.WriteLine("CancelQuery:"); - var cancelResult = await sqlClient.CancelQueryAsync(new FlightCallOptions(), flightInfo); - Console.WriteLine($"Cancellation Status: {cancelResult}"); + // Console.WriteLine("CancelQuery:"); + // var cancelResult = await sqlClient.CancelQueryAsync(new FlightCallOptions(), flightInfo); + // Console.WriteLine($"Cancellation Status: {cancelResult}"); } catch (Exception ex) { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index c9b4d4672df39..7a5632730d180 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -12,9 +12,9 @@ namespace Apache.Arrow.Flight.Sql.Client; public class FlightSqlClient { - private readonly IFlightClient _client; + private readonly FlightClient _client; - public FlightSqlClient(IFlightClient client) + public FlightSqlClient(FlightClient client) { _client = client ?? throw new ArgumentNullException(nameof(client)); } @@ -1176,7 +1176,7 @@ public async Task BeginTransactionAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The transaction. /// A Task representing the asynchronous operation. - public async Task CommitAsync(FlightCallOptions options, Transaction transaction) + public AsyncServerStreamingCall CommitAsync(FlightCallOptions options, Transaction transaction) { if (options == null) { @@ -1191,11 +1191,7 @@ public async Task CommitAsync(FlightCallOptions options, Transaction transaction try { var actionCommit = new FlightAction("Commit", transaction.TransactionId); - var responseStream = _client.DoAction(actionCommit, options.Headers); - await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) - { - Console.WriteLine("Transaction committed successfully."); - } + return _client.DoAction(actionCommit, options.Headers); } catch (RpcException ex) { @@ -1287,6 +1283,7 @@ public async Task CancelQueryAsync(FlightCallOptions options, Flig }; } } + throw new InvalidOperationException("Failed to cancel query: No response received."); } catch (RpcException ex) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 09e8dd26566bd..9afc9bc51faa1 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -17,4 +17,10 @@ + + + ..\..\..\..\..\..\Applications\Rider.app\Contents\lib\ReSharperHost\TestRunner\netcoreapp3.0\JetBrains.ReSharper.TestRunner.Merged.dll + + + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 80917bdbc69f4..6bcca7b600053 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -1,4 +1,6 @@ +using System; using System.Collections.Generic; +using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; @@ -12,129 +14,34 @@ namespace Apache.Arrow.Flight.Sql.Tests; -public class FlightSqlClientTests +public class FlightSqlClientTests : IDisposable { readonly TestWebFactory _testWebFactory; readonly FlightStore _flightStore; + readonly FlightClient _flightClient; public FlightSqlClientTests() { _flightStore = new FlightStore(); _testWebFactory = new TestWebFactory(_flightStore); + _flightClient = new FlightClient(_testWebFactory.GetChannel()); } [Fact] - public async Task CommitAsync_CommitsTransactionSuccessfully() + public void CommitAsync_CommitsTransaction() { // Arrange - var mockClient = new MockFlightClient(_testWebFactory.GetChannel()); string transactionId = "sample-transaction-id"; - var expectedResult = - new FlightResult(ByteString.CopyFromUtf8("Transaction committed successfully.")); - - mockClient.SetupActionResult("Commit", [expectedResult]); - var client = new FlightSqlClient(mockClient); + var client = new FlightSqlClient(_flightClient); var options = new FlightCallOptions(); var transaction = new Transaction(transactionId); // Act - await client.CommitAsync(options, transaction); - - // Assert - Assert.Single(mockClient.SentActions); - Assert.Equal("Commit", mockClient.SentActions.First().Type); - Assert.Equal(transactionId, mockClient.SentActions.First().Body.ToStringUtf8()); - } - - [Fact] - public async Task CommitAsync_WithNoActions_ShouldNotCommitTransaction() - { - // Arrange - var mockClient = new MockFlightClient(_testWebFactory.GetChannel(), new NoActionStrategy()); - var client = new FlightSqlClient(mockClient); - var options = new FlightCallOptions(); - var transaction = new Transaction("sample-transaction-id"); - - // Act - await client.CommitAsync(options, transaction); + var streamCall = client.CommitAsync(options, transaction); // Assert - Assert.Empty(mockClient.SentActions); + // Assert.Contains("Transaction committed successfully.", consoleOutput.ToString()); } -} - -public class MockFlightClient : FlightClient, IFlightClient -{ - private readonly Dictionary> _actionResults; - private readonly List _sentActions = new(); - private readonly IActionStrategy _actionStrategy; - - public MockFlightClient(ChannelBase channel, IActionStrategy actionStrategy = null) : base(channel) - { - _actionResults = new Dictionary>(); - _actionStrategy = actionStrategy ?? new DefaultActionStrategy(); - } - - public void SetupActionResult(string actionType, List results) - { - _actionResults[actionType] = results; - } - - public new AsyncServerStreamingCall DoAction(FlightAction action, Metadata headers = null) - { - _actionStrategy.HandleAction(this, action); - - var result = _actionResults.TryGetValue(action.Type, out List actionResult) - ? actionResult - : new List(); - var responseStream = new MockAsyncStreamReader(result); - return new AsyncServerStreamingCall( - responseStream, - Task.FromResult(new Metadata()), - () => Status.DefaultSuccess, - () => new Metadata(), - () => { }); - } - - public List SentActions => _sentActions; - - private class MockAsyncStreamReader(IEnumerable items) : IAsyncStreamReader - { - private readonly IEnumerator _enumerator = items.GetEnumerator(); - - public T Current => _enumerator.Current; - - - public async Task MoveNext(CancellationToken cancellationToken) - { - return await Task.Run(() => _enumerator.MoveNext(), cancellationToken); - } - - public void Dispose() - { - _enumerator.Dispose(); - } - } -} - -public interface IActionStrategy -{ - void HandleAction(MockFlightClient mockClient, FlightAction action); -} - -internal class DefaultActionStrategy : IActionStrategy -{ - public void HandleAction(MockFlightClient mockClient, FlightAction action) - { - mockClient.SentActions.Add(action); - } -} - -internal class NoActionStrategy : IActionStrategy -{ - public void HandleAction(MockFlightClient mockClient, FlightAction action) - { - // Do nothing, or handle the action differently - } + public void Dispose() => _testWebFactory?.Dispose(); } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs index da86778897ad1..d2e1ac21f9e01 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs @@ -19,25 +19,12 @@ using Apache.Arrow.Flight.Server; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; -using Google.Protobuf; using Grpc.Core; namespace Apache.Arrow.Flight.Sql.Tests; public class TestFlightSqlSever : FlightSqlServer { - public override async Task DoAction(FlightAction action, IAsyncStreamWriter responseStream, ServerCallContext context) - { - if (action.Type == "Commit") - { - await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("Transaction committed successfully."))); - } - else - { - await base.DoAction(action, responseStream, context); - } - } - protected override Task GetStatementQueryFlightInfo(CommandStatementQuery commandStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); protected override Task GetPreparedStatementQueryFlightInfo(CommandPreparedStatementQuery preparedStatementQuery, FlightDescriptor flightDescriptor, ServerCallContext serverCallContext) => Task.FromResult(new FlightInfo(null, FlightDescriptor.CreatePathDescriptor(MethodBase.GetCurrentMethod().Name), System.Array.Empty())); diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index 4a72b73274f1e..cbca22c435c76 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -40,6 +40,12 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter resp await responseStream.WriteAsync(new FlightActionType("put", "add a flight")); await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")); await responseStream.WriteAsync(new FlightActionType("test", "test action")); + await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")); + await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")); } public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, ServerCallContext context) diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index aac4e4209240a..67f8b6b22a6fb 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -187,7 +187,7 @@ public async Task TestGetFlightMetadata() var getStream = _flightClient.GetStream(endpoint.Ticket); - List actualMetadata = new List(); + List actualMetadata = new List(); while(await getStream.ResponseStream.MoveNext(default)) { actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata); From 4d48ab1ec4899124f98254419caed6c7e69ca08c Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Mon, 2 Sep 2024 17:40:19 +0300 Subject: [PATCH 08/58] feat(test): GetFlightInfo --- .../Client/FlightSqlClient.cs | 75 ++----------- .../FlightSqlClientTests.cs | 104 ++++++++++++++++-- .../FlightHolder.cs | 32 ++++-- .../TestFlightServer.cs | 11 +- .../FlightTestUtils.cs | 48 ++++++++ .../Apache.Arrow.Flight.Tests/FlightTests.cs | 92 ++++++---------- 6 files changed, 223 insertions(+), 139 deletions(-) create mode 100644 csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 7a5632730d180..f390f5e78de36 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -32,8 +32,6 @@ public async Task ExecuteAsync(FlightCallOptions options, string que { transaction ??= NoTransaction(); - FlightInfo? flightInfo = null; - if (options == null) { throw new ArgumentNullException(nameof(options)); @@ -46,12 +44,10 @@ public async Task ExecuteAsync(FlightCallOptions options, string que try { - Console.WriteLine($@"Executing query: {query}"); var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); - // Process the response await foreach (var result in call.ResponseStream.ReadAllAsync()) { var preparedStatementResponse = @@ -60,30 +56,17 @@ public async Task ExecuteAsync(FlightCallOptions options, string que { PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle }; + byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - flightInfo = await GetFlightInfoAsync(options, descriptor); - var doGetResult = DoGetAsync(options, flightInfo.Endpoints[0].Ticket); - await foreach (var recordBatch in doGetResult) - { - Console.WriteLine(recordBatch); - } + return await GetFlightInfoAsync(options, descriptor); } - - return flightInfo!; + throw new InvalidOperationException("No results returned from the query."); } catch (RpcException ex) { - // Handle gRPC exceptions - Console.WriteLine($@"gRPC Error: {ex.Status}"); throw new InvalidOperationException("Failed to execute query", ex); } - catch (Exception ex) - { - // Handle other exceptions - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -172,16 +155,8 @@ public async Task GetFlightInfoAsync(FlightCallOptions options, Flig } catch (RpcException ex) { - // Handle gRPC exceptions - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get flight info", ex); } - catch (Exception ex) - { - // Handle other exceptions - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -1147,26 +1122,19 @@ public async Task BeginTransactionAsync(FlightCallOptions options) var actionBeginTransaction = new ActionBeginTransactionRequest(); var action = new FlightAction("BeginTransaction", actionBeginTransaction.PackAndSerialize()); var responseStream = _client.DoAction(action, options.Headers); + await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) { - var beginTransactionResult = - ActionBeginTransactionResult.Parser.ParseFrom(result.Body.Span); - var transaction = new Transaction(beginTransactionResult?.TransactionId.ToBase64()); - return transaction; + string? beginTransactionResult = result.Body.ToStringUtf8(); + return new Transaction(beginTransactionResult); } throw new InvalidOperationException("Failed to begin transaction: No response received."); } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to begin transaction", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -1212,7 +1180,7 @@ public AsyncServerStreamingCall CommitAsync(FlightCallOptions opti /// RPC-layer hints for this call. /// The transaction to rollback. /// A Task representing the asynchronous operation. - public async Task RollbackAsync(FlightCallOptions options, Transaction transaction) + public AsyncServerStreamingCall RollbackAsync(FlightCallOptions options, Transaction transaction) { if (options == null) { @@ -1227,22 +1195,12 @@ public async Task RollbackAsync(FlightCallOptions options, Transaction transacti try { var actionRollback = new FlightAction("Rollback", transaction.TransactionId); - var responseStream = _client.DoAction(actionRollback, options.Headers); - await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) - { - Console.WriteLine("Transaction rolled back successfully."); - } + return _client.DoAction(actionRollback, options.Headers); } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to rollback transaction", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } @@ -1325,7 +1283,9 @@ public async Task PrepareAsync(FlightCallOptions options, str var preparedStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, - // TransactionId = transaction?.TransactionId + TransactionId = transaction is null + ? ByteString.CopyFromUtf8(transaction?.TransactionId) + : ByteString.Empty }; var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); @@ -1334,7 +1294,7 @@ public async Task PrepareAsync(FlightCallOptions options, str await foreach (var result in call.ResponseStream.ReadAllAsync()) { var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(result.Body); + FlightSqlUtils.ParseAndUnpack(result.Body); var commandSqlCall = new CommandPreparedStatementQuery { @@ -1343,11 +1303,6 @@ public async Task PrepareAsync(FlightCallOptions options, str byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); var flightInfo = await GetFlightInfoAsync(options, descriptor); - await foreach (var recordBatch in DoGetAsync(options, flightInfo.Endpoints[0].Ticket)) - { - Console.WriteLine(recordBatch); - } - return new PreparedStatement(this, flightInfo, query); } @@ -1355,14 +1310,8 @@ public async Task PrepareAsync(FlightCallOptions options, str } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status}"); throw new InvalidOperationException("Failed to prepare statement", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 6bcca7b600053..ae4113aac25bb 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -1,15 +1,11 @@ using System; -using System.Collections.Generic; -using System.IO; using System.Linq; -using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; using Apache.Arrow.Flight.Tests; using Apache.Arrow.Flight.TestWeb; -using Google.Protobuf; -using Grpc.Core; +using Grpc.Core.Utils; using Xunit; namespace Apache.Arrow.Flight.Sql.Tests; @@ -19,28 +15,118 @@ public class FlightSqlClientTests : IDisposable readonly TestWebFactory _testWebFactory; readonly FlightStore _flightStore; readonly FlightClient _flightClient; + private readonly FlightSqlClient _flightSqlClient; + private readonly FlightTestUtils _testUtils; public FlightSqlClientTests() { _flightStore = new FlightStore(); _testWebFactory = new TestWebFactory(_flightStore); _flightClient = new FlightClient(_testWebFactory.GetChannel()); + _flightSqlClient = new FlightSqlClient(_flightClient); + + _testUtils = new FlightTestUtils(_testWebFactory, _flightStore); } + #region Transactions [Fact] - public void CommitAsync_CommitsTransaction() + public async Task CommitAsync_Transaction() { // Arrange string transactionId = "sample-transaction-id"; - var client = new FlightSqlClient(_flightClient); var options = new FlightCallOptions(); var transaction = new Transaction(transactionId); // Act - var streamCall = client.CommitAsync(options, transaction); + var streamCall = _flightSqlClient.CommitAsync(options, transaction); + var result = await streamCall.ResponseStream.ToListAsync(); + + // Assert + Assert.NotNull(result); + Assert.Equal(transaction.TransactionId, result.FirstOrDefault()?.Body.ToStringUtf8()); + } + + [Fact] + public async Task BeginTransactionAsync_Transaction() + { + // Arrange + var options = new FlightCallOptions(); + string expectedTransactionId = "sample-transaction-id"; + + // Act + var transaction = await _flightSqlClient.BeginTransactionAsync(options); + + // Assert + Assert.NotNull(transaction); + Assert.Equal(expectedTransactionId, transaction.TransactionId); + } + + [Fact] + public async Task RollbackAsync_Transaction() + { + // Arrange + string transactionId = "sample-transaction-id"; + var options = new FlightCallOptions(); + var transaction = new Transaction(transactionId); + + // Act + var streamCall = _flightSqlClient.RollbackAsync(options, transaction); + var result = await streamCall.ResponseStream.ToListAsync(); + + // Assert + Assert.NotNull(transaction); + Assert.Equal(result.FirstOrDefault()?.Body.ToStringUtf8(), transaction.TransactionId); + } + + #endregion + + #region PreparedStatement + [Fact] + public async Task PreparedStatement() + { + // Arrange + string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; + var options = new FlightCallOptions(); + + // Act + var preparedStatement = await _flightSqlClient.PrepareAsync(options, query); + + // Assert + Assert.NotNull(preparedStatement); + } + #endregion + + [Fact] + public async Task Execute() + { + // Arrange + string query = "SELECT * FROM test_table"; + var options = new FlightCallOptions(); + + // Act + var flightInfo = await _flightSqlClient.ExecuteAsync(options, query); + + // Assert + Assert.NotNull(flightInfo); + Assert.Single(flightInfo.Endpoints); + } + + [Fact] + public async Task GetFlightInfo() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + // Act + var flightInfo = await _flightSqlClient.GetFlightInfoAsync(options, flightDescriptor); // Assert - // Assert.Contains("Transaction committed successfully.", consoleOutput.ToString()); + Assert.NotNull(flightInfo); } public void Dispose() => _testWebFactory?.Dispose(); diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs index c6f7e66c6c2d8..896bc4489c472 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs @@ -30,7 +30,7 @@ public class FlightHolder //Not thread safe, but only used in tests private readonly List _recordBatches = new List(); - + public FlightHolder(FlightDescriptor flightDescriptor, Schema schema, string location) { _flightDescriptor = flightDescriptor; @@ -52,13 +52,31 @@ public IEnumerable GetRecordBatches() public FlightInfo GetFlightInfo() { int batchArrayLength = _recordBatches.Sum(rb => rb.RecordBatch.Length); - int batchBytes = _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b=>b.Length))); - return new FlightInfo(_schema, _flightDescriptor, new List() + int batchBytes = + _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b => b.Length))); + + if (!_flightDescriptor.Paths.Any()) + { + return GetFlightInfoWithCommand(); + } + + var flightInfo = new FlightInfo(_schema, _flightDescriptor, + new List() + { + new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault()), + new List() { new FlightLocation(_location) }) + }, batchArrayLength, batchBytes); + return flightInfo; + } + + public FlightInfo GetFlightInfoWithCommand() + { + if (!_flightDescriptor.Paths.Any()) { - new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault()), new List(){ - new FlightLocation(_location) - }) - }, batchArrayLength, batchBytes); + return new FlightInfo(_schema, _flightDescriptor, new List(), 0, 0); + } + + return null; } } } diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index cbca22c435c76..87cc27744b7a3 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -40,11 +40,14 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter GetFlightInfo(FlightDescriptor request, ServerC { return Task.FromResult(flightHolder.GetFlightInfo()); } + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); } + public override async Task Handshake(IAsyncStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) { while (await requestStream.MoveNext().ConfigureAwait(false)) diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs new file mode 100644 index 0000000000000..6a9184368e658 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs @@ -0,0 +1,48 @@ +using System.Linq; +using Apache.Arrow.Flight.TestWeb; + +namespace Apache.Arrow.Flight.Tests; + +public class FlightTestUtils +{ + private readonly TestWebFactory _testWebFactory; + private readonly FlightStore _flightStore; + + public FlightTestUtils(TestWebFactory testWebFactory, FlightStore flightStore) + { + _testWebFactory = testWebFactory; + _flightStore = flightStore; + } + + public RecordBatch CreateTestBatch(int startValue, int length) + { + var batchBuilder = new RecordBatch.Builder(); + Int32Array.Builder builder = new(); + for (int i = 0; i < length; i++) + { + builder.Append(startValue + i); + } + + batchBuilder.Append("test", true, builder.Build()); + return batchBuilder.Build(); + } + + + public FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, + params RecordBatchWithMetadata[] batches) + { + var initialBatch = batches.FirstOrDefault(); + + var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, + _testWebFactory.GetAddress()); + + foreach (var batch in batches) + { + flightHolder.AddBatch(batch); + } + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + return flightHolder.GetFlightInfo(); + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 67f8b6b22a6fb..3ef08f3080cae 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -31,11 +31,14 @@ public class FlightTests : IDisposable readonly TestWebFactory _testWebFactory; readonly FlightClient _flightClient; readonly FlightStore _flightStore; + private readonly FlightTestUtils _testUtils; + public FlightTests() { _flightStore = new FlightStore(); _testWebFactory = new TestWebFactory(_flightStore); _flightClient = new FlightClient(_testWebFactory.GetChannel()); + _testUtils = new FlightTestUtils(_testWebFactory, _flightStore); } public void Dispose() @@ -43,18 +46,6 @@ public void Dispose() _testWebFactory.Dispose(); } - private RecordBatch CreateTestBatch(int startValue, int length) - { - var batchBuilder = new RecordBatch.Builder(); - Int32Array.Builder builder = new Int32Array.Builder(); - for (int i = 0; i < length; i++) - { - builder.Append(startValue + i); - } - batchBuilder.Append("test", true, builder.Build()); - return batchBuilder.Build(); - } - private IEnumerable GetStoreBatch(FlightDescriptor flightDescriptor) { @@ -64,27 +55,11 @@ private IEnumerable GetStoreBatch(FlightDescriptor flig return flightHolder.GetRecordBatches(); } - private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params RecordBatchWithMetadata[] batches) - { - var initialBatch = batches.FirstOrDefault(); - - var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress()); - - foreach(var batch in batches) - { - flightHolder.AddBatch(batch); - } - - _flightStore.Flights.Add(flightDescriptor, flightHolder); - - return flightHolder.GetFlightInfo(); - } - [Fact] public async Task TestPutSingleRecordBatch() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); var putStream = _flightClient.StartPut(flightDescriptor); await putStream.RequestStream.WriteAsync(expectedBatch); @@ -103,8 +78,8 @@ public async Task TestPutSingleRecordBatch() public async Task TestPutTwoRecordBatches() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = CreateTestBatch(0, 100); - var expectedBatch2 = CreateTestBatch(0, 100); + var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); + var expectedBatch2 = _testUtils.CreateTestBatch(0, 100); var putStream = _flightClient.StartPut(flightDescriptor); await putStream.RequestStream.WriteAsync(expectedBatch1); @@ -125,10 +100,10 @@ public async Task TestPutTwoRecordBatches() public async Task TestGetSingleRecordBatch() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); //Add batch to the in memory store - GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); + _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -147,11 +122,12 @@ public async Task TestGetSingleRecordBatch() public async Task TestGetTwoRecordBatch() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = CreateTestBatch(0, 100); - var expectedBatch2 = CreateTestBatch(100, 100); + var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); + var expectedBatch2 = _testUtils.CreateTestBatch(100, 100); //Add batch to the in memory store - GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), new RecordBatchWithMetadata(expectedBatch2)); + _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), + new RecordBatchWithMetadata(expectedBatch2)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -171,13 +147,13 @@ public async Task TestGetTwoRecordBatch() public async Task TestGetFlightMetadata() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = CreateTestBatch(0, 100); + var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); var expectedMetadataList = new List() { expectedMetadata }; //Add batch to the in memory store - GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1, expectedMetadata)); + _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1, expectedMetadata)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -188,7 +164,7 @@ public async Task TestGetFlightMetadata() var getStream = _flightClient.GetStream(endpoint.Ticket); List actualMetadata = new List(); - while(await getStream.ResponseStream.MoveNext(default)) + while (await getStream.ResponseStream.MoveNext(default)) { actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata); } @@ -200,7 +176,7 @@ public async Task TestGetFlightMetadata() public async Task TestPutWithMetadata() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); var putStream = _flightClient.StartPut(flightDescriptor); @@ -221,10 +197,10 @@ public async Task TestPutWithMetadata() public async Task TestGetSchema() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); var expectedSchema = expectedBatch.Schema; - GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); + _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); var actualSchema = await _flightClient.GetSchema(flightDescriptor); @@ -266,18 +242,18 @@ public async Task TestListFlights() { var flightDescriptor1 = FlightDescriptor.CreatePathDescriptor("test1"); var flightDescriptor2 = FlightDescriptor.CreatePathDescriptor("test2"); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); List expectedFlightInfo = new List(); - expectedFlightInfo.Add(GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); - expectedFlightInfo.Add(GivenStoreBatches(flightDescriptor2, new RecordBatchWithMetadata(expectedBatch))); + expectedFlightInfo.Add(_testUtils.GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); + expectedFlightInfo.Add(_testUtils.GivenStoreBatches(flightDescriptor2, new RecordBatchWithMetadata(expectedBatch))); var listFlightStream = _flightClient.ListFlights(); var actualFlights = await listFlightStream.ResponseStream.ToListAsync(); - for(int i = 0; i < expectedFlightInfo.Count; i++) + for (int i = 0; i < expectedFlightInfo.Count; i++) { FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]); } @@ -301,7 +277,7 @@ public async Task TestSingleExchange() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("single_exchange"); var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch); await duplexStreamingCall.RequestStream.CompleteAsync(); @@ -317,8 +293,8 @@ public async Task TestMultipleExchange() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("multiple_exchange"); var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor); - var expectedBatch1 = CreateTestBatch(0, 100); - var expectedBatch2 = CreateTestBatch(100, 100); + var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); + var expectedBatch2 = _testUtils.CreateTestBatch(100, 100); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch1); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch2); @@ -335,7 +311,7 @@ public async Task TestExchangeWithMetadata() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("metadata_exchange"); var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch, expectedMetadata); @@ -358,7 +334,8 @@ public async Task TestHandshakeWithSpecificMessage() { var duplexStreamingCall = _flightClient.Handshake(); - await duplexStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.CopyFromUtf8("Hello"))); + await duplexStreamingCall.RequestStream.WriteAsync( + new FlightHandshakeRequest(ByteString.CopyFromUtf8("Hello"))); await duplexStreamingCall.RequestStream.CompleteAsync(); var results = await duplexStreamingCall.ResponseStream.ToListAsync(); @@ -370,11 +347,12 @@ public async Task TestHandshakeWithSpecificMessage() public async Task TestGetBatchesWithAsyncEnumerable() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = CreateTestBatch(0, 100); - var expectedBatch2 = CreateTestBatch(100, 100); + var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); + var expectedBatch2 = _testUtils.CreateTestBatch(100, 100); //Add batch to the in memory store - GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), new RecordBatchWithMetadata(expectedBatch2)); + _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), + new RecordBatchWithMetadata(expectedBatch2)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -386,7 +364,7 @@ public async Task TestGetBatchesWithAsyncEnumerable() List resultList = new List(); - await foreach(var recordBatch in getStream.ResponseStream) + await foreach (var recordBatch in getStream.ResponseStream) { resultList.Add(recordBatch); } @@ -400,12 +378,12 @@ public async Task TestGetBatchesWithAsyncEnumerable() public async Task EnsureTheSerializedBatchContainsTheProperTotalRecordsAndTotalBytesProperties() { var flightDescriptor1 = FlightDescriptor.CreatePathDescriptor("test1"); - var expectedBatch = CreateTestBatch(0, 100); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); var expectedTotalBytes = expectedBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b => b.Length)); List expectedFlightInfo = new List(); - expectedFlightInfo.Add(GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); + expectedFlightInfo.Add(_testUtils.GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); var listFlightStream = _flightClient.ListFlights(); From 4bfc2a9d84b784ee9cb297b30d02c29246e87bb6 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Wed, 4 Sep 2024 12:42:22 +0300 Subject: [PATCH 09/58] refactor: Flight sql testing --- .../Program.cs | 7 +- .../FlightSqlHolder.cs | 46 +++++ .../FlightSqlStore.cs | 8 + .../RecordBatchWithMetadata.cs | 15 ++ .../Startup.cs | 39 +++++ .../TestFlightSqlServer.cs | 161 ++++++++++++++++++ .../TestSqlWebFactory.cs | 66 +++++++ csharp/Apache.Arrow.sln | 6 + .../Client/FlightSqlClient.cs | 28 +-- .../Apache.Arrow.Flight.Sql.Tests.csproj | 3 +- .../FlightSqlClientTests.cs | 42 ++++- .../FlightSqlTestUtils.cs | 48 ++++++ .../Apache.Arrow.Flight.TestWeb.csproj | 1 + .../FlightHolder.cs | 30 +--- .../TestFlightServer.cs | 13 -- 15 files changed, 443 insertions(+), 70 deletions(-) create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index 9672de39fcb4a..f8b3bd49ea992 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -157,10 +157,9 @@ static async Task Main(string[] args) // ExecuteUpdate // Console.WriteLine("ExecuteUpdate:"); - // string updateQuery = "UPDATE SYSDB.`Info` SET Key = 1, Val=10 WHERE Id=1"; - // long affectedRows = await sqlClient.ExecuteUpdateAsync(new FlightCallOptions(), updateQuery); - // // Handle the ExecuteUpdate result - // Console.WriteLine($@"Number of affected d rows: {affectedRows}"); + string updateQuery = "UPDATE SYSDB.`Info` SET Key = 1, Val=10 WHERE Id=1"; + long affectedRows = await sqlClient.ExecuteUpdateAsync(new FlightCallOptions(), updateQuery); + Console.WriteLine($@"Number of affected rows: {affectedRows}"); // // // GetExecuteSchema // Console.WriteLine("GetExecuteSchema:"); diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs new file mode 100644 index 0000000000000..7b1a382cd7492 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs @@ -0,0 +1,46 @@ +using System.Collections.Generic; +using System.Linq; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class FlightSqlHolder +{ + private readonly FlightDescriptor _flightDescriptor; + private readonly Schema _schema; + private readonly string _location; + + //Not thread safe, but only used in tests + private readonly List _recordBatches = new List(); + + public FlightSqlHolder(FlightDescriptor flightDescriptor, Schema schema, string location) + { + _flightDescriptor = flightDescriptor; + _schema = schema; + _location = location; + } + + public void AddBatch(RecordBatchWithMetadata recordBatchWithMetadata) + { + //Should validate schema here + _recordBatches.Add(recordBatchWithMetadata); + } + + public IEnumerable GetRecordBatches() + { + return _recordBatches.ToList(); + } + + public FlightInfo GetFlightInfo() + { + int batchArrayLength = _recordBatches.Sum(rb => rb.RecordBatch.Length); + int batchBytes = + _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b => b.Length))); + var flightInfo = new FlightInfo(_schema, _flightDescriptor, + new List() + { + new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault() ?? "test"), + new List() { new FlightLocation(_location) }) + }, batchArrayLength, batchBytes); + return flightInfo; + } +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs new file mode 100644 index 0000000000000..9ac7df457b1b1 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs @@ -0,0 +1,8 @@ +using System.Collections.Generic; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class FlightSqlStore +{ + public Dictionary Flights { get; set; } = new(); +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs new file mode 100644 index 0000000000000..214d5d557b00a --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs @@ -0,0 +1,15 @@ +using Google.Protobuf; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class RecordBatchWithMetadata +{ + public RecordBatch RecordBatch { get; } + public ByteString Metadata { get; } + + public RecordBatchWithMetadata(RecordBatch recordBatch, ByteString metadata = null) + { + RecordBatch = recordBatch; + Metadata = metadata; + } +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs new file mode 100644 index 0000000000000..4019143d57747 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs @@ -0,0 +1,39 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class Startup +{ + public void ConfigureServices(IServiceCollection services) + { + services.AddGrpc() + .AddFlightServer(); + services.AddSingleton(); + } + + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseRouting(); + + app.UseEndpoints(endpoints => + { + endpoints.MapFlightEndpoint(); + + endpoints.MapGet("/", + async context => + { + await context.Response.WriteAsync( + "Communication with gRPC endpoints must be made through a gRPC client. To learn how to create a client, visit: https://go.microsoft.com/fwlink/?linkid=2086909"); + }); + }); + } +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs new file mode 100644 index 0000000000000..186ac344c7676 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs @@ -0,0 +1,161 @@ +using Apache.Arrow.Flight.Server; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class TestFlightSqlServer : FlightServer +{ +} +/* + * public class TestFlightServer : FlightServer + { + private readonly FlightStore _flightStore; + + public TestFlightServer(FlightStore flightStore) + { + _flightStore = flightStore; + } + + public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + switch (request.Type) + { + case "test": + await responseStream.WriteAsync(new FlightResult("test data")); + break; + case "BeginTransaction": + case "Commit": + case "Rollback": + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))); + break; + case "CreatePreparedStatement": + case "ClosePreparedStatement": + var prepareStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") + }; + + var packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); + var flightResult = new FlightResult(packedResult); + await responseStream.WriteAsync(flightResult); + break; + default: + throw new NotImplementedException(); + } + } + + public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, + ServerCallContext context) + { + // var flightDescriptor = FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8()); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); + + if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + var batches = flightHolder.GetRecordBatches(); + foreach (var batch in batches) + { + await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); + } + } + } + + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + var flightDescriptor = await requestStream.FlightDescriptor; + + if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + } + + while (await requestStream.MoveNext()) + { + flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, + requestStream.ApplicationMetadata.FirstOrDefault())); + await responseStream.WriteAsync(FlightPutResult.Empty); + } + } + + public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo()); + } + + if (_flightStore.Flights.Count > 0) + { + // todo: should rethink of the way to implement dynamic Flights search + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + + public override async Task Handshake(IAsyncStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + if (requestStream.Current.Payload.ToStringUtf8() == "Hello") + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))) + .ConfigureAwait(false); + } + else + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); + } + } + } + + public override Task GetSchema(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo().Schema); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + public override async Task ListActions(IAsyncStreamWriter responseStream, + ServerCallContext context) + { + await responseStream.WriteAsync(new FlightActionType("get", "get a flight")); + await responseStream.WriteAsync(new FlightActionType("put", "add a flight")); + await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")); + await responseStream.WriteAsync(new FlightActionType("test", "test action")); + await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")); + await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")); + } + + public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList(); + + foreach (var flightInfo in flightInfos) + { + await responseStream.WriteAsync(flightInfo); + } + } + + public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, + FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + await responseStream + .WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()) + .ConfigureAwait(false); + } + } + } + * + * + */ diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs new file mode 100644 index 0000000000000..5d98a197d65e4 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs @@ -0,0 +1,66 @@ +using System; +using System.Linq; +using Grpc.Net.Client; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class TestSqlWebFactory : IDisposable +{ + readonly IHost host; + private int _port; + + public TestSqlWebFactory(FlightSqlStore flightStore) + { + host = WebHostBuilder(flightStore).Build(); //Create the server + host.Start(); + var addressInfo = host.Services.GetRequiredService().Features.Get(); + if (addressInfo == null) + { + throw new Exception("No address info could be found for configured server"); + } + + var address = addressInfo.Addresses.First(); + var addressUri = new Uri(address); + _port = addressUri.Port; + AppContext.SetSwitch( + "System.Net.Http.SocketsHttpHandler.Http2UnencryptedSupport", true); + } + + private IHostBuilder WebHostBuilder(FlightSqlStore flightStore) + { + return Host.CreateDefaultBuilder() + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder + .ConfigureKestrel(c => { c.ListenAnyIP(0, l => l.Protocols = HttpProtocols.Http2); }) + .UseStartup() + .ConfigureServices(services => { services.AddSingleton(flightStore); }); + }); + } + + public string GetAddress() + { + return $"http://127.0.0.1:{_port}"; + } + + public GrpcChannel GetChannel() + { + return GrpcChannel.ForAddress(GetAddress()); + } + + public void Stop() + { + host.StopAsync().Wait(); + } + + public void Dispose() + { + Stop(); + } +} diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index 1de7202780060..f564971071c01 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -29,6 +29,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", " EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.IntegrationTest", "Apache.Arrow.Flight.Sql.IntegrationTest\Apache.Arrow.Flight.Sql.IntegrationTest.csproj", "{45416D7D-F12B-4524-B641-AD0E1A33B3B0}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.TestWeb", "Apache.Arrow.Flight.Sql.TestWeb\Apache.Arrow.Flight.Sql.TestWeb.csproj", "{85A6CB32-A83B-48C4-96E8-625C8FBDB349}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -87,6 +89,10 @@ Global {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Debug|Any CPU.Build.0 = Debug|Any CPU {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Release|Any CPU.ActiveCfg = Release|Any CPU {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Release|Any CPU.Build.0 = Release|Any CPU + {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Debug|Any CPU.Build.0 = Debug|Any CPU + {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Release|Any CPU.ActiveCfg = Release|Any CPU + {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index f390f5e78de36..9522ead468a76 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -1,6 +1,6 @@ using System; using System.Collections.Generic; -using System.Threading; +using System.IO; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Arrow.Flight.Protocol.Sql; @@ -61,6 +61,7 @@ public async Task ExecuteAsync(FlightCallOptions options, string que var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); return await GetFlightInfoAsync(options, descriptor); } + throw new InvalidOperationException("No results returned from the query."); } catch (RpcException ex) @@ -92,18 +93,15 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que try { - // Step 1: Create statement query - Console.WriteLine($@"Executing query: {query}"); var updateRequestCommand = new ActionCreatePreparedStatementRequest { Query = query }; byte[] serializedUpdateRequestCommand = updateRequestCommand.PackAndSerialize(); var action = new FlightAction(SqlAction.CreateRequest, serializedUpdateRequestCommand); - var call = _client.DoAction(action, options.Headers); + var call = DoActionAsync(options, action); long affectedRows = 0; - await foreach (var result in call.ResponseStream.ReadAllAsync()) - { - var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(result.Body); + await foreach (var result in call) + { + var preparedStatementResponse = result.Body.ParseAndUnpack(); var command = new CommandPreparedStatementQuery { PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle @@ -111,12 +109,10 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); var flightInfo = await GetFlightInfoAsync(options, descriptor); - var doGetResult = DoGetAsync(options, flightInfo.Endpoints[0].Ticket); await foreach (var recordBatch in doGetResult) { - Console.WriteLine(recordBatch); - Interlocked.Increment(ref affectedRows); + affectedRows += recordBatch.Column(0).Length; } } @@ -124,14 +120,8 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to execute update query", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -1322,8 +1312,8 @@ public static byte[] PackAndSerialize(this IMessage command) return Any.Pack(command).Serialize().ToByteArray(); } - public static T ParseAndUnpack(this ByteString body) where T : IMessage, new() + public static T ParseAndUnpack(this ByteString source) where T : IMessage, new() { - return Any.Parser.ParseFrom(body).Unpack(); + return Any.Parser.ParseFrom(source).Unpack(); } } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 810309c4f638f..3b8ce02fa4ae0 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -13,8 +13,9 @@ + - + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index ae4113aac25bb..4bbcc152521f6 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -3,8 +3,7 @@ using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Flight.Tests; -using Apache.Arrow.Flight.TestWeb; +using Apache.Arrow.Flight.Sql.TestWeb; using Grpc.Core.Utils; using Xunit; @@ -12,23 +11,24 @@ namespace Apache.Arrow.Flight.Sql.Tests; public class FlightSqlClientTests : IDisposable { - readonly TestWebFactory _testWebFactory; - readonly FlightStore _flightStore; + readonly TestSqlWebFactory _testWebFactory; + readonly FlightSqlStore _flightStore; readonly FlightClient _flightClient; private readonly FlightSqlClient _flightSqlClient; - private readonly FlightTestUtils _testUtils; + private readonly FlightSqlTestUtils _testUtils; public FlightSqlClientTests() { - _flightStore = new FlightStore(); - _testWebFactory = new TestWebFactory(_flightStore); + _flightStore = new FlightSqlStore(); + _testWebFactory = new TestSqlWebFactory(_flightStore); _flightClient = new FlightClient(_testWebFactory.GetChannel()); _flightSqlClient = new FlightSqlClient(_flightClient); - _testUtils = new FlightTestUtils(_testWebFactory, _flightStore); + _testUtils = new FlightSqlTestUtils(_testWebFactory, _flightStore); } #region Transactions + [Fact] public async Task CommitAsync_Transaction() { @@ -81,6 +81,7 @@ public async Task RollbackAsync_Transaction() #endregion #region PreparedStatement + [Fact] public async Task PreparedStatement() { @@ -94,8 +95,31 @@ public async Task PreparedStatement() // Assert Assert.NotNull(preparedStatement); } + #endregion + [Fact] + public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows() + { + // Arrange + string query = "UPDATE test_table SET column1 = 'value' WHERE column2 = 'condition'"; + var options = new FlightCallOptions(); + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + long affectedRows = await _flightSqlClient.ExecuteUpdateAsync(options, query, transaction); + + // Assert + Assert.Equal(100, affectedRows); + } + [Fact] public async Task Execute() { @@ -118,7 +142,7 @@ public async Task GetFlightInfo() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs new file mode 100644 index 0000000000000..25147349628f3 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs @@ -0,0 +1,48 @@ +using System.Linq; +using Apache.Arrow.Flight.Sql.TestWeb; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class FlightSqlTestUtils +{ + private readonly TestSqlWebFactory _testWebFactory; + private readonly FlightSqlStore _flightStore; + + public FlightSqlTestUtils(TestSqlWebFactory testWebFactory, FlightSqlStore flightStore) + { + _testWebFactory = testWebFactory; + _flightStore = flightStore; + } + + public RecordBatch CreateTestBatch(int startValue, int length) + { + var batchBuilder = new RecordBatch.Builder(); + Int32Array.Builder builder = new(); + for (int i = 0; i < length; i++) + { + builder.Append(startValue + i); + } + + batchBuilder.Append("test", true, builder.Build()); + return batchBuilder.Build(); + } + + + public FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, + params RecordBatchWithMetadata[] batches) + { + var initialBatch = batches.FirstOrDefault(); + + var flightHolder = new FlightSqlHolder(flightDescriptor, initialBatch.RecordBatch.Schema, + _testWebFactory.GetAddress()); + + foreach (var batch in batches) + { + flightHolder.AddBatch(batch); + } + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + return flightHolder.GetFlightInfo(); + } +} diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj index 14227e2c4eb6b..fcbb24033af45 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj @@ -11,6 +11,7 @@ + diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs index 896bc4489c472..c033987db3c30 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs @@ -52,31 +52,13 @@ public IEnumerable GetRecordBatches() public FlightInfo GetFlightInfo() { int batchArrayLength = _recordBatches.Sum(rb => rb.RecordBatch.Length); - int batchBytes = - _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b => b.Length))); - - if (!_flightDescriptor.Paths.Any()) + int batchBytes = _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b=>b.Length))); + return new FlightInfo(_schema, _flightDescriptor, new List() { - return GetFlightInfoWithCommand(); - } - - var flightInfo = new FlightInfo(_schema, _flightDescriptor, - new List() - { - new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault()), - new List() { new FlightLocation(_location) }) - }, batchArrayLength, batchBytes); - return flightInfo; - } - - public FlightInfo GetFlightInfoWithCommand() - { - if (!_flightDescriptor.Paths.Any()) - { - return new FlightInfo(_schema, _flightDescriptor, new List(), 0, 0); - } - - return null; + new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault()), new List(){ + new FlightLocation(_location) + }) + }, batchArrayLength, batchBytes); } } } diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index 87cc27744b7a3..149fb92f9916b 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -40,15 +40,6 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter GetFlightInfo(FlightDescriptor request, ServerC { return Task.FromResult(flightHolder.GetFlightInfo()); } - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); } - public override async Task Handshake(IAsyncStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) { while (await requestStream.MoveNext().ConfigureAwait(false)) @@ -128,8 +117,6 @@ public override async Task ListActions(IAsyncStreamWriter resp await responseStream.WriteAsync(new FlightActionType("put", "add a flight")); await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")); await responseStream.WriteAsync(new FlightActionType("test", "test action")); - await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")); - await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")); } public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, ServerCallContext context) From 291670dede4dc001b752db469f9cd37af73b43be Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Wed, 4 Sep 2024 12:43:11 +0300 Subject: [PATCH 10/58] testing: flightsql --- .../Apache.Arrow.Flight.Sql.TestWeb.csproj | 18 +++++++++ .../Program.cs | 29 ++++++++++++++ .../Properties/launchSettings.json | 38 +++++++++++++++++++ .../appsettings.Development.json | 8 ++++ .../appsettings.json | 9 +++++ 5 files changed, 102 insertions(+) create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json create mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj new file mode 100644 index 0000000000000..650089c643f6f --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj @@ -0,0 +1,18 @@ + + + + net8.0 + + + + + + + + + + + + + + diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs new file mode 100644 index 0000000000000..9a56cb0c998e6 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs @@ -0,0 +1,29 @@ +using System.Net; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Server.Kestrel.Core; +using Microsoft.Extensions.Hosting; + +namespace Apache.Arrow.Flight.Sql.TestWeb; + +public class Program +{ + public static void Main(string[] args) + { + CreateHostBuilder(args).Build().Run(); + } + + private static IHostBuilder CreateHostBuilder(string[] args) => + Host.CreateDefaultBuilder(args) + .ConfigureWebHostDefaults(webBuilder => + { + webBuilder + .ConfigureKestrel((context, options) => + { + if (context.HostingEnvironment.IsDevelopment()) + { + options.Listen(IPEndPoint.Parse("0.0.0.0:5001"), l => l.Protocols = HttpProtocols.Http2); + } + }) + .UseStartup(); + }); +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json new file mode 100644 index 0000000000000..08f737dc9415e --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json @@ -0,0 +1,38 @@ +{ + "$schema": "http://json.schemastore.org/launchsettings.json", + "iisSettings": { + "windowsAuthentication": false, + "anonymousAuthentication": true, + "iisExpress": { + "applicationUrl": "http://localhost:64484", + "sslPort": 44321 + } + }, + "profiles": { + "http": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "http://localhost:5285", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "https://localhost:7276;http://localhost:5285", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + }, + "IIS Express": { + "commandName": "IISExpress", + "launchBrowser": true, + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json b/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json new file mode 100644 index 0000000000000..0c208ae9181e5 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + } +} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json b/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json new file mode 100644 index 0000000000000..10f68b8c8b4f7 --- /dev/null +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft.AspNetCore": "Warning" + } + }, + "AllowedHosts": "*" +} From 2df0baaa3266b6d34268526ef5becd77045ebc4f Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Wed, 4 Sep 2024 12:50:20 +0300 Subject: [PATCH 11/58] rollback: FlightTests --- .../FlightTestUtils.cs | 48 ---------- .../Apache.Arrow.Flight.Tests/FlightTests.cs | 92 ++++++++++++------- 2 files changed, 57 insertions(+), 83 deletions(-) delete mode 100644 csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs deleted file mode 100644 index 6a9184368e658..0000000000000 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTestUtils.cs +++ /dev/null @@ -1,48 +0,0 @@ -using System.Linq; -using Apache.Arrow.Flight.TestWeb; - -namespace Apache.Arrow.Flight.Tests; - -public class FlightTestUtils -{ - private readonly TestWebFactory _testWebFactory; - private readonly FlightStore _flightStore; - - public FlightTestUtils(TestWebFactory testWebFactory, FlightStore flightStore) - { - _testWebFactory = testWebFactory; - _flightStore = flightStore; - } - - public RecordBatch CreateTestBatch(int startValue, int length) - { - var batchBuilder = new RecordBatch.Builder(); - Int32Array.Builder builder = new(); - for (int i = 0; i < length; i++) - { - builder.Append(startValue + i); - } - - batchBuilder.Append("test", true, builder.Build()); - return batchBuilder.Build(); - } - - - public FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, - params RecordBatchWithMetadata[] batches) - { - var initialBatch = batches.FirstOrDefault(); - - var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, - _testWebFactory.GetAddress()); - - foreach (var batch in batches) - { - flightHolder.AddBatch(batch); - } - - _flightStore.Flights.Add(flightDescriptor, flightHolder); - - return flightHolder.GetFlightInfo(); - } -} diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 3ef08f3080cae..67f8b6b22a6fb 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -31,14 +31,11 @@ public class FlightTests : IDisposable readonly TestWebFactory _testWebFactory; readonly FlightClient _flightClient; readonly FlightStore _flightStore; - private readonly FlightTestUtils _testUtils; - public FlightTests() { _flightStore = new FlightStore(); _testWebFactory = new TestWebFactory(_flightStore); _flightClient = new FlightClient(_testWebFactory.GetChannel()); - _testUtils = new FlightTestUtils(_testWebFactory, _flightStore); } public void Dispose() @@ -46,6 +43,18 @@ public void Dispose() _testWebFactory.Dispose(); } + private RecordBatch CreateTestBatch(int startValue, int length) + { + var batchBuilder = new RecordBatch.Builder(); + Int32Array.Builder builder = new Int32Array.Builder(); + for (int i = 0; i < length; i++) + { + builder.Append(startValue + i); + } + batchBuilder.Append("test", true, builder.Build()); + return batchBuilder.Build(); + } + private IEnumerable GetStoreBatch(FlightDescriptor flightDescriptor) { @@ -55,11 +64,27 @@ private IEnumerable GetStoreBatch(FlightDescriptor flig return flightHolder.GetRecordBatches(); } + private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params RecordBatchWithMetadata[] batches) + { + var initialBatch = batches.FirstOrDefault(); + + var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress()); + + foreach(var batch in batches) + { + flightHolder.AddBatch(batch); + } + + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + return flightHolder.GetFlightInfo(); + } + [Fact] public async Task TestPutSingleRecordBatch() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); var putStream = _flightClient.StartPut(flightDescriptor); await putStream.RequestStream.WriteAsync(expectedBatch); @@ -78,8 +103,8 @@ public async Task TestPutSingleRecordBatch() public async Task TestPutTwoRecordBatches() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); - var expectedBatch2 = _testUtils.CreateTestBatch(0, 100); + var expectedBatch1 = CreateTestBatch(0, 100); + var expectedBatch2 = CreateTestBatch(0, 100); var putStream = _flightClient.StartPut(flightDescriptor); await putStream.RequestStream.WriteAsync(expectedBatch1); @@ -100,10 +125,10 @@ public async Task TestPutTwoRecordBatches() public async Task TestGetSingleRecordBatch() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); //Add batch to the in memory store - _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -122,12 +147,11 @@ public async Task TestGetSingleRecordBatch() public async Task TestGetTwoRecordBatch() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); - var expectedBatch2 = _testUtils.CreateTestBatch(100, 100); + var expectedBatch1 = CreateTestBatch(0, 100); + var expectedBatch2 = CreateTestBatch(100, 100); //Add batch to the in memory store - _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), - new RecordBatchWithMetadata(expectedBatch2)); + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), new RecordBatchWithMetadata(expectedBatch2)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -147,13 +171,13 @@ public async Task TestGetTwoRecordBatch() public async Task TestGetFlightMetadata() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); + var expectedBatch1 = CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); var expectedMetadataList = new List() { expectedMetadata }; //Add batch to the in memory store - _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1, expectedMetadata)); + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1, expectedMetadata)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -164,7 +188,7 @@ public async Task TestGetFlightMetadata() var getStream = _flightClient.GetStream(endpoint.Ticket); List actualMetadata = new List(); - while (await getStream.ResponseStream.MoveNext(default)) + while(await getStream.ResponseStream.MoveNext(default)) { actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata); } @@ -176,7 +200,7 @@ public async Task TestGetFlightMetadata() public async Task TestPutWithMetadata() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); var putStream = _flightClient.StartPut(flightDescriptor); @@ -197,10 +221,10 @@ public async Task TestPutWithMetadata() public async Task TestGetSchema() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); var expectedSchema = expectedBatch.Schema; - _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch)); var actualSchema = await _flightClient.GetSchema(flightDescriptor); @@ -242,18 +266,18 @@ public async Task TestListFlights() { var flightDescriptor1 = FlightDescriptor.CreatePathDescriptor("test1"); var flightDescriptor2 = FlightDescriptor.CreatePathDescriptor("test2"); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); List expectedFlightInfo = new List(); - expectedFlightInfo.Add(_testUtils.GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); - expectedFlightInfo.Add(_testUtils.GivenStoreBatches(flightDescriptor2, new RecordBatchWithMetadata(expectedBatch))); + expectedFlightInfo.Add(GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); + expectedFlightInfo.Add(GivenStoreBatches(flightDescriptor2, new RecordBatchWithMetadata(expectedBatch))); var listFlightStream = _flightClient.ListFlights(); var actualFlights = await listFlightStream.ResponseStream.ToListAsync(); - for (int i = 0; i < expectedFlightInfo.Count; i++) + for(int i = 0; i < expectedFlightInfo.Count; i++) { FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]); } @@ -277,7 +301,7 @@ public async Task TestSingleExchange() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("single_exchange"); var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch); await duplexStreamingCall.RequestStream.CompleteAsync(); @@ -293,8 +317,8 @@ public async Task TestMultipleExchange() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("multiple_exchange"); var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor); - var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); - var expectedBatch2 = _testUtils.CreateTestBatch(100, 100); + var expectedBatch1 = CreateTestBatch(0, 100); + var expectedBatch2 = CreateTestBatch(100, 100); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch1); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch2); @@ -311,7 +335,7 @@ public async Task TestExchangeWithMetadata() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("metadata_exchange"); var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); var expectedMetadata = ByteString.CopyFromUtf8("test metadata"); await duplexStreamingCall.RequestStream.WriteAsync(expectedBatch, expectedMetadata); @@ -334,8 +358,7 @@ public async Task TestHandshakeWithSpecificMessage() { var duplexStreamingCall = _flightClient.Handshake(); - await duplexStreamingCall.RequestStream.WriteAsync( - new FlightHandshakeRequest(ByteString.CopyFromUtf8("Hello"))); + await duplexStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.CopyFromUtf8("Hello"))); await duplexStreamingCall.RequestStream.CompleteAsync(); var results = await duplexStreamingCall.ResponseStream.ToListAsync(); @@ -347,12 +370,11 @@ await duplexStreamingCall.RequestStream.WriteAsync( public async Task TestGetBatchesWithAsyncEnumerable() { var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); - var expectedBatch1 = _testUtils.CreateTestBatch(0, 100); - var expectedBatch2 = _testUtils.CreateTestBatch(100, 100); + var expectedBatch1 = CreateTestBatch(0, 100); + var expectedBatch2 = CreateTestBatch(100, 100); //Add batch to the in memory store - _testUtils.GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), - new RecordBatchWithMetadata(expectedBatch2)); + GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(expectedBatch1), new RecordBatchWithMetadata(expectedBatch2)); //Get the flight info for the ticket var flightInfo = await _flightClient.GetInfo(flightDescriptor); @@ -364,7 +386,7 @@ public async Task TestGetBatchesWithAsyncEnumerable() List resultList = new List(); - await foreach (var recordBatch in getStream.ResponseStream) + await foreach(var recordBatch in getStream.ResponseStream) { resultList.Add(recordBatch); } @@ -378,12 +400,12 @@ public async Task TestGetBatchesWithAsyncEnumerable() public async Task EnsureTheSerializedBatchContainsTheProperTotalRecordsAndTotalBytesProperties() { var flightDescriptor1 = FlightDescriptor.CreatePathDescriptor("test1"); - var expectedBatch = _testUtils.CreateTestBatch(0, 100); + var expectedBatch = CreateTestBatch(0, 100); var expectedTotalBytes = expectedBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b => b.Length)); List expectedFlightInfo = new List(); - expectedFlightInfo.Add(_testUtils.GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); + expectedFlightInfo.Add(GivenStoreBatches(flightDescriptor1, new RecordBatchWithMetadata(expectedBatch))); var listFlightStream = _flightClient.ListFlights(); From 6c6cbf7cd935a89ddd581b27da3ca7162851e7c0 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Wed, 4 Sep 2024 13:07:39 +0300 Subject: [PATCH 12/58] test: flight client --- csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 19dc8bf906b98..9746700d81d10 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -16,6 +16,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.TestWeb; From 89b6661ebb40d10fdcfa337edecd8a32275cd8fb Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Mon, 9 Sep 2024 19:21:55 +0300 Subject: [PATCH 13/58] testing: FlightSqlClient --- .../Program.cs | 2 +- .../Apache.Arrow.Flight.Sql.TestWeb.csproj | 2 +- .../FlightSqlHolder.cs | 15 +- .../TestFlightSqlServer.cs | 329 +++++----- .../CancelFlightInfoRequest.cs | 51 +- .../CancelFlightInfoResult.cs | 7 +- .../Client/FlightSqlClient.cs | 333 +++------- .../FlightSqlClientTests.cs | 599 +++++++++++++++++- .../FlightSqlTestExtensions.cs | 1 + .../TestFlightServer.cs | 9 + 10 files changed, 911 insertions(+), 437 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index f8b3bd49ea992..b782ea6f92acb 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -392,7 +392,7 @@ static async Task PutExample(FlightSqlClient client, string query) var metadata = new List> { new("db_name", "SYSDB"), new("table_name", "Info") }; var schema = new Schema(fields, metadata); - var doPutResult = await client.DoPut(options, descriptor, schema).ConfigureAwait(false); + var doPutResult = await client.DoPutAsync(options, descriptor, schema).ConfigureAwait(false); // Example data to write var col1 = new Int32Array.Builder().AppendRange(new[] { 8, 9, 10, 11 }).Build(); diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj index 650089c643f6f..4985573674ae1 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj @@ -6,7 +6,7 @@ - + diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs index 7b1a382cd7492..88e6e545c2b99 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs @@ -38,9 +38,22 @@ public FlightInfo GetFlightInfo() var flightInfo = new FlightInfo(_schema, _flightDescriptor, new List() { - new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault() ?? "test"), + new FlightEndpoint(new FlightTicket( + CustomTicketStrategy(_flightDescriptor) + ), new List() { new FlightLocation(_location) }) }, batchArrayLength, batchBytes); return flightInfo; } + + private string CustomTicketStrategy(FlightDescriptor descriptor) + { + if (descriptor.Command.Length > 0) + { + return $"{descriptor.Command.ToStringUtf8()}"; + } + + // Fallback in case there is no command in the descriptor + return "default_custom_ticket"; + } } diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs index 186ac344c7676..73abdd5c89363 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs @@ -1,161 +1,182 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; using Apache.Arrow.Flight.Server; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; namespace Apache.Arrow.Flight.Sql.TestWeb; public class TestFlightSqlServer : FlightServer { + private readonly FlightSqlStore _flightStore; + + public TestFlightSqlServer(FlightSqlStore flightStore) + { + _flightStore = flightStore; + } + + public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + switch (request.Type) + { + case "test": + await responseStream.WriteAsync(new FlightResult("test data")); + break; + case "GetPrimaryKeys": + await responseStream.WriteAsync(new FlightResult("test data")); + break; + case "CancelFlightInfo": + var schema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); + var cancelRequest = new CancelFlightInfoRequest(flightInfo); + await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())); + break; + case "BeginTransaction": + case "Commit": + case "Rollback": + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))); + break; + case "CreatePreparedStatement": + case "ClosePreparedStatement": + var prepareStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") + }; + + var packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); + var flightResult = new FlightResult(packedResult); + await responseStream.WriteAsync(flightResult); + break; + default: + throw new NotImplementedException(); + } + } + + public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, + ServerCallContext context) + { + // var flightDescriptor = FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8()); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); + + if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + var batches = flightHolder.GetRecordBatches(); + foreach (var batch in batches) + { + await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); + } + } + } + + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + var flightDescriptor = await requestStream.FlightDescriptor; + + if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + flightHolder = new FlightSqlHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + } + + while (await requestStream.MoveNext()) + { + flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, + requestStream.ApplicationMetadata.FirstOrDefault())); + await responseStream.WriteAsync(FlightPutResult.Empty); + } + } + + public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo()); + } + + if (_flightStore.Flights.Count > 0) + { + // todo: should rethink of the way to implement dynamic Flights search + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + + public override async Task Handshake(IAsyncStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + if (requestStream.Current.Payload.ToStringUtf8() == "Hello") + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))) + .ConfigureAwait(false); + } + else + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); + } + } + } + + public override Task GetSchema(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo().Schema); + } + + if (_flightStore.Flights.Count > 0) + { + // todo: should rethink of the way to implement dynamic Flights search + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + public override async Task ListActions(IAsyncStreamWriter responseStream, + ServerCallContext context) + { + await responseStream.WriteAsync(new FlightActionType("get", "get a flight")); + await responseStream.WriteAsync(new FlightActionType("put", "add a flight")); + await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")); + await responseStream.WriteAsync(new FlightActionType("test", "test action")); + await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")); + await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")); + } + + public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList(); + + foreach (var flightInfo in flightInfos) + { + await responseStream.WriteAsync(flightInfo); + } + } + + public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, + FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + await responseStream + .WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()) + .ConfigureAwait(false); + } + } } -/* - * public class TestFlightServer : FlightServer - { - private readonly FlightStore _flightStore; - - public TestFlightServer(FlightStore flightStore) - { - _flightStore = flightStore; - } - - public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, - ServerCallContext context) - { - switch (request.Type) - { - case "test": - await responseStream.WriteAsync(new FlightResult("test data")); - break; - case "BeginTransaction": - case "Commit": - case "Rollback": - await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))); - break; - case "CreatePreparedStatement": - case "ClosePreparedStatement": - var prepareStatementResponse = new ActionCreatePreparedStatementResult - { - PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") - }; - - var packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); - var flightResult = new FlightResult(packedResult); - await responseStream.WriteAsync(flightResult); - break; - default: - throw new NotImplementedException(); - } - } - - public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, - ServerCallContext context) - { - // var flightDescriptor = FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8()); - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); - - if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) - { - var batches = flightHolder.GetRecordBatches(); - foreach (var batch in batches) - { - await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); - } - } - } - - public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) - { - var flightDescriptor = await requestStream.FlightDescriptor; - - if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) - { - flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); - _flightStore.Flights.Add(flightDescriptor, flightHolder); - } - - while (await requestStream.MoveNext()) - { - flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, - requestStream.ApplicationMetadata.FirstOrDefault())); - await responseStream.WriteAsync(FlightPutResult.Empty); - } - } - - public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) - { - if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) - { - return Task.FromResult(flightHolder.GetFlightInfo()); - } - - if (_flightStore.Flights.Count > 0) - { - // todo: should rethink of the way to implement dynamic Flights search - return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); - } - - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); - } - - - public override async Task Handshake(IAsyncStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) - { - while (await requestStream.MoveNext().ConfigureAwait(false)) - { - if (requestStream.Current.Payload.ToStringUtf8() == "Hello") - { - await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))) - .ConfigureAwait(false); - } - else - { - await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); - } - } - } - - public override Task GetSchema(FlightDescriptor request, ServerCallContext context) - { - if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) - { - return Task.FromResult(flightHolder.GetFlightInfo().Schema); - } - - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); - } - - public override async Task ListActions(IAsyncStreamWriter responseStream, - ServerCallContext context) - { - await responseStream.WriteAsync(new FlightActionType("get", "get a flight")); - await responseStream.WriteAsync(new FlightActionType("put", "add a flight")); - await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")); - await responseStream.WriteAsync(new FlightActionType("test", "test action")); - await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")); - await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")); - } - - public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, - ServerCallContext context) - { - var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList(); - - foreach (var flightInfo in flightInfos) - { - await responseStream.WriteAsync(flightInfo); - } - } - - public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, - FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) - { - while (await requestStream.MoveNext().ConfigureAwait(false)) - { - await responseStream - .WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()) - .ConfigureAwait(false); - } - } - } - * - * - */ diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs index 586f8cece913a..9afddf656aa93 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs @@ -12,10 +12,10 @@ public sealed class CancelFlightInfoRequest : IMessage public CancelFlightInfoRequest(FlightInfo flightInfo) { FlightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); - Descriptor = - DescriptorReflection.Descriptor.MessageTypes[0]; } + public MessageDescriptor Descriptor => + DescriptorReflection.Descriptor.MessageTypes[0]; public void MergeFrom(CancelFlightInfoRequest message) { @@ -43,19 +43,56 @@ public void MergeFrom(CodedInputStream input) public void WriteTo(CodedOutputStream output) { - output.WriteTag(1, WireFormat.WireType.LengthDelimited); - output.WriteMessage(FlightInfo.Descriptor.ParsedAndUnpackedMessage()); + if (FlightInfo != null) + { + output.WriteTag(1, WireFormat.WireType.LengthDelimited); + output.WriteString(FlightInfo.Descriptor.Command.ToStringUtf8()); + foreach (var path in FlightInfo.Descriptor.Paths) + { + output.WriteString(path); + } + output.WriteInt64(FlightInfo.TotalRecords); + output.WriteInt64(FlightInfo.TotalBytes); + } } public int CalculateSize() { int size = 0; - size += 1 + CodedOutputStream.ComputeMessageSize(FlightInfo.Descriptor.ParsedAndUnpackedMessage()); + + if (FlightInfo != null) + { + // Manually compute the size of FlightInfo + size += 1 + ComputeFlightInfoSize(FlightInfo); + } + return size; } - public MessageDescriptor Descriptor { get; } - public bool Equals(CancelFlightInfoRequest other) => other != null && FlightInfo.Equals(other.FlightInfo); public CancelFlightInfoRequest Clone() => new(FlightInfo); + + private int ComputeFlightInfoSize(FlightInfo flightInfo) + { + int size = 0; + + if (flightInfo.Descriptor != null) + { + size += CodedOutputStream.ComputeStringSize(flightInfo.Descriptor.Command.ToStringUtf8()); + } + + if (flightInfo.Descriptor?.Paths != null) + { + foreach (string? path in flightInfo.Descriptor.Paths) + { + size += CodedOutputStream.ComputeStringSize(path); + } + } + + // Compute size for other fields within FlightInfo + size += CodedOutputStream.ComputeInt64Size(flightInfo.TotalRecords); + size += CodedOutputStream.ComputeInt64Size(flightInfo.TotalBytes); + + return size; + } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs index 5e61fc6f975f5..d98384e38aa53 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs @@ -34,11 +34,12 @@ public void MergeFrom(CancelFlightInfoResult message) public void MergeFrom(CodedInputStream input) { - while (input.ReadTag() != 0) + uint tag; + while ((tag = input.ReadTag()) != 0) { - switch (input.Position) + switch (tag) { - case 1: + case 10: CancelStatus = (CancelStatus)input.ReadEnum(); break; default: diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 9522ead468a76..1f9e5463c7100 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -3,6 +3,7 @@ using System.IO; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Server; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; @@ -220,7 +221,6 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); - // Process the response await foreach (var result in call.ResponseStream.ReadAllAsync()) { var preparedStatementResponse = @@ -238,16 +238,8 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin } catch (RpcException ex) { - // Handle gRPC exceptions - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get execute schema", ex); } - catch (Exception ex) - { - // Handle other exceptions - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -255,7 +247,7 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin /// /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. - public async Task GetCatalogs(FlightCallOptions options) + public async Task GetCatalogsAsync(FlightCallOptions options) { if (options == null) { @@ -271,14 +263,8 @@ public async Task GetCatalogs(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status}"); throw new InvalidOperationException("Failed to get catalogs", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -286,7 +272,7 @@ public async Task GetCatalogs(FlightCallOptions options) /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the catalogs. - public async Task GetCatalogsSchema(FlightCallOptions options) + public async Task GetCatalogsSchemaAsync(FlightCallOptions options) { if (options == null) { @@ -302,14 +288,8 @@ public async Task GetCatalogsSchema(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status}"); throw new InvalidOperationException("Failed to get catalogs schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -333,16 +313,8 @@ public async Task GetSchemaAsync(FlightCallOptions options, FlightDescri } catch (RpcException ex) { - // Handle gRPC exceptions - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get schema", ex); } - catch (Exception ex) - { - // Handle other exceptions - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -394,14 +366,8 @@ public async Task GetDbSchemasAsync(FlightCallOptions options, strin } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get database schemas", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -421,21 +387,13 @@ public async Task GetDbSchemasSchemaAsync(FlightCallOptions options) var command = new CommandGetDbSchemas(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResultCall = _client.GetSchema(descriptor, options.Headers); - var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); - + var schemaResult = await GetSchemaAsync(options, descriptor); return schemaResult; } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get database schemas schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -471,7 +429,7 @@ public async IAsyncEnumerable DoGetAsync(FlightCallOptions options, /// The descriptor of the stream. /// The schema for the data to upload. /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. - public Task DoPut(FlightCallOptions options, FlightDescriptor descriptor, Schema schema) + public async Task DoPutAsync(FlightCallOptions options, FlightDescriptor descriptor, Schema schema) { if (descriptor is null) throw new ArgumentNullException(nameof(descriptor)); @@ -481,30 +439,16 @@ public Task DoPut(FlightCallOptions options, FlightDescriptor descr try { var doPutResult = _client.StartPut(descriptor, options.Headers); - // Get the writer and reader var writer = doPutResult.RequestStream; var reader = doPutResult.ResponseStream; - // TODO: After Re-Check it with Jeremy - // Create an empty RecordBatch to begin the writer with the schema - // var emptyRecordBatch = new RecordBatch(schema, new List(), 0); - // await writer.WriteAsync(emptyRecordBatch); - - // Begin the writer with the schema - return Task.FromResult(new DoPutResult(writer, reader)); + await writer.WriteAsync(new RecordBatch(schema, new List(), 0)).ConfigureAwait(false); + return new DoPutResult(writer, reader); } catch (RpcException ex) { - // Handle gRPC exceptions - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to perform DoPut operation", ex); } - catch (Exception ex) - { - // Handle other exceptions - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -516,7 +460,7 @@ public Task DoPut(FlightCallOptions options, FlightDescriptor descr /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. public Task DoPutAsync(FlightDescriptor descriptor, Schema schema) { - return DoPut(new FlightCallOptions(), descriptor, schema); + return DoPutAsync(new FlightCallOptions(), descriptor, schema); } /// @@ -525,7 +469,7 @@ public Task DoPutAsync(FlightDescriptor descriptor, Schema schema) /// RPC-layer hints for this call. /// The table reference. /// The FlightInfo describing where to access the dataset. - public async Task GetPrimaryKeys(FlightCallOptions options, TableRef tableRef) + public async Task GetPrimaryKeysAsync(FlightCallOptions options, TableRef tableRef) { if (tableRef == null) throw new ArgumentNullException(nameof(tableRef)); @@ -536,34 +480,16 @@ public async Task GetPrimaryKeys(FlightCallOptions options, TableRef { Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; - var action = new FlightAction("GetPrimaryKeys", getPrimaryKeysRequest.PackAndSerialize()); - var doActionResult = DoActionAsync(options, action); - await foreach (var result in doActionResult) - { - var getPrimaryKeysResponse = - result.Body.ParseAndUnpack(); - var command = new CommandPreparedStatementQuery - { - PreparedStatementHandle = getPrimaryKeysResponse.PreparedStatementHandle - }; - - var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); - return flightInfo; - } + byte[] packedRequest = getPrimaryKeysRequest.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(packedRequest); + var flightInfo = await GetFlightInfoAsync(options, descriptor); - throw new InvalidOperationException("Failed to retrieve primary keys information."); + return flightInfo; } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get primary keys", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -593,7 +519,11 @@ public async Task> GetTablesAsync(FlightCallOptions opti TableNameFilterPattern = tableFilterPattern ?? string.Empty, IncludeSchema = includeSchema }; - command.TableTypes.AddRange(tableTypes); + + if (tableTypes != null) + { + command.TableTypes.AddRange(tableTypes); + } var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); var flightInfoCall = GetFlightInfoAsync(options, descriptor); @@ -628,14 +558,8 @@ public async Task GetExportedKeysAsync(FlightCallOptions options, Ta } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get exported keys", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -659,14 +583,8 @@ public async Task GetExportedKeysSchemaAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get exported keys schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -686,38 +604,15 @@ public async Task GetImportedKeysAsync(FlightCallOptions options, Ta { Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(getImportedKeysRequest.PackAndSerialize()); - var action = - new FlightAction("GetImportedKeys", - getImportedKeysRequest.PackAndSerialize()); // check: whether using SqlAction.Enum - var doActionResult = DoActionAsync(options, action); - - await foreach (var result in doActionResult) - { - var getImportedKeysResponse = - result.Body.ParseAndUnpack(); - var command = new CommandPreparedStatementQuery - { - PreparedStatementHandle = getImportedKeysResponse.PreparedStatementHandle - }; - - var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); - return flightInfo; - } - - throw new InvalidOperationException("Failed to retrieve imported keys information."); + var flightInfo = await GetFlightInfoAsync(options, descriptor); + return flightInfo; } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get imported keys", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -741,14 +636,8 @@ public async Task GetImportedKeysSchemaAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get imported keys schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -786,14 +675,8 @@ public async Task GetCrossReferenceAsync(FlightCallOptions options, } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get cross reference", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -820,14 +703,8 @@ public async Task GetCrossReferenceSchemaAsync(FlightCallOptions options } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get cross-reference schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -851,14 +728,8 @@ public async Task GetTableTypesAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get table types", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -882,14 +753,8 @@ public async Task GetTableTypesSchemaAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get table types schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -914,14 +779,8 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options, in } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get XDBC type info", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -945,14 +804,8 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get XDBC type info", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -976,14 +829,8 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get XDBC type info schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -1014,14 +861,8 @@ public async Task GetSqlInfoAsync(FlightCallOptions options, List @@ -1047,14 +888,8 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to get SQL info schema", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -1075,23 +910,64 @@ public async Task CancelFlightInfoAsync(FlightCallOption var call = _client.DoAction(action, options.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync()) { - var cancelResult = FlightSqlUtils.ParseAndUnpack(result.Body); - return cancelResult; + var cancelResult = Any.Parser.ParseFrom(result.Body); + if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) + { + return cancelFlightInfoResult; + } } throw new InvalidOperationException("No response received for the CancelFlightInfo request."); } catch (RpcException ex) { - // Handle gRPC exceptions - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to cancel flight info", ex); } - catch (Exception ex) + } + + /// + /// Explicitly cancel a query. + /// + /// RPC-layer hints for this call. + /// The FlightInfo of the query to cancel. + /// A Task representing the asynchronous operation. + public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) + { + if (options == null) + { + throw new ArgumentNullException(nameof(options)); + } + + if (info == null) + { + throw new ArgumentNullException(nameof(info)); + } + + try + { + var cancelRequest = new CancelFlightInfoRequest(info); + var action = new FlightAction("CancelFlightInfo", cancelRequest.ToByteString()); + var call = _client.DoAction(action, options.Headers); + await foreach (var result in call.ResponseStream.ReadAllAsync()) + { + var cancelResult = Any.Parser.ParseFrom(result.Body); + if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) + { + return cancelFlightInfoResult.CancelStatus switch + { + CancelStatus.Cancelled => CancelStatus.Cancelled, + CancelStatus.Cancelling => CancelStatus.Cancelling, + CancelStatus.NotCancellable => CancelStatus.NotCancellable, + _ => CancelStatus.Unspecified + }; + } + } + + throw new InvalidOperationException("Failed to cancel query: No response received."); + } + catch (RpcException ex) { - // Handle other exceptions - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; + throw new InvalidOperationException("Failed to cancel query", ex); } } @@ -1112,13 +988,11 @@ public async Task BeginTransactionAsync(FlightCallOptions options) var actionBeginTransaction = new ActionBeginTransactionRequest(); var action = new FlightAction("BeginTransaction", actionBeginTransaction.PackAndSerialize()); var responseStream = _client.DoAction(action, options.Headers); - await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) { string? beginTransactionResult = result.Body.ToStringUtf8(); return new Transaction(beginTransactionResult); } - throw new InvalidOperationException("Failed to begin transaction: No response received."); } catch (RpcException ex) @@ -1153,14 +1027,8 @@ public AsyncServerStreamingCall CommitAsync(FlightCallOptions opti } catch (RpcException ex) { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); throw new InvalidOperationException("Failed to commit transaction", ex); } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } } /// @@ -1193,59 +1061,6 @@ public AsyncServerStreamingCall RollbackAsync(FlightCallOptions op } } - - /// - /// Explicitly cancel a query. - /// - /// RPC-layer hints for this call. - /// The FlightInfo of the query to cancel. - /// A Task representing the asynchronous operation. - public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) - { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - - if (info == null) - { - throw new ArgumentNullException(nameof(info)); - } - - try - { - var cancelRequest = new CancelFlightInfoRequest(info); - var action = new FlightAction("CancelFlightInfo", cancelRequest.ToByteString()); - var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) - { - var cancelResult = Any.Parser.ParseFrom(result.Body); - if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) - { - return cancelFlightInfoResult.CancelStatus switch - { - CancelStatus.Cancelled => CancelStatus.Cancelled, - CancelStatus.Cancelling => CancelStatus.Cancelling, - CancelStatus.NotCancellable => CancelStatus.NotCancellable, - _ => CancelStatus.Unspecified - }; - } - } - - throw new InvalidOperationException("Failed to cancel query: No response received."); - } - catch (RpcException ex) - { - Console.WriteLine($@"gRPC Error: {ex.Status.Detail}"); - throw new InvalidOperationException("Failed to cancel query", ex); - } - catch (Exception ex) - { - Console.WriteLine($@"Unexpected Error: {ex.Message}"); - throw; - } - } - /// /// Create a prepared statement object. /// @@ -1273,9 +1088,9 @@ public async Task PrepareAsync(FlightCallOptions options, str var preparedStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, - TransactionId = transaction is null - ? ByteString.CopyFromUtf8(transaction?.TransactionId) - : ByteString.Empty + TransactionId = transaction.IsValid() + ? ByteString.CopyFromUtf8(transaction.TransactionId) + : ByteString.Empty, }; var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); @@ -1284,7 +1099,7 @@ public async Task PrepareAsync(FlightCallOptions options, str await foreach (var result in call.ResponseStream.ReadAllAsync()) { var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(result.Body); + FlightSqlUtils.ParseAndUnpack(result.Body); var commandSqlCall = new CommandPreparedStatementQuery { diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 4bbcc152521f6..c6f9888b724c5 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -1,11 +1,16 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; using Apache.Arrow.Flight.Sql.TestWeb; +using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf.WellKnownTypes; using Grpc.Core.Utils; using Xunit; +using RecordBatchWithMetadata = Apache.Arrow.Flight.Sql.TestWeb.RecordBatchWithMetadata; namespace Apache.Arrow.Flight.Sql.Tests; @@ -30,7 +35,7 @@ public FlightSqlClientTests() #region Transactions [Fact] - public async Task CommitAsync_Transaction() + public async Task CommitTransactionAsync() { // Arrange string transactionId = "sample-transaction-id"; @@ -47,7 +52,7 @@ public async Task CommitAsync_Transaction() } [Fact] - public async Task BeginTransactionAsync_Transaction() + public async Task BeginTransactionAsync() { // Arrange var options = new FlightCallOptions(); @@ -62,7 +67,7 @@ public async Task BeginTransactionAsync_Transaction() } [Fact] - public async Task RollbackAsync_Transaction() + public async Task RollbackTransactionAsync() { // Arrange string transactionId = "sample-transaction-id"; @@ -83,14 +88,21 @@ public async Task RollbackAsync_Transaction() #region PreparedStatement [Fact] - public async Task PreparedStatement() + public async Task PreparedStatementAsync() { // Arrange string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; var options = new FlightCallOptions(); + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + + _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var preparedStatement = await _flightSqlClient.PrepareAsync(options, query); + var preparedStatement = await _flightSqlClient.PrepareAsync(options, query, transaction); // Assert Assert.NotNull(preparedStatement); @@ -99,7 +111,7 @@ public async Task PreparedStatement() #endregion [Fact] - public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows() + public async Task ExecuteUpdateAsync() { // Arrange string query = "UPDATE test_table SET column1 = 'value' WHERE column2 = 'condition'"; @@ -110,7 +122,6 @@ public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows() var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); - _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -121,14 +132,22 @@ public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows() } [Fact] - public async Task Execute() + public async Task ExecuteAsync() { // Arrange string query = "SELECT * FROM test_table"; var options = new FlightCallOptions(); + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + + _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var flightInfo = await _flightSqlClient.ExecuteAsync(options, query); + var flightInfo = await _flightSqlClient.ExecuteAsync(options, query, transaction); // Assert Assert.NotNull(flightInfo); @@ -136,7 +155,7 @@ public async Task Execute() } [Fact] - public async Task GetFlightInfo() + public async Task GetFlightInfoAsync() { // Arrange var options = new FlightCallOptions(); @@ -144,7 +163,6 @@ public async Task GetFlightInfo() var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); - _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act var flightInfo = await _flightSqlClient.GetFlightInfoAsync(options, flightDescriptor); @@ -153,5 +171,564 @@ public async Task GetFlightInfo() Assert.NotNull(flightInfo); } + [Fact] + public async Task GetExecuteSchemaAsync() + { + // Arrange + string query = "SELECT * FROM test_table"; + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + Schema resultSchema = + await _flightSqlClient.GetExecuteSchemaAsync(options, query, new Transaction("sample-transaction-id")); + + // Assert + Assert.NotNull(resultSchema); + Assert.Equal(recordBatch.Schema.FieldsList.Count, resultSchema.FieldsList.Count); + CompareSchemas(resultSchema, recordBatch.Schema); + } + + [Fact] + public async Task GetCatalogsAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.GetCatalogsAsync(options); + + // Assert + Assert.NotNull(result); + Assert.Equal(flightHolder.GetFlightInfo().Endpoints.Count, result.Endpoints.Count); + Assert.Equal(flightDescriptor, result.Descriptor); + } + + [Fact] + public async Task GetSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.GetSchemaAsync(options, flightDescriptor); + + // Assert + Assert.NotNull(result); + Assert.Equal(recordBatch.Schema.FieldsList.Count, result.FieldsList.Count); + CompareSchemas(result, recordBatch.Schema); + } + + [Fact] + public async Task GetDbSchemasAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + string catalog = "test-catalog"; + string dbSchemaFilterPattern = "test-schema-pattern"; + + // Act + var result = await _flightSqlClient.GetDbSchemasAsync(options, catalog, dbSchemaFilterPattern); + + // Assert + Assert.NotNull(result); + var expectedFlightInfo = flightHolder.GetFlightInfo(); + Assert.Equal(recordBatch.Schema.FieldsList.Count, result.Schema.FieldsList.Count); + Assert.Equal(expectedFlightInfo.Descriptor.Command, result.Descriptor.Command); + Assert.Equal(expectedFlightInfo.Descriptor.Type, result.Descriptor.Type); + Assert.Equal(expectedFlightInfo.Schema.FieldsList.Count, result.Schema.FieldsList.Count); + Assert.Equal(expectedFlightInfo.Endpoints.Count, result.Endpoints.Count); + + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = result.Schema.FieldsList[i]; + + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + Assert.Equal(expectedField.Metadata?.Count ?? 0, actualField.Metadata?.Count ?? 0); + } + + for (int i = 0; i < expectedFlightInfo.Endpoints.Count; i++) + { + var expectedEndpoint = expectedFlightInfo.Endpoints[i]; + var actualEndpoint = result.Endpoints[i]; + + Assert.Equal(expectedEndpoint.Ticket, actualEndpoint.Ticket); + Assert.Equal(expectedEndpoint.Locations.Count(), actualEndpoint.Locations.Count()); + } + } + + [Fact] + public async Task GetPrimaryKeysAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.GetPrimaryKeysAsync(options, tableRef); + + // Assert + Assert.NotNull(result); + var expectedFlightInfo = flightHolder.GetFlightInfo(); + Assert.Equal(expectedFlightInfo.Descriptor.Command, result.Descriptor.Command); + Assert.Equal(expectedFlightInfo.Descriptor.Type, result.Descriptor.Type); + Assert.Equal(expectedFlightInfo.Schema.FieldsList.Count, result.Schema.FieldsList.Count); + + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = result.Schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + Assert.Equal(expectedField.Metadata?.Count ?? 0, actualField.Metadata?.Count ?? 0); + } + + Assert.Equal(expectedFlightInfo.Endpoints.Count, result.Endpoints.Count); + + for (int i = 0; i < expectedFlightInfo.Endpoints.Count; i++) + { + var expectedEndpoint = expectedFlightInfo.Endpoints[i]; + var actualEndpoint = result.Endpoints[i]; + + Assert.Equal(expectedEndpoint.Ticket, actualEndpoint.Ticket); + Assert.Equal(expectedEndpoint.Locations.Count(), actualEndpoint.Locations.Count()); + } + } + + [Fact] + public async Task GetTablesAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + string catalog = "sample_catalog"; + string dbSchemaFilterPattern = "sample_schema"; + string tableFilterPattern = "sample_table"; + bool includeSchema = true; + var tableTypes = new List { "BASE TABLE" }; + + // Act + var result = await _flightSqlClient.GetTablesAsync(options, catalog, dbSchemaFilterPattern, tableFilterPattern, + includeSchema, tableTypes); + + // Assert + Assert.NotNull(result); + Assert.Single(result); + + var expectedFlightInfo = flightHolder.GetFlightInfo(); + var flightInfo = result.First(); + Assert.Equal(expectedFlightInfo.Descriptor.Command, flightInfo.Descriptor.Command); + Assert.Equal(expectedFlightInfo.Descriptor.Type, flightInfo.Descriptor.Type); + Assert.Equal(expectedFlightInfo.Schema.FieldsList.Count, flightInfo.Schema.FieldsList.Count); + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = flightInfo.Schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + + Assert.Equal(expectedFlightInfo.Endpoints.Count, flightInfo.Endpoints.Count); + } + + + [Fact] + public async Task GetCatalogsSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetCatalogsSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + var expectedFlightInfo = flightHolder.GetFlightInfo(); + for (int i = 0; i < expectedFlightInfo.Schema.FieldsList.Count; i++) + { + var expectedField = expectedFlightInfo.Schema.FieldsList[i]; + var actualField = schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + } + + [Fact] + public async Task GetDbSchemasSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetDbSchemasSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + for (int i = 0; i < schema.FieldsList.Count; i++) + { + var expectedField = schema.FieldsList[i]; + var actualField = schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + } + + [Fact] + public async Task DoPutAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var result = await _flightSqlClient.DoPutAsync(options, flightDescriptor, recordBatch.Schema); + + // Assert + Assert.NotNull(result); + } + + [Fact] + public async Task GetExportedKeysAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetExportedKeysAsync(options, tableRef); + + // Assert + Assert.NotNull(flightInfo); + Assert.Equal("test", flightInfo.Descriptor.Command.ToStringUtf8()); + } + + [Fact] + public async Task GetExportedKeysSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetExportedKeysSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + Assert.True(schema.FieldsList.Count > 0, "Schema should contain fields."); + Assert.Equal("test", schema.FieldsList.First().Name); + } + + [Fact] + public async Task GetImportedKeysAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetImportedKeysAsync(options, + new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }); + + // Assert + Assert.NotNull(flightInfo); + for (int i = 0; i < recordBatch.Schema.FieldsList.Count; i++) + { + var expectedField = recordBatch.Schema.FieldsList[i]; + var actualField = flightInfo.Schema.FieldsList[i]; + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + } + } + + [Fact] + public async Task GetImportedKeysSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetImportedKeysSchemaAsync(options); + + // Assert + var expectedSchema = recordBatch.Schema; + Assert.NotNull(schema); + Assert.Equal(expectedSchema.FieldsList.Count, schema.FieldsList.Count); + CompareSchemas(expectedSchema, schema); + } + + [Fact] + public async Task GetCrossReferenceAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + var pkTableRef = new TableRef { Catalog = "PKCatalog", DbSchema = "PKSchema", Table = "PKTable" }; + var fkTableRef = new TableRef { Catalog = "FKCatalog", DbSchema = "FKSchema", Table = "FKTable" }; + + // Act + var flightInfo = await _flightSqlClient.GetCrossReferenceAsync(options, pkTableRef, fkTableRef); + + // Assert + Assert.NotNull(flightInfo); + Assert.Equal(flightDescriptor, flightInfo.Descriptor); + Assert.Single(flightInfo.Schema.FieldsList); + } + + [Fact] + public async Task GetCrossReferenceSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetCrossReferenceSchemaAsync(options); + + // Assert + var expectedSchema = recordBatch.Schema; + Assert.NotNull(schema); + Assert.Equal(expectedSchema.FieldsList.Count, schema.FieldsList.Count); + } + + [Fact] + public async Task GetTableTypesAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var commandGetTableTypes = new CommandGetTableTypes(); + byte[] packedCommand = commandGetTableTypes.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetTableTypesAsync(options); + var actualSchema = flightInfo.Schema; + + // Assert + Assert.NotNull(flightInfo); + CompareSchemas(expectedSchema, actualSchema); + } + + [Fact] + public async Task GetTableTypesSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var commandGetTableTypesSchema = new CommandGetTableTypes(); + byte[] packedCommand = commandGetTableTypesSchema.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schemaResult = await _flightSqlClient.GetTableTypesSchemaAsync(options); + + // Assert + Assert.NotNull(schemaResult); + CompareSchemas(expectedSchema, schemaResult); + } + + [Fact] + public async Task GetXdbcTypeInfoAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("COLUMN_SIZE").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var commandGetXdbcTypeInfo = new CommandGetXdbcTypeInfo(); + byte[] packedCommand = commandGetXdbcTypeInfo.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + // Creating a flight holder with the expected schema and adding it to the flight store + var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.GetXdbcTypeInfoAsync(options); + + // Assert + Assert.NotNull(flightInfo); + CompareSchemas(expectedSchema, flightInfo.Schema); + } + + [Fact] + public async Task GetXdbcTypeInfoSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var expectedSchema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + .Build(); + + var commandGetXdbcTypeInfo = new CommandGetXdbcTypeInfo(); + byte[] packedCommand = commandGetXdbcTypeInfo.PackAndSerialize().ToByteArray(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetXdbcTypeInfoSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + CompareSchemas(expectedSchema, schema); + } + + [Fact] + public async Task GetSqlInfoSchemaAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("sqlInfo"); + var expectedSchema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var schema = await _flightSqlClient.GetSqlInfoSchemaAsync(options); + + // Assert + Assert.NotNull(schema); + CompareSchemas(expectedSchema, schema); + } + + [Fact] + public async Task CancelFlightInfoAsync() + { + // Arrange + var options = new FlightCallOptions(); + var schema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); + var cancelRequest = new CancelFlightInfoRequest(flightInfo); + + // Act + var cancelResult = await _flightSqlClient.CancelFlightInfoAsync(options, cancelRequest); + + // Assert + Assert.NotNull(cancelResult); + Assert.True(cancelResult.CancelStatus == CancelStatus.Cancelled); + } + public void Dispose() => _testWebFactory?.Dispose(); + + private void CompareSchemas(Schema expectedSchema, Schema actualSchema) + { + Assert.Equal(expectedSchema.FieldsList.Count, actualSchema.FieldsList.Count); + + for (int i = 0; i < expectedSchema.FieldsList.Count; i++) + { + var expectedField = expectedSchema.FieldsList[i]; + var actualField = actualSchema.FieldsList[i]; + + Assert.Equal(expectedField.Name, actualField.Name); + Assert.Equal(expectedField.DataType, actualField.DataType); + Assert.Equal(expectedField.IsNullable, actualField.IsNullable); + Assert.Equal(expectedField.Metadata, actualField.Metadata); + } + } } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs index 0df6460e56b0e..db6804f4e3c04 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -25,4 +25,5 @@ public static ByteString PackAndSerialize(this IMessage command) { return Any.Pack(command).Serialize(); } + } diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index 149fb92f9916b..722b49d9063a9 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -40,6 +40,15 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter Date: Mon, 9 Sep 2024 19:32:00 +0300 Subject: [PATCH 14/58] testing: CancelQuery --- .../FlightSqlClientTests.cs | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index c6f9888b724c5..6a6e7e3836e7f 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -714,6 +714,29 @@ public async Task CancelFlightInfoAsync() Assert.True(cancelResult.CancelStatus == CancelStatus.Cancelled); } + [Fact] + public async Task CancelQueryAsync() + { + // Arrange + var options = new FlightCallOptions(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); + var schema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); + + // Adding the flight info to the flight store for testing + _flightStore.Flights.Add(flightDescriptor, new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); + + // Act + var cancelStatus = await _flightSqlClient.CancelQueryAsync(options, flightInfo); + + // Assert + Assert.Equal(CancelStatus.Cancelled, cancelStatus); + } + + public void Dispose() => _testWebFactory?.Dispose(); private void CompareSchemas(Schema expectedSchema, Schema actualSchema) From 0263c44078332ead4a895a4dc2b974ea23123583 Mon Sep 17 00:00:00 2001 From: Genady Shmunik Date: Wed, 11 Sep 2024 18:06:48 +0300 Subject: [PATCH 15/58] test: DoPut --- .../Program.cs | 8 +- .../Client/FlightSqlClient.cs | 79 ++++++++++++++++++- .../FlightSqlClientTests.cs | 44 +++++++++-- 3 files changed, 118 insertions(+), 13 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs index b782ea6f92acb..4683603d67037 100644 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs @@ -395,10 +395,10 @@ static async Task PutExample(FlightSqlClient client, string query) var doPutResult = await client.DoPutAsync(options, descriptor, schema).ConfigureAwait(false); // Example data to write - var col1 = new Int32Array.Builder().AppendRange(new[] { 8, 9, 10, 11 }).Build(); - var col2 = new StringArray.Builder().AppendRange(new[] { "a", "b", "c", "d" }).Build(); - var col3 = new StringArray.Builder().AppendRange(new[] { "x", "y", "z", "q" }).Build(); - var batch = new RecordBatch(schema, new IArrowArray[] { col1, col2, col3 }, 4); + var col1 = new Int32Array.Builder().AppendRange([8, 9, 10, 11]).Build(); + var col2 = new StringArray.Builder().AppendRange(["a", "b", "c", "d"]).Build(); + var col3 = new StringArray.Builder().AppendRange(["x", "y", "z", "q"]).Build(); + var batch = new RecordBatch(schema, [col1, col2, col3], 4); await doPutResult.Writer.WriteAsync(batch); await doPutResult.Writer.CompleteAsync(); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 1f9e5463c7100..9d2de3e62c401 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -4,6 +4,7 @@ using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Server; +using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; @@ -442,7 +443,10 @@ public async Task DoPutAsync(FlightCallOptions options, FlightDescr var writer = doPutResult.RequestStream; var reader = doPutResult.ResponseStream; - await writer.WriteAsync(new RecordBatch(schema, new List(), 0)).ConfigureAwait(false); + var recordBatch = new RecordBatch(schema, BuildArrowArraysFromSchema(schema, schema.FieldsList.Count), 0); + await writer.WriteAsync(recordBatch).ConfigureAwait(false); + await writer.CompleteAsync().ConfigureAwait(false); + return new DoPutResult(writer, reader); } catch (RpcException ex) @@ -451,6 +455,78 @@ public async Task DoPutAsync(FlightCallOptions options, FlightDescr } } + public List BuildArrowArraysFromSchema(Schema schema, int rowCount) + { + var arrays = new List(); + + foreach (var field in schema.FieldsList) + { + switch (field.DataType) + { + case Int32Type _: + // Create an Int32 array + var intArrayBuilder = new Int32Array.Builder(); + for (int i = 0; i < rowCount; i++) + { + intArrayBuilder.Append(i); // Just filling with sample data + } + + arrays.Add(intArrayBuilder.Build()); + break; + + case StringType: + // Create a String array + var stringArrayBuilder = new StringArray.Builder(); + for (int i = 0; i < rowCount; i++) + { + stringArrayBuilder.Append($"Value-{i}"); // Sample string values + } + + arrays.Add(stringArrayBuilder.Build()); + break; + + case Int64Type: + // Create an Int64 array + var longArrayBuilder = new Int64Array.Builder(); + for (int i = 0; i < rowCount; i++) + { + longArrayBuilder.Append((long)i * 100); // Sample data + } + + arrays.Add(longArrayBuilder.Build()); + break; + + case FloatType: + // Create a Float array + var floatArrayBuilder = new FloatArray.Builder(); + for (int i = 0; i < rowCount; i++) + { + floatArrayBuilder.Append((float)(i * 1.1)); // Sample data + } + + arrays.Add(floatArrayBuilder.Build()); + break; + + case BooleanType: + // Create a Boolean array + var boolArrayBuilder = new BooleanArray.Builder(); + for (int i = 0; i < rowCount; i++) + { + boolArrayBuilder.Append(i % 2 == 0); // Alternate between true and false + } + + arrays.Add(boolArrayBuilder.Build()); + break; + + default: + throw new NotSupportedException($"Data type {field.DataType} not supported yet."); + } + } + + return arrays; + } + + /// /// Upload data to a Flight described by the given descriptor. The caller must call Close() on the returned stream /// once they are done writing. Uses default options. @@ -993,6 +1069,7 @@ public async Task BeginTransactionAsync(FlightCallOptions options) string? beginTransactionResult = result.Body.ToStringUtf8(); return new Transaction(beginTransactionResult); } + throw new InvalidOperationException("Failed to begin transaction: No response received."); } catch (RpcException ex) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 6a6e7e3836e7f..06ea264b572e3 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -7,7 +7,7 @@ using Apache.Arrow.Flight.Sql.TestWeb; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; -using Google.Protobuf.WellKnownTypes; +using Google.Protobuf; using Grpc.Core.Utils; using Xunit; using RecordBatchWithMetadata = Apache.Arrow.Flight.Sql.TestWeb.RecordBatchWithMetadata; @@ -424,14 +424,41 @@ public async Task DoPutAsync() { // Arrange var options = new FlightCallOptions(); - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, - _testWebFactory.GetAddress()); - _flightStore.Flights.Add(flightDescriptor, flightHolder); + // var schema = new Schema + // .Builder() + // .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + // .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + // .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + // .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + // .Field(f => f.Name("COLUMN_SIZE").DataType(Int32Type.Default).Nullable(false)) + // .Build(); + // var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + // + // int[] dataTypeIds = { 1, 2, 3 }; + // string[] typeNames = ["INTEGER", "VARCHAR", "BOOLEAN"]; + // int[] precisions = { 32, 255, 1 }; // For PRECISION + // string[] literalPrefixes = ["N'", "'", "b'"]; + // int[] columnSizes = [10, 255, 1]; + // + // var recordBatch = new RecordBatch(schema, + // [ + // new Int32Array.Builder().AppendRange(dataTypeIds).Build(), + // new StringArray.Builder().AppendRange(typeNames).Build(), + // new Int32Array.Builder().AppendRange(precisions).Build(), + // new StringArray.Builder().AppendRange(literalPrefixes).Build(), + // new Int32Array.Builder().AppendRange(columnSizes).Build() + // ], 5); + // Assert.NotNull(recordBatch); + // Assert.Equal(5, recordBatch.Length); + // var flightHolder = new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); + // flightHolder.AddBatch(new RecordBatchWithMetadata(_testUtils.CreateTestBatch(0, 100))); + // _flightStore.Flights.Add(flightDescriptor, flightHolder); + + var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); // Act - var result = await _flightSqlClient.DoPutAsync(options, flightDescriptor, recordBatch.Schema); + var result = await _flightSqlClient.DoPutAsync(options, flightDescriptor, expectedBatch.Schema); // Assert Assert.NotNull(result); @@ -727,7 +754,8 @@ public async Task CancelQueryAsync() var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); // Adding the flight info to the flight store for testing - _flightStore.Flights.Add(flightDescriptor, new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); + _flightStore.Flights.Add(flightDescriptor, + new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); // Act var cancelStatus = await _flightSqlClient.CancelQueryAsync(options, flightInfo); From c5eb5d6c63584d35b6437101d79e2e59fe9145fa Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 18 Sep 2024 09:48:39 +0300 Subject: [PATCH 16/58] chore: protobuf version fix --- .../Apache.Arrow.Flight.Sql.TestWeb.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj index 4985573674ae1..0fd8f47c393f5 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj @@ -6,7 +6,7 @@ - + From c4abd00c10c09fb7487c05febe48541680986703 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 18 Sep 2024 09:57:20 +0300 Subject: [PATCH 17/58] chore: remove commented out commands in CMakeLists.txt --- cpp/CMakeLists.txt | 6 +++--- cpp/src/arrow/flight/sql/CMakeLists.txt | 8 ++++---- csharp/Apache.Arrow.sln | 6 ------ 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 31a0be1ac5cf0..08c6a9b846488 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -415,9 +415,9 @@ else() endif() endif() -#if(NOT ARROW_BUILD_EXAMPLES) -# set(NO_EXAMPLES 1) -#endif() +if(NOT ARROW_BUILD_EXAMPLES) + set(NO_EXAMPLES 1) +endif() if(ARROW_FUZZING) # Fuzzing builds enable ASAN without setting our home-grown option for it. diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index eed30f2c47ca0..b32f731496749 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -94,7 +94,7 @@ endif() list(APPEND ARROW_FLIGHT_SQL_TEST_LINK_LIBS ${ARROW_FLIGHT_TEST_LINK_LIBS}) # Build test server for unit tests -#if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) +if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) find_package(SQLite3Alt REQUIRED) set(ARROW_FLIGHT_SQL_TEST_SERVER_SRCS @@ -121,14 +121,14 @@ list(APPEND ARROW_FLIGHT_SQL_TEST_LINK_LIBS ${ARROW_FLIGHT_TEST_LINK_LIBS}) list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared) endif() - #if(ARROW_BUILD_EXAMPLES) + if(ARROW_BUILD_EXAMPLES) add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS} example/acero_main.cc) target_link_libraries(acero-flight-sql-server PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES}) - #endif() - #endif() + endif() + endif() add_arrow_test(flight_sql_test SOURCES diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index f564971071c01..524fcf3f56b81 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -27,8 +27,6 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.Tes EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", "src\Apache.Arrow.Flight.Sql\Apache.Arrow.Flight.Sql.csproj", "{2ADE087A-B424-4895-8CC5-10170D10BA62}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.IntegrationTest", "Apache.Arrow.Flight.Sql.IntegrationTest\Apache.Arrow.Flight.Sql.IntegrationTest.csproj", "{45416D7D-F12B-4524-B641-AD0E1A33B3B0}" -EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.TestWeb", "Apache.Arrow.Flight.Sql.TestWeb\Apache.Arrow.Flight.Sql.TestWeb.csproj", "{85A6CB32-A83B-48C4-96E8-625C8FBDB349}" EndProject Global @@ -85,10 +83,6 @@ Global {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.Build.0 = Debug|Any CPU {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.ActiveCfg = Release|Any CPU {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.Build.0 = Release|Any CPU - {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Debug|Any CPU.Build.0 = Debug|Any CPU - {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Release|Any CPU.ActiveCfg = Release|Any CPU - {45416D7D-F12B-4524-B641-AD0E1A33B3B0}.Release|Any CPU.Build.0 = Release|Any CPU {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Debug|Any CPU.Build.0 = Debug|Any CPU {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Release|Any CPU.ActiveCfg = Release|Any CPU From 2ce6f4c80c093429f51cde8de6c58151f2eb8473 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 18 Sep 2024 11:57:21 +0300 Subject: [PATCH 18/58] chore: adding ConfigureAwait(false) to every async method to TestFlightSqlServer.cs csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs var batches = flightHolder.GetRecordBatches(); foreach (var batch in batches) { await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); Add ConfigureAwait(false) --- .../TestFlightSqlServer.cs | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs index 73abdd5c89363..0dac59fa2bfe6 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs @@ -26,10 +26,10 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter(), 0, 0); var cancelRequest = new CancelFlightInfoRequest(flightInfo); - await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())); + await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())).ConfigureAwait(false); break; case "BeginTransaction": case "Commit": case "Rollback": - await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))); + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))).ConfigureAwait(false); break; case "CreatePreparedStatement": case "ClosePreparedStatement": @@ -52,10 +52,9 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter GetSchema(FlightDescriptor request, ServerCallConte public override async Task ListActions(IAsyncStreamWriter responseStream, ServerCallContext context) { - await responseStream.WriteAsync(new FlightActionType("get", "get a flight")); - await responseStream.WriteAsync(new FlightActionType("put", "add a flight")); - await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")); - await responseStream.WriteAsync(new FlightActionType("test", "test action")); - await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")); - await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")); + await responseStream.WriteAsync(new FlightActionType("get", "get a flight")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("put", "add a flight")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("test", "test action")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")).ConfigureAwait(false); } public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, @@ -165,7 +163,7 @@ public override async Task ListFlights(FlightCriteria request, IAsyncStreamWrite foreach (var flightInfo in flightInfos) { - await responseStream.WriteAsync(flightInfo); + await responseStream.WriteAsync(flightInfo).ConfigureAwait(false); } } From 8907913903384bdd36771175a92cfae96fe166e7 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 18 Sep 2024 12:55:39 +0300 Subject: [PATCH 19/58] chore: attending issues in github --- .../Client/FlightSqlClient.cs | 54 +++++++++---------- .../src/Apache.Arrow.Flight.Sql/Savepoint.cs | 13 ----- .../FlightSqlTestExtensions.cs | 1 - .../TestFlightSqlSever.cs | 1 - 4 files changed, 26 insertions(+), 43 deletions(-) delete mode 100644 csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 9d2de3e62c401..519f689fe1090 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -50,7 +50,7 @@ public async Task ExecuteAsync(FlightCallOptions options, string que var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { var preparedStatementResponse = FlightSqlUtils.ParseAndUnpack(result.Body); @@ -61,7 +61,7 @@ public async Task ExecuteAsync(FlightCallOptions options, string que byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - return await GetFlightInfoAsync(options, descriptor); + return await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); } throw new InvalidOperationException("No results returned from the query."); @@ -101,7 +101,7 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que var call = DoActionAsync(options, action); long affectedRows = 0; - await foreach (var result in call) + await foreach (var result in call.ConfigureAwait(false)) { var preparedStatementResponse = result.Body.ParseAndUnpack(); var command = new CommandPreparedStatementQuery @@ -112,7 +112,7 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); var flightInfo = await GetFlightInfoAsync(options, descriptor); var doGetResult = DoGetAsync(options, flightInfo.Endpoints[0].Ticket); - await foreach (var recordBatch in doGetResult) + await foreach (var recordBatch in doGetResult.ConfigureAwait(false)) { affectedRows += recordBatch.Column(0).Length; } @@ -178,7 +178,7 @@ public async IAsyncEnumerable DoActionAsync(FlightCallOptions opti var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { yield return result; } @@ -222,7 +222,7 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { var preparedStatementResponse = FlightSqlUtils.ParseAndUnpack(result.Body); @@ -259,7 +259,7 @@ public async Task GetCatalogsAsync(FlightCallOptions options) { var command = new CommandGetCatalogs(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var catalogsInfo = await GetFlightInfoAsync(options, descriptor); + var catalogsInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return catalogsInfo; } catch (RpcException ex) @@ -284,7 +284,7 @@ public async Task GetCatalogsSchemaAsync(FlightCallOptions options) { var commandGetCatalogsSchema = new CommandGetCatalogs(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCatalogsSchema.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor); + var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -387,8 +387,7 @@ public async Task GetDbSchemasSchemaAsync(FlightCallOptions options) { var command = new CommandGetDbSchemas(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - - var schemaResult = await GetSchemaAsync(options, descriptor); + var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -416,7 +415,7 @@ public async IAsyncEnumerable DoGetAsync(FlightCallOptions options, } var call = _client.GetStream(ticket, options.Headers); - await foreach (var recordBatch in call.ResponseStream.ReadAllAsync()) + await foreach (var recordBatch in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { yield return recordBatch; } @@ -558,7 +557,7 @@ public async Task GetPrimaryKeysAsync(FlightCallOptions options, Tab }; byte[] packedRequest = getPrimaryKeysRequest.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(packedRequest); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } @@ -629,7 +628,7 @@ public async Task GetExportedKeysAsync(FlightCallOptions options, Ta }; var descriptor = FlightDescriptor.CreateCommandDescriptor(getExportedKeysRequest.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -654,7 +653,7 @@ public async Task GetExportedKeysSchemaAsync(FlightCallOptions options) { var commandGetExportedKeysSchema = new CommandGetExportedKeys(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetExportedKeysSchema.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor); + var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -681,8 +680,7 @@ public async Task GetImportedKeysAsync(FlightCallOptions options, Ta Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; var descriptor = FlightDescriptor.CreateCommandDescriptor(getImportedKeysRequest.PackAndSerialize()); - - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -707,7 +705,7 @@ public async Task GetImportedKeysSchemaAsync(FlightCallOptions options) { var commandGetImportedKeysSchema = new CommandGetImportedKeys(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetImportedKeysSchema.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor); + var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -745,7 +743,7 @@ public async Task GetCrossReferenceAsync(FlightCallOptions options, }; var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCrossReference.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } @@ -799,7 +797,7 @@ public async Task GetTableTypesAsync(FlightCallOptions options) { var command = new CommandGetTableTypes(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -824,7 +822,7 @@ public async Task GetTableTypesSchemaAsync(FlightCallOptions options) { var command = new CommandGetTableTypes(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor); + var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -850,7 +848,7 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options, in { var command = new CommandGetXdbcTypeInfo { DataType = dataType }; var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -875,7 +873,7 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options) { var command = new CommandGetXdbcTypeInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -900,7 +898,7 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions options) { var command = new CommandGetXdbcTypeInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor); + var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -932,7 +930,7 @@ public async Task GetSqlInfoAsync(FlightCallOptions options, List (uint)item)); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -984,7 +982,7 @@ public async Task CancelFlightInfoAsync(FlightCallOption { var action = new FlightAction("CancelFlightInfo", request.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { var cancelResult = Any.Parser.ParseFrom(result.Body); if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) @@ -1024,7 +1022,7 @@ public async Task CancelQueryAsync(FlightCallOptions options, Flig var cancelRequest = new CancelFlightInfoRequest(info); var action = new FlightAction("CancelFlightInfo", cancelRequest.ToByteString()); var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { var cancelResult = Any.Parser.ParseFrom(result.Body); if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) @@ -1064,7 +1062,7 @@ public async Task BeginTransactionAsync(FlightCallOptions options) var actionBeginTransaction = new ActionBeginTransactionRequest(); var action = new FlightAction("BeginTransaction", actionBeginTransaction.PackAndSerialize()); var responseStream = _client.DoAction(action, options.Headers); - await foreach (var result in responseStream.ResponseStream.ReadAllAsync()) + await foreach (var result in responseStream.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { string? beginTransactionResult = result.Body.ToStringUtf8(); return new Transaction(beginTransactionResult); @@ -1184,7 +1182,7 @@ public async Task PrepareAsync(FlightCallOptions options, str }; byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - var flightInfo = await GetFlightInfoAsync(options, descriptor); + var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); return new PreparedStatement(this, flightInfo, query); } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs b/csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs deleted file mode 100644 index 212e4d73fbd61..0000000000000 --- a/csharp/src/Apache.Arrow.Flight.Sql/Savepoint.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace Apache.Arrow.Flight.Sql; - -public class Savepoint -{ - public string SavepointId { get; private set; } - - public Savepoint(string savepointId) - { - SavepointId = savepointId; - } - - public bool IsValid() => !string.IsNullOrEmpty(SavepointId); -} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs index db6804f4e3c04..cd2ebc725bcc7 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -13,7 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -using Apache.Arrow.Flight.Sql.Client; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs index d2e1ac21f9e01..3dca632b5b761 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlSever.cs @@ -86,5 +86,4 @@ private RecordBatch MockRecordBatch(string name) var schema = new Schema(new List {new(name, StringType.Default, false)}, System.Array.Empty>()); return new RecordBatch(schema, new []{ new StringArray.Builder().Append(name).Build() }, 1); } - } From 861df24002d0b1b379e374db7a7f928ff5cf566d Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 22 Sep 2024 16:30:05 +0300 Subject: [PATCH 20/58] chore: adding features - PreparedStatement implementation - SqlActions adding command requests -- Commit -- Rollback -- BeginTransaction -- CancelFlightInfo --- .../Client/FlightSqlClient.cs | 34 +++---- .../PreparedStatement.cs | 92 +++++++++++++++++-- .../src/Apache.Arrow.Flight.Sql/SqlActions.cs | 4 + .../Apache.Arrow.Flight.Sql/Transaction.cs | 31 ++++++- .../FlightSqlClientTests.cs | 11 +-- 5 files changed, 133 insertions(+), 39 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 519f689fe1090..d3acc4323a036 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -1,9 +1,7 @@ using System; using System.Collections.Generic; -using System.IO; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; -using Apache.Arrow.Flight.Server; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; @@ -21,8 +19,6 @@ public FlightSqlClient(FlightClient client) _client = client ?? throw new ArgumentNullException(nameof(client)); } - public static Transaction NoTransaction() => new(null); - /// /// Execute a SQL query on the server. /// @@ -32,7 +28,7 @@ public FlightSqlClient(FlightClient client) /// The FlightInfo describing where to access the dataset. public async Task ExecuteAsync(FlightCallOptions options, string query, Transaction? transaction = null) { - transaction ??= NoTransaction(); + transaction ??= Transaction.NoTransaction; if (options == null) { @@ -46,7 +42,8 @@ public async Task ExecuteAsync(FlightCallOptions options, string que try { - var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query }; + var prepareStatementRequest = + new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); @@ -81,7 +78,7 @@ public async Task ExecuteAsync(FlightCallOptions options, string que /// The number of rows affected by the operation. public async Task ExecuteUpdateAsync(FlightCallOptions options, string query, Transaction? transaction = null) { - transaction ??= NoTransaction(); + transaction ??= Transaction.NoTransaction; if (options == null) { @@ -95,7 +92,7 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que try { - var updateRequestCommand = new ActionCreatePreparedStatementRequest { Query = query }; + var updateRequestCommand = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; byte[] serializedUpdateRequestCommand = updateRequestCommand.PackAndSerialize(); var action = new FlightAction(SqlAction.CreateRequest, serializedUpdateRequestCommand); var call = DoActionAsync(options, action); @@ -207,7 +204,7 @@ public async IAsyncEnumerable DoActionAsync(FlightAction action) public async Task GetExecuteSchemaAsync(FlightCallOptions options, string query, Transaction? transaction = null) { - transaction ??= NoTransaction(); + transaction ??= Transaction.NoTransaction; if (options is null) throw new ArgumentNullException(nameof(options)); @@ -218,7 +215,7 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin FlightInfo schemaResult = null!; try { - var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query }; + var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); @@ -980,7 +977,7 @@ public async Task CancelFlightInfoAsync(FlightCallOption try { - var action = new FlightAction("CancelFlightInfo", request.PackAndSerialize()); + var action = new FlightAction(SqlAction.CancelFlightInfoRequest, request.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -1020,7 +1017,7 @@ public async Task CancelQueryAsync(FlightCallOptions options, Flig try { var cancelRequest = new CancelFlightInfoRequest(info); - var action = new FlightAction("CancelFlightInfo", cancelRequest.ToByteString()); + var action = new FlightAction(SqlAction.CancelFlightInfoRequest, cancelRequest.ToByteString()); var call = _client.DoAction(action, options.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -1060,7 +1057,7 @@ public async Task BeginTransactionAsync(FlightCallOptions options) try { var actionBeginTransaction = new ActionBeginTransactionRequest(); - var action = new FlightAction("BeginTransaction", actionBeginTransaction.PackAndSerialize()); + var action = new FlightAction(SqlAction.BeginTransactionRequest, actionBeginTransaction.PackAndSerialize()); var responseStream = _client.DoAction(action, options.Headers); await foreach (var result in responseStream.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -1097,7 +1094,7 @@ public AsyncServerStreamingCall CommitAsync(FlightCallOptions opti try { - var actionCommit = new FlightAction("Commit", transaction.TransactionId); + var actionCommit = new FlightAction(SqlAction.CommitRequest, transaction.TransactionId); return _client.DoAction(actionCommit, options.Headers); } catch (RpcException ex) @@ -1127,7 +1124,7 @@ public AsyncServerStreamingCall RollbackAsync(FlightCallOptions op try { - var actionRollback = new FlightAction("Rollback", transaction.TransactionId); + var actionRollback = new FlightAction(SqlAction.RollbackRequest, transaction.TransactionId); return _client.DoAction(actionRollback, options.Headers); } catch (RpcException ex) @@ -1146,7 +1143,7 @@ public AsyncServerStreamingCall RollbackAsync(FlightCallOptions op public async Task PrepareAsync(FlightCallOptions options, string query, Transaction? transaction = null) { - transaction ??= NoTransaction(); + transaction ??= Transaction.NoTransaction; if (options == null) { @@ -1162,10 +1159,7 @@ public async Task PrepareAsync(FlightCallOptions options, str { var preparedStatementRequest = new ActionCreatePreparedStatementRequest { - Query = query, - TransactionId = transaction.IsValid() - ? ByteString.CopyFromUtf8(transaction.TransactionId) - : ByteString.Empty, + Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 342b23ef3613d..2cf4e0f559939 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -1,5 +1,9 @@ +using System; +using System.Linq; using System.Threading.Tasks; using Apache.Arrow.Flight.Sql.Client; +using Arrow.Flight.Protocol.Sql; +using Grpc.Core; namespace Apache.Arrow.Flight.Sql; @@ -8,30 +12,98 @@ public class PreparedStatement { private readonly FlightSqlClient _client; private readonly FlightInfo _flightInfo; + private RecordBatch? _parameterBatch; private readonly string _query; + private bool _isClosed; public PreparedStatement(FlightSqlClient client, FlightInfo flightInfo, string query) { - _client = client; - _flightInfo = flightInfo; - _query = query; + _client = client ?? throw new ArgumentNullException(nameof(client)); + _flightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); + _query = query ?? throw new ArgumentNullException(nameof(query)); + _isClosed = false; } + /// + /// Set parameters for the prepared statement + /// + /// The batch of parameters to bind public Task SetParameters(RecordBatch parameterBatch) { - // Implement setting parameters + if (_isClosed) + { + throw new InvalidOperationException("Cannot set parameters on a closed statement."); + } + + _parameterBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); return Task.CompletedTask; } - public Task ExecuteUpdateAsync(FlightCallOptions options) + /// + /// Execute the prepared statement, returning the number of affected rows + /// + /// The FlightCallOptions for the execution + /// Task representing the asynchronous operation + public async Task ExecuteUpdateAsync(FlightCallOptions options) { - // Implement execution of the prepared statement - return Task.CompletedTask; + if (_isClosed) + { + throw new InvalidOperationException("Cannot execute a closed statement."); + } + + if (_parameterBatch == null) + { + throw new InvalidOperationException("No parameters set for the prepared statement."); + } + + var commandSqlCall = new CommandPreparedStatementQuery + { + PreparedStatementHandle = _flightInfo.Endpoints.First().Ticket.Ticket + }; + byte[] packedCommand = commandSqlCall.PackAndSerialize(); + var descriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); + + var flightInfo = await _client.GetFlightInfoAsync(options, descriptor); + return await ExecuteAndGetAffectedRowsAsync(options, flightInfo); } - public Task CloseAsync(FlightCallOptions options) + /// + /// Closes the prepared statement + /// + public async Task CloseAsync(FlightCallOptions options) { - // Implement closing the prepared statement - return Task.CompletedTask; + if (_isClosed) + { + throw new InvalidOperationException("Statement already closed."); + } + + try + { + var actionClose = new FlightAction(SqlAction.CloseRequest, _flightInfo.Descriptor.Command); + await foreach (var result in _client.DoActionAsync(options, actionClose).ConfigureAwait(false)) + { + // Process any result if necessary (e.g., logging). + } + _isClosed = true; + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to close the prepared statement", ex); + } + } + + /// + /// Helper method to execute the statement and get affected rows + /// + private async Task ExecuteAndGetAffectedRowsAsync(FlightCallOptions options, FlightInfo flightInfo) + { + long affectedRows = 0; + var doGetResult = _client.DoGetAsync(options, flightInfo.Endpoints.First().Ticket); + await foreach (var recordBatch in doGetResult.ConfigureAwait(false)) + { + affectedRows += recordBatch.Length; + } + + return affectedRows; } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs index f3f3bef1e1d00..8e935f7614e33 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs @@ -19,4 +19,8 @@ public static class SqlAction { public const string CreateRequest = "CreatePreparedStatement"; public const string CloseRequest = "ClosePreparedStatement"; + public const string CancelFlightInfoRequest = "CancelFlightInfo"; + public const string BeginTransactionRequest = "BeginTransaction"; + public const string CommitRequest = "Commit"; + public const string RollbackRequest = "Rollback"; } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs index 1f725074bb489..32c882b94c0bc 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs @@ -1,8 +1,33 @@ +using Google.Protobuf; + namespace Apache.Arrow.Flight.Sql; -public class Transaction(string? transactionId) +public class Transaction { - public string? TransactionId { get; } = transactionId; + private static readonly ByteString TransactionIdDefaultValue = ByteString.Empty; + private ByteString? _transactionId; + + public ByteString TransactionId + { + get => _transactionId ?? TransactionIdDefaultValue; + set => _transactionId = ProtoPreconditions.CheckNotNull(value, nameof(value)); + } + + public static readonly Transaction NoTransaction = new(TransactionIdDefaultValue); + + public Transaction(ByteString transactionId) + { + TransactionId = transactionId; + } + + public Transaction(string transactionId) + { + _transactionId = ByteString.CopyFromUtf8(transactionId); + } - public bool IsValid() => !string.IsNullOrEmpty(TransactionId); + public bool IsValid() => TransactionId.Length > 0; + public void ResetTransaction() + { + _transactionId = TransactionIdDefaultValue; + } } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 06ea264b572e3..2a71ac4b19aaa 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -18,7 +18,6 @@ public class FlightSqlClientTests : IDisposable { readonly TestSqlWebFactory _testWebFactory; readonly FlightSqlStore _flightStore; - readonly FlightClient _flightClient; private readonly FlightSqlClient _flightSqlClient; private readonly FlightSqlTestUtils _testUtils; @@ -26,8 +25,8 @@ public FlightSqlClientTests() { _flightStore = new FlightSqlStore(); _testWebFactory = new TestSqlWebFactory(_flightStore); - _flightClient = new FlightClient(_testWebFactory.GetChannel()); - _flightSqlClient = new FlightSqlClient(_flightClient); + FlightClient flightClient = new(_testWebFactory.GetChannel()); + _flightSqlClient = new FlightSqlClient(flightClient); _testUtils = new FlightSqlTestUtils(_testWebFactory, _flightStore); } @@ -48,7 +47,7 @@ public async Task CommitTransactionAsync() // Assert Assert.NotNull(result); - Assert.Equal(transaction.TransactionId, result.FirstOrDefault()?.Body.ToStringUtf8()); + Assert.Equal(transaction.TransactionId, result.FirstOrDefault()?.Body); } [Fact] @@ -63,7 +62,7 @@ public async Task BeginTransactionAsync() // Assert Assert.NotNull(transaction); - Assert.Equal(expectedTransactionId, transaction.TransactionId); + Assert.Equal(ByteString.CopyFromUtf8(expectedTransactionId), transaction.TransactionId); } [Fact] @@ -80,7 +79,7 @@ public async Task RollbackTransactionAsync() // Assert Assert.NotNull(transaction); - Assert.Equal(result.FirstOrDefault()?.Body.ToStringUtf8(), transaction.TransactionId); + Assert.Equal(result.FirstOrDefault()?.Body, transaction.TransactionId); } #endregion From 8c201836b9fe127692e809b7f24ddf5aae151729 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 23 Sep 2024 14:28:41 +0300 Subject: [PATCH 21/58] feat: PrepareStatement implementation: PreparedStatement - tests --- .../PreparedStatement.cs | 56 ++++--- .../src/Apache.Arrow.Flight.Sql/SqlActions.cs | 1 + .../FlightSqlClientTests.cs | 62 ++++---- .../FlightSqlPreparedStatementTests.cs | 143 ++++++++++++++++++ 4 files changed, 206 insertions(+), 56 deletions(-) create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 2cf4e0f559939..599854f8f0a45 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -7,8 +7,7 @@ namespace Apache.Arrow.Flight.Sql; -// TODO: Refactor this to match C++ implementation -public class PreparedStatement +public class PreparedStatement : IDisposable { private readonly FlightSqlClient _client; private readonly FlightInfo _flightInfo; @@ -30,11 +29,7 @@ public PreparedStatement(FlightSqlClient client, FlightInfo flightInfo, string q /// The batch of parameters to bind public Task SetParameters(RecordBatch parameterBatch) { - if (_isClosed) - { - throw new InvalidOperationException("Cannot set parameters on a closed statement."); - } - + EnsureStatementIsNotClosed(); _parameterBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); return Task.CompletedTask; } @@ -46,23 +41,14 @@ public Task SetParameters(RecordBatch parameterBatch) /// Task representing the asynchronous operation public async Task ExecuteUpdateAsync(FlightCallOptions options) { - if (_isClosed) - { - throw new InvalidOperationException("Cannot execute a closed statement."); - } - - if (_parameterBatch == null) - { - throw new InvalidOperationException("No parameters set for the prepared statement."); - } - + EnsureStatementIsNotClosed(); + EnsureParametersAreSet(); var commandSqlCall = new CommandPreparedStatementQuery { PreparedStatementHandle = _flightInfo.Endpoints.First().Ticket.Ticket }; byte[] packedCommand = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); - var flightInfo = await _client.GetFlightInfoAsync(options, descriptor); return await ExecuteAndGetAffectedRowsAsync(options, flightInfo); } @@ -72,17 +58,12 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options) /// public async Task CloseAsync(FlightCallOptions options) { - if (_isClosed) - { - throw new InvalidOperationException("Statement already closed."); - } - + EnsureStatementIsNotClosed(); try { var actionClose = new FlightAction(SqlAction.CloseRequest, _flightInfo.Descriptor.Command); await foreach (var result in _client.DoActionAsync(options, actionClose).ConfigureAwait(false)) { - // Process any result if necessary (e.g., logging). } _isClosed = true; } @@ -106,4 +87,31 @@ private async Task ExecuteAndGetAffectedRowsAsync(FlightCallOptions option return affectedRows; } + + /// + /// Helper method to ensure the statement is not closed. + /// + private void EnsureStatementIsNotClosed() + { + if (_isClosed) + throw new InvalidOperationException("Cannot execute a closed statement."); + } + + private void EnsureParametersAreSet() + { + if (_parameterBatch == null || _parameterBatch.Length == 0) + { + throw new InvalidOperationException("Prepared statement parameters have not been set."); + } + } + + public void Dispose() + { + _parameterBatch?.Dispose(); + + if (!_isClosed) + { + _isClosed = true; + } + } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs index 8e935f7614e33..aea5a3522783f 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/SqlActions.cs @@ -23,4 +23,5 @@ public static class SqlAction public const string BeginTransactionRequest = "BeginTransaction"; public const string CommitRequest = "Commit"; public const string RollbackRequest = "Rollback"; + public const string GetPrimaryKeysRequest = "GetPrimaryKeys"; } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 2a71ac4b19aaa..4dc1c7c9d71e9 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -423,37 +423,36 @@ public async Task DoPutAsync() { // Arrange var options = new FlightCallOptions(); - // var schema = new Schema - // .Builder() - // .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) - // .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) - // .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) - // .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) - // .Field(f => f.Name("COLUMN_SIZE").DataType(Int32Type.Default).Nullable(false)) - // .Build(); - // var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - // - // int[] dataTypeIds = { 1, 2, 3 }; - // string[] typeNames = ["INTEGER", "VARCHAR", "BOOLEAN"]; - // int[] precisions = { 32, 255, 1 }; // For PRECISION - // string[] literalPrefixes = ["N'", "'", "b'"]; - // int[] columnSizes = [10, 255, 1]; - // - // var recordBatch = new RecordBatch(schema, - // [ - // new Int32Array.Builder().AppendRange(dataTypeIds).Build(), - // new StringArray.Builder().AppendRange(typeNames).Build(), - // new Int32Array.Builder().AppendRange(precisions).Build(), - // new StringArray.Builder().AppendRange(literalPrefixes).Build(), - // new Int32Array.Builder().AppendRange(columnSizes).Build() - // ], 5); - // Assert.NotNull(recordBatch); - // Assert.Equal(5, recordBatch.Length); - // var flightHolder = new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); - // flightHolder.AddBatch(new RecordBatchWithMetadata(_testUtils.CreateTestBatch(0, 100))); - // _flightStore.Flights.Add(flightDescriptor, flightHolder); - - var flightDescriptor = FlightDescriptor.CreatePathDescriptor("test"); + var schema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("LITERAL_PREFIX").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("COLUMN_SIZE").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + + int[] dataTypeIds = { 1, 2, 3 }; + string[] typeNames = ["INTEGER", "VARCHAR", "BOOLEAN"]; + int[] precisions = { 32, 255, 1 }; + string[] literalPrefixes = ["N'", "'", "b'"]; + int[] columnSizes = [10, 255, 1]; + + var recordBatch = new RecordBatch(schema, + [ + new Int32Array.Builder().AppendRange(dataTypeIds).Build(), + new StringArray.Builder().AppendRange(typeNames).Build(), + new Int32Array.Builder().AppendRange(precisions).Build(), + new StringArray.Builder().AppendRange(literalPrefixes).Build(), + new Int32Array.Builder().AppendRange(columnSizes).Build() + ], 5); + Assert.NotNull(recordBatch); + Assert.Equal(5, recordBatch.Length); + var flightHolder = new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(_testUtils.CreateTestBatch(0, 100))); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + var expectedBatch = _testUtils.CreateTestBatch(0, 100); // Act @@ -763,7 +762,6 @@ public async Task CancelQueryAsync() Assert.Equal(CancelStatus.Cancelled, cancelStatus); } - public void Dispose() => _testWebFactory?.Dispose(); private void CompareSchemas(Schema expectedSchema, Schema actualSchema) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs new file mode 100644 index 0000000000000..dbdcabb03c95a --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -0,0 +1,143 @@ +using System; +using System.Threading.Tasks; +using Apache.Arrow.Flight.Client; +using Apache.Arrow.Flight.Sql.Client; +using Apache.Arrow.Flight.Sql.TestWeb; +using Apache.Arrow.Types; +using Xunit; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class FlightSqlPreparedStatementTests +{ + readonly TestSqlWebFactory _testWebFactory; + readonly FlightSqlStore _flightStore; + private readonly PreparedStatement _preparedStatement; + private readonly Schema _schema; + private readonly RecordBatch _parameterBatch; + private readonly FlightDescriptor _flightDescriptor; + + public FlightSqlPreparedStatementTests() + { + _flightStore = new FlightSqlStore(); + _testWebFactory = new TestSqlWebFactory(_flightStore); + FlightClient flightClient = new(_testWebFactory.GetChannel()); + FlightSqlClient flightSqlClient = new(flightClient); + + // Setup mock + _flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); + _schema = new Schema + .Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); + + int[] dataTypeIds = [1, 2, 3]; + string[] typeNames = ["INTEGER", "VARCHAR", "BOOLEAN"]; + int[] precisions = [32, 255, 1]; + string[] literalPrefixes = ["N'", "'", "b'"]; + int[] columnSizes = [10, 255, 1]; + _parameterBatch = new RecordBatch(_schema, + [ + new Int32Array.Builder().AppendRange(dataTypeIds).Build(), + new StringArray.Builder().AppendRange(typeNames).Build(), + new Int32Array.Builder().AppendRange(precisions).Build(), + new StringArray.Builder().AppendRange(literalPrefixes).Build(), + new Int32Array.Builder().AppendRange(columnSizes).Build() + ], 5); + + var flightHolder = new FlightSqlHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); + _preparedStatement = new PreparedStatement(flightSqlClient, flightHolder.GetFlightInfo(), "SELECT * FROM test"); + } + + // PreparedStatement + [Fact] + public async Task SetParameters_ShouldSetParameters_WhenStatementIsOpen() + { + await _preparedStatement.SetParameters(_parameterBatch); + Assert.NotNull(_parameterBatch); + } + + [Fact] + public async Task SetParameters_ShouldThrowException_WhenStatementIsClosed() + { + // Arrange + await _preparedStatement.CloseAsync(new FlightCallOptions()); + + // Act & Assert + await Assert.ThrowsAsync( + () => _preparedStatement.SetParameters(_parameterBatch) + ); + } + + [Fact] + public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreSet() + { + // Arrange + var options = new FlightCallOptions(); + var flightHolder = new FlightSqlHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(_parameterBatch)); + _flightStore.Flights.Add(_flightDescriptor, flightHolder); + + await _preparedStatement.SetParameters(_parameterBatch); + + + // Act + long affectedRows = await _preparedStatement.ExecuteUpdateAsync(options); + + // Assert + Assert.True(affectedRows > 0); // Verifies that the statement executed successfully. + } + + [Fact] + public async Task ExecuteUpdateAsync_ShouldThrowException_WhenNoParametersSet() + { + // Arrange + var options = new FlightCallOptions(); + + // Act & Assert + await Assert.ThrowsAsync( + () => _preparedStatement.ExecuteUpdateAsync(options) + ); + } + + [Fact] + public async Task ExecuteUpdateAsync_ShouldThrowException_WhenStatementIsClosed() + { + // Arrange + var options = new FlightCallOptions(); + await _preparedStatement.CloseAsync(options); + + // Act & Assert + await Assert.ThrowsAsync( + () => _preparedStatement.ExecuteUpdateAsync(options) + ); + } + + [Fact] + public async Task CloseAsync_ShouldCloseStatement_WhenCalled() + { + // Arrange + var options = new FlightCallOptions(); + + // Act + await _preparedStatement.CloseAsync(options); + + // Assert + await Assert.ThrowsAsync( + () => _preparedStatement.CloseAsync(options) + ); + } + + [Fact] + public async Task CloseAsync_ShouldThrowException_WhenStatementAlreadyClosed() + { + // Arrange + var options = new FlightCallOptions(); + await _preparedStatement.CloseAsync(options); + + // Act & Assert + await Assert.ThrowsAsync( + () => _preparedStatement.CloseAsync(options) + ); + } +} From 75d9e0910fb015e8ffb9d2f56fbbed2aedf603b2 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 23 Sep 2024 15:56:20 +0300 Subject: [PATCH 22/58] fix: update PreparedStatement fixed update to receive sql query and update batch parameters --- .../Apache.Arrow.Flight.Sql/PreparedStatement.cs | 14 +++++++------- .../FlightSqlPreparedStatementTests.cs | 2 -- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 599854f8f0a45..982177d4fc06d 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -43,14 +43,14 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options) { EnsureStatementIsNotClosed(); EnsureParametersAreSet(); - var commandSqlCall = new CommandPreparedStatementQuery + try { - PreparedStatementHandle = _flightInfo.Endpoints.First().Ticket.Ticket - }; - byte[] packedCommand = commandSqlCall.PackAndSerialize(); - var descriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); - var flightInfo = await _client.GetFlightInfoAsync(options, descriptor); - return await ExecuteAndGetAffectedRowsAsync(options, flightInfo); + return await _client.ExecuteUpdateAsync(options, _query); + } + catch (RpcException ex) + { + throw new InvalidOperationException("Failed to execute update query", ex); + } } /// diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index dbdcabb03c95a..0f35482d2a803 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -77,10 +77,8 @@ public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreS var flightHolder = new FlightSqlHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(_parameterBatch)); _flightStore.Flights.Add(_flightDescriptor, flightHolder); - await _preparedStatement.SetParameters(_parameterBatch); - // Act long affectedRows = await _preparedStatement.ExecuteUpdateAsync(options); From 6d19d062944c9d3a8ae94266d6433d1577df5dfa Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 25 Sep 2024 16:19:14 +0300 Subject: [PATCH 23/58] feat: CancelFlightInfo & CancelFlightQuery --- .../TestFlightSqlServer.cs | 2 +- .../CancelFlightInfoRequest.cs | 98 ------------------- .../CancelFlightInfoResult.cs | 96 ------------------ .../Client/FlightSqlClient.cs | 45 ++++----- .../FlightInfoCancelRequest.cs | 38 +++++++ .../FlightInfoCancelResult.cs | 34 +++++++ .../FlightSqlClientTests.cs | 10 +- 7 files changed, 96 insertions(+), 227 deletions(-) delete mode 100644 csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs delete mode 100644 csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs create mode 100644 csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs index 0dac59fa2bfe6..481b0c9c0e29c 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs @@ -38,7 +38,7 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter(), 0, 0); - var cancelRequest = new CancelFlightInfoRequest(flightInfo); + var cancelRequest = new FlightInfoCancelRequest(flightInfo); await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())).ConfigureAwait(false); break; case "BeginTransaction": diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs deleted file mode 100644 index 9afddf656aa93..0000000000000 --- a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoRequest.cs +++ /dev/null @@ -1,98 +0,0 @@ -using System; -using Google.Protobuf; -using Google.Protobuf.Reflection; - -namespace Apache.Arrow.Flight.Sql; - -public sealed class CancelFlightInfoRequest : IMessage -{ - public FlightInfo FlightInfo { get; set; } - - // Overloaded constructor - public CancelFlightInfoRequest(FlightInfo flightInfo) - { - FlightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); - } - - public MessageDescriptor Descriptor => - DescriptorReflection.Descriptor.MessageTypes[0]; - - public void MergeFrom(CancelFlightInfoRequest message) - { - if (message != null) - { - FlightInfo = message.FlightInfo; - } - } - - public void MergeFrom(CodedInputStream input) - { - while (input.ReadTag() != 0) - { - switch (input.Position) - { - case 1: - input.ReadMessage(this); - break; - default: - input.SkipLastField(); - break; - } - } - } - - public void WriteTo(CodedOutputStream output) - { - if (FlightInfo != null) - { - output.WriteTag(1, WireFormat.WireType.LengthDelimited); - output.WriteString(FlightInfo.Descriptor.Command.ToStringUtf8()); - foreach (var path in FlightInfo.Descriptor.Paths) - { - output.WriteString(path); - } - output.WriteInt64(FlightInfo.TotalRecords); - output.WriteInt64(FlightInfo.TotalBytes); - } - } - - public int CalculateSize() - { - int size = 0; - - if (FlightInfo != null) - { - // Manually compute the size of FlightInfo - size += 1 + ComputeFlightInfoSize(FlightInfo); - } - - return size; - } - - public bool Equals(CancelFlightInfoRequest other) => other != null && FlightInfo.Equals(other.FlightInfo); - public CancelFlightInfoRequest Clone() => new(FlightInfo); - - private int ComputeFlightInfoSize(FlightInfo flightInfo) - { - int size = 0; - - if (flightInfo.Descriptor != null) - { - size += CodedOutputStream.ComputeStringSize(flightInfo.Descriptor.Command.ToStringUtf8()); - } - - if (flightInfo.Descriptor?.Paths != null) - { - foreach (string? path in flightInfo.Descriptor.Paths) - { - size += CodedOutputStream.ComputeStringSize(path); - } - } - - // Compute size for other fields within FlightInfo - size += CodedOutputStream.ComputeInt64Size(flightInfo.TotalRecords); - size += CodedOutputStream.ComputeInt64Size(flightInfo.TotalBytes); - - return size; - } -} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs deleted file mode 100644 index d98384e38aa53..0000000000000 --- a/csharp/src/Apache.Arrow.Flight.Sql/CancelFlightInfoResult.cs +++ /dev/null @@ -1,96 +0,0 @@ -using Google.Protobuf; -using Google.Protobuf.Reflection; - -namespace Apache.Arrow.Flight.Sql; - -public enum CancelStatus -{ - Unspecified = 0, - Cancelled = 1, - Cancelling = 2, - NotCancellable = 3, - Unrecognized = -1 -} - -public sealed class CancelFlightInfoResult : IMessage -{ - public CancelStatus CancelStatus { get; private set; } - - // Public parameterless constructor - public CancelFlightInfoResult() - { - CancelStatus = CancelStatus.Unspecified; - Descriptor = - DescriptorReflection.Descriptor.MessageTypes[0]; - } - - public void MergeFrom(CancelFlightInfoResult message) - { - if (message != null) - { - CancelStatus = message.CancelStatus; - } - } - - public void MergeFrom(CodedInputStream input) - { - uint tag; - while ((tag = input.ReadTag()) != 0) - { - switch (tag) - { - case 10: - CancelStatus = (CancelStatus)input.ReadEnum(); - break; - default: - input.SkipLastField(); - break; - } - } - } - - public void WriteTo(CodedOutputStream output) - { - if (CancelStatus != CancelStatus.Unspecified) - { - output.WriteRawTag(8); // Field number 1, wire type 0 (varint) - output.WriteEnum((int)CancelStatus); - } - } - - public int CalculateSize() - { - int size = 0; - if (CancelStatus != CancelStatus.Unspecified) - { - size += 1 + CodedOutputStream.ComputeEnumSize((int)CancelStatus); - } - - return size; - } - - public MessageDescriptor? Descriptor { get; } - - - public CancelFlightInfoResult Clone() => new() { CancelStatus = CancelStatus }; - - public bool Equals(CancelFlightInfoResult other) - { - if (other == null) - { - return false; - } - - return CancelStatus == other.CancelStatus; - } - - public override int GetHashCode() - { - return CancelStatus.GetHashCode(); - } - - public override string ToString() - { - return $"CancelFlightInfoResult {{ CancelStatus = {CancelStatus} }}"; - } -} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index d3acc4323a036..62ac3e5a48a7e 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -92,7 +92,8 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que try { - var updateRequestCommand = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; + var updateRequestCommand = + new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; byte[] serializedUpdateRequestCommand = updateRequestCommand.PackAndSerialize(); var action = new FlightAction(SqlAction.CreateRequest, serializedUpdateRequestCommand); var call = DoActionAsync(options, action); @@ -215,7 +216,8 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin FlightInfo schemaResult = null!; try { - var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; + var prepareStatementRequest = + new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); var call = _client.DoAction(action, options.Headers); @@ -969,8 +971,8 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The CancelFlightInfoRequest. /// A Task representing the asynchronous operation. The task result contains the CancelFlightInfoResult describing the canceled result. - public async Task CancelFlightInfoAsync(FlightCallOptions options, - CancelFlightInfoRequest request) + public async Task CancelFlightInfoAsync(FlightCallOptions options, + FlightInfoCancelRequest request) { if (options == null) throw new ArgumentNullException(nameof(options)); if (request == null) throw new ArgumentNullException(nameof(request)); @@ -981,10 +983,10 @@ public async Task CancelFlightInfoAsync(FlightCallOption var call = _client.DoAction(action, options.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { - var cancelResult = Any.Parser.ParseFrom(result.Body); - if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) + if (Any.Parser.ParseFrom(result.Body) is Any anyResult && + anyResult.TryUnpack(out FlightInfoCancelResult cancelResult)) { - return cancelFlightInfoResult; + return cancelResult; } } @@ -1002,38 +1004,29 @@ public async Task CancelFlightInfoAsync(FlightCallOption /// RPC-layer hints for this call. /// The FlightInfo of the query to cancel. /// A Task representing the asynchronous operation. - public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) + public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) { if (options == null) - { throw new ArgumentNullException(nameof(options)); - } if (info == null) - { throw new ArgumentNullException(nameof(info)); - } try { - var cancelRequest = new CancelFlightInfoRequest(info); - var action = new FlightAction(SqlAction.CancelFlightInfoRequest, cancelRequest.ToByteString()); - var call = _client.DoAction(action, options.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + var cancelQueryRequest = new FlightInfoCancelRequest(info); + var cancelQueryAction = + new FlightAction(SqlAction.CancelFlightInfoRequest, cancelQueryRequest.PackAndSerialize()); + var cancelQueryCall = _client.DoAction(cancelQueryAction, options.Headers); + + await foreach (var result in cancelQueryCall.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { - var cancelResult = Any.Parser.ParseFrom(result.Body); - if (cancelResult.TryUnpack(out CancelFlightInfoResult cancelFlightInfoResult)) + if (Any.Parser.ParseFrom(result.Body) is Any anyResult && + anyResult.TryUnpack(out FlightInfoCancelResult cancelResult)) { - return cancelFlightInfoResult.CancelStatus switch - { - CancelStatus.Cancelled => CancelStatus.Cancelled, - CancelStatus.Cancelling => CancelStatus.Cancelling, - CancelStatus.NotCancellable => CancelStatus.NotCancellable, - _ => CancelStatus.Unspecified - }; + return cancelResult; } } - throw new InvalidOperationException("Failed to cancel query: No response received."); } catch (RpcException ex) diff --git a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs new file mode 100644 index 0000000000000..a4cbe9cb75b4c --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelRequest.cs @@ -0,0 +1,38 @@ +using System; +using Apache.Arrow.Flight.Protocol; +using Google.Protobuf; +using Google.Protobuf.Reflection; + +namespace Apache.Arrow.Flight; + +public class FlightInfoCancelRequest : IMessage +{ + private readonly CancelFlightInfoRequest _cancelFlightInfoRequest; + public FlightInfo FlightInfo { get; private set; } + + public FlightInfoCancelRequest(FlightInfo flightInfo) + { + FlightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); + _cancelFlightInfoRequest = new CancelFlightInfoRequest(); + } + + public FlightInfoCancelRequest() + { + _cancelFlightInfoRequest = new CancelFlightInfoRequest(); + } + + public void MergeFrom(CodedInputStream input) + { + _cancelFlightInfoRequest.MergeFrom(input); + } + + public void WriteTo(CodedOutputStream output) + { + _cancelFlightInfoRequest.WriteTo(output); + } + + public int CalculateSize() => _cancelFlightInfoRequest.CalculateSize(); + + public MessageDescriptor Descriptor => + DescriptorReflection.Descriptor.MessageTypes[0]; +} diff --git a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs new file mode 100644 index 0000000000000..a10c5977fa601 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs @@ -0,0 +1,34 @@ +using System; +using Apache.Arrow.Flight.Protocol; +using Google.Protobuf; +using Google.Protobuf.Reflection; + +namespace Apache.Arrow.Flight; + +public class FlightInfoCancelResult : IMessage +{ + private readonly CancelFlightInfoResult _flightInfoCancelResult; + + public FlightInfoCancelResult() + { + _flightInfoCancelResult = new CancelFlightInfoResult(); + Descriptor = + DescriptorReflection.Descriptor.MessageTypes[0]; + } + + public void MergeFrom(CodedInputStream input) => _flightInfoCancelResult.MergeFrom(input); + + public void WriteTo(CodedOutputStream output) => _flightInfoCancelResult.WriteTo(output); + + public int CalculateSize() + { + return _flightInfoCancelResult.CalculateSize(); + } + + public MessageDescriptor Descriptor { get; } + + public int GetCancelStatus() + { + return (int)_flightInfoCancelResult.Status; + } +} diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 4dc1c7c9d71e9..2c487cde8d17c 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -10,7 +10,6 @@ using Google.Protobuf; using Grpc.Core.Utils; using Xunit; -using RecordBatchWithMetadata = Apache.Arrow.Flight.Sql.TestWeb.RecordBatchWithMetadata; namespace Apache.Arrow.Flight.Sql.Tests; @@ -729,14 +728,13 @@ public async Task CancelFlightInfoAsync() .Build(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); - var cancelRequest = new CancelFlightInfoRequest(flightInfo); + var cancelRequest = new FlightInfoCancelRequest(flightInfo); // Act var cancelResult = await _flightSqlClient.CancelFlightInfoAsync(options, cancelRequest); // Assert - Assert.NotNull(cancelResult); - Assert.True(cancelResult.CancelStatus == CancelStatus.Cancelled); + Assert.Equal(0, cancelResult.GetCancelStatus()); } [Fact] @@ -744,7 +742,7 @@ public async Task CancelQueryAsync() { // Arrange var options = new FlightCallOptions(); - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var schema = new Schema .Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) @@ -759,7 +757,7 @@ public async Task CancelQueryAsync() var cancelStatus = await _flightSqlClient.CancelQueryAsync(options, flightInfo); // Assert - Assert.Equal(CancelStatus.Cancelled, cancelStatus); + Assert.Equal(0, cancelStatus.GetCancelStatus()); } public void Dispose() => _testWebFactory?.Dispose(); From f36616b9b092c379589aa20d9b45aa41489007d7 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 25 Sep 2024 17:13:13 +0300 Subject: [PATCH 24/58] test: allow update status from FlightInfoCancelResult --- .../TestFlightSqlServer.cs | 9 ++------- csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs | 5 +++++ .../FlightSqlClientTests.cs | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs index 481b0c9c0e29c..3ff0291dffd53 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs @@ -32,13 +32,8 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) - .Build(); - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); - var cancelRequest = new FlightInfoCancelRequest(flightInfo); + var cancelRequest = new FlightInfoCancelResult(); + cancelRequest.SetStatus(1); await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())).ConfigureAwait(false); break; case "BeginTransaction": diff --git a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs index a10c5977fa601..b41c54c6f3455 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs @@ -31,4 +31,9 @@ public int GetCancelStatus() { return (int)_flightInfoCancelResult.Status; } + + public void SetStatus(int status) + { + _flightInfoCancelResult.Status = (CancelStatus)status; + } } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 2c487cde8d17c..0f99caf30e58c 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -734,7 +734,7 @@ public async Task CancelFlightInfoAsync() var cancelResult = await _flightSqlClient.CancelFlightInfoAsync(options, cancelRequest); // Assert - Assert.Equal(0, cancelResult.GetCancelStatus()); + Assert.Equal(1, cancelResult.GetCancelStatus()); } [Fact] @@ -757,7 +757,7 @@ public async Task CancelQueryAsync() var cancelStatus = await _flightSqlClient.CancelQueryAsync(options, flightInfo); // Assert - Assert.Equal(0, cancelStatus.GetCancelStatus()); + Assert.Equal(1, cancelStatus.GetCancelStatus()); } public void Dispose() => _testWebFactory?.Dispose(); From be327a2efcd4525272cb7d039d239f08c1a8fa2f Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 16:27:46 +0300 Subject: [PATCH 25/58] version: update 3.28.1 to 3.28.2 --- .../Apache.Arrow.Flight.Sql.TestWeb.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj index 0fd8f47c393f5..0ce968ee7cd1e 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj @@ -6,7 +6,7 @@ - + From 1940c862d8996acd27bad41fe8f936e0a9aa4b20 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 16:51:56 +0300 Subject: [PATCH 26/58] chore: remove commented out code - ref: https://github.com/HackPoint/arrow/pull/1#pullrequestreview-2331978166 --- cpp/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 3e152fb59d6d8..e333892c8f2cb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -740,10 +740,10 @@ if(ARROW_SKYHOOK) add_subdirectory(src/skyhook) endif() -#if(ARROW_BUILD_EXAMPLES) +if(ARROW_BUILD_EXAMPLES) add_custom_target(runexample ctest -L example) add_subdirectory(examples/arrow) -#endif() +endif() install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt @@ -759,4 +759,4 @@ validate_config() config_summary_message() if(${ARROW_BUILD_CONFIG_SUMMARY_JSON}) config_summary_json() -endif() +endif() \ No newline at end of file From 3c6a6dd1ca7c4b10a8dbbe6254d7a21ebe31c888 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 16:56:50 +0300 Subject: [PATCH 27/58] chore: removed testing project --- ...he.Arrow.Flight.Sql.IntegrationTest.csproj | 15 - .../Program.cs | 425 ------------------ csharp/Apache.Arrow.sln | 2 - 3 files changed, 442 deletions(-) delete mode 100644 csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj delete mode 100644 csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj deleted file mode 100644 index 0c5b923d72880..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Apache.Arrow.Flight.Sql.IntegrationTest.csproj +++ /dev/null @@ -1,15 +0,0 @@ - - - - Exe - net6.0 - enable - enable - - - - - - - - diff --git a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs b/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs deleted file mode 100644 index 4683603d67037..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.IntegrationTest/Program.cs +++ /dev/null @@ -1,425 +0,0 @@ -/*using Apache.Arrow; -using Apache.Arrow.Flight; -using Apache.Arrow.Flight.Client; -using Grpc.Core; -using Grpc.Net.Client; - -namespace FlightClientExample -{ - public class Program - { - public static async Task Main(string[] args) - { - string host = args.Length > 0 ? args[0] : "localhost"; - string port = args.Length > 1 ? args[1] : "5000"; - - // Create client - // (In production systems, you should use https not http) - var address = $"http://{host}:{port}"; - Console.WriteLine($"Connecting to: {address}"); - var channel = GrpcChannel.ForAddress(address); - var client = new FlightClient(channel); - - var recordBatches = new[] { CreateTestBatch(0, 2000), CreateTestBatch(50, 9000) }; - - // Particular flights are identified by a descriptor. This might be a name, - // a SQL query, or a path. Here, just using the name "test". - var descriptor = FlightDescriptor.CreatePathDescriptor("//SYSDB/Info"); //.CreateCommandDescriptor("SELECT * FROM SYSDB.`Info` "); - - // Upload data with StartPut - // var batchStreamingCall = client.StartPut(descriptor); - // foreach (var batch in recordBatches) - // { - // await batchStreamingCall.RequestStream.WriteAsync(batch); - // } - // - // // Signal we are done sending record batches - // await batchStreamingCall.RequestStream.CompleteAsync(); - // // Retrieve final response - // await batchStreamingCall.ResponseStream.MoveNext(); - // Console.WriteLine(batchStreamingCall.ResponseStream.Current.ApplicationMetadata.ToStringUtf8()); - // Console.WriteLine($"Wrote {recordBatches.Length} batches to server."); - - // Request information: - //var schema = await client.GetSchema(descriptor).ResponseAsync; - //Console.WriteLine($"Schema saved as: \n {schema}"); - - var info = await client.GetInfo(descriptor).ResponseAsync; - Console.WriteLine($"Info provided: \n {info.TotalRecords}"); - - Console.WriteLine($"Available flights:"); - // var flights_call = client.ListFlights(); - // - // while (await flights_call.ResponseStream.MoveNext()) - // { - // Console.WriteLine(" " + flights_call.ResponseStream.Current); - // } - - // // Download data - // await foreach (var batch in StreamRecordBatches(info)) - // { - // Console.WriteLine($"Read batch from flight server: \n {batch}"); - // } - - // See available commands on this server - // var action_stream = client.ListActions(); - // Console.WriteLine("Actions:"); - // while (await action_stream.ResponseStream.MoveNext()) - // { - // var action = action_stream.ResponseStream.Current; - // Console.WriteLine($" {action.Type}: {action.Description}"); - // } - // - // // Send clear command to drop all data from the server. - // var clear_result = client.DoAction(new FlightAction("clear")); - // await clear_result.ResponseStream.MoveNext(default); - } - - public static async IAsyncEnumerable StreamRecordBatches( - FlightInfo info - ) - { - // There might be multiple endpoints hosting part of the data. In simple services, - // the only endpoint might be the same server we initially queried. - foreach (var endpoint in info.Endpoints) - { - // We may have multiple locations to choose from. Here we choose the first. - var download_channel = GrpcChannel.ForAddress(endpoint.Locations.First().Uri); - var download_client = new FlightClient(download_channel); - - var stream = download_client.GetStream(endpoint.Ticket); - - while (await stream.ResponseStream.MoveNext()) - { - yield return stream.ResponseStream.Current; - } - } - } - - public static RecordBatch CreateTestBatch(int start, int length) - { - return new RecordBatch.Builder() - .Append("Column A", false, - col => col.Int32(array => array.AppendRange(Enumerable.Range(start, start + length)))) - .Append("Column B", false, - col => col.Float(array => - array.AppendRange(Enumerable.Range(start, start + length) - .Select(x => Convert.ToSingle(x * 2))))) - .Append("Column C", false, - col => col.String(array => - array.AppendRange(Enumerable.Range(start, start + length).Select(x => $"Item {x + 1}")))) - .Append("Column D", false, - col => col.Boolean(array => - array.AppendRange(Enumerable.Range(start, start + length).Select(x => x % 2 == 0)))) - .Build(); - } - } -}*/ - -using Apache.Arrow.Flight.Client; -using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Types; -using Arrow.Flight.Protocol.Sql; -using Google.Protobuf; -using Google.Protobuf.WellKnownTypes; -using Grpc.Core; -using Grpc.Net.Client; - -namespace Apache.Arrow.Flight.Sql.IntegrationTest; - -class Program -{ - static async Task Main(string[] args) - { - var httpHandler = new SocketsHttpHandler - { - PooledConnectionIdleTimeout = TimeSpan.FromMinutes(1), - KeepAlivePingDelay = TimeSpan.FromSeconds(60), - KeepAlivePingTimeout = TimeSpan.FromSeconds(30), - EnableMultipleHttp2Connections = true - }; - // Initialize the gRPC channel to connect to the Flight server - using var channel = GrpcChannel.ForAddress("http://localhost:5000", - new GrpcChannelOptions { HttpHandler = httpHandler, Credentials = ChannelCredentials.Insecure }); - - // Initialize the Flight client - var flightClient = new FlightClient(channel); - var sqlClient = new FlightSqlClient(flightClient); - - // Define the SQL query - string query = "SELECT * FROM SYSDB.`Info`"; - - try - { - // ExecuteAsync - Console.WriteLine("ExecuteAsync:"); - var flightInfo = await sqlClient.ExecuteAsync(new FlightCallOptions(), query); - - // ExecuteUpdate - // Console.WriteLine("ExecuteUpdate:"); - string updateQuery = "UPDATE SYSDB.`Info` SET Key = 1, Val=10 WHERE Id=1"; - long affectedRows = await sqlClient.ExecuteUpdateAsync(new FlightCallOptions(), updateQuery); - Console.WriteLine($@"Number of affected rows: {affectedRows}"); - // - // // GetExecuteSchema - // Console.WriteLine("GetExecuteSchema:"); - // var schemaResult = await sqlClient.GetExecuteSchemaAsync(new FlightCallOptions(), query); - // // Process the schemaResult as needed - // Console.WriteLine($"Schema retrieved successfully:{schemaResult}"); - // - // // ExecuteIngest - // - // // GetCatalogs - // Console.WriteLine("GetCatalogs:"); - // var catalogsInfo = await sqlClient.GetCatalogs(new FlightCallOptions()); - // // Print catalog details - // Console.WriteLine("Catalogs retrieved:"); - // foreach (var endpoint in catalogsInfo.Endpoints) - // { - // var ticket = endpoint.Ticket; - // Console.WriteLine($"- Ticket: {ticket}"); - // } - - // GetCatalogsSchema - // Console.WriteLine("GetCatalogsSchema:"); - // Schema schemaCatalogResult = await sqlClient.GetCatalogsSchema(new FlightCallOptions()); - // Console.WriteLine("Catalogs Schema retrieved:"); - // Console.WriteLine(schemaCatalogResult); - - // GetDbSchemasAsync - // Console.WriteLine("GetDbSchemasAsync:"); - // FlightInfo flightInfoDbSchemas = - // await sqlClient.GetDbSchemasAsync(new FlightCallOptions(), "default_catalog", "public"); - // // Process the FlightInfoDbSchemas - // Console.WriteLine("Database schemas retrieved:"); - // Console.WriteLine(flightInfoDbSchemas); - - // GetDbSchemasSchemaAsync - // Console.WriteLine("GetDbSchemasSchemaAsync:"); - // Schema schema = await sqlClient.GetDbSchemasSchemaAsync(new FlightCallOptions()); - // // Process the Schema - // Console.WriteLine("Database schemas schema retrieved:"); - // Console.WriteLine(schema); - - // DoPut - // Console.WriteLine("DoPut:"); - // await PutExample(sqlClient, query); - - // GetPrimaryKeys - // Console.WriteLine("GetPrimaryKeys:"); - // var tableRef = new TableRef - // { - // DbSchema = "SYSDB", - // Table = "Info" - // }; - // var getPrimaryKeysInfo = await sqlClient.GetPrimaryKeys(new FlightCallOptions(), tableRef); - // Console.WriteLine("Primary keys information retrieved successfully."); - - - // Call GetTablesAsync method - // Console.WriteLine("GetTablesAsync:"); - // IEnumerable tables = await sqlClient.GetTablesAsync( - // new FlightCallOptions(), - // catalog: "", - // dbSchemaFilterPattern: "public", - // tableFilterPattern: "SYSDB", - // includeSchema: true, - // tableTypes: new List { "TABLE", "VIEW" }); - // foreach (var table in tables) - // { - // Console.WriteLine($"Table URI: {table.Descriptor.Paths}"); - // foreach (var endpoint in table.Endpoints) - // { - // Console.WriteLine($"Endpoint Ticket: {endpoint.Ticket}"); - // } - // } - // - // var tableRef = new TableRef { Catalog = "", DbSchema = "SYSDB", Table = "Info" }; - - // Get exported keys - // Console.WriteLine("GetExportedKeysAsync:"); - // var tableRef = new TableRef { Catalog = "", DbSchema = "SYSDB", Table = "Info" }; - // var flightInfoExportedKeys = await sqlClient.GetExportedKeysAsync(new FlightCallOptions(), tableRef); - // Console.WriteLine("FlightInfo obtained:"); - // Console.WriteLine($" FlightDescriptor: {flightInfoExportedKeys.Descriptor}"); - // Console.WriteLine($" Total records: {flightInfoExportedKeys.TotalRecords}"); - // Console.WriteLine($" Total bytes: {flightInfoExportedKeys.TotalBytes}"); - - // Get exported keys schema - // var schema = await sqlClient.GetExportedKeysSchemaAsync(new FlightCallOptions()); - // Console.WriteLine("Schema obtained:"); - // Console.WriteLine($" Fields: {string.Join(", ", schema.FieldsList)}"); - - // Get imported keys - // Console.WriteLine("GetImportedKeys"); - // var flightInfoGetImportedKeys = sqlClient.GetImportedKeysAsync(new FlightCallOptions(), tableRef); - // Console.WriteLine("FlightInfo obtained:"); - // Console.WriteLine($@" Location: {flightInfoGetImportedKeys.Result.Endpoints[0]}"); - - // Get imported keys schema - // Console.WriteLine("GetImportedKeysSchemaAsync:"); - // var schema = await sqlClient.GetImportedKeysSchemaAsync(new FlightCallOptions()); - // Console.WriteLine("Imported Keys Schema obtained:"); - // Console.WriteLine($"Schema Fields: {string.Join(", ", schema.FieldsList)}"); - - // Get cross reference - // Console.WriteLine("GetCrossReferenceAsync:"); - // var flightInfoGetCrossReference = await sqlClient.GetCrossReferenceAsync(new FlightCallOptions(), tableRef, new TableRef - // { - // Catalog = "catalog2", - // DbSchema = "schema2", - // Table = "table2" - // }); - // Console.WriteLine("Cross Reference Information obtained:"); - // Console.WriteLine($"Flight Descriptor: {flightInfoGetCrossReference.Descriptor}"); - // Console.WriteLine($"Endpoints: {string.Join(", ", flightInfoGetCrossReference.Endpoints)}"); - - // Get cross-reference schema - // Console.WriteLine("GetCrossReferenceSchemaAsync:"); - // var schema = await sqlClient.GetCrossReferenceSchemaAsync(new FlightCallOptions()); - // Console.WriteLine("Cross Reference Schema obtained:"); - // Console.WriteLine($"Schema: {schema}"); - - - // Get table types - // Console.WriteLine("GetTableTypesAsync:"); - // var tableTypesInfo = await sqlClient.GetTableTypesAsync(new FlightCallOptions()); - // Console.WriteLine("Table Types Info obtained:"); - // Console.WriteLine($"FlightInfo: {tableTypesInfo}"); - - // Get table types schema - // Console.WriteLine("GetTableTypesSchemaAsync:"); - // var tableTypesSchema = await sqlClient.GetTableTypesSchemaAsync(new FlightCallOptions()); - // Console.WriteLine("Table Types Schema obtained:"); - // Console.WriteLine($"Schema: {tableTypesSchema}"); - - // Get XDBC type info (with DataType) - // Console.WriteLine("GetXdbcTypeInfoAsync: (With DataType)"); - // var flightInfoGetXdbcTypeInfoWithoutDataType = - // await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions(), 4); - // Console.WriteLine("XDBC With DataType Info obtained:"); - // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfoWithoutDataType}"); - - // Get XDBC type info - // Console.WriteLine("GetXdbcTypeInfoAsync:"); - // var flightInfoGetXdbcTypeInfo = await sqlClient.GetXdbcTypeInfoAsync(new FlightCallOptions()); - // Console.WriteLine("XDBC Type Info obtained:"); - // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeInfo}"); - - // Get XDBC type info schema - // Console.WriteLine("GetXdbcTypeInfoSchemaAsync:"); - // var flightInfoGetXdbcTypeSchemaInfo = await sqlClient.GetXdbcTypeInfoSchemaAsync(new FlightCallOptions()); - // Console.WriteLine("XDBC Type Info obtained:"); - // Console.WriteLine($"FlightInfo: {flightInfoGetXdbcTypeSchemaInfo}"); - - // Get SQL info - // Console.WriteLine("GetSqlInfoAsync:"); - // Define SQL info list - // var sqlInfo = new List { 1, 2, 3 }; - // var flightInfoGetSqlInfo = sqlClient.GetSqlInfoAsync(new FlightCallOptions(), sqlInfo); - // Console.WriteLine("SQL Info obtained:"); - // Console.WriteLine($"FlightInfo: {flightInfoGetSqlInfo}"); - - // Get SQL info schema - // Console.WriteLine("GetSqlInfoSchemaAsync:"); - // var schema = await sqlClient.GetSqlInfoSchemaAsync(new FlightCallOptions()); - // Console.WriteLine("SQL Info Schema obtained:"); - // Console.WriteLine($"Schema: {schema}"); - - // Prepare a SQL statement - // Console.WriteLine("PrepareAsync:"); - // var preparedStatement = await sqlClient.PrepareAsync(new FlightCallOptions(), query); - // Console.WriteLine("Prepared statement created successfully."); - - - // Cancel FlightInfo Request - // Console.WriteLine("CancelFlightInfoRequest:"); - // var cancelRequest = new CancelFlightInfoRequest(flightInfo); - // var cancelResult = await sqlClient.CancelFlightInfoAsync(new FlightCallOptions(), cancelRequest); - // Console.WriteLine($"Cancellation Status: {cancelResult.CancelStatus}"); - - // Begin Transaction - // Console.WriteLine("BeginTransaction:"); - // Transaction transaction = await sqlClient.BeginTransactionAsync(new FlightCallOptions()); - // Console.WriteLine($"Transaction started with ID: {transaction.TransactionId}"); - // FlightInfo flightInfoBeginTransaction = - // await sqlClient.ExecuteAsync(new FlightCallOptions(), query, transaction); - // Console.WriteLine("Query executed within transaction"); - // - // // Commit Transaction - // Console.WriteLine("CommitTransaction:"); - // await sqlClient.CommitAsync(new FlightCallOptions(), new Transaction("transaction-id")); - // Console.WriteLine("Transaction committed successfully."); - // - // // Rollback Transaction - // Console.WriteLine("RollbackTransaction"); - // await sqlClient.RollbackAsync(new FlightCallOptions(), new Transaction("transaction-id")); - // Console.WriteLine("Transaction rolled back successfully."); - - // Cancel Query - // Console.WriteLine("CancelQuery:"); - // var cancelResult = await sqlClient.CancelQueryAsync(new FlightCallOptions(), flightInfo); - // Console.WriteLine($"Cancellation Status: {cancelResult}"); - } - catch (Exception ex) - { - Console.WriteLine($"Error executing query: {ex.Message}"); - } - } - - static async Task PutExample(FlightSqlClient client, string query) - { - // TODO: Talk with Jeremy about the implementation: DoPut - seems that needed to resolve missing part - var options = new FlightCallOptions(); - var body = new ActionCreatePreparedStatementRequest { Query = query }.PackAndSerialize(); - var action = new FlightAction(SqlAction.CreateRequest, body); - await foreach (FlightResult flightResult in client.DoActionAsync(action)) - { - var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(flightResult.Body); - - var command = new CommandPreparedStatementUpdate - { - PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle - }; - - var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - // Define schema - var fields = new List - { - new("id", Int32Type.Default, nullable: false), new("name", StringType.Default, nullable: false) - }; - var metadata = - new List> { new("db_name", "SYSDB"), new("table_name", "Info") }; - var schema = new Schema(fields, metadata); - var doPutResult = await client.DoPutAsync(options, descriptor, schema).ConfigureAwait(false); - - // Example data to write - var col1 = new Int32Array.Builder().AppendRange([8, 9, 10, 11]).Build(); - var col2 = new StringArray.Builder().AppendRange(["a", "b", "c", "d"]).Build(); - var col3 = new StringArray.Builder().AppendRange(["x", "y", "z", "q"]).Build(); - var batch = new RecordBatch(schema, [col1, col2, col3], 4); - - await doPutResult.Writer.WriteAsync(batch); - await doPutResult.Writer.CompleteAsync(); - - // Handle metadata response (if any) - while (await doPutResult.Reader.MoveNext()) - { - var receivedMetadata = doPutResult.Reader.Current.ApplicationMetadata; - if (receivedMetadata != null) - { - Console.WriteLine("Received metadata: " + receivedMetadata.ToStringUtf8()); - } - } - } - } -} - -internal static class FlightDescriptorExtensions -{ - public static byte[] PackAndSerialize(this IMessage command) - { - return Any.Pack(command).Serialize().ToByteArray(); - } -} diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index 524fcf3f56b81..071561101e1e9 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -17,8 +17,6 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Flight", "src\ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Flight.AspNetCore", "src\Apache.Arrow.Flight.AspNetCore\Apache.Arrow.Flight.AspNetCore.csproj", "{E4F74938-E8FF-4AC1-A495-FEE95FC1EFDF}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.IntegrationTest", "test\Apache.Arrow.IntegrationTest\Apache.Arrow.IntegrationTest.csproj", "{E8264B7F-B680-4A55-939B-85DB628164BB}" -EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Compression", "src\Apache.Arrow.Compression\Apache.Arrow.Compression.csproj", "{B62E77D2-D0B0-4C0C-BA78-1C117DE4C299}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.Compression.Tests", "test\Apache.Arrow.Compression.Tests\Apache.Arrow.Compression.Tests.csproj", "{5D7FF380-B7DF-4752-B415-7C08C70C9F06}" From a48a85e8afe92ce6c1b52d2a012a0004c4bf3376 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 17:58:50 +0300 Subject: [PATCH 28/58] chore: revert ident --- csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs index c033987db3c30..14a4d491c13a6 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs @@ -61,4 +61,4 @@ public FlightInfo GetFlightInfo() }, batchArrayLength, batchBytes); } } -} +} \ No newline at end of file From 55449df49754b0e1c43cf1bb05ed2c7b40b1205b Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 18:01:16 +0300 Subject: [PATCH 29/58] chore: revert ident --- csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs index 667182097fbc4..0e82673d02240 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs @@ -489,7 +489,7 @@ public async Task EnsureCallRaisesRequestCancelled() { var cts = new CancellationTokenSource(); cts.CancelAfter(1); - + var batch = CreateTestBatch(0, 100); var metadata = new Metadata(); var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled"); From 5a77bfa3e6580216607f490e14083cbfba392af6 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 18:02:40 +0300 Subject: [PATCH 30/58] chore: revert ident --- .../Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs index cd2ebc725bcc7..031495fffdcc7 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -24,5 +24,4 @@ public static ByteString PackAndSerialize(this IMessage command) { return Any.Pack(command).Serialize(); } - } From 3cb317150be6409e094cbca455da67cb14d6d662 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 29 Sep 2024 18:04:04 +0300 Subject: [PATCH 31/58] chore: fix ident --- csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs index c66c7401c7cef..adc229a051227 100644 --- a/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs +++ b/csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs @@ -304,7 +304,7 @@ public void Visit(DictionaryType type) // type in the DictionaryEncoding metadata in the parent field type.ValueType.Accept(this); } - + public void Visit(FixedSizeBinaryType type) { Result = FieldType.Build( From 157530bb797eccf41840a3b91e271854448355d1 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 30 Sep 2024 12:41:17 +0300 Subject: [PATCH 32/58] chore: remove line wrap --- csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs index b41c54c6f3455..0afcb193f8967 100644 --- a/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs +++ b/csharp/src/Apache.Arrow.Flight/FlightInfoCancelResult.cs @@ -12,8 +12,7 @@ public class FlightInfoCancelResult : IMessage public FlightInfoCancelResult() { _flightInfoCancelResult = new CancelFlightInfoResult(); - Descriptor = - DescriptorReflection.Descriptor.MessageTypes[0]; + Descriptor = DescriptorReflection.Descriptor.MessageTypes[0]; } public void MergeFrom(CodedInputStream input) => _flightInfoCancelResult.MergeFrom(input); From e42336039dec7bfc951b4d894062f41ce9eba9b6 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 30 Sep 2024 12:45:26 +0300 Subject: [PATCH 33/58] chore: remove as this is specific to JetBrains --- .../Apache.Arrow.Flight.Sql.Tests.csproj | 7 ------- 1 file changed, 7 deletions(-) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 9a6fb1f95ad4f..3a978f9b411d8 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -17,11 +17,4 @@ - - - - ..\..\..\..\..\..\Applications\Rider.app\Contents\lib\ReSharperHost\TestRunner\netcoreapp3.0\JetBrains.ReSharper.TestRunner.Merged.dll - - - From 5d93c15f456944a82a43a96a4284317759346e64 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 30 Sep 2024 13:13:26 +0300 Subject: [PATCH 34/58] chore: use the predefined constants instead of the hardcoded strings --- .../TestFlightSqlServer.cs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs index 3ff0291dffd53..cbebf02a5157e 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs +++ b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs @@ -28,21 +28,21 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter Date: Sun, 6 Oct 2024 16:19:13 +0300 Subject: [PATCH 35/58] chore: move the runner of the test project and testing env to inner project --- .../Apache.Arrow.Flight.Sql.TestWeb.csproj | 18 - .../FlightSqlHolder.cs | 59 ---- .../FlightSqlStore.cs | 8 - .../Program.cs | 29 -- .../Properties/launchSettings.json | 38 -- .../RecordBatchWithMetadata.cs | 15 - .../Startup.cs | 39 --- .../TestFlightSqlServer.cs | 175 ---------- .../appsettings.Development.json | 8 - .../appsettings.json | 9 - csharp/Apache.Arrow.sln | 6 - .../Apache.Arrow.Flight.Sql.Tests.csproj | 2 +- .../FlightSqlClientTests.cs | 61 ++-- .../FlightSqlPreparedStatementTests.cs | 19 +- .../FlightSqlTestUtils.cs | 13 +- .../Apache.Arrow.Flight.Sql.Tests/Startup.cs | 42 +++ .../TestFlightSqlWebFactory.cs} | 15 +- .../FlightHolder.cs | 17 +- .../TestFlightServer.cs | 9 - .../TestFlightSqlServer.cs | 326 ++++++++++++++++++ 20 files changed, 441 insertions(+), 467 deletions(-) delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json delete mode 100644 csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json create mode 100644 csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs rename csharp/{Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs => test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs} (81%) create mode 100644 csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj deleted file mode 100644 index 0ce968ee7cd1e..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Apache.Arrow.Flight.Sql.TestWeb.csproj +++ /dev/null @@ -1,18 +0,0 @@ - - - - net8.0 - - - - - - - - - - - - - - diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs deleted file mode 100644 index 88e6e545c2b99..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlHolder.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System.Collections.Generic; -using System.Linq; - -namespace Apache.Arrow.Flight.Sql.TestWeb; - -public class FlightSqlHolder -{ - private readonly FlightDescriptor _flightDescriptor; - private readonly Schema _schema; - private readonly string _location; - - //Not thread safe, but only used in tests - private readonly List _recordBatches = new List(); - - public FlightSqlHolder(FlightDescriptor flightDescriptor, Schema schema, string location) - { - _flightDescriptor = flightDescriptor; - _schema = schema; - _location = location; - } - - public void AddBatch(RecordBatchWithMetadata recordBatchWithMetadata) - { - //Should validate schema here - _recordBatches.Add(recordBatchWithMetadata); - } - - public IEnumerable GetRecordBatches() - { - return _recordBatches.ToList(); - } - - public FlightInfo GetFlightInfo() - { - int batchArrayLength = _recordBatches.Sum(rb => rb.RecordBatch.Length); - int batchBytes = - _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b => b.Length))); - var flightInfo = new FlightInfo(_schema, _flightDescriptor, - new List() - { - new FlightEndpoint(new FlightTicket( - CustomTicketStrategy(_flightDescriptor) - ), - new List() { new FlightLocation(_location) }) - }, batchArrayLength, batchBytes); - return flightInfo; - } - - private string CustomTicketStrategy(FlightDescriptor descriptor) - { - if (descriptor.Command.Length > 0) - { - return $"{descriptor.Command.ToStringUtf8()}"; - } - - // Fallback in case there is no command in the descriptor - return "default_custom_ticket"; - } -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs deleted file mode 100644 index 9ac7df457b1b1..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/FlightSqlStore.cs +++ /dev/null @@ -1,8 +0,0 @@ -using System.Collections.Generic; - -namespace Apache.Arrow.Flight.Sql.TestWeb; - -public class FlightSqlStore -{ - public Dictionary Flights { get; set; } = new(); -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs deleted file mode 100644 index 9a56cb0c998e6..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Program.cs +++ /dev/null @@ -1,29 +0,0 @@ -using System.Net; -using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Server.Kestrel.Core; -using Microsoft.Extensions.Hosting; - -namespace Apache.Arrow.Flight.Sql.TestWeb; - -public class Program -{ - public static void Main(string[] args) - { - CreateHostBuilder(args).Build().Run(); - } - - private static IHostBuilder CreateHostBuilder(string[] args) => - Host.CreateDefaultBuilder(args) - .ConfigureWebHostDefaults(webBuilder => - { - webBuilder - .ConfigureKestrel((context, options) => - { - if (context.HostingEnvironment.IsDevelopment()) - { - options.Listen(IPEndPoint.Parse("0.0.0.0:5001"), l => l.Protocols = HttpProtocols.Http2); - } - }) - .UseStartup(); - }); -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json deleted file mode 100644 index 08f737dc9415e..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Properties/launchSettings.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "$schema": "http://json.schemastore.org/launchsettings.json", - "iisSettings": { - "windowsAuthentication": false, - "anonymousAuthentication": true, - "iisExpress": { - "applicationUrl": "http://localhost:64484", - "sslPort": 44321 - } - }, - "profiles": { - "http": { - "commandName": "Project", - "dotnetRunMessages": true, - "launchBrowser": true, - "applicationUrl": "http://localhost:5285", - "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" - } - }, - "https": { - "commandName": "Project", - "dotnetRunMessages": true, - "launchBrowser": true, - "applicationUrl": "https://localhost:7276;http://localhost:5285", - "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" - } - }, - "IIS Express": { - "commandName": "IISExpress", - "launchBrowser": true, - "environmentVariables": { - "ASPNETCORE_ENVIRONMENT": "Development" - } - } - } -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs deleted file mode 100644 index 214d5d557b00a..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/RecordBatchWithMetadata.cs +++ /dev/null @@ -1,15 +0,0 @@ -using Google.Protobuf; - -namespace Apache.Arrow.Flight.Sql.TestWeb; - -public class RecordBatchWithMetadata -{ - public RecordBatch RecordBatch { get; } - public ByteString Metadata { get; } - - public RecordBatchWithMetadata(RecordBatch recordBatch, ByteString metadata = null) - { - RecordBatch = recordBatch; - Metadata = metadata; - } -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs deleted file mode 100644 index 4019143d57747..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/Startup.cs +++ /dev/null @@ -1,39 +0,0 @@ -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; - -namespace Apache.Arrow.Flight.Sql.TestWeb; - -public class Startup -{ - public void ConfigureServices(IServiceCollection services) - { - services.AddGrpc() - .AddFlightServer(); - services.AddSingleton(); - } - - public void Configure(IApplicationBuilder app, IWebHostEnvironment env) - { - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - app.UseRouting(); - - app.UseEndpoints(endpoints => - { - endpoints.MapFlightEndpoint(); - - endpoints.MapGet("/", - async context => - { - await context.Response.WriteAsync( - "Communication with gRPC endpoints must be made through a gRPC client. To learn how to create a client, visit: https://go.microsoft.com/fwlink/?linkid=2086909"); - }); - }); - } -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs b/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs deleted file mode 100644 index cbebf02a5157e..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestFlightSqlServer.cs +++ /dev/null @@ -1,175 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Apache.Arrow.Flight.Server; -using Apache.Arrow.Types; -using Arrow.Flight.Protocol.Sql; -using Google.Protobuf; -using Google.Protobuf.WellKnownTypes; -using Grpc.Core; - -namespace Apache.Arrow.Flight.Sql.TestWeb; - -public class TestFlightSqlServer : FlightServer -{ - private readonly FlightSqlStore _flightStore; - - public TestFlightSqlServer(FlightSqlStore flightStore) - { - _flightStore = flightStore; - } - - public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, - ServerCallContext context) - { - switch (request.Type) - { - case "test": - await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); - break; - case SqlAction.GetPrimaryKeysRequest: - await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); - break; - case SqlAction.CancelFlightInfoRequest: - var cancelRequest = new FlightInfoCancelResult(); - cancelRequest.SetStatus(1); - await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())).ConfigureAwait(false); - break; - case SqlAction.BeginTransactionRequest: - case SqlAction.CommitRequest: - case SqlAction.RollbackRequest: - await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))).ConfigureAwait(false); - break; - case SqlAction.CreateRequest: - case SqlAction.CloseRequest: - var prepareStatementResponse = new ActionCreatePreparedStatementResult - { - PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") - }; - byte[] packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); - var flightResult = new FlightResult(packedResult); - await responseStream.WriteAsync(flightResult).ConfigureAwait(false); - break; - default: - throw new NotImplementedException(); - } - } - - public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, - ServerCallContext context) - { - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); - - if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) - { - var batches = flightHolder.GetRecordBatches(); - foreach (var batch in batches) - { - await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata).ConfigureAwait(false); - } - } - } - - public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) - { - var flightDescriptor = await requestStream.FlightDescriptor; - - if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) - { - flightHolder = new FlightSqlHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); - _flightStore.Flights.Add(flightDescriptor, flightHolder); - } - - while (await requestStream.MoveNext()) - { - flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, - requestStream.ApplicationMetadata.FirstOrDefault())); - await responseStream.WriteAsync(FlightPutResult.Empty).ConfigureAwait(false); - } - } - - public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) - { - if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) - { - return Task.FromResult(flightHolder.GetFlightInfo()); - } - - if (_flightStore.Flights.Count > 0) - { - // todo: should rethink of the way to implement dynamic Flights search - return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); - } - - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); - } - - - public override async Task Handshake(IAsyncStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) - { - while (await requestStream.MoveNext().ConfigureAwait(false)) - { - if (requestStream.Current.Payload.ToStringUtf8() == "Hello") - { - await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))) - .ConfigureAwait(false); - } - else - { - await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); - } - } - } - - public override Task GetSchema(FlightDescriptor request, ServerCallContext context) - { - if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) - { - return Task.FromResult(flightHolder.GetFlightInfo().Schema); - } - - if (_flightStore.Flights.Count > 0) - { - // todo: should rethink of the way to implement dynamic Flights search - return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema); - } - - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); - } - - public override async Task ListActions(IAsyncStreamWriter responseStream, - ServerCallContext context) - { - await responseStream.WriteAsync(new FlightActionType("get", "get a flight")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("put", "add a flight")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("test", "test action")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("commit", "commit a transaction")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("rollback", "rollback a transaction")).ConfigureAwait(false); - } - - public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, - ServerCallContext context) - { - var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList(); - - foreach (var flightInfo in flightInfos) - { - await responseStream.WriteAsync(flightInfo).ConfigureAwait(false); - } - } - - public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, - FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) - { - while (await requestStream.MoveNext().ConfigureAwait(false)) - { - await responseStream - .WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()) - .ConfigureAwait(false); - } - } -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json b/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json deleted file mode 100644 index 0c208ae9181e5..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.Development.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "Logging": { - "LogLevel": { - "Default": "Information", - "Microsoft.AspNetCore": "Warning" - } - } -} diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json b/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json deleted file mode 100644 index 10f68b8c8b4f7..0000000000000 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/appsettings.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "Logging": { - "LogLevel": { - "Default": "Information", - "Microsoft.AspNetCore": "Warning" - } - }, - "AllowedHosts": "*" -} diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index 071561101e1e9..b140eed87b226 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -25,8 +25,6 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.Tes EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", "src\Apache.Arrow.Flight.Sql\Apache.Arrow.Flight.Sql.csproj", "{2ADE087A-B424-4895-8CC5-10170D10BA62}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql.TestWeb", "Apache.Arrow.Flight.Sql.TestWeb\Apache.Arrow.Flight.Sql.TestWeb.csproj", "{85A6CB32-A83B-48C4-96E8-625C8FBDB349}" -EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -81,10 +79,6 @@ Global {2ADE087A-B424-4895-8CC5-10170D10BA62}.Debug|Any CPU.Build.0 = Debug|Any CPU {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.ActiveCfg = Release|Any CPU {2ADE087A-B424-4895-8CC5-10170D10BA62}.Release|Any CPU.Build.0 = Release|Any CPU - {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Debug|Any CPU.Build.0 = Debug|Any CPU - {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Release|Any CPU.ActiveCfg = Release|Any CPU - {85A6CB32-A83B-48C4-96E8-625C8FBDB349}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index 3a978f9b411d8..9319f7629fac5 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -13,8 +13,8 @@ - + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 0f99caf30e58c..0deb4b0676f59 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -4,7 +4,8 @@ using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Flight.Sql.TestWeb; +using Apache.Arrow.Flight.Tests; +using Apache.Arrow.Flight.TestWeb; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; @@ -15,15 +16,15 @@ namespace Apache.Arrow.Flight.Sql.Tests; public class FlightSqlClientTests : IDisposable { - readonly TestSqlWebFactory _testWebFactory; - readonly FlightSqlStore _flightStore; + readonly TestFlightSqlWebFactory _testWebFactory; + readonly FlightStore _flightStore; private readonly FlightSqlClient _flightSqlClient; private readonly FlightSqlTestUtils _testUtils; public FlightSqlClientTests() { - _flightStore = new FlightSqlStore(); - _testWebFactory = new TestSqlWebFactory(_flightStore); + _flightStore = new FlightStore(); + _testWebFactory = new TestFlightSqlWebFactory(_flightStore); FlightClient flightClient = new(_testWebFactory.GetChannel()); _flightSqlClient = new FlightSqlClient(flightClient); @@ -94,7 +95,7 @@ public async Task PreparedStatementAsync() var transaction = new Transaction("sample-transaction-id"); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -118,7 +119,7 @@ public async Task ExecuteUpdateAsync() var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -139,7 +140,7 @@ public async Task ExecuteAsync() var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -159,7 +160,7 @@ public async Task GetFlightInfoAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -177,7 +178,7 @@ public async Task GetExecuteSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -198,7 +199,7 @@ public async Task GetCatalogsAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -218,7 +219,7 @@ public async Task GetSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -238,7 +239,7 @@ public async Task GetDbSchemasAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); string catalog = "test-catalog"; @@ -285,7 +286,7 @@ public async Task GetPrimaryKeysAsync() var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -328,7 +329,7 @@ public async Task GetTablesAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -371,7 +372,7 @@ public async Task GetCatalogsSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -398,7 +399,7 @@ public async Task GetDbSchemasSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -448,7 +449,7 @@ public async Task DoPutAsync() ], 5); Assert.NotNull(recordBatch); Assert.Equal(5, recordBatch.Length); - var flightHolder = new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(_testUtils.CreateTestBatch(0, 100))); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -469,7 +470,7 @@ public async Task GetExportedKeysAsync() var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -488,7 +489,7 @@ public async Task GetExportedKeysSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -508,7 +509,7 @@ public async Task GetImportedKeysAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -535,7 +536,7 @@ public async Task GetImportedKeysSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -556,7 +557,7 @@ public async Task GetCrossReferenceAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); var pkTableRef = new TableRef { Catalog = "PKCatalog", DbSchema = "PKSchema", Table = "PKTable" }; @@ -578,7 +579,7 @@ public async Task GetCrossReferenceSchemaAsync() var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightSqlHolder(flightDescriptor, recordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); @@ -603,7 +604,7 @@ public async Task GetTableTypesAsync() var commandGetTableTypes = new CommandGetTableTypes(); byte[] packedCommand = commandGetTableTypes.PackAndSerialize().ToByteArray(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); - var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -628,7 +629,7 @@ public async Task GetTableTypesSchemaAsync() byte[] packedCommand = commandGetTableTypesSchema.PackAndSerialize().ToByteArray(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); - var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -657,7 +658,7 @@ public async Task GetXdbcTypeInfoAsync() var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); // Creating a flight holder with the expected schema and adding it to the flight store - var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -685,7 +686,7 @@ public async Task GetXdbcTypeInfoSchemaAsync() byte[] packedCommand = commandGetXdbcTypeInfo.PackAndSerialize().ToByteArray(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor(packedCommand); - var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, "http://localhost:5000"); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -706,7 +707,7 @@ public async Task GetSqlInfoSchemaAsync() .Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Build(); - var flightHolder = new FlightSqlHolder(flightDescriptor, expectedSchema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act @@ -751,7 +752,7 @@ public async Task CancelQueryAsync() // Adding the flight info to the flight store for testing _flightStore.Flights.Add(flightDescriptor, - new FlightSqlHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); + new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); // Act var cancelStatus = await _flightSqlClient.CancelQueryAsync(options, flightInfo); diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index 0f35482d2a803..444c80f0919ff 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -2,7 +2,8 @@ using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Flight.Sql.TestWeb; +using Apache.Arrow.Flight.Tests; +using Apache.Arrow.Flight.TestWeb; using Apache.Arrow.Types; using Xunit; @@ -10,8 +11,8 @@ namespace Apache.Arrow.Flight.Sql.Tests; public class FlightSqlPreparedStatementTests { - readonly TestSqlWebFactory _testWebFactory; - readonly FlightSqlStore _flightStore; + readonly TestWebFactory _testWebFactory; + readonly FlightStore _flightStore; private readonly PreparedStatement _preparedStatement; private readonly Schema _schema; private readonly RecordBatch _parameterBatch; @@ -19,8 +20,8 @@ public class FlightSqlPreparedStatementTests public FlightSqlPreparedStatementTests() { - _flightStore = new FlightSqlStore(); - _testWebFactory = new TestSqlWebFactory(_flightStore); + _flightStore = new FlightStore(); + _testWebFactory = new TestWebFactory(_flightStore); FlightClient flightClient = new(_testWebFactory.GetChannel()); FlightSqlClient flightSqlClient = new(flightClient); @@ -45,7 +46,7 @@ public FlightSqlPreparedStatementTests() new Int32Array.Builder().AppendRange(columnSizes).Build() ], 5); - var flightHolder = new FlightSqlHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); _preparedStatement = new PreparedStatement(flightSqlClient, flightHolder.GetFlightInfo(), "SELECT * FROM test"); } @@ -74,7 +75,7 @@ public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreS { // Arrange var options = new FlightCallOptions(); - var flightHolder = new FlightSqlHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(_parameterBatch)); _flightStore.Flights.Add(_flightDescriptor, flightHolder); await _preparedStatement.SetParameters(_parameterBatch); @@ -83,7 +84,7 @@ public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreS long affectedRows = await _preparedStatement.ExecuteUpdateAsync(options); // Assert - Assert.True(affectedRows > 0); // Verifies that the statement executed successfully. + Assert.True(affectedRows > 0); } [Fact] @@ -138,4 +139,4 @@ await Assert.ThrowsAsync( () => _preparedStatement.CloseAsync(options) ); } -} +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs index 25147349628f3..8094a0b491bd3 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs @@ -1,20 +1,21 @@ using System.Linq; -using Apache.Arrow.Flight.Sql.TestWeb; +using Apache.Arrow.Flight.Tests; +using Apache.Arrow.Flight.TestWeb; namespace Apache.Arrow.Flight.Sql.Tests; public class FlightSqlTestUtils { - private readonly TestSqlWebFactory _testWebFactory; - private readonly FlightSqlStore _flightStore; + private readonly TestFlightSqlWebFactory _testWebFactory; + private readonly FlightStore _flightStore; - public FlightSqlTestUtils(TestSqlWebFactory testWebFactory, FlightSqlStore flightStore) + public FlightSqlTestUtils(TestFlightSqlWebFactory testWebFactory, FlightStore flightStore) { _testWebFactory = testWebFactory; _flightStore = flightStore; } - public RecordBatch CreateTestBatch(int startValue, int length) + public RecordBatch CreateTestBatch(int startValue, int length) { var batchBuilder = new RecordBatch.Builder(); Int32Array.Builder builder = new(); @@ -33,7 +34,7 @@ public FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, { var initialBatch = batches.FirstOrDefault(); - var flightHolder = new FlightSqlHolder(flightDescriptor, initialBatch.RecordBatch.Schema, + var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress()); foreach (var batch in batches) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs new file mode 100644 index 0000000000000..fedb7de114498 --- /dev/null +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs @@ -0,0 +1,42 @@ +using Apache.Arrow.Flight.TestWeb; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Apache.Arrow.Flight.Sql.Tests; + +public class StartupFlightSql +{ + // This method gets called by the runtime. Use this method to add services to the container. + // For more information on how to configure your application, visit https://go.microsoft.com/fwlink/?LinkID=398940 + public void ConfigureServices(IServiceCollection services) + { + services.AddGrpc() + .AddFlightServer(); + + services.AddSingleton(new FlightStore()); + } + + // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. + public void Configure(IApplicationBuilder app, IWebHostEnvironment env) + { + if (env.IsDevelopment()) + { + app.UseDeveloperExceptionPage(); + } + + app.UseRouting(); + + app.UseEndpoints(endpoints => + { + endpoints.MapFlightEndpoint(); + + endpoints.MapGet("/", async context => + { + await context.Response.WriteAsync("Communication with gRPC endpoints must be made through a gRPC client. To learn how to create a client, visit: https://go.microsoft.com/fwlink/?linkid=2086909"); + }); + }); + } +} \ No newline at end of file diff --git a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs similarity index 81% rename from csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs rename to csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs index 5d98a197d65e4..714d801f59fde 100644 --- a/csharp/Apache.Arrow.Flight.Sql.TestWeb/TestSqlWebFactory.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using Apache.Arrow.Flight.TestWeb; using Grpc.Net.Client; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting.Server; @@ -8,16 +9,16 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -namespace Apache.Arrow.Flight.Sql.TestWeb; +namespace Apache.Arrow.Flight.Sql.Tests; -public class TestSqlWebFactory : IDisposable +public class TestFlightSqlWebFactory : IDisposable { readonly IHost host; private int _port; - public TestSqlWebFactory(FlightSqlStore flightStore) + public TestFlightSqlWebFactory(FlightStore flightStore) { - host = WebHostBuilder(flightStore).Build(); //Create the server + host = WebHostBuilder(flightStore).Build(); host.Start(); var addressInfo = host.Services.GetRequiredService().Features.Get(); if (addressInfo == null) @@ -32,14 +33,14 @@ public TestSqlWebFactory(FlightSqlStore flightStore) "System.Net.Http.SocketsHttpHandler.Http2UnencryptedSupport", true); } - private IHostBuilder WebHostBuilder(FlightSqlStore flightStore) + private IHostBuilder WebHostBuilder(FlightStore flightStore) { return Host.CreateDefaultBuilder() .ConfigureWebHostDefaults(webBuilder => { webBuilder .ConfigureKestrel(c => { c.ListenAnyIP(0, l => l.Protocols = HttpProtocols.Http2); }) - .UseStartup() + .UseStartup() .ConfigureServices(services => { services.AddSingleton(flightStore); }); }); } @@ -63,4 +64,4 @@ public void Dispose() { Stop(); } -} +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs index 14a4d491c13a6..43308ac4f8edd 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs @@ -55,10 +55,25 @@ public FlightInfo GetFlightInfo() int batchBytes = _recordBatches.Sum(rb => rb.RecordBatch.Arrays.Sum(arr => arr.Data.Buffers.Sum(b=>b.Length))); return new FlightInfo(_schema, _flightDescriptor, new List() { - new FlightEndpoint(new FlightTicket(_flightDescriptor.Paths.FirstOrDefault()), new List(){ + new FlightEndpoint(new FlightTicket(GetTicket(_flightDescriptor)), new List(){ new FlightLocation(_location) }) }, batchArrayLength, batchBytes); } + + private string GetTicket(FlightDescriptor descriptor) + { + if (descriptor.Paths.FirstOrDefault() != null) + { + return descriptor.Paths.FirstOrDefault(); + } + + if (descriptor.Command.Length > 0) + { + return $"{descriptor.Command.ToStringUtf8()}"; + } + + return "default_custom_ticket"; + } } } \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index 722b49d9063a9..149fb92f9916b 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -40,15 +40,6 @@ public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + switch (request.Type) + { + case "test": + await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); + break; + case SqlAction.GetPrimaryKeysRequest: + await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); + break; + case SqlAction.CancelFlightInfoRequest: + var cancelRequest = new FlightInfoCancelResult(); + cancelRequest.SetStatus(1); + await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())) + .ConfigureAwait(false); + break; + case SqlAction.BeginTransactionRequest: + case SqlAction.CommitRequest: + case SqlAction.RollbackRequest: + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))) + .ConfigureAwait(false); + break; + case SqlAction.CreateRequest: + case SqlAction.CloseRequest: + var prepareStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") + }; + byte[] packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); + var flightResult = new FlightResult(packedResult); + await responseStream.WriteAsync(flightResult).ConfigureAwait(false); + break; + default: + throw new NotImplementedException(); + } + } + + public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, + ServerCallContext context) + { + FlightDescriptor flightDescriptor = FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); + + if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + var batches = flightHolder.GetRecordBatches(); + + foreach (var batch in batches) + { + await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); + } + } + } + + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + var flightDescriptor = await requestStream.FlightDescriptor; + + if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + } + + while (await requestStream.MoveNext()) + { + flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, + requestStream.ApplicationMetadata.FirstOrDefault())); + await responseStream.WriteAsync(FlightPutResult.Empty); + } + } + + public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo()); + } + + if (_flightStore.Flights.Count > 0) + { + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + public override Task GetSchema(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo().Schema); + } + + if (_flightStore.Flights.Count > 0) + { + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } +} + + +/* + * + * + + using System; + using System.Collections.Generic; + using System.Linq; + using System.Threading.Tasks; + using Apache.Arrow.Flight.Server; + using Apache.Arrow.Flight.Sql; + using Arrow.Flight.Protocol.Sql; + using Google.Protobuf; + using Google.Protobuf.WellKnownTypes; + using Grpc.Core; + using Grpc.Core.Utils; + + namespace Apache.Arrow.Flight.TestWeb + { + public class TestFlightServer : FlightServer + { + private readonly FlightStore _flightStore; + + public TestFlightServer(FlightStore flightStore) + { + _flightStore = flightStore; + } + + public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + switch (request.Type) + { + case "test": + await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); + break; + case SqlAction.GetPrimaryKeysRequest: + await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); + break; + case SqlAction.CancelFlightInfoRequest: + var cancelRequest = new FlightInfoCancelResult(); + cancelRequest.SetStatus(1); + await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())) + .ConfigureAwait(false); + break; + case SqlAction.BeginTransactionRequest: + case SqlAction.CommitRequest: + case SqlAction.RollbackRequest: + await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))) + .ConfigureAwait(false); + break; + case SqlAction.CreateRequest: + case SqlAction.CloseRequest: + var prepareStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") + }; + byte[] packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); + var flightResult = new FlightResult(packedResult); + await responseStream.WriteAsync(flightResult).ConfigureAwait(false); + break; + default: + throw new NotImplementedException(); + } + } + + public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, + ServerCallContext context) + { + FlightDescriptor flightDescriptor = null; + flightDescriptor = flightDescriptor is not null && flightDescriptor.Paths.Any() + ? FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8()) + : FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); + + if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + var batches = flightHolder.GetRecordBatches(); + + + foreach (var batch in batches) + { + await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); + } + } + } + + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + var flightDescriptor = await requestStream.FlightDescriptor; + + if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) + { + flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + } + + while (await requestStream.MoveNext()) + { + flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, + requestStream.ApplicationMetadata.FirstOrDefault())); + await responseStream.WriteAsync(FlightPutResult.Empty); + } + } + + public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo()); + } + + if (_flightStore.Flights.Count > 0) + { + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + public override async Task Handshake(IAsyncStreamReader requestStream, + IAsyncStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + if (requestStream.Current.Payload.ToStringUtf8() == "Hello") + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))) + .ConfigureAwait(false); + } + else + { + await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); + } + } + } + + public override Task GetSchema(FlightDescriptor request, ServerCallContext context) + { + if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) + { + return Task.FromResult(flightHolder.GetFlightInfo().Schema); + } + + if (_flightStore.Flights.Count > 0) + { + return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema); + } + + throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); + } + + public override async Task ListActions(IAsyncStreamWriter responseStream, + ServerCallContext context) + { + await responseStream.WriteAsync(new FlightActionType("get", "get a flight")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("put", "add a flight")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")).ConfigureAwait(false); + await responseStream.WriteAsync(new FlightActionType("test", "test action")).ConfigureAwait(false); + } + + public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, + ServerCallContext context) + { + var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList(); + + foreach (var flightInfo in flightInfos) + { + await responseStream.WriteAsync(flightInfo).ConfigureAwait(false); + } + } + + public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, + FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) + { + while (await requestStream.MoveNext().ConfigureAwait(false)) + { + await responseStream + .WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()) + .ConfigureAwait(false); + } + } + } + } + * + */ \ No newline at end of file From 6c229bce1c02f7e74db14d060edb370214d44171 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 13 Oct 2024 19:51:15 +0300 Subject: [PATCH 36/58] feat: PreparedStatement desc: Implemented testing and PreparedStatement functionality based on C++ implementation --- .../Apache.Arrow.Flight.Sql.csproj | 1 + .../Client/FlightSqlClient.cs | 21 +- .../Apache.Arrow.Flight.Sql/DoPutResult.cs | 23 + .../PreparedStatement.cs | 668 ++++++++++++++++-- .../FlightSqlPreparedStatementTests.cs | 417 +++++++++-- 5 files changed, 1029 insertions(+), 101 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj index ec438fde843f4..bec7dbdf54f7f 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj +++ b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj @@ -6,6 +6,7 @@ + diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 62ac3e5a48a7e..07c4a20a54d07 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -16,7 +16,7 @@ public class FlightSqlClient public FlightSqlClient(FlightClient client) { - _client = client ?? throw new ArgumentNullException(nameof(client)); + _client = client; } /// @@ -298,7 +298,7 @@ public async Task GetCatalogsSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The descriptor of the dataset request, whether a named dataset or a command. /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. - public async Task GetSchemaAsync(FlightCallOptions options, FlightDescriptor descriptor) + public virtual async Task GetSchemaAsync(FlightCallOptions options, FlightDescriptor descriptor) { if (descriptor is null) { @@ -466,53 +466,48 @@ public List BuildArrowArraysFromSchema(Schema schema, int rowCount) var intArrayBuilder = new Int32Array.Builder(); for (int i = 0; i < rowCount; i++) { - intArrayBuilder.Append(i); // Just filling with sample data + intArrayBuilder.Append(i); } arrays.Add(intArrayBuilder.Build()); break; case StringType: - // Create a String array var stringArrayBuilder = new StringArray.Builder(); for (int i = 0; i < rowCount; i++) { - stringArrayBuilder.Append($"Value-{i}"); // Sample string values + stringArrayBuilder.Append($"Value-{i}"); } arrays.Add(stringArrayBuilder.Build()); break; case Int64Type: - // Create an Int64 array var longArrayBuilder = new Int64Array.Builder(); for (int i = 0; i < rowCount; i++) { - longArrayBuilder.Append((long)i * 100); // Sample data + longArrayBuilder.Append((long)i * 100); } arrays.Add(longArrayBuilder.Build()); break; case FloatType: - // Create a Float array var floatArrayBuilder = new FloatArray.Builder(); for (int i = 0; i < rowCount; i++) { - floatArrayBuilder.Append((float)(i * 1.1)); // Sample data + floatArrayBuilder.Append((float)(i * 1.1)); } arrays.Add(floatArrayBuilder.Build()); break; case BooleanType: - // Create a Boolean array var boolArrayBuilder = new BooleanArray.Builder(); for (int i = 0; i < rowCount; i++) { - boolArrayBuilder.Append(i % 2 == 0); // Alternate between true and false + boolArrayBuilder.Append(i % 2 == 0); } - arrays.Add(boolArrayBuilder.Build()); break; @@ -1170,7 +1165,7 @@ public async Task PrepareAsync(FlightCallOptions options, str byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); - return new PreparedStatement(this, flightInfo, query); + return new PreparedStatement(this, transaction.TransactionId.ToStringUtf8(), flightInfo.Schema, flightInfo.Schema); } throw new NullReferenceException($"{nameof(PreparedStatement)} was not able to be created"); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs index 48dcf78328416..9c5bdc0c099a9 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs @@ -1,3 +1,4 @@ +using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Grpc.Core; @@ -13,4 +14,26 @@ public DoPutResult(FlightClientRecordBatchStreamWriter writer, IAsyncStreamReade Writer = writer; Reader = reader; } + + /// + /// Reads the metadata asynchronously from the reader. + /// + /// A ByteString containing the metadata read from the reader. + public async Task ReadMetadataAsync() + { + if (await Reader.MoveNext().ConfigureAwait(false)) + { + return Reader.Current.ApplicationMetadata; + } + throw new RpcException(new Status(StatusCode.Internal, "No metadata available in the response stream.")); + } + + /// + /// Completes the writer by signaling the end of the writing process. + /// + /// A Task representing the completion of the writer. + public async Task CompleteAsync() + { + await Writer.CompleteAsync().ConfigureAwait(false); + } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 982177d4fc06d..c1d245cb363ab 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -1,70 +1,92 @@ using System; +using System.Collections.Generic; +using System.IO; using System.Linq; +using System.Text; +using System.Threading; using System.Threading.Tasks; +using Apache.Arrow.Flight.Server; using Apache.Arrow.Flight.Sql.Client; +using Apache.Arrow.Ipc; using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; using Grpc.Core; +using System.Threading.Channels; namespace Apache.Arrow.Flight.Sql; public class PreparedStatement : IDisposable { private readonly FlightSqlClient _client; - private readonly FlightInfo _flightInfo; - private RecordBatch? _parameterBatch; - private readonly string _query; + private readonly string _handle; + private Schema _datasetSchema; + private Schema _parameterSchema; private bool _isClosed; + public bool IsClosed => _isClosed; + public string Handle => _handle; + private FlightServerRecordBatchStreamReader? _parameterReader; + public FlightServerRecordBatchStreamReader? ParameterReader => _parameterReader; - public PreparedStatement(FlightSqlClient client, FlightInfo flightInfo, string query) + public PreparedStatement(FlightSqlClient client, string handle, Schema datasetSchema, Schema parameterSchema) { _client = client ?? throw new ArgumentNullException(nameof(client)); - _flightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); - _query = query ?? throw new ArgumentNullException(nameof(query)); + _handle = handle ?? throw new ArgumentNullException(nameof(handle)); + _datasetSchema = datasetSchema ?? throw new ArgumentNullException(nameof(datasetSchema)); + _parameterSchema = parameterSchema ?? throw new ArgumentNullException(nameof(parameterSchema)); _isClosed = false; } /// - /// Set parameters for the prepared statement + /// Retrieves the schema associated with the prepared statement asynchronously. /// - /// The batch of parameters to bind - public Task SetParameters(RecordBatch parameterBatch) + /// The FlightCallOptions for the operation. + /// A Task representing the asynchronous operation. The task result contains the SchemaResult object. + public async Task GetSchemaAsync(FlightCallOptions options) { EnsureStatementIsNotClosed(); - _parameterBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); - return Task.CompletedTask; - } - /// - /// Execute the prepared statement, returning the number of affected rows - /// - /// The FlightCallOptions for the execution - /// Task representing the asynchronous operation - public async Task ExecuteUpdateAsync(FlightCallOptions options) - { - EnsureStatementIsNotClosed(); - EnsureParametersAreSet(); try { - return await _client.ExecuteUpdateAsync(options, _query); + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8) + }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); + var schema = await _client.GetSchemaAsync(options, descriptor).ConfigureAwait(false); + if (schema == null || !schema.FieldsList.Any()) + { + throw new InvalidOperationException("Schema is empty or invalid."); + } + return schema; } catch (RpcException ex) { - throw new InvalidOperationException("Failed to execute update query", ex); + throw new InvalidOperationException("Failed to retrieve the schema for the prepared statement", ex); } } + /// - /// Closes the prepared statement + /// Closes the prepared statement asynchronously. /// + /// The FlightCallOptions for the operation. + /// A Task representing the asynchronous operation. public async Task CloseAsync(FlightCallOptions options) { EnsureStatementIsNotClosed(); try { - var actionClose = new FlightAction(SqlAction.CloseRequest, _flightInfo.Descriptor.Command); - await foreach (var result in _client.DoActionAsync(options, actionClose).ConfigureAwait(false)) + var closeRequest = new ActionClosePreparedStatementRequest { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8) + }; + + var action = new FlightAction(SqlAction.CloseRequest, closeRequest.ToByteArray()); + await foreach (var result in _client.DoActionAsync(options, action).ConfigureAwait(false)) + { + // Just drain the results to complete the operation } + _isClosed = true; } catch (RpcException ex) @@ -74,22 +96,205 @@ public async Task CloseAsync(FlightCallOptions options) } /// - /// Helper method to execute the statement and get affected rows + /// Reads the result from an asynchronous stream of FlightData and populates the provided Protobuf message. /// - private async Task ExecuteAndGetAffectedRowsAsync(FlightCallOptions options, FlightInfo flightInfo) + /// The async enumerable stream of FlightData results. + /// The Protobuf message to populate from the results. + /// A task that represents the asynchronous read operation. + public async Task ReadResultAsync(IAsyncEnumerable results, IMessage message) { - long affectedRows = 0; - var doGetResult = _client.DoGetAsync(options, flightInfo.Endpoints.First().Ticket); - await foreach (var recordBatch in doGetResult.ConfigureAwait(false)) + if (results == null) throw new ArgumentNullException(nameof(results)); + if (message == null) throw new ArgumentNullException(nameof(message)); + + await foreach (var flightData in results.ConfigureAwait(false)) + { + // Ensure that the data received is valid and non-empty. + if (flightData.DataBody == null || flightData.DataBody.Length == 0) + throw new InvalidOperationException("Received empty or invalid FlightData."); + + try + { + // Merge the flight data's body into the provided message. + message.MergeFrom(message.PackAndSerialize()); + } + catch (InvalidProtocolBufferException ex) + { + throw new InvalidOperationException( + "Failed to parse the received FlightData into the specified message.", ex); + } + } + } + + public async Task ParseResponseAsync(FlightSqlClient client, IAsyncEnumerable results) + { + if (client == null) + throw new ArgumentNullException(nameof(client)); + + if (results == null) + throw new ArgumentNullException(nameof(results)); + + var preparedStatementResult = new ActionCreatePreparedStatementResult(); + await foreach (var flightData in results.ConfigureAwait(false)) + { + if (flightData.DataBody == null || flightData.DataBody.Length == 0) + { + continue; + } + + try + { + preparedStatementResult.MergeFrom(flightData.DataBody.ToByteArray()); + } + catch (InvalidProtocolBufferException ex) + { + throw new InvalidOperationException( + "Failed to parse FlightData into ActionCreatePreparedStatementResult.", ex); + } + } + + // If the response is empty or incomplete + if (preparedStatementResult.PreparedStatementHandle.Length == 0) + { + throw new InvalidOperationException("Received an empty or invalid PreparedStatementHandle."); + } + + // Parse dataset and parameter schemas from the response + Schema datasetSchema = null!; + Schema parameterSchema = null!; + + if (preparedStatementResult.DatasetSchema.Length > 0) + { + datasetSchema = SchemaExtensions.DeserializeSchema(preparedStatementResult.DatasetSchema.ToByteArray()); + } + + if (preparedStatementResult.ParameterSchema.Length > 0) { - affectedRows += recordBatch.Length; + parameterSchema = SchemaExtensions.DeserializeSchema(preparedStatementResult.ParameterSchema.ToByteArray()); } - return affectedRows; + // Create and return the PreparedStatement object + return new PreparedStatement(client, preparedStatementResult.PreparedStatementHandle.ToStringUtf8(), + datasetSchema, parameterSchema); } + public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancellationToken = default) + { + EnsureStatementIsNotClosed(); + + if (parameterBatch == null) + throw new ArgumentNullException(nameof(parameterBatch)); + + var channel = Channel.CreateUnbounded(); + var task = Task.Run(async () => + { + try + { + using (var memoryStream = new MemoryStream()) + { + var writer = new ArrowStreamWriter(memoryStream, parameterBatch.Schema); + + cancellationToken.ThrowIfCancellationRequested(); + await writer.WriteRecordBatchAsync(parameterBatch, cancellationToken).ConfigureAwait(false); + await writer.WriteEndAsync(cancellationToken).ConfigureAwait(false); + + memoryStream.Position = 0; + + cancellationToken.ThrowIfCancellationRequested(); + + var flightData = new FlightData( + FlightDescriptor.CreateCommandDescriptor(_handle), + ByteString.CopyFrom(memoryStream.ToArray()), + ByteString.Empty, + ByteString.Empty + ); + await channel.Writer.WriteAsync(flightData, cancellationToken).ConfigureAwait(false); + } + + channel.Writer.Complete(); + } + catch (OperationCanceledException) + { + channel.Writer.TryComplete(new OperationCanceledException("Task was canceled")); + } + catch (Exception ex) + { + channel.Writer.TryComplete(ex); + } + }, cancellationToken); + + _parameterReader = new FlightServerRecordBatchStreamReader(new ChannelReaderStreamAdapter(channel.Reader)); + if (task.IsCanceled || cancellationToken.IsCancellationRequested) + { + return Status.DefaultCancelled; + } + + return Status.DefaultSuccess; + } + + + /// - /// Helper method to ensure the statement is not closed. + /// Executes the prepared statement asynchronously and retrieves the query results as . + /// + /// The for the operation, which may include timeouts, headers, and other options for the call. + /// Optional containing parameters to bind before executing the statement. + /// Optional to observe while waiting for the task to complete. The task will be canceled if the token is canceled. + /// A representing the asynchronous operation. The task result contains the describing the executed query results. + /// Thrown if the prepared statement is closed or if there is an error during execution. + /// Thrown if the operation is canceled by the . + public async Task ExecuteAsync(FlightCallOptions options, RecordBatch parameterBatch, CancellationToken cancellationToken = default) + { + EnsureStatementIsNotClosed(); + + var descriptor = FlightDescriptor.CreateCommandDescriptor(_handle); + cancellationToken.ThrowIfCancellationRequested(); + + if (parameterBatch != null) + { + var boundParametersAsync = await BindParametersAsync(options, descriptor, parameterBatch, cancellationToken).ConfigureAwait(false); + } + cancellationToken.ThrowIfCancellationRequested(); + return await _client.GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + } + + + /// + /// Binds parameters to the prepared statement by streaming the given RecordBatch to the server asynchronously. + /// + /// The for the operation, which may include timeouts, headers, and other options for the call. + /// The that identifies the statement or command being executed. + /// The containing the parameters to bind to the prepared statement. + /// Optional to observe while waiting for the task to complete. The task will be canceled if the token is canceled. + /// A that represents the asynchronous operation. The task result contains the metadata from the server after binding the parameters. + /// Thrown when is null. + /// Thrown if the operation is canceled or if there is an error during the DoPut operation. + public async Task BindParametersAsync(FlightCallOptions options, FlightDescriptor descriptor, RecordBatch parameterBatch, CancellationToken cancellationToken = default) + { + if (parameterBatch == null) + { + throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); + } + + var putResult = await _client.DoPutAsync(options, descriptor, parameterBatch.Schema).ConfigureAwait(false); + + try + { + var metadata = await putResult.ReadMetadataAsync().ConfigureAwait(false); + await putResult.CompleteAsync().ConfigureAwait(false); + return metadata; + } + catch (OperationCanceledException) + { + throw new InvalidOperationException("Parameter binding was canceled."); + } + catch (Exception ex) + { + throw new InvalidOperationException("Failed to bind parameters to the prepared statement.", ex); + } + } + + /// + /// Ensures that the statement is not already closed. /// private void EnsureStatementIsNotClosed() { @@ -97,21 +302,404 @@ private void EnsureStatementIsNotClosed() throw new InvalidOperationException("Cannot execute a closed statement."); } - private void EnsureParametersAreSet() + /// + /// Protected implementation of the dispose pattern. + /// + /// True if called from Dispose, false if called from the finalizer. + protected virtual void Dispose(bool disposing) { - if (_parameterBatch == null || _parameterBatch.Length == 0) + if (_isClosed) return; + + if (disposing) { - throw new InvalidOperationException("Prepared statement parameters have not been set."); + // Close the statement if it's not already closed. + CloseAsync(new FlightCallOptions()).GetAwaiter().GetResult(); } + + _isClosed = true; } public void Dispose() { - _parameterBatch?.Dispose(); + Dispose(true); + GC.SuppressFinalize(this); + } +} - if (!_isClosed) +public static class SchemaExtensions +{ + /// + /// Deserializes a schema from a byte array. + /// + /// The byte array representing the serialized schema. + /// The deserialized Schema object. + public static Schema DeserializeSchema(byte[] serializedSchema) + { + if (serializedSchema == null || serializedSchema.Length == 0) { - _isClosed = true; + throw new ArgumentException("Invalid serialized schema"); + } + + using var stream = new MemoryStream(serializedSchema); + var reader = new ArrowStreamReader(stream); + return reader.Schema; + } +} + +// +// public class PreparedStatement : IDisposable +// { +// private readonly FlightSqlClient _client; +// private readonly FlightInfo _flightInfo; +// private readonly string _query; +// private bool _isClosed; +// private readonly FlightDescriptor _descriptor; +// private RecordBatch? _parameterBatch; +// +// public PreparedStatement(FlightSqlClient client, FlightInfo flightInfo, string query) +// { +// _client = client ?? throw new ArgumentNullException(nameof(client)); +// _flightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); +// _query = query ?? throw new ArgumentNullException(nameof(query)); +// _descriptor = flightInfo.Descriptor ?? throw new ArgumentNullException(nameof(flightInfo.Descriptor)); +// _isClosed = false; +// } +// +// /// +// /// Set parameters for the prepared statement +// /// +// /// The batch of parameters to bind +// public Task SetParameters(RecordBatch parameterBatch) +// { +// EnsureStatementIsNotClosed(); +// if (parameterBatch == null) +// { +// throw new ArgumentNullException(nameof(parameterBatch)); +// } +// +// _parameterBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); +// return Task.CompletedTask; +// } +// +// /// +// /// Execute the prepared statement, returning the number of affected rows +// /// +// /// The FlightCallOptions for the execution +// /// Task representing the asynchronous operation +// public async Task ExecuteUpdateAsync(FlightCallOptions options) +// { +// EnsureStatementIsNotClosed(); +// EnsureParametersAreSet(); +// try +// { +// var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = _query }; +// var command = new CommandPreparedStatementQuery +// { +// PreparedStatementHandle = prepareStatementRequest.ToByteString() +// }; +// var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); +// var metadata = await BindParametersAsync(options, descriptor).ConfigureAwait(false); +// await _client.ExecuteUpdateAsync(options, _query); +// +// return ParseUpdateResult(metadata); +// } +// catch (RpcException ex) +// { +// throw new InvalidOperationException("Failed to execute update query", ex); +// } +// } +// +// /// +// /// Binds parameters to the server using DoPut and retrieves metadata. +// /// +// /// The FlightCallOptions for the execution. +// /// The FlightDescriptor for the command. +// /// A ByteString containing metadata from the server response. +// public async Task BindParametersAsync(FlightCallOptions options, FlightDescriptor descriptor) +// { +// if (_parameterBatch == null) +// throw new InvalidOperationException("Parameters have not been set."); +// +// // Start the DoPut operation +// var doPutResult = await _client.DoPutAsync(options, descriptor, _parameterBatch.Schema); +// var writer = doPutResult.Writer; +// +// // Write the record batch to the stream +// await writer.WriteAsync(_parameterBatch).ConfigureAwait(false); +// await writer.CompleteAsync().ConfigureAwait(false); +// +// // Read metadata from response +// var metadata = await doPutResult.ReadMetadataAsync().ConfigureAwait(false); +// +// // Close the writer and reader streams +// await writer.CompleteAsync().ConfigureAwait(false); +// await doPutResult.CompleteAsync().ConfigureAwait(false); +// return metadata; +// } +// +// /// +// /// Closes the prepared statement +// /// +// public async Task CloseAsync(FlightCallOptions options) +// { +// EnsureStatementIsNotClosed(); +// try +// { +// var actionClose = new FlightAction(SqlAction.CloseRequest, _flightInfo.Descriptor.Command); +// await foreach (var result in _client.DoActionAsync(options, actionClose).ConfigureAwait(false)) +// { +// } +// +// _isClosed = true; +// } +// catch (RpcException ex) +// { +// throw new InvalidOperationException("Failed to close the prepared statement", ex); +// } +// } +// +// /// +// /// Parses the metadata returned from the server to get the number of affected rows. +// /// +// private long ParseUpdateResult(ByteString metadata) +// { +// var updateResult = new DoPutUpdateResult(); +// updateResult.MergeFrom(metadata); +// return updateResult.RecordCount; +// } +// +// /// +// /// Helper method to ensure the statement is not closed. +// /// +// private void EnsureStatementIsNotClosed() +// { +// if (_isClosed) +// throw new InvalidOperationException("Cannot execute a closed statement."); +// } +// +// private void EnsureParametersAreSet() +// { +// if (_parameterBatch == null || _parameterBatch.Length == 0) +// { +// throw new InvalidOperationException("Prepared statement parameters have not been set."); +// } +// } +// +// public void Dispose() +// { +// _parameterBatch?.Dispose(); +// +// if (!_isClosed) +// { +// _isClosed = true; +// } +// } +// } +// +public static class RecordBatchExtensions +{ + /// + /// Converts a RecordBatch into an asynchronous stream of FlightData. + /// + /// The RecordBatch to convert. + /// The FlightDescriptor describing the Flight data. + /// An asynchronous stream of FlightData objects. + public static async IAsyncEnumerable ToFlightDataStreamAsync(this RecordBatch recordBatch, + FlightDescriptor flightDescriptor) + { + if (recordBatch == null) + { + throw new ArgumentNullException(nameof(recordBatch)); + } + + // Use a memory stream to write the Arrow RecordBatch into FlightData format + using var memoryStream = new MemoryStream(); + var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); + + // Write the RecordBatch to the stream + await writer.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); + await writer.WriteEndAsync().ConfigureAwait(false); + + // Reset the memory stream position + memoryStream.Position = 0; + + // Read back the data to create FlightData + var flightData = new FlightData(flightDescriptor, ByteString.CopyFrom(memoryStream.ToArray()), + ByteString.CopyFrom(memoryStream.ToArray())); + yield return flightData; + } + + /// + /// Converts a RecordBatch into an IAsyncStreamReader. + /// + /// The RecordBatch to convert. + /// The FlightDescriptor describing the Flight data. + /// An IAsyncStreamReader of FlightData. + public static IAsyncStreamReader ToFlightDataStream(this RecordBatch recordBatch, FlightDescriptor flightDescriptor) + { + if (recordBatch == null) throw new ArgumentNullException(nameof(recordBatch)); + if (flightDescriptor == null) throw new ArgumentNullException(nameof(flightDescriptor)); + + var channel = Channel.CreateUnbounded(); + + try + { + if (recordBatch.Schema == null || !recordBatch.Schema.FieldsList.Any()) + { + throw new InvalidOperationException("The record batch has an invalid or empty schema."); + } + + using var memoryStream = new MemoryStream(); + using var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); + writer.WriteRecordBatch(recordBatch); + writer.WriteEnd(); + memoryStream.Position = 0; + var flightData = new FlightData(flightDescriptor, ByteString.CopyFrom(memoryStream.ToArray()), ByteString.Empty, ByteString.Empty); + if (flightData.DataBody.IsEmpty) + { + throw new InvalidOperationException( + "The generated FlightData is empty. Check the RecordBatch content."); + } + + channel.Writer.TryWrite(flightData); + } + finally + { + // Mark the channel as complete once done + channel.Writer.Complete(); + } + return new ChannelFlightDataReader(channel.Reader); + } + + /*public static IAsyncStreamReader ToFlightDataStream(this RecordBatch recordBatch, + FlightDescriptor flightDescriptor) + { + if (recordBatch == null) throw new ArgumentNullException(nameof(recordBatch)); + if (flightDescriptor == null) throw new ArgumentNullException(nameof(flightDescriptor)); + + // Create a channel to act as the data stream. + var channel = Channel.CreateUnbounded(); + + // Start a background task to generate the FlightData asynchronously. + _ = Task.Run(async () => + { + try + { + // Check if the schema is set and there are fields in the RecordBatch + if (recordBatch.Schema == null || !recordBatch.Schema.FieldsList.Any()) + { + throw new InvalidOperationException("The record batch has an invalid or empty schema."); + } + + // Use a memory stream to convert the RecordBatch to FlightData + await using var memoryStream = new MemoryStream(); + var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); + + // Write the RecordBatch to the memory stream + await writer.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); + await writer.WriteEndAsync().ConfigureAwait(false); + + // Reset the memory stream position to read from it + memoryStream.Position = 0; + + // Read back the data from the stream and create FlightData + var flightData = new FlightData( + flightDescriptor, + ByteString.CopyFrom(memoryStream.ToArray()), // Use the data from memory stream + ByteString.Empty // Empty application metadata for now + ); + + // Check if flightData has valid data before writing + if (flightData.DataBody.IsEmpty) + { + throw new InvalidOperationException( + "The generated FlightData is empty. Check the RecordBatch content."); + } + + // Write the FlightData to the channel + await channel.Writer.WriteAsync(flightData).ConfigureAwait(false); + } + catch (Exception ex) + { + // Log any exceptions for debugging purposes + Console.WriteLine($"Error generating FlightData: {ex.Message}"); + } + finally + { + // Mark the channel as complete once done + channel.Writer.Complete(); + } + }); + + // Return a custom IAsyncStreamReader implementation. + return new ChannelFlightDataReader(channel.Reader); + }*/ + + + /// + /// Custom IAsyncStreamReader implementation to read from a ChannelReader. + /// + private class ChannelFlightDataReader : IAsyncStreamReader + { + private readonly ChannelReader _channelReader; + + public ChannelFlightDataReader(ChannelReader channelReader) + { + _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); + Current = default!; + } + + public FlightData Current { get; private set; } + + public async Task MoveNext(CancellationToken cancellationToken) + { + if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (_channelReader.TryRead(out var flightData)) + { + Current = flightData; + return true; + } + } + + return false; + } + + public void Dispose() + { + // No additional cleanup is required here since we're not managing external resources. } } } + +public class ChannelReaderStreamAdapter : IAsyncStreamReader +{ + private readonly ChannelReader _channelReader; + + public ChannelReaderStreamAdapter(ChannelReader channelReader) + { + _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); + Current = default!; + } + + public T Current { get; private set; } + + public async Task MoveNext(CancellationToken cancellationToken) + { + if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (_channelReader.TryRead(out var item)) + { + Current = item; + return true; + } + } + + return false; + } + + public void Dispose() + { + // No additional cleanup is required here since we are using a channel + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index 444c80f0919ff..20d365e8478b7 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -1,29 +1,37 @@ using System; +using System.Collections.Generic; +using System.Text; +using System.Threading; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Flight.Tests; using Apache.Arrow.Flight.TestWeb; +using Apache.Arrow.Tests; using Apache.Arrow.Types; +using Arrow.Flight.Protocol.Sql; +using Google.Protobuf; +using Grpc.Core; using Xunit; namespace Apache.Arrow.Flight.Sql.Tests; public class FlightSqlPreparedStatementTests { - readonly TestWebFactory _testWebFactory; + readonly TestFlightSqlWebFactory _testWebFactory; readonly FlightStore _flightStore; + readonly FlightSqlClient _flightSqlClient; private readonly PreparedStatement _preparedStatement; private readonly Schema _schema; - private readonly RecordBatch _parameterBatch; private readonly FlightDescriptor _flightDescriptor; + private readonly FlightHolder _flightHolder; + private readonly RecordBatch _parameterBatch; public FlightSqlPreparedStatementTests() { _flightStore = new FlightStore(); - _testWebFactory = new TestWebFactory(_flightStore); + _testWebFactory = new TestFlightSqlWebFactory(_flightStore); FlightClient flightClient = new(_testWebFactory.GetChannel()); - FlightSqlClient flightSqlClient = new(flightClient); + _flightSqlClient = new(flightClient); // Setup mock _flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); @@ -46,97 +54,410 @@ public FlightSqlPreparedStatementTests() new Int32Array.Builder().AppendRange(columnSizes).Build() ], 5); - var flightHolder = new FlightHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); - _preparedStatement = new PreparedStatement(flightSqlClient, flightHolder.GetFlightInfo(), "SELECT * FROM test"); + _flightHolder = new FlightHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); + _flightStore.Flights.Add(_flightDescriptor, _flightHolder); + _preparedStatement = new PreparedStatement(_flightSqlClient, handle: "test-handle-guid", datasetSchema: _schema, + parameterSchema: _schema); } - // PreparedStatement [Fact] - public async Task SetParameters_ShouldSetParameters_WhenStatementIsOpen() + public async Task GetSchemaAsync_ShouldThrowInvalidOperationException_WhenStatementIsClosed() { - await _preparedStatement.SetParameters(_parameterBatch); - Assert.NotNull(_parameterBatch); + // Arrange: + await _preparedStatement.CloseAsync(new FlightCallOptions()); + + // Act & Assert: Ensure that calling GetSchemaAsync on a closed statement throws an exception. + await Assert.ThrowsAsync(() => _preparedStatement.GetSchemaAsync(new FlightCallOptions())); } [Fact] - public async Task SetParameters_ShouldThrowException_WhenStatementIsClosed() + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() { // Arrange - await _preparedStatement.CloseAsync(new FlightCallOptions()); + var validSchema = new Schema.Builder() + .Field(f => f.Name("field1").DataType(Int32Type.Default)) + .Build(); + string handle = "TestHandle"; + var preparedStatement = new PreparedStatement(_flightSqlClient, handle, validSchema, validSchema); + var validRecordBatch = CreateRecordBatch(validSchema, [1, 2, 3]); - // Act & Assert - await Assert.ThrowsAsync( - () => _preparedStatement.SetParameters(_parameterBatch) - ); + // Act + var result = preparedStatement.SetParameters(validRecordBatch); + var flightInfo = await preparedStatement.ExecuteAsync(new FlightCallOptions(), validRecordBatch); + + // Assert + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + Assert.Equal(Status.DefaultSuccess, result); } + + [Fact] + public async Task BindParametersAsync_ShouldReturnMetadata_WhenValidInputsAreProvided() + { + // Arrange + var validSchema = new Schema.Builder() + .Field(f => f.Name("field1").DataType(Int32Type.Default)) + .Build(); + + var validRecordBatch = CreateRecordBatch(validSchema, new[] { 1, 2, 3 }); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("TestCommand"); + + var preparedStatement = new PreparedStatement(_flightSqlClient, "TestHandle", validSchema, validSchema); + + // Act + var metadata = await preparedStatement.BindParametersAsync(new FlightCallOptions(), flightDescriptor, validRecordBatch); + + // Assert + Assert.NotNull(metadata); + // Check if metadata has valid content + // Some systems might return empty metadata, so we validate it's non-null and proceed accordingly + if (metadata.Length == 0) + { + // Optionally, check if the server returned empty metadata but still succeeded + Assert.Equal(0, metadata.Length); + } + else + { + // If metadata is present, validate its contents + Assert.True(metadata.Length > 0, "Metadata should have a length greater than 0 when valid."); + } + } + + [Theory] + [MemberData(nameof(GetTestData))] + public async Task TestSetParameters(RecordBatch parameterBatch, Schema parameterSchema, Type expectedException) + { + // Arrange + var validSchema = new Schema.Builder() + .Field(f => f.Name("field1").DataType(Int32Type.Default)) + .Build(); + string handle = "TestHandle"; + + var preparedStatement = new PreparedStatement(_flightSqlClient, handle, validSchema, parameterSchema); + + if (expectedException != null) + { + // Act and Assert (Expected to throw exception) + var exception = await Record.ExceptionAsync(() => Task.Run(() => preparedStatement.SetParameters(parameterBatch))); + Assert.NotNull(exception); + Assert.IsType(expectedException, exception); // Ensure correct exception type + } + else + { + // Act + var result = await Task.Run(() => preparedStatement.SetParameters(parameterBatch)); + + // Assert + Assert.NotNull(preparedStatement.ParameterReader); + Assert.Equal(Status.DefaultSuccess, result); // Ensure Status is success + } + } + [Fact] - public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreSet() + public async Task TestSetParameters_Cancelled() + { + // Arrange + var validSchema = new Schema.Builder() + .Field(f => f.Name("field1").DataType(Int32Type.Default)) + .Build(); + + string handle = "TestHandle"; + + var preparedStatement = new PreparedStatement(_flightSqlClient, handle, validSchema, validSchema); + var validRecordBatch = CreateRecordBatch(validSchema, [1, 2, 3]); + + // Create a CancellationTokenSource + var cts = new CancellationTokenSource(); + + // Act: Simulate cancellation before setting parameters + await cts.CancelAsync(); + var result = preparedStatement.SetParameters(validRecordBatch, cts.Token); + + // Assert: Ensure the status is DefaultCancelled + Assert.Equal(Status.DefaultCancelled, result); + } + + [Fact] + public async Task TestCloseAsync() { // Arrange var options = new FlightCallOptions(); - var flightHolder = new FlightHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); - flightHolder.AddBatch(new RecordBatchWithMetadata(_parameterBatch)); - _flightStore.Flights.Add(_flightDescriptor, flightHolder); - await _preparedStatement.SetParameters(_parameterBatch); // Act - long affectedRows = await _preparedStatement.ExecuteUpdateAsync(options); + await _preparedStatement.CloseAsync(options); + + // Assert + Assert.True(_preparedStatement.IsClosed, + "PreparedStatement should be marked as closed after calling CloseAsync."); + } + + [Fact] + public async Task ReadResultAsync_ShouldPopulateMessage_WhenValidFlightData() + { + // Arrange + var message = new ActionCreatePreparedStatementResult(); + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("test-data")); + + var results = GetAsyncEnumerable(new List { flightData }); + + // Act + await _preparedStatement.ReadResultAsync(results, message); // Assert - Assert.True(affectedRows > 0); + Assert.NotEmpty(message.PreparedStatementHandle.ToStringUtf8()); } [Fact] - public async Task ExecuteUpdateAsync_ShouldThrowException_WhenNoParametersSet() + public async Task ReadResultAsync_ShouldNotThrow_WhenFlightDataBodyIsNullOrEmpty() { // Arrange - var options = new FlightCallOptions(); + var message = new ActionCreatePreparedStatementResult(); + var flightData1 = new FlightData(_flightDescriptor, ByteString.Empty); + var flightData2 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("")); - // Act & Assert - await Assert.ThrowsAsync( - () => _preparedStatement.ExecuteUpdateAsync(options) - ); + var results = GetAsyncEnumerable(new List { flightData1, flightData2 }); + + // Act + await _preparedStatement.ReadResultAsync(results, message); + + // Assert + Assert.Empty(message.PreparedStatementHandle.ToStringUtf8()); } [Fact] - public async Task ExecuteUpdateAsync_ShouldThrowException_WhenStatementIsClosed() + public async Task ReadResultAsync_ShouldThrowInvalidOperationException_WhenFlightDataIsInvalid() { // Arrange - var options = new FlightCallOptions(); - await _preparedStatement.CloseAsync(options); + var invalidMessage = new ActionCreatePreparedStatementResult(); + var invalidFlightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(new byte[] { })); + + // Act + var results = GetAsyncEnumerable(new List { invalidFlightData }); // Act & Assert - await Assert.ThrowsAsync( - () => _preparedStatement.ExecuteUpdateAsync(options) - ); + await Assert.ThrowsAsync(() => _preparedStatement.ReadResultAsync(results, invalidMessage)); } [Fact] - public async Task CloseAsync_ShouldCloseStatement_WhenCalled() + public async Task ReadResultAsync_ShouldProcessMultipleFlightDataEntries() { // Arrange - var options = new FlightCallOptions(); + var message = new ActionCreatePreparedStatementResult(); + var flightData1 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("data1")); + var flightData2 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("data2")); + + var results = GetAsyncEnumerable(new List { flightData1, flightData2 }); // Act - await _preparedStatement.CloseAsync(options); + await _preparedStatement.ReadResultAsync(results, message); // Assert - await Assert.ThrowsAsync( - () => _preparedStatement.CloseAsync(options) - ); + Assert.NotEmpty(message.PreparedStatementHandle.ToStringUtf8()); } + [Fact] - public async Task CloseAsync_ShouldThrowException_WhenStatementAlreadyClosed() + public async Task ParseResponseAsync_ShouldReturnPreparedStatement_WhenValidData() { // Arrange - var options = new FlightCallOptions(); - await _preparedStatement.CloseAsync(options); + var preparedStatementHandle = "test-handle"; + var actionResult = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFrom(preparedStatementHandle, Encoding.UTF8), + DatasetSchema = _schema.ToByteString(), + ParameterSchema = _schema.ToByteString() + }; + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); + var results = GetAsyncEnumerable(new List { flightData }); + + // Act + var preparedStatement = await _preparedStatement.ParseResponseAsync(_flightSqlClient, results); + + // Assert + Assert.NotNull(preparedStatement); + Assert.Equal(preparedStatementHandle, preparedStatement.Handle); + } + + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task ParseResponseAsync_ShouldThrowException_WhenPreparedStatementHandleIsNullOrEmpty(string handle) + { + // Arrange + ActionCreatePreparedStatementResult actionResult; + + // Check if handle is null or empty and handle accordingly + if (string.IsNullOrEmpty(handle)) + { + actionResult = new ActionCreatePreparedStatementResult(); + } + else + { + actionResult = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFrom(handle, Encoding.UTF8) + }; + } + + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); + var results = GetAsyncEnumerable(new List { flightData }); // Act & Assert - await Assert.ThrowsAsync( - () => _preparedStatement.CloseAsync(options) - ); + await Assert.ThrowsAsync(() => + _preparedStatement.ParseResponseAsync(_flightSqlClient, results)); + } + + [Fact] + public async Task GetSchemaAsync_ShouldReturnSchemaResult_WhenValidInput() + { + // Arrange: Create a ExpectedSchemaResult for the test scenario. + var sqlClient = new TestFlightSqlClient(); + var datasetSchema = new Schema.Builder() + .Field(f => f.Name("Column1").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var parameterSchema = new Schema.Builder() + .Field(f => f.Name("Parameter1").DataType(Int32Type.Default).Nullable(false)) + .Build(); + var preparedStatement = new PreparedStatement(sqlClient, "test-handle", datasetSchema, parameterSchema); + var expectedSchemaResult = new Schema.Builder() + .Field(f => f.Name("Column1").DataType(Int32Type.Default).Nullable(false)) + .Build(); + + // Act: + var result = await preparedStatement.GetSchemaAsync(new FlightCallOptions()); + + // Assert: + Assert.NotNull(result); + SchemaComparer.Compare(expectedSchemaResult, result); + } + + [Fact] + public async Task GetSchemaAsync_ShouldThrowException_WhenSchemaIsEmpty() + { + var sqlClient = new TestFlightSqlClient { ReturnEmptySchema = true }; + var emptySchema = new Schema.Builder().Build(); // Create an empty schema + var preparedStatement = new PreparedStatement(sqlClient, "test-handle", emptySchema, emptySchema); + // Act & Assert: Ensure that calling GetSchemaAsync with an empty schema throws an exception. + await Assert.ThrowsAsync(() => preparedStatement.GetSchemaAsync(new FlightCallOptions())); + } + + [Fact] + public void Dispose_ShouldSetIsClosedToTrue() + { + // Act + _preparedStatement.Dispose(); + + // Assert + Assert.True(_preparedStatement.IsClosed, "The PreparedStatement should be closed after Dispose is called."); + } + + [Fact] + public void Dispose_MultipleTimes_ShouldNotThrowException() + { + // Act + _preparedStatement.Dispose(); + var exception = Record.Exception(() => _preparedStatement.Dispose()); + + // Assert + Assert.Null(exception); + } + + [Fact] + public async Task ToFlightDataStream_ShouldConvertRecordBatchToFlightDataStream() + { + // Arrange + var schema = new Schema.Builder() + .Field(f => f.Name("Name").DataType(StringType.Default).Nullable(false)) + .Field(f => f.Name("Age").DataType(Int32Type.Default).Nullable(false)) + .Build(); + + var names = new StringArray.Builder().Append("Hello").Append("World").Build(); + var ages = new Int32Array.Builder().Append(30).Append(40).Build(); + var recordBatch = new RecordBatch(schema, [names, ages], 2); + + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-command"); + + // Act + var flightDataStream = recordBatch.ToFlightDataStream(flightDescriptor); + var flightDataList = new List(); + + await foreach (var flightData in flightDataStream.ReadAllAsync()) + { + flightDataList.Add(flightData); + } + + // Assert + Assert.Single(flightDataList); + Assert.NotNull(flightDataList[0].DataBody); + } + + private async IAsyncEnumerable GetAsyncEnumerable(IEnumerable enumerable) + { + foreach (var item in enumerable) + { + yield return item; + await Task.Yield(); + } + } + + /// + /// Test client implementation that simulates the behavior of FlightSqlClient for testing purposes. + /// + private class TestFlightSqlClient : FlightSqlClient + { + public bool ReturnEmptySchema { get; set; } = false; + + public TestFlightSqlClient() : base(null) + { + } + + public override Task GetSchemaAsync(FlightCallOptions options, FlightDescriptor descriptor) + { + if (ReturnEmptySchema) + { + // Return an empty schema to simulate an edge case. + return Task.FromResult(new Schema.Builder().Build()); + } + + // Return a valid SchemaResult for the test. + var schemaResult = new Schema.Builder() + .Field(f => f.Name("Column1").DataType(Int32Type.Default).Nullable(false)) + .Build(); + return Task.FromResult(schemaResult); + } + } + + public static IEnumerable GetTestData() + { + // Define schema + var schema = new Schema.Builder() + .Field(f => f.Name("field1").DataType(Int32Type.Default)) + .Build(); + int[] validValues = { 1, 2, 3 }; + int[] invalidValues = { 4, 5, 6 }; + var validRecordBatch = CreateRecordBatch(schema, validValues); + var invalidSchema = new Schema.Builder() + .Field(f => f.Name("invalid_field").DataType(Int32Type.Default)) + .Build(); + + var invalidRecordBatch = CreateRecordBatch(invalidSchema, invalidValues); + return new List + { + // Valid RecordBatch and schema - no exception expected + new object[] { validRecordBatch, new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(), null }, + + // Null RecordBatch - expect ArgumentNullException + new object[] { null, new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(), typeof(ArgumentNullException) } + }; + } + + public static RecordBatch CreateRecordBatch(Schema schema, int[] values) + { + var int32Array = new Int32Array.Builder().AppendRange(values).Build(); + var recordBatch = new RecordBatch.Builder() + .Append("field1", true, int32Array) + .Build(); + return recordBatch; } } \ No newline at end of file From 300490fd9e8dbfd85d0f775c94ebf8b614d09082 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Tue, 15 Oct 2024 18:30:06 +0300 Subject: [PATCH 37/58] chore: PreparedStatement implemented and tested --- .../ChannelReaderStreamAdapter.cs | 39 ++ .../PreparedStatement.cs | 519 ++++------------ .../RecordBatchExtensions.cs | 123 ++++ .../SchemaExtensions.cs | 25 + .../FlightSqlPreparedStatementTests.cs | 562 +++++------------- .../TestFlightSqlServer.cs | 212 +------ 6 files changed, 477 insertions(+), 1003 deletions(-) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs b/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs new file mode 100644 index 0000000000000..bbcbe17d32685 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs @@ -0,0 +1,39 @@ +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +internal class ChannelReaderStreamAdapter : IAsyncStreamReader +{ + private readonly ChannelReader _channelReader; + + public ChannelReaderStreamAdapter(ChannelReader channelReader) + { + _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); + Current = default!; + } + + public T Current { get; private set; } + + public async Task MoveNext(CancellationToken cancellationToken) + { + if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (_channelReader.TryRead(out var item)) + { + Current = item; + return true; + } + } + + return false; + } + + public void Dispose() + { + // No additional cleanup is required here since we are using a channel + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index c1d245cb363ab..13c7d9b543fd0 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -27,6 +27,13 @@ public class PreparedStatement : IDisposable private FlightServerRecordBatchStreamReader? _parameterReader; public FlightServerRecordBatchStreamReader? ParameterReader => _parameterReader; + /// + /// Initializes a new instance of the class. + /// + /// The Flight SQL client used for executing SQL operations. + /// The handle representing the prepared statement. + /// The schema of the result dataset. + /// The schema of the parameters for this prepared statement. public PreparedStatement(FlightSqlClient client, string handle, Schema datasetSchema, Schema parameterSchema) { _client = client ?? throw new ArgumentNullException(nameof(client)); @@ -39,8 +46,9 @@ public PreparedStatement(FlightSqlClient client, string handle, Schema datasetSc /// /// Retrieves the schema associated with the prepared statement asynchronously. /// - /// The FlightCallOptions for the operation. - /// A Task representing the asynchronous operation. The task result contains the SchemaResult object. + /// The options used to configure the Flight call. + /// A task representing the asynchronous operation, which returns the schema of the result set. + /// Thrown when the schema is empty or invalid. public async Task GetSchemaAsync(FlightCallOptions options) { EnsureStatementIsNotClosed(); @@ -69,8 +77,9 @@ public async Task GetSchemaAsync(FlightCallOptions options) /// /// Closes the prepared statement asynchronously. /// - /// The FlightCallOptions for the operation. - /// A Task representing the asynchronous operation. + /// The options used to configure the Flight call. + /// A task representing the asynchronous operation. + /// Thrown if closing the prepared statement fails. public async Task CloseAsync(FlightCallOptions options) { EnsureStatementIsNotClosed(); @@ -98,9 +107,11 @@ public async Task CloseAsync(FlightCallOptions options) /// /// Reads the result from an asynchronous stream of FlightData and populates the provided Protobuf message. /// - /// The async enumerable stream of FlightData results. - /// The Protobuf message to populate from the results. + /// The asynchronous stream of objects. + /// The Protobuf message to populate with the data from the stream. /// A task that represents the asynchronous read operation. + /// Thrown if or is null. + /// Thrown if parsing the data fails. public async Task ReadResultAsync(IAsyncEnumerable results, IMessage message) { if (results == null) throw new ArgumentNullException(nameof(results)); @@ -108,23 +119,28 @@ public async Task ReadResultAsync(IAsyncEnumerable results, IMessage await foreach (var flightData in results.ConfigureAwait(false)) { - // Ensure that the data received is valid and non-empty. if (flightData.DataBody == null || flightData.DataBody.Length == 0) - throw new InvalidOperationException("Received empty or invalid FlightData."); + continue; try { - // Merge the flight data's body into the provided message. message.MergeFrom(message.PackAndSerialize()); } catch (InvalidProtocolBufferException ex) { - throw new InvalidOperationException( - "Failed to parse the received FlightData into the specified message.", ex); + throw new InvalidOperationException("Failed to parse the received FlightData into the specified message.", ex); } } } - + + /// + /// Parses the response of a prepared statement execution from the FlightData stream. + /// + /// The Flight SQL client. + /// The asynchronous stream of objects. + /// A task representing the asynchronous operation, which returns the populated . + /// Thrown if or is null. + /// Thrown if the prepared statement handle or data is invalid. public async Task ParseResponseAsync(FlightSqlClient client, IAsyncEnumerable results) { if (client == null) @@ -147,18 +163,15 @@ public async Task ParseResponseAsync(FlightSqlClient client, } catch (InvalidProtocolBufferException ex) { - throw new InvalidOperationException( - "Failed to parse FlightData into ActionCreatePreparedStatementResult.", ex); + throw new InvalidOperationException("Failed to parse FlightData into ActionCreatePreparedStatementResult.", ex); } } - // If the response is empty or incomplete if (preparedStatementResult.PreparedStatementHandle.Length == 0) { throw new InvalidOperationException("Received an empty or invalid PreparedStatementHandle."); } - // Parse dataset and parameter schemas from the response Schema datasetSchema = null!; Schema parameterSchema = null!; @@ -177,6 +190,13 @@ public async Task ParseResponseAsync(FlightSqlClient client, datasetSchema, parameterSchema); } + /// + /// Binds the specified parameter batch to the prepared statement and returns the status. + /// + /// The containing parameters to bind to the statement. + /// A cancellation token for the binding operation. + /// A indicating success or failure. + /// Thrown if is null. public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancellationToken = default) { EnsureStatementIsNotClosed(); @@ -251,24 +271,91 @@ public async Task ExecuteAsync(FlightCallOptions options, RecordBatc if (parameterBatch != null) { - var boundParametersAsync = await BindParametersAsync(options, descriptor, parameterBatch, cancellationToken).ConfigureAwait(false); + await BindParametersAsync(options, descriptor, parameterBatch).ConfigureAwait(false); } cancellationToken.ThrowIfCancellationRequested(); return await _client.GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); } - + /// + /// Executes a prepared update statement asynchronously with the provided parameter batch. + /// + /// + /// This method executes an update operation using a prepared statement. The provided + /// is bound to the statement, and the operation is sent to the server. The server processes the update and returns + /// metadata indicating the number of affected rows. + /// + /// This operation is asynchronous and can be canceled via the provided . + /// + /// The for this execution, containing headers and other options. + /// + /// A containing the parameters to be bound to the update statement. + /// This batch should match the schema expected by the prepared statement. + /// + /// + /// A representing the asynchronous operation. + /// The task result contains the number of rows affected by the update. + /// + /// + /// Thrown if is null, as a valid parameter batch is required for execution. + /// + /// + /// Thrown if the update operation fails for any reason, including when the server returns invalid or empty metadata, + /// or if the operation is canceled via the . + /// + /// + /// The following example demonstrates how to use the method to execute an update operation: + /// + /// var parameterBatch = CreateParameterBatch(); + /// var affectedRows = await preparedStatement.ExecuteUpdateAsync(new FlightCallOptions(), parameterBatch); + /// Console.WriteLine($"Rows affected: {affectedRows}"); + /// + /// + public async Task ExecuteUpdateAsync(FlightCallOptions options, RecordBatch parameterBatch) + { + if (parameterBatch == null) + { + throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); + } + var descriptor = FlightDescriptor.CreateCommandDescriptor(_handle); + var metadata = await BindParametersAsync(options, descriptor, parameterBatch).ConfigureAwait(false); + + try + { + return ParseAffectedRows(metadata); + } + catch (OperationCanceledException) + { + throw new InvalidOperationException("Update operation was canceled."); + } + catch (Exception ex) + { + throw new InvalidOperationException("Failed to execute the prepared update statement.", ex); + } + } + + private long ParseAffectedRows(ByteString metadata) + { + if (metadata == null || metadata.Length == 0) + { + throw new InvalidOperationException("Server returned empty metadata, unable to determine affected row count."); + } + + var updateResult = new DoPutUpdateResult(); + updateResult.MergeFrom(metadata); + return updateResult.RecordCount; + } + /// /// Binds parameters to the prepared statement by streaming the given RecordBatch to the server asynchronously. /// /// The for the operation, which may include timeouts, headers, and other options for the call. /// The that identifies the statement or command being executed. /// The containing the parameters to bind to the prepared statement. - /// Optional to observe while waiting for the task to complete. The task will be canceled if the token is canceled. /// A that represents the asynchronous operation. The task result contains the metadata from the server after binding the parameters. /// Thrown when is null. /// Thrown if the operation is canceled or if there is an error during the DoPut operation. - public async Task BindParametersAsync(FlightCallOptions options, FlightDescriptor descriptor, RecordBatch parameterBatch, CancellationToken cancellationToken = default) + public async Task BindParametersAsync(FlightCallOptions options, FlightDescriptor descriptor, RecordBatch parameterBatch) { if (parameterBatch == null) { @@ -303,403 +390,27 @@ private void EnsureStatementIsNotClosed() } /// - /// Protected implementation of the dispose pattern. + /// Disposes of the resources used by the prepared statement. /// - /// True if called from Dispose, false if called from the finalizer. - protected virtual void Dispose(bool disposing) - { - if (_isClosed) return; - - if (disposing) - { - // Close the statement if it's not already closed. - CloseAsync(new FlightCallOptions()).GetAwaiter().GetResult(); - } - - _isClosed = true; - } - public void Dispose() { Dispose(true); GC.SuppressFinalize(this); } -} -public static class SchemaExtensions -{ /// - /// Deserializes a schema from a byte array. + /// Disposes of the resources used by the prepared statement. /// - /// The byte array representing the serialized schema. - /// The deserialized Schema object. - public static Schema DeserializeSchema(byte[] serializedSchema) - { - if (serializedSchema == null || serializedSchema.Length == 0) - { - throw new ArgumentException("Invalid serialized schema"); - } - - using var stream = new MemoryStream(serializedSchema); - var reader = new ArrowStreamReader(stream); - return reader.Schema; - } -} - -// -// public class PreparedStatement : IDisposable -// { -// private readonly FlightSqlClient _client; -// private readonly FlightInfo _flightInfo; -// private readonly string _query; -// private bool _isClosed; -// private readonly FlightDescriptor _descriptor; -// private RecordBatch? _parameterBatch; -// -// public PreparedStatement(FlightSqlClient client, FlightInfo flightInfo, string query) -// { -// _client = client ?? throw new ArgumentNullException(nameof(client)); -// _flightInfo = flightInfo ?? throw new ArgumentNullException(nameof(flightInfo)); -// _query = query ?? throw new ArgumentNullException(nameof(query)); -// _descriptor = flightInfo.Descriptor ?? throw new ArgumentNullException(nameof(flightInfo.Descriptor)); -// _isClosed = false; -// } -// -// /// -// /// Set parameters for the prepared statement -// /// -// /// The batch of parameters to bind -// public Task SetParameters(RecordBatch parameterBatch) -// { -// EnsureStatementIsNotClosed(); -// if (parameterBatch == null) -// { -// throw new ArgumentNullException(nameof(parameterBatch)); -// } -// -// _parameterBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); -// return Task.CompletedTask; -// } -// -// /// -// /// Execute the prepared statement, returning the number of affected rows -// /// -// /// The FlightCallOptions for the execution -// /// Task representing the asynchronous operation -// public async Task ExecuteUpdateAsync(FlightCallOptions options) -// { -// EnsureStatementIsNotClosed(); -// EnsureParametersAreSet(); -// try -// { -// var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = _query }; -// var command = new CommandPreparedStatementQuery -// { -// PreparedStatementHandle = prepareStatementRequest.ToByteString() -// }; -// var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); -// var metadata = await BindParametersAsync(options, descriptor).ConfigureAwait(false); -// await _client.ExecuteUpdateAsync(options, _query); -// -// return ParseUpdateResult(metadata); -// } -// catch (RpcException ex) -// { -// throw new InvalidOperationException("Failed to execute update query", ex); -// } -// } -// -// /// -// /// Binds parameters to the server using DoPut and retrieves metadata. -// /// -// /// The FlightCallOptions for the execution. -// /// The FlightDescriptor for the command. -// /// A ByteString containing metadata from the server response. -// public async Task BindParametersAsync(FlightCallOptions options, FlightDescriptor descriptor) -// { -// if (_parameterBatch == null) -// throw new InvalidOperationException("Parameters have not been set."); -// -// // Start the DoPut operation -// var doPutResult = await _client.DoPutAsync(options, descriptor, _parameterBatch.Schema); -// var writer = doPutResult.Writer; -// -// // Write the record batch to the stream -// await writer.WriteAsync(_parameterBatch).ConfigureAwait(false); -// await writer.CompleteAsync().ConfigureAwait(false); -// -// // Read metadata from response -// var metadata = await doPutResult.ReadMetadataAsync().ConfigureAwait(false); -// -// // Close the writer and reader streams -// await writer.CompleteAsync().ConfigureAwait(false); -// await doPutResult.CompleteAsync().ConfigureAwait(false); -// return metadata; -// } -// -// /// -// /// Closes the prepared statement -// /// -// public async Task CloseAsync(FlightCallOptions options) -// { -// EnsureStatementIsNotClosed(); -// try -// { -// var actionClose = new FlightAction(SqlAction.CloseRequest, _flightInfo.Descriptor.Command); -// await foreach (var result in _client.DoActionAsync(options, actionClose).ConfigureAwait(false)) -// { -// } -// -// _isClosed = true; -// } -// catch (RpcException ex) -// { -// throw new InvalidOperationException("Failed to close the prepared statement", ex); -// } -// } -// -// /// -// /// Parses the metadata returned from the server to get the number of affected rows. -// /// -// private long ParseUpdateResult(ByteString metadata) -// { -// var updateResult = new DoPutUpdateResult(); -// updateResult.MergeFrom(metadata); -// return updateResult.RecordCount; -// } -// -// /// -// /// Helper method to ensure the statement is not closed. -// /// -// private void EnsureStatementIsNotClosed() -// { -// if (_isClosed) -// throw new InvalidOperationException("Cannot execute a closed statement."); -// } -// -// private void EnsureParametersAreSet() -// { -// if (_parameterBatch == null || _parameterBatch.Length == 0) -// { -// throw new InvalidOperationException("Prepared statement parameters have not been set."); -// } -// } -// -// public void Dispose() -// { -// _parameterBatch?.Dispose(); -// -// if (!_isClosed) -// { -// _isClosed = true; -// } -// } -// } -// -public static class RecordBatchExtensions -{ - /// - /// Converts a RecordBatch into an asynchronous stream of FlightData. - /// - /// The RecordBatch to convert. - /// The FlightDescriptor describing the Flight data. - /// An asynchronous stream of FlightData objects. - public static async IAsyncEnumerable ToFlightDataStreamAsync(this RecordBatch recordBatch, - FlightDescriptor flightDescriptor) - { - if (recordBatch == null) - { - throw new ArgumentNullException(nameof(recordBatch)); - } - - // Use a memory stream to write the Arrow RecordBatch into FlightData format - using var memoryStream = new MemoryStream(); - var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); - - // Write the RecordBatch to the stream - await writer.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); - await writer.WriteEndAsync().ConfigureAwait(false); - - // Reset the memory stream position - memoryStream.Position = 0; - - // Read back the data to create FlightData - var flightData = new FlightData(flightDescriptor, ByteString.CopyFrom(memoryStream.ToArray()), - ByteString.CopyFrom(memoryStream.ToArray())); - yield return flightData; - } - - /// - /// Converts a RecordBatch into an IAsyncStreamReader. - /// - /// The RecordBatch to convert. - /// The FlightDescriptor describing the Flight data. - /// An IAsyncStreamReader of FlightData. - public static IAsyncStreamReader ToFlightDataStream(this RecordBatch recordBatch, FlightDescriptor flightDescriptor) - { - if (recordBatch == null) throw new ArgumentNullException(nameof(recordBatch)); - if (flightDescriptor == null) throw new ArgumentNullException(nameof(flightDescriptor)); - - var channel = Channel.CreateUnbounded(); - - try - { - if (recordBatch.Schema == null || !recordBatch.Schema.FieldsList.Any()) - { - throw new InvalidOperationException("The record batch has an invalid or empty schema."); - } - - using var memoryStream = new MemoryStream(); - using var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); - writer.WriteRecordBatch(recordBatch); - writer.WriteEnd(); - memoryStream.Position = 0; - var flightData = new FlightData(flightDescriptor, ByteString.CopyFrom(memoryStream.ToArray()), ByteString.Empty, ByteString.Empty); - if (flightData.DataBody.IsEmpty) - { - throw new InvalidOperationException( - "The generated FlightData is empty. Check the RecordBatch content."); - } - - channel.Writer.TryWrite(flightData); - } - finally - { - // Mark the channel as complete once done - channel.Writer.Complete(); - } - return new ChannelFlightDataReader(channel.Reader); - } - - /*public static IAsyncStreamReader ToFlightDataStream(this RecordBatch recordBatch, - FlightDescriptor flightDescriptor) - { - if (recordBatch == null) throw new ArgumentNullException(nameof(recordBatch)); - if (flightDescriptor == null) throw new ArgumentNullException(nameof(flightDescriptor)); - - // Create a channel to act as the data stream. - var channel = Channel.CreateUnbounded(); - - // Start a background task to generate the FlightData asynchronously. - _ = Task.Run(async () => - { - try - { - // Check if the schema is set and there are fields in the RecordBatch - if (recordBatch.Schema == null || !recordBatch.Schema.FieldsList.Any()) - { - throw new InvalidOperationException("The record batch has an invalid or empty schema."); - } - - // Use a memory stream to convert the RecordBatch to FlightData - await using var memoryStream = new MemoryStream(); - var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); - - // Write the RecordBatch to the memory stream - await writer.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); - await writer.WriteEndAsync().ConfigureAwait(false); - - // Reset the memory stream position to read from it - memoryStream.Position = 0; - - // Read back the data from the stream and create FlightData - var flightData = new FlightData( - flightDescriptor, - ByteString.CopyFrom(memoryStream.ToArray()), // Use the data from memory stream - ByteString.Empty // Empty application metadata for now - ); - - // Check if flightData has valid data before writing - if (flightData.DataBody.IsEmpty) - { - throw new InvalidOperationException( - "The generated FlightData is empty. Check the RecordBatch content."); - } - - // Write the FlightData to the channel - await channel.Writer.WriteAsync(flightData).ConfigureAwait(false); - } - catch (Exception ex) - { - // Log any exceptions for debugging purposes - Console.WriteLine($"Error generating FlightData: {ex.Message}"); - } - finally - { - // Mark the channel as complete once done - channel.Writer.Complete(); - } - }); - - // Return a custom IAsyncStreamReader implementation. - return new ChannelFlightDataReader(channel.Reader); - }*/ - - - /// - /// Custom IAsyncStreamReader implementation to read from a ChannelReader. - /// - private class ChannelFlightDataReader : IAsyncStreamReader - { - private readonly ChannelReader _channelReader; - - public ChannelFlightDataReader(ChannelReader channelReader) - { - _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); - Current = default!; - } - - public FlightData Current { get; private set; } - - public async Task MoveNext(CancellationToken cancellationToken) - { - if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) - { - if (_channelReader.TryRead(out var flightData)) - { - Current = flightData; - return true; - } - } - - return false; - } - - public void Dispose() - { - // No additional cleanup is required here since we're not managing external resources. - } - } -} - -public class ChannelReaderStreamAdapter : IAsyncStreamReader -{ - private readonly ChannelReader _channelReader; - - public ChannelReaderStreamAdapter(ChannelReader channelReader) + /// Whether the method is called from . + protected virtual void Dispose(bool disposing) { - _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); - Current = default!; - } - - public T Current { get; private set; } + if (_isClosed) return; - public async Task MoveNext(CancellationToken cancellationToken) - { - if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + if (disposing) { - if (_channelReader.TryRead(out var item)) - { - Current = item; - return true; - } + CloseAsync(new FlightCallOptions()).GetAwaiter().GetResult(); } - return false; - } - - public void Dispose() - { - // No additional cleanup is required here since we are using a channel + _isClosed = true; } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs new file mode 100644 index 0000000000000..67f889565f318 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Apache.Arrow.Ipc; +using Google.Protobuf; +using Grpc.Core; + +namespace Apache.Arrow.Flight.Sql; + +public static class RecordBatchExtensions +{ + /// + /// Converts a RecordBatch into an asynchronous stream of FlightData. + /// + /// The RecordBatch to convert. + /// The FlightDescriptor describing the Flight data. + /// An asynchronous stream of FlightData objects. + public static async IAsyncEnumerable ToFlightDataStreamAsync(this RecordBatch recordBatch, + FlightDescriptor flightDescriptor) + { + if (recordBatch == null) + { + throw new ArgumentNullException(nameof(recordBatch)); + } + + // Use a memory stream to write the Arrow RecordBatch into FlightData format + using var memoryStream = new MemoryStream(); + var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); + + // Write the RecordBatch to the stream + await writer.WriteRecordBatchAsync(recordBatch).ConfigureAwait(false); + await writer.WriteEndAsync().ConfigureAwait(false); + + // Reset the memory stream position + memoryStream.Position = 0; + + // Read back the data to create FlightData + var flightData = new FlightData(flightDescriptor, ByteString.CopyFrom(memoryStream.ToArray()), + ByteString.CopyFrom(memoryStream.ToArray())); + yield return flightData; + } + + /// + /// Converts a RecordBatch into an IAsyncStreamReader. + /// + /// The RecordBatch to convert. + /// The FlightDescriptor describing the Flight data. + /// An IAsyncStreamReader of FlightData. + public static IAsyncStreamReader ToFlightDataStream(this RecordBatch recordBatch, FlightDescriptor flightDescriptor) + { + if (recordBatch == null) throw new ArgumentNullException(nameof(recordBatch)); + if (flightDescriptor == null) throw new ArgumentNullException(nameof(flightDescriptor)); + + var channel = Channel.CreateUnbounded(); + + try + { + if (recordBatch.Schema == null || !recordBatch.Schema.FieldsList.Any()) + { + throw new InvalidOperationException("The record batch has an invalid or empty schema."); + } + + using var memoryStream = new MemoryStream(); + using var writer = new ArrowStreamWriter(memoryStream, recordBatch.Schema); + writer.WriteRecordBatch(recordBatch); + writer.WriteEnd(); + memoryStream.Position = 0; + var flightData = new FlightData(flightDescriptor, ByteString.CopyFrom(memoryStream.ToArray()), ByteString.Empty, ByteString.Empty); + if (flightData.DataBody.IsEmpty) + { + throw new InvalidOperationException( + "The generated FlightData is empty. Check the RecordBatch content."); + } + + channel.Writer.TryWrite(flightData); + } + finally + { + // Mark the channel as complete once done + channel.Writer.Complete(); + } + return new ChannelFlightDataReader(channel.Reader); + } + + /// + /// Custom IAsyncStreamReader implementation to read from a ChannelReader. + /// + private class ChannelFlightDataReader : IAsyncStreamReader + { + private readonly ChannelReader _channelReader; + + public ChannelFlightDataReader(ChannelReader channelReader) + { + _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); + Current = default!; + } + + public FlightData Current { get; private set; } + + public async Task MoveNext(CancellationToken cancellationToken) + { + if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) + { + if (_channelReader.TryRead(out var flightData)) + { + Current = flightData; + return true; + } + } + + return false; + } + + public void Dispose() + { + // No additional cleanup is required here since we're not managing external resources. + } + } +} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs new file mode 100644 index 0000000000000..4df34ee8e1009 --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs @@ -0,0 +1,25 @@ +using System; +using System.IO; +using Apache.Arrow.Ipc; + +namespace Apache.Arrow.Flight.Sql; + +public static class SchemaExtensions +{ + /// + /// Deserializes a schema from a byte array. + /// + /// The byte array representing the serialized schema. + /// The deserialized Schema object. + public static Schema DeserializeSchema(byte[] serializedSchema) + { + if (serializedSchema == null || serializedSchema.Length == 0) + { + throw new ArgumentException("Invalid serialized schema"); + } + + using var stream = new MemoryStream(serializedSchema); + var reader = new ArrowStreamReader(stream); + return reader.Schema; + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index 20d365e8478b7..6df48886e1721 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -6,458 +6,208 @@ using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; using Apache.Arrow.Flight.TestWeb; -using Apache.Arrow.Tests; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; using Grpc.Core; using Xunit; -namespace Apache.Arrow.Flight.Sql.Tests; - -public class FlightSqlPreparedStatementTests +namespace Apache.Arrow.Flight.Sql.Tests { - readonly TestFlightSqlWebFactory _testWebFactory; - readonly FlightStore _flightStore; - readonly FlightSqlClient _flightSqlClient; - private readonly PreparedStatement _preparedStatement; - private readonly Schema _schema; - private readonly FlightDescriptor _flightDescriptor; - private readonly FlightHolder _flightHolder; - private readonly RecordBatch _parameterBatch; - - public FlightSqlPreparedStatementTests() - { - _flightStore = new FlightStore(); - _testWebFactory = new TestFlightSqlWebFactory(_flightStore); - FlightClient flightClient = new(_testWebFactory.GetChannel()); - _flightSqlClient = new(flightClient); - - // Setup mock - _flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); - _schema = new Schema - .Builder() - .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) - .Build(); - - int[] dataTypeIds = [1, 2, 3]; - string[] typeNames = ["INTEGER", "VARCHAR", "BOOLEAN"]; - int[] precisions = [32, 255, 1]; - string[] literalPrefixes = ["N'", "'", "b'"]; - int[] columnSizes = [10, 255, 1]; - _parameterBatch = new RecordBatch(_schema, - [ - new Int32Array.Builder().AppendRange(dataTypeIds).Build(), - new StringArray.Builder().AppendRange(typeNames).Build(), - new Int32Array.Builder().AppendRange(precisions).Build(), - new StringArray.Builder().AppendRange(literalPrefixes).Build(), - new Int32Array.Builder().AppendRange(columnSizes).Build() - ], 5); - - _flightHolder = new FlightHolder(_flightDescriptor, _schema, _testWebFactory.GetAddress()); - _flightStore.Flights.Add(_flightDescriptor, _flightHolder); - _preparedStatement = new PreparedStatement(_flightSqlClient, handle: "test-handle-guid", datasetSchema: _schema, - parameterSchema: _schema); - } - - [Fact] - public async Task GetSchemaAsync_ShouldThrowInvalidOperationException_WhenStatementIsClosed() - { - // Arrange: - await _preparedStatement.CloseAsync(new FlightCallOptions()); - - // Act & Assert: Ensure that calling GetSchemaAsync on a closed statement throws an exception. - await Assert.ThrowsAsync(() => _preparedStatement.GetSchemaAsync(new FlightCallOptions())); - } - - [Fact] - public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() - { - // Arrange - var validSchema = new Schema.Builder() - .Field(f => f.Name("field1").DataType(Int32Type.Default)) - .Build(); - string handle = "TestHandle"; - var preparedStatement = new PreparedStatement(_flightSqlClient, handle, validSchema, validSchema); - var validRecordBatch = CreateRecordBatch(validSchema, [1, 2, 3]); - - // Act - var result = preparedStatement.SetParameters(validRecordBatch); - var flightInfo = await preparedStatement.ExecuteAsync(new FlightCallOptions(), validRecordBatch); - - // Assert - Assert.NotNull(flightInfo); - Assert.IsType(flightInfo); - Assert.Equal(Status.DefaultSuccess, result); - } - - [Fact] - public async Task BindParametersAsync_ShouldReturnMetadata_WhenValidInputsAreProvided() + public class FlightSqlPreparedStatementTests { - // Arrange - var validSchema = new Schema.Builder() - .Field(f => f.Name("field1").DataType(Int32Type.Default)) - .Build(); - - var validRecordBatch = CreateRecordBatch(validSchema, new[] { 1, 2, 3 }); - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("TestCommand"); - - var preparedStatement = new PreparedStatement(_flightSqlClient, "TestHandle", validSchema, validSchema); - - // Act - var metadata = await preparedStatement.BindParametersAsync(new FlightCallOptions(), flightDescriptor, validRecordBatch); - - // Assert - Assert.NotNull(metadata); + readonly TestFlightSqlWebFactory _testWebFactory; + readonly FlightStore _flightStore; + readonly FlightSqlClient _flightSqlClient; + private readonly PreparedStatement _preparedStatement; + private readonly Schema _schema; + private readonly FlightDescriptor _flightDescriptor; + private readonly RecordBatch _parameterBatch; - // Check if metadata has valid content - // Some systems might return empty metadata, so we validate it's non-null and proceed accordingly - if (metadata.Length == 0) + public FlightSqlPreparedStatementTests() { - // Optionally, check if the server returned empty metadata but still succeeded - Assert.Equal(0, metadata.Length); + _flightStore = new FlightStore(); + _testWebFactory = new TestFlightSqlWebFactory(_flightStore); + _flightSqlClient = new FlightSqlClient(new FlightClient(_testWebFactory.GetChannel())); + + _flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-query"); + _schema = CreateSchema(); + _parameterBatch = CreateParameterBatch(); + _preparedStatement = new PreparedStatement(_flightSqlClient, "test-handle-guid", _schema, _schema); } - else + + private static Schema CreateSchema() { - // If metadata is present, validate its contents - Assert.True(metadata.Length > 0, "Metadata should have a length greater than 0 when valid."); + return new Schema.Builder() + .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) + .Build(); } - } - - [Theory] - [MemberData(nameof(GetTestData))] - public async Task TestSetParameters(RecordBatch parameterBatch, Schema parameterSchema, Type expectedException) - { - // Arrange - var validSchema = new Schema.Builder() - .Field(f => f.Name("field1").DataType(Int32Type.Default)) - .Build(); - string handle = "TestHandle"; - - var preparedStatement = new PreparedStatement(_flightSqlClient, handle, validSchema, parameterSchema); - - if (expectedException != null) + + private RecordBatch CreateParameterBatch() { - // Act and Assert (Expected to throw exception) - var exception = await Record.ExceptionAsync(() => Task.Run(() => preparedStatement.SetParameters(parameterBatch))); - Assert.NotNull(exception); - Assert.IsType(expectedException, exception); // Ensure correct exception type + return new RecordBatch(_schema, + new IArrowArray[] + { + new Int32Array.Builder().AppendRange(new[] { 1, 2, 3 }).Build(), + new StringArray.Builder().AppendRange(new[] { "INTEGER", "VARCHAR", "BOOLEAN" }).Build(), + new Int32Array.Builder().AppendRange(new[] { 32, 255, 1 }).Build() + }, 3); } - else - { - // Act - var result = await Task.Run(() => preparedStatement.SetParameters(parameterBatch)); - // Assert - Assert.NotNull(preparedStatement.ParameterReader); - Assert.Equal(Status.DefaultSuccess, result); // Ensure Status is success + [Fact] + public async Task GetSchemaAsync_ShouldThrowInvalidOperationException_WhenStatementIsClosed() + { + await _preparedStatement.CloseAsync(new FlightCallOptions()); + await Assert.ThrowsAsync(() => _preparedStatement.GetSchemaAsync(new FlightCallOptions())); } - } - - [Fact] - public async Task TestSetParameters_Cancelled() - { - // Arrange - var validSchema = new Schema.Builder() - .Field(f => f.Name("field1").DataType(Int32Type.Default)) - .Build(); - - string handle = "TestHandle"; - - var preparedStatement = new PreparedStatement(_flightSqlClient, handle, validSchema, validSchema); - var validRecordBatch = CreateRecordBatch(validSchema, [1, 2, 3]); - - // Create a CancellationTokenSource - var cts = new CancellationTokenSource(); - - // Act: Simulate cancellation before setting parameters - await cts.CancelAsync(); - var result = preparedStatement.SetParameters(validRecordBatch, cts.Token); - - // Assert: Ensure the status is DefaultCancelled - Assert.Equal(Status.DefaultCancelled, result); - } - - [Fact] - public async Task TestCloseAsync() - { - // Arrange - var options = new FlightCallOptions(); - - // Act - await _preparedStatement.CloseAsync(options); - - // Assert - Assert.True(_preparedStatement.IsClosed, - "PreparedStatement should be marked as closed after calling CloseAsync."); - } - - [Fact] - public async Task ReadResultAsync_ShouldPopulateMessage_WhenValidFlightData() - { - // Arrange - var message = new ActionCreatePreparedStatementResult(); - var flightData = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("test-data")); - - var results = GetAsyncEnumerable(new List { flightData }); - - // Act - await _preparedStatement.ReadResultAsync(results, message); - - // Assert - Assert.NotEmpty(message.PreparedStatementHandle.ToStringUtf8()); - } - - [Fact] - public async Task ReadResultAsync_ShouldNotThrow_WhenFlightDataBodyIsNullOrEmpty() - { - // Arrange - var message = new ActionCreatePreparedStatementResult(); - var flightData1 = new FlightData(_flightDescriptor, ByteString.Empty); - var flightData2 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("")); - - var results = GetAsyncEnumerable(new List { flightData1, flightData2 }); - - // Act - await _preparedStatement.ReadResultAsync(results, message); - - // Assert - Assert.Empty(message.PreparedStatementHandle.ToStringUtf8()); - } - - [Fact] - public async Task ReadResultAsync_ShouldThrowInvalidOperationException_WhenFlightDataIsInvalid() - { - // Arrange - var invalidMessage = new ActionCreatePreparedStatementResult(); - var invalidFlightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(new byte[] { })); - - // Act - var results = GetAsyncEnumerable(new List { invalidFlightData }); - - // Act & Assert - await Assert.ThrowsAsync(() => _preparedStatement.ReadResultAsync(results, invalidMessage)); - } - - [Fact] - public async Task ReadResultAsync_ShouldProcessMultipleFlightDataEntries() - { - // Arrange - var message = new ActionCreatePreparedStatementResult(); - var flightData1 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("data1")); - var flightData2 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("data2")); - - var results = GetAsyncEnumerable(new List { flightData1, flightData2 }); - // Act - await _preparedStatement.ReadResultAsync(results, message); - - // Assert - Assert.NotEmpty(message.PreparedStatementHandle.ToStringUtf8()); - } - - - [Fact] - public async Task ParseResponseAsync_ShouldReturnPreparedStatement_WhenValidData() - { - // Arrange - var preparedStatementHandle = "test-handle"; - var actionResult = new ActionCreatePreparedStatementResult + [Fact] + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() { - PreparedStatementHandle = ByteString.CopyFrom(preparedStatementHandle, Encoding.UTF8), - DatasetSchema = _schema.ToByteString(), - ParameterSchema = _schema.ToByteString() - }; - var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); - var results = GetAsyncEnumerable(new List { flightData }); - - // Act - var preparedStatement = await _preparedStatement.ParseResponseAsync(_flightSqlClient, results); + var validRecordBatch = CreateRecordBatch(_schema, new[] { 1, 2, 3 }); + var result = _preparedStatement.SetParameters(validRecordBatch); + var flightInfo = await _preparedStatement.ExecuteAsync(new FlightCallOptions(), validRecordBatch); - // Assert - Assert.NotNull(preparedStatement); - Assert.Equal(preparedStatementHandle, preparedStatement.Handle); - } + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + Assert.Equal(Status.DefaultSuccess, result); + } - [Theory] - [InlineData(null)] - [InlineData("")] - public async Task ParseResponseAsync_ShouldThrowException_WhenPreparedStatementHandleIsNullOrEmpty(string handle) - { - // Arrange - ActionCreatePreparedStatementResult actionResult; + [Fact] + public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreSet() + { + var affectedRows = await _preparedStatement.ExecuteUpdateAsync(new FlightCallOptions(), _parameterBatch); + Assert.True(affectedRows > 0, "Expected affected rows to be greater than 0."); + } - // Check if handle is null or empty and handle accordingly - if (string.IsNullOrEmpty(handle)) + [Fact] + public async Task BindParametersAsync_ShouldReturnMetadata_WhenValidInputsAreProvided() { - actionResult = new ActionCreatePreparedStatementResult(); + var metadata = await _preparedStatement.BindParametersAsync(new FlightCallOptions(), _flightDescriptor, _parameterBatch); + Assert.NotNull(metadata); + Assert.True(metadata.Length > 0, "Metadata should have a length greater than 0 when valid."); } - else + + [Theory] + [MemberData(nameof(GetTestData))] + public async Task TestSetParameters(RecordBatch parameterBatch, Schema parameterSchema, Type expectedException) { - actionResult = new ActionCreatePreparedStatementResult + var preparedStatement = new PreparedStatement(_flightSqlClient, "TestHandle", _schema, parameterSchema); + if (expectedException != null) { - PreparedStatementHandle = ByteString.CopyFrom(handle, Encoding.UTF8) - }; + var exception = await Record.ExceptionAsync(() => Task.Run(() => preparedStatement.SetParameters(parameterBatch))); + Assert.NotNull(exception); + Assert.IsType(expectedException, exception); + } + else + { + var result = await Task.Run(() => preparedStatement.SetParameters(parameterBatch)); + Assert.NotNull(preparedStatement.ParameterReader); + Assert.Equal(Status.DefaultSuccess, result); + } } - var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); - var results = GetAsyncEnumerable(new List { flightData }); - - // Act & Assert - await Assert.ThrowsAsync(() => - _preparedStatement.ParseResponseAsync(_flightSqlClient, results)); - } - - [Fact] - public async Task GetSchemaAsync_ShouldReturnSchemaResult_WhenValidInput() - { - // Arrange: Create a ExpectedSchemaResult for the test scenario. - var sqlClient = new TestFlightSqlClient(); - var datasetSchema = new Schema.Builder() - .Field(f => f.Name("Column1").DataType(Int32Type.Default).Nullable(false)) - .Build(); - var parameterSchema = new Schema.Builder() - .Field(f => f.Name("Parameter1").DataType(Int32Type.Default).Nullable(false)) - .Build(); - var preparedStatement = new PreparedStatement(sqlClient, "test-handle", datasetSchema, parameterSchema); - var expectedSchemaResult = new Schema.Builder() - .Field(f => f.Name("Column1").DataType(Int32Type.Default).Nullable(false)) - .Build(); - - // Act: - var result = await preparedStatement.GetSchemaAsync(new FlightCallOptions()); - - // Assert: - Assert.NotNull(result); - SchemaComparer.Compare(expectedSchemaResult, result); - } - - [Fact] - public async Task GetSchemaAsync_ShouldThrowException_WhenSchemaIsEmpty() - { - var sqlClient = new TestFlightSqlClient { ReturnEmptySchema = true }; - var emptySchema = new Schema.Builder().Build(); // Create an empty schema - var preparedStatement = new PreparedStatement(sqlClient, "test-handle", emptySchema, emptySchema); - // Act & Assert: Ensure that calling GetSchemaAsync with an empty schema throws an exception. - await Assert.ThrowsAsync(() => preparedStatement.GetSchemaAsync(new FlightCallOptions())); - } - - [Fact] - public void Dispose_ShouldSetIsClosedToTrue() - { - // Act - _preparedStatement.Dispose(); - - // Assert - Assert.True(_preparedStatement.IsClosed, "The PreparedStatement should be closed after Dispose is called."); - } - - [Fact] - public void Dispose_MultipleTimes_ShouldNotThrowException() - { - // Act - _preparedStatement.Dispose(); - var exception = Record.Exception(() => _preparedStatement.Dispose()); + [Fact] + public async Task TestSetParameters_Cancelled() + { + var validRecordBatch = CreateRecordBatch(_schema, new[] { 1, 2, 3 }); + var cts = new CancellationTokenSource(); + await cts.CancelAsync(); + var result = _preparedStatement.SetParameters(validRecordBatch, cts.Token); + Assert.Equal(Status.DefaultCancelled, result); + } - // Assert - Assert.Null(exception); - } + [Fact] + public async Task TestCloseAsync() + { + await _preparedStatement.CloseAsync(new FlightCallOptions()); + Assert.True(_preparedStatement.IsClosed, "PreparedStatement should be marked as closed after calling CloseAsync."); + } - [Fact] - public async Task ToFlightDataStream_ShouldConvertRecordBatchToFlightDataStream() - { - // Arrange - var schema = new Schema.Builder() - .Field(f => f.Name("Name").DataType(StringType.Default).Nullable(false)) - .Field(f => f.Name("Age").DataType(Int32Type.Default).Nullable(false)) - .Build(); + [Fact] + public async Task ReadResultAsync_ShouldPopulateMessage_WhenValidFlightData() + { + var message = new ActionCreatePreparedStatementResult(); + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("test-data")); + var results = GetAsyncEnumerable(new List { flightData }); - var names = new StringArray.Builder().Append("Hello").Append("World").Build(); - var ages = new Int32Array.Builder().Append(30).Append(40).Build(); - var recordBatch = new RecordBatch(schema, [names, ages], 2); + await _preparedStatement.ReadResultAsync(results, message); + Assert.NotEmpty(message.PreparedStatementHandle.ToStringUtf8()); + } - var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test-command"); + [Fact] + public async Task ReadResultAsync_ShouldNotThrow_WhenFlightDataBodyIsNullOrEmpty() + { + var message = new ActionCreatePreparedStatementResult(); + var flightData1 = new FlightData(_flightDescriptor, ByteString.Empty); + var flightData2 = new FlightData(_flightDescriptor, ByteString.CopyFromUtf8("")); + var results = GetAsyncEnumerable(new List { flightData1, flightData2 }); - // Act - var flightDataStream = recordBatch.ToFlightDataStream(flightDescriptor); - var flightDataList = new List(); + await _preparedStatement.ReadResultAsync(results, message); + Assert.Empty(message.PreparedStatementHandle.ToStringUtf8()); + } - await foreach (var flightData in flightDataStream.ReadAllAsync()) + [Fact] + public async Task ParseResponseAsync_ShouldReturnPreparedStatement_WhenValidData() { - flightDataList.Add(flightData); - } + var preparedStatementHandle = "test-handle"; + var actionResult = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFrom(preparedStatementHandle, Encoding.UTF8), + DatasetSchema = _schema.ToByteString(), + ParameterSchema = _schema.ToByteString() + }; + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); + var results = GetAsyncEnumerable(new List { flightData }); - // Assert - Assert.Single(flightDataList); - Assert.NotNull(flightDataList[0].DataBody); - } + var preparedStatement = await _preparedStatement.ParseResponseAsync(_flightSqlClient, results); + Assert.NotNull(preparedStatement); + Assert.Equal(preparedStatementHandle, preparedStatement.Handle); + } - private async IAsyncEnumerable GetAsyncEnumerable(IEnumerable enumerable) - { - foreach (var item in enumerable) + [Theory] + [InlineData(null)] + [InlineData("")] + public async Task ParseResponseAsync_ShouldThrowException_WhenPreparedStatementHandleIsNullOrEmpty(string handle) { - yield return item; - await Task.Yield(); - } - } + ActionCreatePreparedStatementResult actionResult = string.IsNullOrEmpty(handle) + ? new ActionCreatePreparedStatementResult() + : new ActionCreatePreparedStatementResult { PreparedStatementHandle = ByteString.CopyFrom(handle, Encoding.UTF8) }; - /// - /// Test client implementation that simulates the behavior of FlightSqlClient for testing purposes. - /// - private class TestFlightSqlClient : FlightSqlClient - { - public bool ReturnEmptySchema { get; set; } = false; + var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); + var results = GetAsyncEnumerable(new List { flightData }); - public TestFlightSqlClient() : base(null) - { + await Assert.ThrowsAsync(() => + _preparedStatement.ParseResponseAsync(_flightSqlClient, results)); } - public override Task GetSchemaAsync(FlightCallOptions options, FlightDescriptor descriptor) + private async IAsyncEnumerable GetAsyncEnumerable(IEnumerable enumerable) { - if (ReturnEmptySchema) + foreach (var item in enumerable) { - // Return an empty schema to simulate an edge case. - return Task.FromResult(new Schema.Builder().Build()); + yield return item; + await Task.Yield(); } - - // Return a valid SchemaResult for the test. - var schemaResult = new Schema.Builder() - .Field(f => f.Name("Column1").DataType(Int32Type.Default).Nullable(false)) - .Build(); - return Task.FromResult(schemaResult); } - } - - public static IEnumerable GetTestData() - { - // Define schema - var schema = new Schema.Builder() - .Field(f => f.Name("field1").DataType(Int32Type.Default)) - .Build(); - int[] validValues = { 1, 2, 3 }; - int[] invalidValues = { 4, 5, 6 }; - var validRecordBatch = CreateRecordBatch(schema, validValues); - var invalidSchema = new Schema.Builder() - .Field(f => f.Name("invalid_field").DataType(Int32Type.Default)) - .Build(); - - var invalidRecordBatch = CreateRecordBatch(invalidSchema, invalidValues); - return new List + + public static IEnumerable GetTestData() { - // Valid RecordBatch and schema - no exception expected - new object[] { validRecordBatch, new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(), null }, + var schema = new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(); + var validRecordBatch = CreateRecordBatch(schema, new[] { 1, 2, 3 }); + var invalidSchema = new Schema.Builder().Field(f => f.Name("invalid_field").DataType(Int32Type.Default)).Build(); + var invalidRecordBatch = CreateRecordBatch(invalidSchema, new[] { 4, 5, 6 }); - // Null RecordBatch - expect ArgumentNullException - new object[] { null, new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(), typeof(ArgumentNullException) } - }; - } - - public static RecordBatch CreateRecordBatch(Schema schema, int[] values) - { - var int32Array = new Int32Array.Builder().AppendRange(values).Build(); - var recordBatch = new RecordBatch.Builder() - .Append("field1", true, int32Array) - .Build(); - return recordBatch; + return new List + { + new object[] { validRecordBatch, schema, null }, + new object[] { null, schema, typeof(ArgumentNullException) } + }; + } + + public static RecordBatch CreateRecordBatch(Schema schema, int[] values) + { + var int32Array = new Int32Array.Builder().AppendRange(values).Build(); + return new RecordBatch.Builder().Append("field1", true, int32Array).Build(); + } } -} \ No newline at end of file +} diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs index 3d1d086a3c1cc..0463ee79379f3 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs @@ -88,8 +88,7 @@ public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStr } } - public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) + public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, IAsyncStreamWriter responseStream, ServerCallContext context) { var flightDescriptor = await requestStream.FlightDescriptor; @@ -98,13 +97,27 @@ public override async Task DoPut(FlightServerRecordBatchStreamReader requestStre flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); _flightStore.Flights.Add(flightDescriptor, flightHolder); } - + + int affectedRows = 0; while (await requestStream.MoveNext()) { - flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, - requestStream.ApplicationMetadata.FirstOrDefault())); - await responseStream.WriteAsync(FlightPutResult.Empty); + // Process the record batch (if needed) here + // Increment the affected row count for demonstration purposes + affectedRows += requestStream.Current.Column(0).Length; // Example of counting rows in the first column } + + // Create a DoPutUpdateResult with the affected row count + var updateResult = new DoPutUpdateResult + { + RecordCount = affectedRows // Set the actual affected row count + }; + + // Serialize the DoPutUpdateResult into a ByteString + var metadata = updateResult.ToByteString(); + + // Send the metadata back as part of the FlightPutResult + var flightPutResult = new FlightPutResult(metadata); + await responseStream.WriteAsync(flightPutResult); } public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) @@ -137,190 +150,3 @@ public override Task GetSchema(FlightDescriptor request, ServerCallConte throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); } } - - -/* - * - * - - using System; - using System.Collections.Generic; - using System.Linq; - using System.Threading.Tasks; - using Apache.Arrow.Flight.Server; - using Apache.Arrow.Flight.Sql; - using Arrow.Flight.Protocol.Sql; - using Google.Protobuf; - using Google.Protobuf.WellKnownTypes; - using Grpc.Core; - using Grpc.Core.Utils; - - namespace Apache.Arrow.Flight.TestWeb - { - public class TestFlightServer : FlightServer - { - private readonly FlightStore _flightStore; - - public TestFlightServer(FlightStore flightStore) - { - _flightStore = flightStore; - } - - public override async Task DoAction(FlightAction request, IAsyncStreamWriter responseStream, - ServerCallContext context) - { - switch (request.Type) - { - case "test": - await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); - break; - case SqlAction.GetPrimaryKeysRequest: - await responseStream.WriteAsync(new FlightResult("test data")).ConfigureAwait(false); - break; - case SqlAction.CancelFlightInfoRequest: - var cancelRequest = new FlightInfoCancelResult(); - cancelRequest.SetStatus(1); - await responseStream.WriteAsync(new FlightResult(Any.Pack(cancelRequest).Serialize().ToByteArray())) - .ConfigureAwait(false); - break; - case SqlAction.BeginTransactionRequest: - case SqlAction.CommitRequest: - case SqlAction.RollbackRequest: - await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample-transaction-id"))) - .ConfigureAwait(false); - break; - case SqlAction.CreateRequest: - case SqlAction.CloseRequest: - var prepareStatementResponse = new ActionCreatePreparedStatementResult - { - PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") - }; - byte[] packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); - var flightResult = new FlightResult(packedResult); - await responseStream.WriteAsync(flightResult).ConfigureAwait(false); - break; - default: - throw new NotImplementedException(); - } - } - - public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStreamWriter responseStream, - ServerCallContext context) - { - FlightDescriptor flightDescriptor = null; - flightDescriptor = flightDescriptor is not null && flightDescriptor.Paths.Any() - ? FlightDescriptor.CreatePathDescriptor(ticket.Ticket.ToStringUtf8()) - : FlightDescriptor.CreateCommandDescriptor(ticket.Ticket.ToStringUtf8()); - - if (_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) - { - var batches = flightHolder.GetRecordBatches(); - - - foreach (var batch in batches) - { - await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); - } - } - } - - public override async Task DoPut(FlightServerRecordBatchStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) - { - var flightDescriptor = await requestStream.FlightDescriptor; - - if (!_flightStore.Flights.TryGetValue(flightDescriptor, out var flightHolder)) - { - flightHolder = new FlightHolder(flightDescriptor, await requestStream.Schema, $"http://{context.Host}"); - _flightStore.Flights.Add(flightDescriptor, flightHolder); - } - - while (await requestStream.MoveNext()) - { - flightHolder.AddBatch(new RecordBatchWithMetadata(requestStream.Current, - requestStream.ApplicationMetadata.FirstOrDefault())); - await responseStream.WriteAsync(FlightPutResult.Empty); - } - } - - public override Task GetFlightInfo(FlightDescriptor request, ServerCallContext context) - { - if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) - { - return Task.FromResult(flightHolder.GetFlightInfo()); - } - - if (_flightStore.Flights.Count > 0) - { - return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo()); - } - - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); - } - - public override async Task Handshake(IAsyncStreamReader requestStream, - IAsyncStreamWriter responseStream, ServerCallContext context) - { - while (await requestStream.MoveNext().ConfigureAwait(false)) - { - if (requestStream.Current.Payload.ToStringUtf8() == "Hello") - { - await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Hello handshake"))) - .ConfigureAwait(false); - } - else - { - await responseStream.WriteAsync(new(ByteString.CopyFromUtf8("Done"))).ConfigureAwait(false); - } - } - } - - public override Task GetSchema(FlightDescriptor request, ServerCallContext context) - { - if (_flightStore.Flights.TryGetValue(request, out var flightHolder)) - { - return Task.FromResult(flightHolder.GetFlightInfo().Schema); - } - - if (_flightStore.Flights.Count > 0) - { - return Task.FromResult(_flightStore.Flights.First().Value.GetFlightInfo().Schema); - } - - throw new RpcException(new Status(StatusCode.NotFound, "Flight not found")); - } - - public override async Task ListActions(IAsyncStreamWriter responseStream, - ServerCallContext context) - { - await responseStream.WriteAsync(new FlightActionType("get", "get a flight")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("put", "add a flight")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("delete", "delete a flight")).ConfigureAwait(false); - await responseStream.WriteAsync(new FlightActionType("test", "test action")).ConfigureAwait(false); - } - - public override async Task ListFlights(FlightCriteria request, IAsyncStreamWriter responseStream, - ServerCallContext context) - { - var flightInfos = _flightStore.Flights.Select(x => x.Value.GetFlightInfo()).ToList(); - - foreach (var flightInfo in flightInfos) - { - await responseStream.WriteAsync(flightInfo).ConfigureAwait(false); - } - } - - public override async Task DoExchange(FlightServerRecordBatchStreamReader requestStream, - FlightServerRecordBatchStreamWriter responseStream, ServerCallContext context) - { - while (await requestStream.MoveNext().ConfigureAwait(false)) - { - await responseStream - .WriteAsync(requestStream.Current, requestStream.ApplicationMetadata.FirstOrDefault()) - .ConfigureAwait(false); - } - } - } - } - * - */ \ No newline at end of file From 34e0113e6ff29e98d3dda26b47fcdc61fcc9e53f Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 28 Oct 2024 18:39:07 +0200 Subject: [PATCH 38/58] refactor: FlightCallOptions --- .../Client/FlightSqlClient.cs | 258 +++++------------- .../PreparedStatement.cs | 24 +- .../Ipc/ICompressionCodecFactory.cs | 2 +- .../FlightSqlClientTests.cs | 64 ++--- .../FlightSqlPreparedStatementTests.cs | 6 +- 5 files changed, 107 insertions(+), 247 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 07c4a20a54d07..6ef9fb324d5bf 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -26,15 +26,10 @@ public FlightSqlClient(FlightClient client) /// The UTF8-encoded SQL query to be executed. /// A transaction to associate this query with. /// The FlightInfo describing where to access the dataset. - public async Task ExecuteAsync(FlightCallOptions options, string query, Transaction? transaction = null) + public async Task ExecuteAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { transaction ??= Transaction.NoTransaction; - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - if (string.IsNullOrEmpty(query)) { throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); @@ -45,7 +40,7 @@ public async Task ExecuteAsync(FlightCallOptions options, string que var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -58,7 +53,7 @@ public async Task ExecuteAsync(FlightCallOptions options, string que byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - return await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); } throw new InvalidOperationException("No results returned from the query."); @@ -76,15 +71,10 @@ public async Task ExecuteAsync(FlightCallOptions options, string que /// The UTF8-encoded SQL query to be executed. /// A transaction to associate this query with. Defaults to no transaction if not provided. /// The number of rows affected by the operation. - public async Task ExecuteUpdateAsync(FlightCallOptions options, string query, Transaction? transaction = null) + public async Task ExecuteUpdateAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { transaction ??= Transaction.NoTransaction; - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - if (string.IsNullOrEmpty(query)) { throw new ArgumentException("Query cannot be null or empty", nameof(query)); @@ -96,7 +86,7 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; byte[] serializedUpdateRequestCommand = updateRequestCommand.PackAndSerialize(); var action = new FlightAction(SqlAction.CreateRequest, serializedUpdateRequestCommand); - var call = DoActionAsync(options, action); + var call = DoActionAsync(action, options); long affectedRows = 0; await foreach (var result in call.ConfigureAwait(false)) @@ -108,8 +98,8 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que }; var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor); - var doGetResult = DoGetAsync(options, flightInfo.Endpoints[0].Ticket); + var flightInfo = await GetFlightInfoAsync(descriptor, options); + var doGetResult = DoGetAsync(flightInfo.Endpoints[0].Ticket, options); await foreach (var recordBatch in doGetResult.ConfigureAwait(false)) { affectedRows += recordBatch.Column(0).Length; @@ -130,7 +120,7 @@ public async Task ExecuteUpdateAsync(FlightCallOptions options, string que /// RPC-layer hints for this call. /// The descriptor of the dataset request, whether a named dataset or a command. /// A task that represents the asynchronous operation. The task result contains the FlightInfo describing where to access the dataset. - public async Task GetFlightInfoAsync(FlightCallOptions options, FlightDescriptor descriptor) + public async Task GetFlightInfoAsync(FlightDescriptor descriptor, FlightCallOptions? options = default) { if (descriptor is null) { @@ -139,7 +129,7 @@ public async Task GetFlightInfoAsync(FlightCallOptions options, Flig try { - var flightInfoCall = _client.GetInfo(descriptor, options.Headers); + var flightInfoCall = _client.GetInfo(descriptor, options?.Headers); var flightInfo = await flightInfoCall.ResponseAsync.ConfigureAwait(false); return flightInfo; } @@ -149,32 +139,18 @@ public async Task GetFlightInfoAsync(FlightCallOptions options, Flig } } - /// - /// Asynchronously retrieves flight information for a given flight descriptor. - /// - /// The descriptor of the dataset request, whether a named dataset or a command. - /// A task that represents the asynchronous operation. The task result contains the FlightInfo describing where to access the dataset. - public Task GetFlightInfoAsync(FlightDescriptor descriptor) - { - var options = new FlightCallOptions(); - return GetFlightInfoAsync(options, descriptor); - } - /// /// Perform the indicated action, returning an iterator to the stream of results, if any. /// /// Per-RPC options /// The action to be performed /// An async enumerable of results - public async IAsyncEnumerable DoActionAsync(FlightCallOptions options, FlightAction action) + public async IAsyncEnumerable DoActionAsync(FlightAction action, FlightCallOptions? options = default) { - if (options is null) - throw new ArgumentNullException(nameof(options)); - if (action is null) throw new ArgumentNullException(nameof(action)); - var call = _client.DoAction(action, options.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -182,19 +158,6 @@ public async IAsyncEnumerable DoActionAsync(FlightCallOptions opti } } - /// - /// Perform the indicated action with default options, returning an iterator to the stream of results, if any. - /// - /// The action to be performed - /// An async enumerable of results - public async IAsyncEnumerable DoActionAsync(FlightAction action) - { - await foreach (var result in DoActionAsync(new FlightCallOptions(), action)) - { - yield return result; - } - } - /// /// Get the result set schema from the server for the given query. /// @@ -202,14 +165,10 @@ public async IAsyncEnumerable DoActionAsync(FlightAction action) /// The UTF8-encoded SQL query /// A transaction to associate this query with /// The SchemaResult describing the schema of the result set - public async Task GetExecuteSchemaAsync(FlightCallOptions options, string query, - Transaction? transaction = null) + public async Task GetExecuteSchemaAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { transaction ??= Transaction.NoTransaction; - if (options is null) - throw new ArgumentNullException(nameof(options)); - if (string.IsNullOrEmpty(query)) throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); @@ -219,7 +178,7 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -231,7 +190,7 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin }; byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - schemaResult = await GetFlightInfoAsync(options, descriptor); + schemaResult = await GetFlightInfoAsync(descriptor, options); } return schemaResult.Schema; @@ -247,7 +206,7 @@ public async Task GetExecuteSchemaAsync(FlightCallOptions options, strin /// /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. - public async Task GetCatalogsAsync(FlightCallOptions options) + public async Task GetCatalogsAsync(FlightCallOptions? options = default) { if (options == null) { @@ -258,7 +217,7 @@ public async Task GetCatalogsAsync(FlightCallOptions options) { var command = new CommandGetCatalogs(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var catalogsInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var catalogsInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return catalogsInfo; } catch (RpcException ex) @@ -272,18 +231,13 @@ public async Task GetCatalogsAsync(FlightCallOptions options) /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the catalogs. - public async Task GetCatalogsSchemaAsync(FlightCallOptions options) + public async Task GetCatalogsSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var commandGetCatalogsSchema = new CommandGetCatalogs(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCatalogsSchema.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -298,7 +252,7 @@ public async Task GetCatalogsSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The descriptor of the dataset request, whether a named dataset or a command. /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. - public virtual async Task GetSchemaAsync(FlightCallOptions options, FlightDescriptor descriptor) + public virtual async Task GetSchemaAsync(FlightDescriptor descriptor, FlightCallOptions? options = default) { if (descriptor is null) { @@ -307,7 +261,7 @@ public virtual async Task GetSchemaAsync(FlightCallOptions options, Flig try { - var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); return schemaResult; } @@ -317,17 +271,6 @@ public virtual async Task GetSchemaAsync(FlightCallOptions options, Flig } } - /// - /// Asynchronously retrieves schema information for a given flight descriptor. - /// - /// The descriptor of the dataset request, whether a named dataset or a command. - /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. - public Task GetSchemaAsync(FlightDescriptor descriptor) - { - var options = new FlightCallOptions(); - return GetSchemaAsync(options, descriptor); - } - /// /// Request a list of database schemas. /// @@ -335,14 +278,8 @@ public Task GetSchemaAsync(FlightDescriptor descriptor) /// The catalog. /// The schema filter pattern. /// The FlightInfo describing where to access the dataset. - public async Task GetDbSchemasAsync(FlightCallOptions options, string? catalog = null, - string? dbSchemaFilterPattern = null) + public async Task GetDbSchemasAsync(string? catalog = null, string? dbSchemaFilterPattern = null, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetDbSchemas(); @@ -359,7 +296,7 @@ public async Task GetDbSchemasAsync(FlightCallOptions options, strin byte[] serializedAndPackedCommand = command.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(serializedAndPackedCommand); - var flightInfoCall = GetFlightInfoAsync(options, descriptor); + var flightInfoCall = GetFlightInfoAsync(descriptor, options); var flightInfo = await flightInfoCall.ConfigureAwait(false); return flightInfo; @@ -375,7 +312,7 @@ public async Task GetDbSchemasAsync(FlightCallOptions options, strin /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the database schemas. - public async Task GetDbSchemasSchemaAsync(FlightCallOptions options) + public async Task GetDbSchemasSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -386,7 +323,7 @@ public async Task GetDbSchemasSchemaAsync(FlightCallOptions options) { var command = new CommandGetDbSchemas(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -401,19 +338,14 @@ public async Task GetDbSchemasSchemaAsync(FlightCallOptions options) /// Per-RPC options /// The flight ticket to use /// The returned RecordBatchReader - public async IAsyncEnumerable DoGetAsync(FlightCallOptions options, FlightTicket ticket) + public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - if (ticket == null) { throw new ArgumentNullException(nameof(ticket)); } - var call = _client.GetStream(ticket, options.Headers); + var call = _client.GetStream(ticket, options?.Headers); await foreach (var recordBatch in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { yield return recordBatch; @@ -428,7 +360,7 @@ public async IAsyncEnumerable DoGetAsync(FlightCallOptions options, /// The descriptor of the stream. /// The schema for the data to upload. /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. - public async Task DoPutAsync(FlightCallOptions options, FlightDescriptor descriptor, Schema schema) + public async Task DoPutAsync(FlightDescriptor descriptor, Schema schema, FlightCallOptions? options = default) { if (descriptor is null) throw new ArgumentNullException(nameof(descriptor)); @@ -437,7 +369,7 @@ public async Task DoPutAsync(FlightCallOptions options, FlightDescr throw new ArgumentNullException(nameof(schema)); try { - var doPutResult = _client.StartPut(descriptor, options.Headers); + var doPutResult = _client.StartPut(descriptor, options?.Headers); var writer = doPutResult.RequestStream; var reader = doPutResult.ResponseStream; @@ -520,25 +452,13 @@ public List BuildArrowArraysFromSchema(Schema schema, int rowCount) } - /// - /// Upload data to a Flight described by the given descriptor. The caller must call Close() on the returned stream - /// once they are done writing. Uses default options. - /// - /// The descriptor of the stream. - /// The schema for the data to upload. - /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. - public Task DoPutAsync(FlightDescriptor descriptor, Schema schema) - { - return DoPutAsync(new FlightCallOptions(), descriptor, schema); - } - /// /// Request the primary keys for a table. /// /// RPC-layer hints for this call. /// The table reference. /// The FlightInfo describing where to access the dataset. - public async Task GetPrimaryKeysAsync(FlightCallOptions options, TableRef tableRef) + public async Task GetPrimaryKeysAsync(TableRef tableRef, FlightCallOptions? options = default) { if (tableRef == null) throw new ArgumentNullException(nameof(tableRef)); @@ -551,7 +471,7 @@ public async Task GetPrimaryKeysAsync(FlightCallOptions options, Tab }; byte[] packedRequest = getPrimaryKeysRequest.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(packedRequest); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } @@ -571,16 +491,9 @@ public async Task GetPrimaryKeysAsync(FlightCallOptions options, Tab /// True to include the schema upon return, false to not include the schema. /// The table types to include. /// The FlightInfo describing where to access the dataset. - public async Task> GetTablesAsync(FlightCallOptions options, - string? catalog = null, - string? dbSchemaFilterPattern = null, - string? tableFilterPattern = null, - bool includeSchema = false, - IEnumerable? tableTypes = null) + public async Task> + GetTablesAsync(string? catalog = null, string? dbSchemaFilterPattern = null, string? tableFilterPattern = null, bool includeSchema = false, IEnumerable? tableTypes = null, FlightCallOptions? options = default) { - if (options == null) - throw new ArgumentNullException(nameof(options)); - var command = new CommandGetTables { Catalog = catalog ?? string.Empty, @@ -595,7 +508,7 @@ public async Task> GetTablesAsync(FlightCallOptions opti } var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfoCall = GetFlightInfoAsync(options, descriptor); + var flightInfoCall = GetFlightInfoAsync(descriptor, options); var flightInfo = await flightInfoCall.ConfigureAwait(false); var flightInfos = new List { flightInfo }; @@ -609,7 +522,7 @@ public async Task> GetTablesAsync(FlightCallOptions opti /// RPC-layer hints for this call. /// The table reference. /// The FlightInfo describing where to access the dataset. - public async Task GetExportedKeysAsync(FlightCallOptions options, TableRef tableRef) + public async Task GetExportedKeysAsync(TableRef tableRef, FlightCallOptions? options = default) { if (tableRef == null) throw new ArgumentNullException(nameof(tableRef)); @@ -622,7 +535,7 @@ public async Task GetExportedKeysAsync(FlightCallOptions options, Ta }; var descriptor = FlightDescriptor.CreateCommandDescriptor(getExportedKeysRequest.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -636,7 +549,7 @@ public async Task GetExportedKeysAsync(FlightCallOptions options, Ta /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the exported keys. - public async Task GetExportedKeysSchemaAsync(FlightCallOptions options) + public async Task GetExportedKeysSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -647,7 +560,7 @@ public async Task GetExportedKeysSchemaAsync(FlightCallOptions options) { var commandGetExportedKeysSchema = new CommandGetExportedKeys(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetExportedKeysSchema.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -662,7 +575,7 @@ public async Task GetExportedKeysSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The table reference. /// The FlightInfo describing where to access the dataset. - public async Task GetImportedKeysAsync(FlightCallOptions options, TableRef tableRef) + public async Task GetImportedKeysAsync(TableRef tableRef, FlightCallOptions? options = default) { if (tableRef == null) throw new ArgumentNullException(nameof(tableRef)); @@ -674,7 +587,7 @@ public async Task GetImportedKeysAsync(FlightCallOptions options, Ta Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; var descriptor = FlightDescriptor.CreateCommandDescriptor(getImportedKeysRequest.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -688,7 +601,7 @@ public async Task GetImportedKeysAsync(FlightCallOptions options, Ta /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the imported keys. - public async Task GetImportedKeysSchemaAsync(FlightCallOptions options) + public async Task GetImportedKeysSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -699,7 +612,7 @@ public async Task GetImportedKeysSchemaAsync(FlightCallOptions options) { var commandGetImportedKeysSchema = new CommandGetImportedKeys(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetImportedKeysSchema.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -715,8 +628,7 @@ public async Task GetImportedKeysSchemaAsync(FlightCallOptions options) /// The table reference that exports the key. /// The table reference that imports the key. /// The FlightInfo describing where to access the dataset. - public async Task GetCrossReferenceAsync(FlightCallOptions options, TableRef pkTableRef, - TableRef fkTableRef) + public async Task GetCrossReferenceAsync(TableRef pkTableRef, TableRef fkTableRef, FlightCallOptions? options = default) { if (pkTableRef == null) throw new ArgumentNullException(nameof(pkTableRef)); @@ -737,7 +649,7 @@ public async Task GetCrossReferenceAsync(FlightCallOptions options, }; var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCrossReference.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } @@ -752,7 +664,7 @@ public async Task GetCrossReferenceAsync(FlightCallOptions options, /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the cross-reference. - public async Task GetCrossReferenceSchemaAsync(FlightCallOptions options) + public async Task GetCrossReferenceSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -764,7 +676,7 @@ public async Task GetCrossReferenceSchemaAsync(FlightCallOptions options var commandGetCrossReferenceSchema = new CommandGetCrossReference(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetCrossReferenceSchema.PackAndSerialize()); - var schemaResultCall = GetSchemaAsync(options, descriptor); + var schemaResultCall = GetSchemaAsync(descriptor, options); var schemaResult = await schemaResultCall.ConfigureAwait(false); return schemaResult; @@ -780,7 +692,7 @@ public async Task GetCrossReferenceSchemaAsync(FlightCallOptions options /// /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. - public async Task GetTableTypesAsync(FlightCallOptions options) + public async Task GetTableTypesAsync(FlightCallOptions? options = default) { if (options == null) { @@ -791,7 +703,7 @@ public async Task GetTableTypesAsync(FlightCallOptions options) { var command = new CommandGetTableTypes(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -805,7 +717,7 @@ public async Task GetTableTypesAsync(FlightCallOptions options) /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the table types. - public async Task GetTableTypesSchemaAsync(FlightCallOptions options) + public async Task GetTableTypesSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -816,7 +728,7 @@ public async Task GetTableTypesSchemaAsync(FlightCallOptions options) { var command = new CommandGetTableTypes(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -831,7 +743,7 @@ public async Task GetTableTypesSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The data type to search for as filtering. /// The FlightInfo describing where to access the dataset. - public async Task GetXdbcTypeInfoAsync(FlightCallOptions options, int dataType) + public async Task GetXdbcTypeInfoAsync(int dataType, FlightCallOptions? options = default) { if (options == null) { @@ -842,7 +754,7 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options, in { var command = new CommandGetXdbcTypeInfo { DataType = dataType }; var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -856,7 +768,7 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options, in /// /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. - public async Task GetXdbcTypeInfoAsync(FlightCallOptions options) + public async Task GetXdbcTypeInfoAsync(FlightCallOptions? options = default) { if (options == null) { @@ -867,7 +779,7 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options) { var command = new CommandGetXdbcTypeInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -881,7 +793,7 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions options) /// /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the type info. - public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions options) + public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -892,7 +804,7 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions options) { var command = new CommandGetXdbcTypeInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResult = await GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; } catch (RpcException ex) @@ -907,7 +819,7 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The SQL info required. /// The FlightInfo describing where to access the dataset. - public async Task GetSqlInfoAsync(FlightCallOptions options, List sqlInfo) + public async Task GetSqlInfoAsync(List sqlInfo, FlightCallOptions? options = default) { if (options == null) { @@ -924,7 +836,7 @@ public async Task GetSqlInfoAsync(FlightCallOptions options, List (uint)item)); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; } catch (RpcException ex) @@ -938,7 +850,7 @@ public async Task GetSqlInfoAsync(FlightCallOptions options, List /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the SQL information. - public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) + public async Task GetSqlInfoSchemaAsync(FlightCallOptions? options = default) { if (options == null) { @@ -966,16 +878,14 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The CancelFlightInfoRequest. /// A Task representing the asynchronous operation. The task result contains the CancelFlightInfoResult describing the canceled result. - public async Task CancelFlightInfoAsync(FlightCallOptions options, - FlightInfoCancelRequest request) + public async Task CancelFlightInfoAsync(FlightInfoCancelRequest request, FlightCallOptions? options = default) { - if (options == null) throw new ArgumentNullException(nameof(options)); if (request == null) throw new ArgumentNullException(nameof(request)); try { var action = new FlightAction(SqlAction.CancelFlightInfoRequest, request.PackAndSerialize()); - var call = _client.DoAction(action, options.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { if (Any.Parser.ParseFrom(result.Body) is Any anyResult && @@ -999,11 +909,8 @@ public async Task CancelFlightInfoAsync(FlightCallOption /// RPC-layer hints for this call. /// The FlightInfo of the query to cancel. /// A Task representing the asynchronous operation. - public async Task CancelQueryAsync(FlightCallOptions options, FlightInfo info) + public async Task CancelQueryAsync(FlightInfo info, FlightCallOptions? options = default) { - if (options == null) - throw new ArgumentNullException(nameof(options)); - if (info == null) throw new ArgumentNullException(nameof(info)); @@ -1012,7 +919,7 @@ public async Task CancelQueryAsync(FlightCallOptions opt var cancelQueryRequest = new FlightInfoCancelRequest(info); var cancelQueryAction = new FlightAction(SqlAction.CancelFlightInfoRequest, cancelQueryRequest.PackAndSerialize()); - var cancelQueryCall = _client.DoAction(cancelQueryAction, options.Headers); + var cancelQueryCall = _client.DoAction(cancelQueryAction, options?.Headers); await foreach (var result in cancelQueryCall.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -1035,18 +942,13 @@ public async Task CancelQueryAsync(FlightCallOptions opt /// /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. The task result contains the Transaction object representing the new transaction. - public async Task BeginTransactionAsync(FlightCallOptions options) + public async Task BeginTransactionAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var actionBeginTransaction = new ActionBeginTransactionRequest(); var action = new FlightAction(SqlAction.BeginTransactionRequest, actionBeginTransaction.PackAndSerialize()); - var responseStream = _client.DoAction(action, options.Headers); + var responseStream = _client.DoAction(action, options?.Headers); await foreach (var result in responseStream.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { string? beginTransactionResult = result.Body.ToStringUtf8(); @@ -1068,22 +970,15 @@ public async Task BeginTransactionAsync(FlightCallOptions options) /// RPC-layer hints for this call. /// The transaction. /// A Task representing the asynchronous operation. - public AsyncServerStreamingCall CommitAsync(FlightCallOptions options, Transaction transaction) + public AsyncServerStreamingCall CommitAsync(Transaction transaction, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - if (transaction == null) - { throw new ArgumentNullException(nameof(transaction)); - } try { var actionCommit = new FlightAction(SqlAction.CommitRequest, transaction.TransactionId); - return _client.DoAction(actionCommit, options.Headers); + return _client.DoAction(actionCommit, options?.Headers); } catch (RpcException ex) { @@ -1098,13 +993,8 @@ public AsyncServerStreamingCall CommitAsync(FlightCallOptions opti /// RPC-layer hints for this call. /// The transaction to rollback. /// A Task representing the asynchronous operation. - public AsyncServerStreamingCall RollbackAsync(FlightCallOptions options, Transaction transaction) + public AsyncServerStreamingCall RollbackAsync(Transaction transaction, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - if (transaction == null) { throw new ArgumentNullException(nameof(transaction)); @@ -1113,7 +1003,7 @@ public AsyncServerStreamingCall RollbackAsync(FlightCallOptions op try { var actionRollback = new FlightAction(SqlAction.RollbackRequest, transaction.TransactionId); - return _client.DoAction(actionRollback, options.Headers); + return _client.DoAction(actionRollback, options?.Headers); } catch (RpcException ex) { @@ -1128,16 +1018,10 @@ public AsyncServerStreamingCall RollbackAsync(FlightCallOptions op /// The query that will be executed. /// A transaction to associate this query with. /// The created prepared statement. - public async Task PrepareAsync(FlightCallOptions options, string query, - Transaction? transaction = null) + public async Task PrepareAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { transaction ??= Transaction.NoTransaction; - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - if (string.IsNullOrEmpty(query)) { throw new ArgumentException("Query cannot be null or empty", nameof(query)); @@ -1151,7 +1035,7 @@ public async Task PrepareAsync(FlightCallOptions options, str }; var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync()) { @@ -1164,7 +1048,7 @@ public async Task PrepareAsync(FlightCallOptions options, str }; byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - var flightInfo = await GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return new PreparedStatement(this, transaction.TransactionId.ToStringUtf8(), flightInfo.Schema, flightInfo.Schema); } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 13c7d9b543fd0..527d7b6844eac 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -49,7 +49,7 @@ public PreparedStatement(FlightSqlClient client, string handle, Schema datasetSc /// The options used to configure the Flight call. /// A task representing the asynchronous operation, which returns the schema of the result set. /// Thrown when the schema is empty or invalid. - public async Task GetSchemaAsync(FlightCallOptions options) + public async Task GetSchemaAsync(FlightCallOptions? options = default) { EnsureStatementIsNotClosed(); @@ -60,7 +60,7 @@ public async Task GetSchemaAsync(FlightCallOptions options) PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8) }; var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schema = await _client.GetSchemaAsync(options, descriptor).ConfigureAwait(false); + var schema = await _client.GetSchemaAsync(descriptor, options).ConfigureAwait(false); if (schema == null || !schema.FieldsList.Any()) { throw new InvalidOperationException("Schema is empty or invalid."); @@ -80,7 +80,7 @@ public async Task GetSchemaAsync(FlightCallOptions options) /// The options used to configure the Flight call. /// A task representing the asynchronous operation. /// Thrown if closing the prepared statement fails. - public async Task CloseAsync(FlightCallOptions options) + public async Task CloseAsync(FlightCallOptions? options = default) { EnsureStatementIsNotClosed(); try @@ -91,7 +91,7 @@ public async Task CloseAsync(FlightCallOptions options) }; var action = new FlightAction(SqlAction.CloseRequest, closeRequest.ToByteArray()); - await foreach (var result in _client.DoActionAsync(options, action).ConfigureAwait(false)) + await foreach (var result in _client.DoActionAsync(action, options).ConfigureAwait(false)) { // Just drain the results to complete the operation } @@ -251,8 +251,6 @@ public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancel return Status.DefaultSuccess; } - - /// /// Executes the prepared statement asynchronously and retrieves the query results as . /// @@ -262,7 +260,7 @@ public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancel /// A representing the asynchronous operation. The task result contains the describing the executed query results. /// Thrown if the prepared statement is closed or if there is an error during execution. /// Thrown if the operation is canceled by the . - public async Task ExecuteAsync(FlightCallOptions options, RecordBatch parameterBatch, CancellationToken cancellationToken = default) + public async Task ExecuteAsync(RecordBatch parameterBatch, FlightCallOptions? options = default, CancellationToken cancellationToken = default) { EnsureStatementIsNotClosed(); @@ -271,10 +269,10 @@ public async Task ExecuteAsync(FlightCallOptions options, RecordBatc if (parameterBatch != null) { - await BindParametersAsync(options, descriptor, parameterBatch).ConfigureAwait(false); + await BindParametersAsync(descriptor, parameterBatch, options).ConfigureAwait(false); } cancellationToken.ThrowIfCancellationRequested(); - return await _client.GetFlightInfoAsync(options, descriptor).ConfigureAwait(false); + return await _client.GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); } /// @@ -311,14 +309,14 @@ public async Task ExecuteAsync(FlightCallOptions options, RecordBatc /// Console.WriteLine($"Rows affected: {affectedRows}"); /// /// - public async Task ExecuteUpdateAsync(FlightCallOptions options, RecordBatch parameterBatch) + public async Task ExecuteUpdateAsync(RecordBatch parameterBatch, FlightCallOptions? options = default) { if (parameterBatch == null) { throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); } var descriptor = FlightDescriptor.CreateCommandDescriptor(_handle); - var metadata = await BindParametersAsync(options, descriptor, parameterBatch).ConfigureAwait(false); + var metadata = await BindParametersAsync(descriptor, parameterBatch, options).ConfigureAwait(false); try { @@ -355,14 +353,14 @@ private long ParseAffectedRows(ByteString metadata) /// A that represents the asynchronous operation. The task result contains the metadata from the server after binding the parameters. /// Thrown when is null. /// Thrown if the operation is canceled or if there is an error during the DoPut operation. - public async Task BindParametersAsync(FlightCallOptions options, FlightDescriptor descriptor, RecordBatch parameterBatch) + public async Task BindParametersAsync(FlightDescriptor descriptor, RecordBatch parameterBatch, FlightCallOptions? options = default) { if (parameterBatch == null) { throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); } - var putResult = await _client.DoPutAsync(options, descriptor, parameterBatch.Schema).ConfigureAwait(false); + var putResult = await _client.DoPutAsync(descriptor, parameterBatch.Schema, options).ConfigureAwait(false); try { diff --git a/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs b/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs index f367b15574b6e..2bb059f8c5637 100644 --- a/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs +++ b/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs @@ -43,4 +43,4 @@ ICompressionCodec CreateCodec(CompressionCodecType compressionCodecType, int? co ; #endif } -} +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 0deb4b0676f59..81f51ec8696ba 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -4,7 +4,6 @@ using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Flight.Tests; using Apache.Arrow.Flight.TestWeb; using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; @@ -38,11 +37,10 @@ public async Task CommitTransactionAsync() { // Arrange string transactionId = "sample-transaction-id"; - var options = new FlightCallOptions(); var transaction = new Transaction(transactionId); // Act - var streamCall = _flightSqlClient.CommitAsync(options, transaction); + var streamCall = _flightSqlClient.CommitAsync(transaction); var result = await streamCall.ResponseStream.ToListAsync(); // Assert @@ -54,11 +52,10 @@ public async Task CommitTransactionAsync() public async Task BeginTransactionAsync() { // Arrange - var options = new FlightCallOptions(); string expectedTransactionId = "sample-transaction-id"; // Act - var transaction = await _flightSqlClient.BeginTransactionAsync(options); + var transaction = await _flightSqlClient.BeginTransactionAsync(); // Assert Assert.NotNull(transaction); @@ -70,11 +67,10 @@ public async Task RollbackTransactionAsync() { // Arrange string transactionId = "sample-transaction-id"; - var options = new FlightCallOptions(); var transaction = new Transaction(transactionId); // Act - var streamCall = _flightSqlClient.RollbackAsync(options, transaction); + var streamCall = _flightSqlClient.RollbackAsync(transaction); var result = await streamCall.ResponseStream.ToListAsync(); // Assert @@ -91,7 +87,6 @@ public async Task PreparedStatementAsync() { // Arrange string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; - var options = new FlightCallOptions(); var transaction = new Transaction("sample-transaction-id"); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); @@ -101,7 +96,7 @@ public async Task PreparedStatementAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var preparedStatement = await _flightSqlClient.PrepareAsync(options, query, transaction); + var preparedStatement = await _flightSqlClient.PrepareAsync(query, transaction); // Assert Assert.NotNull(preparedStatement); @@ -114,7 +109,6 @@ public async Task ExecuteUpdateAsync() { // Arrange string query = "UPDATE test_table SET column1 = 'value' WHERE column2 = 'condition'"; - var options = new FlightCallOptions(); var transaction = new Transaction("sample-transaction-id"); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); @@ -124,7 +118,7 @@ public async Task ExecuteUpdateAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - long affectedRows = await _flightSqlClient.ExecuteUpdateAsync(options, query, transaction); + long affectedRows = await _flightSqlClient.ExecuteUpdateAsync(query, transaction); // Assert Assert.Equal(100, affectedRows); @@ -135,7 +129,6 @@ public async Task ExecuteAsync() { // Arrange string query = "SELECT * FROM test_table"; - var options = new FlightCallOptions(); var transaction = new Transaction("sample-transaction-id"); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); @@ -146,7 +139,7 @@ public async Task ExecuteAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var flightInfo = await _flightSqlClient.ExecuteAsync(options, query, transaction); + var flightInfo = await _flightSqlClient.ExecuteAsync(query, transaction); // Assert Assert.NotNull(flightInfo); @@ -164,7 +157,7 @@ public async Task GetFlightInfoAsync() _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var flightInfo = await _flightSqlClient.GetFlightInfoAsync(options, flightDescriptor); + var flightInfo = await _flightSqlClient.GetFlightInfoAsync(flightDescriptor); // Assert Assert.NotNull(flightInfo); @@ -175,7 +168,6 @@ public async Task GetExecuteSchemaAsync() { // Arrange string query = "SELECT * FROM test_table"; - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -184,7 +176,7 @@ public async Task GetExecuteSchemaAsync() // Act Schema resultSchema = - await _flightSqlClient.GetExecuteSchemaAsync(options, query, new Transaction("sample-transaction-id")); + await _flightSqlClient.GetExecuteSchemaAsync(query, new Transaction("sample-transaction-id")); // Assert Assert.NotNull(resultSchema); @@ -216,7 +208,6 @@ public async Task GetCatalogsAsync() public async Task GetSchemaAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -224,7 +215,7 @@ public async Task GetSchemaAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var result = await _flightSqlClient.GetSchemaAsync(options, flightDescriptor); + var result = await _flightSqlClient.GetSchemaAsync(flightDescriptor); // Assert Assert.NotNull(result); @@ -236,7 +227,6 @@ public async Task GetSchemaAsync() public async Task GetDbSchemasAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -246,7 +236,7 @@ public async Task GetDbSchemasAsync() string dbSchemaFilterPattern = "test-schema-pattern"; // Act - var result = await _flightSqlClient.GetDbSchemasAsync(options, catalog, dbSchemaFilterPattern); + var result = await _flightSqlClient.GetDbSchemasAsync(catalog, dbSchemaFilterPattern); // Assert Assert.NotNull(result); @@ -282,7 +272,6 @@ public async Task GetDbSchemasAsync() public async Task GetPrimaryKeysAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; @@ -291,7 +280,7 @@ public async Task GetPrimaryKeysAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var result = await _flightSqlClient.GetPrimaryKeysAsync(options, tableRef); + var result = await _flightSqlClient.GetPrimaryKeysAsync(tableRef); // Assert Assert.NotNull(result); @@ -326,7 +315,6 @@ public async Task GetPrimaryKeysAsync() public async Task GetTablesAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -340,7 +328,7 @@ public async Task GetTablesAsync() var tableTypes = new List { "BASE TABLE" }; // Act - var result = await _flightSqlClient.GetTablesAsync(options, catalog, dbSchemaFilterPattern, tableFilterPattern, + var result = await _flightSqlClient.GetTablesAsync(catalog, dbSchemaFilterPattern, tableFilterPattern, includeSchema, tableTypes); // Assert @@ -369,7 +357,6 @@ public async Task GetTablesAsync() public async Task GetCatalogsSchemaAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -377,7 +364,7 @@ public async Task GetCatalogsSchemaAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var schema = await _flightSqlClient.GetCatalogsSchemaAsync(options); + var schema = await _flightSqlClient.GetCatalogsSchemaAsync(); // Assert Assert.NotNull(schema); @@ -422,7 +409,6 @@ public async Task GetDbSchemasSchemaAsync() public async Task DoPutAsync() { // Arrange - var options = new FlightCallOptions(); var schema = new Schema .Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) @@ -456,7 +442,7 @@ public async Task DoPutAsync() var expectedBatch = _testUtils.CreateTestBatch(0, 100); // Act - var result = await _flightSqlClient.DoPutAsync(options, flightDescriptor, expectedBatch.Schema); + var result = await _flightSqlClient.DoPutAsync(flightDescriptor, expectedBatch.Schema); // Assert Assert.NotNull(result); @@ -466,7 +452,6 @@ public async Task DoPutAsync() public async Task GetExportedKeysAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; @@ -475,7 +460,7 @@ public async Task GetExportedKeysAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var flightInfo = await _flightSqlClient.GetExportedKeysAsync(options, tableRef); + var flightInfo = await _flightSqlClient.GetExportedKeysAsync(tableRef); // Assert Assert.NotNull(flightInfo); @@ -506,7 +491,6 @@ public async Task GetExportedKeysSchemaAsync() public async Task GetImportedKeysAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -514,8 +498,7 @@ public async Task GetImportedKeysAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var flightInfo = await _flightSqlClient.GetImportedKeysAsync(options, - new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }); + var flightInfo = await _flightSqlClient.GetImportedKeysAsync(new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }); // Assert Assert.NotNull(flightInfo); @@ -554,17 +537,15 @@ public async Task GetImportedKeysSchemaAsync() public async Task GetCrossReferenceAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, - _testWebFactory.GetAddress()); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); _flightStore.Flights.Add(flightDescriptor, flightHolder); var pkTableRef = new TableRef { Catalog = "PKCatalog", DbSchema = "PKSchema", Table = "PKTable" }; var fkTableRef = new TableRef { Catalog = "FKCatalog", DbSchema = "FKSchema", Table = "FKTable" }; // Act - var flightInfo = await _flightSqlClient.GetCrossReferenceAsync(options, pkTableRef, fkTableRef); + var flightInfo = await _flightSqlClient.GetCrossReferenceAsync(pkTableRef, fkTableRef); // Assert Assert.NotNull(flightInfo); @@ -722,7 +703,6 @@ public async Task GetSqlInfoSchemaAsync() public async Task CancelFlightInfoAsync() { // Arrange - var options = new FlightCallOptions(); var schema = new Schema .Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) @@ -732,7 +712,7 @@ public async Task CancelFlightInfoAsync() var cancelRequest = new FlightInfoCancelRequest(flightInfo); // Act - var cancelResult = await _flightSqlClient.CancelFlightInfoAsync(options, cancelRequest); + var cancelResult = await _flightSqlClient.CancelFlightInfoAsync(cancelRequest); // Assert Assert.Equal(1, cancelResult.GetCancelStatus()); @@ -742,7 +722,6 @@ public async Task CancelFlightInfoAsync() public async Task CancelQueryAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var schema = new Schema .Builder() @@ -751,11 +730,10 @@ public async Task CancelQueryAsync() var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); // Adding the flight info to the flight store for testing - _flightStore.Flights.Add(flightDescriptor, - new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); + _flightStore.Flights.Add(flightDescriptor, new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress())); // Act - var cancelStatus = await _flightSqlClient.CancelQueryAsync(options, flightInfo); + var cancelStatus = await _flightSqlClient.CancelQueryAsync(flightInfo); // Assert Assert.Equal(1, cancelStatus.GetCancelStatus()); diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index 6df48886e1721..dce21b736cddc 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -66,7 +66,7 @@ public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided { var validRecordBatch = CreateRecordBatch(_schema, new[] { 1, 2, 3 }); var result = _preparedStatement.SetParameters(validRecordBatch); - var flightInfo = await _preparedStatement.ExecuteAsync(new FlightCallOptions(), validRecordBatch); + var flightInfo = await _preparedStatement.ExecuteAsync(validRecordBatch); Assert.NotNull(flightInfo); Assert.IsType(flightInfo); @@ -76,14 +76,14 @@ public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided [Fact] public async Task ExecuteUpdateAsync_ShouldReturnAffectedRows_WhenParametersAreSet() { - var affectedRows = await _preparedStatement.ExecuteUpdateAsync(new FlightCallOptions(), _parameterBatch); + var affectedRows = await _preparedStatement.ExecuteUpdateAsync(_parameterBatch); Assert.True(affectedRows > 0, "Expected affected rows to be greater than 0."); } [Fact] public async Task BindParametersAsync_ShouldReturnMetadata_WhenValidInputsAreProvided() { - var metadata = await _preparedStatement.BindParametersAsync(new FlightCallOptions(), _flightDescriptor, _parameterBatch); + var metadata = await _preparedStatement.BindParametersAsync(_flightDescriptor, _parameterBatch); Assert.NotNull(metadata); Assert.True(metadata.Length > 0, "Metadata should have a length greater than 0 when valid."); } From 0e6b3c52cfec0fbb4521a9b2c2e11185985bf050 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 28 Oct 2024 18:40:18 +0200 Subject: [PATCH 39/58] refactor: add code line to match source --- csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs b/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs index 2bb059f8c5637..f367b15574b6e 100644 --- a/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs +++ b/csharp/src/Apache.Arrow/Ipc/ICompressionCodecFactory.cs @@ -43,4 +43,4 @@ ICompressionCodec CreateCodec(CompressionCodecType compressionCodecType, int? co ; #endif } -} \ No newline at end of file +} From fbf0b927ce5628b969a06bea2e7d45e385c02542 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 30 Oct 2024 08:14:39 +0200 Subject: [PATCH 40/58] refactor: Execute method removed ParameterBatch argument --- .../Client/FlightSqlClient.cs | 40 +++++++++---------- .../PreparedStatement.cs | 23 +++++------ .../FlightSqlPreparedStatementTests.cs | 2 +- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 6ef9fb324d5bf..cd4eee8b4f3ef 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -22,9 +22,9 @@ public FlightSqlClient(FlightClient client) /// /// Execute a SQL query on the server. /// - /// RPC-layer hints for this call. /// The UTF8-encoded SQL query to be executed. /// A transaction to associate this query with. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task ExecuteAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { @@ -67,9 +67,9 @@ public async Task ExecuteAsync(string query, Transaction? transactio /// /// Executes an update query on the server. /// - /// RPC-layer hints for this call. /// The UTF8-encoded SQL query to be executed. /// A transaction to associate this query with. Defaults to no transaction if not provided. + /// RPC-layer hints for this call. /// The number of rows affected by the operation. public async Task ExecuteUpdateAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { @@ -117,8 +117,8 @@ public async Task ExecuteUpdateAsync(string query, Transaction? transactio /// /// Asynchronously retrieves flight information for a given flight descriptor. /// - /// RPC-layer hints for this call. /// The descriptor of the dataset request, whether a named dataset or a command. + /// RPC-layer hints for this call. /// A task that represents the asynchronous operation. The task result contains the FlightInfo describing where to access the dataset. public async Task GetFlightInfoAsync(FlightDescriptor descriptor, FlightCallOptions? options = default) { @@ -142,8 +142,8 @@ public async Task GetFlightInfoAsync(FlightDescriptor descriptor, Fl /// /// Perform the indicated action, returning an iterator to the stream of results, if any. /// - /// Per-RPC options /// The action to be performed + /// Per-RPC options /// An async enumerable of results public async IAsyncEnumerable DoActionAsync(FlightAction action, FlightCallOptions? options = default) { @@ -161,9 +161,9 @@ public async IAsyncEnumerable DoActionAsync(FlightAction action, F /// /// Get the result set schema from the server for the given query. /// - /// Per-RPC options /// The UTF8-encoded SQL query /// A transaction to associate this query with + /// Per-RPC options /// The SchemaResult describing the schema of the result set public async Task GetExecuteSchemaAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { @@ -249,8 +249,8 @@ public async Task GetCatalogsSchemaAsync(FlightCallOptions? options = de /// /// Asynchronously retrieves schema information for a given flight descriptor. /// - /// RPC-layer hints for this call. /// The descriptor of the dataset request, whether a named dataset or a command. + /// RPC-layer hints for this call. /// A task that represents the asynchronous operation. The task result contains the SchemaResult describing the dataset schema. public virtual async Task GetSchemaAsync(FlightDescriptor descriptor, FlightCallOptions? options = default) { @@ -335,8 +335,8 @@ public async Task GetDbSchemasSchemaAsync(FlightCallOptions? options = d /// /// Given a flight ticket and schema, request to be sent the stream. Returns record batch stream reader. /// - /// Per-RPC options /// The flight ticket to use + /// Per-RPC options /// The returned RecordBatchReader public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, FlightCallOptions? options = default) { @@ -356,9 +356,9 @@ public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, Fligh /// Upload data to a Flight described by the given descriptor. The caller must call Close() on the returned stream /// once they are done writing. /// - /// RPC-layer hints for this call. /// The descriptor of the stream. /// The schema for the data to upload. + /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. public async Task DoPutAsync(FlightDescriptor descriptor, Schema schema, FlightCallOptions? options = default) { @@ -455,8 +455,8 @@ public List BuildArrowArraysFromSchema(Schema schema, int rowCount) /// /// Request the primary keys for a table. /// - /// RPC-layer hints for this call. /// The table reference. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task GetPrimaryKeysAsync(TableRef tableRef, FlightCallOptions? options = default) { @@ -484,12 +484,12 @@ public async Task GetPrimaryKeysAsync(TableRef tableRef, FlightCallO /// /// Request a list of tables. /// - /// RPC-layer hints for this call. /// The catalog. /// The schema filter pattern. /// The table filter pattern. /// True to include the schema upon return, false to not include the schema. /// The table types to include. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task> GetTablesAsync(string? catalog = null, string? dbSchemaFilterPattern = null, string? tableFilterPattern = null, bool includeSchema = false, IEnumerable? tableTypes = null, FlightCallOptions? options = default) @@ -519,8 +519,8 @@ public async Task> /// /// Retrieves a description about the foreign key columns that reference the primary key columns of the given table. /// - /// RPC-layer hints for this call. /// The table reference. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task GetExportedKeysAsync(TableRef tableRef, FlightCallOptions? options = default) { @@ -572,8 +572,8 @@ public async Task GetExportedKeysSchemaAsync(FlightCallOptions? options /// /// Retrieves the foreign key columns for the given table. /// - /// RPC-layer hints for this call. /// The table reference. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task GetImportedKeysAsync(TableRef tableRef, FlightCallOptions? options = default) { @@ -624,9 +624,9 @@ public async Task GetImportedKeysSchemaAsync(FlightCallOptions? options /// /// Retrieves a description of the foreign key columns in the given foreign key table that reference the primary key or the columns representing a unique constraint of the parent table. /// - /// RPC-layer hints for this call. /// The table reference that exports the key. /// The table reference that imports the key. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task GetCrossReferenceAsync(TableRef pkTableRef, TableRef fkTableRef, FlightCallOptions? options = default) { @@ -740,8 +740,8 @@ public async Task GetTableTypesSchemaAsync(FlightCallOptions? options = /// /// Request the information about all the data types supported with filtering by data type. /// - /// RPC-layer hints for this call. /// The data type to search for as filtering. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task GetXdbcTypeInfoAsync(int dataType, FlightCallOptions? options = default) { @@ -816,8 +816,8 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions? options /// /// Request a list of SQL information. /// - /// RPC-layer hints for this call. /// The SQL info required. + /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. public async Task GetSqlInfoAsync(List sqlInfo, FlightCallOptions? options = default) { @@ -875,8 +875,8 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions? options = def /// /// Explicitly cancel a FlightInfo. /// - /// RPC-layer hints for this call. /// The CancelFlightInfoRequest. + /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. The task result contains the CancelFlightInfoResult describing the canceled result. public async Task CancelFlightInfoAsync(FlightInfoCancelRequest request, FlightCallOptions? options = default) { @@ -906,8 +906,8 @@ public async Task CancelFlightInfoAsync(FlightInfoCancel /// /// Explicitly cancel a query. /// - /// RPC-layer hints for this call. /// The FlightInfo of the query to cancel. + /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. public async Task CancelQueryAsync(FlightInfo info, FlightCallOptions? options = default) { @@ -967,8 +967,8 @@ public async Task BeginTransactionAsync(FlightCallOptions? options /// Commit a transaction. /// After this, the transaction and all associated savepoints will be invalidated. /// - /// RPC-layer hints for this call. /// The transaction. + /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. public AsyncServerStreamingCall CommitAsync(Transaction transaction, FlightCallOptions? options = default) { @@ -990,8 +990,8 @@ public AsyncServerStreamingCall CommitAsync(Transaction transactio /// Rollback a transaction. /// After this, the transaction and all associated savepoints will be invalidated. /// - /// RPC-layer hints for this call. /// The transaction to rollback. + /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. public AsyncServerStreamingCall RollbackAsync(Transaction transaction, FlightCallOptions? options = default) { @@ -1014,9 +1014,9 @@ public AsyncServerStreamingCall RollbackAsync(Transaction transact /// /// Create a prepared statement object. /// - /// RPC-layer hints for this call. /// The query that will be executed. /// A transaction to associate this query with. + /// RPC-layer hints for this call. /// The created prepared statement. public async Task PrepareAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 527d7b6844eac..d6bb1056bbe26 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -21,6 +21,7 @@ public class PreparedStatement : IDisposable private readonly string _handle; private Schema _datasetSchema; private Schema _parameterSchema; + private RecordBatch? _recordsBatch; private bool _isClosed; public bool IsClosed => _isClosed; public string Handle => _handle; @@ -201,9 +202,8 @@ public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancel { EnsureStatementIsNotClosed(); - if (parameterBatch == null) - throw new ArgumentNullException(nameof(parameterBatch)); - + _recordsBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); + var channel = Channel.CreateUnbounded(); var task = Task.Run(async () => { @@ -211,10 +211,10 @@ public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancel { using (var memoryStream = new MemoryStream()) { - var writer = new ArrowStreamWriter(memoryStream, parameterBatch.Schema); + var writer = new ArrowStreamWriter(memoryStream, _recordsBatch.Schema); cancellationToken.ThrowIfCancellationRequested(); - await writer.WriteRecordBatchAsync(parameterBatch, cancellationToken).ConfigureAwait(false); + await writer.WriteRecordBatchAsync(_recordsBatch, cancellationToken).ConfigureAwait(false); await writer.WriteEndAsync(cancellationToken).ConfigureAwait(false); memoryStream.Position = 0; @@ -254,22 +254,21 @@ public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancel /// /// Executes the prepared statement asynchronously and retrieves the query results as . /// - /// The for the operation, which may include timeouts, headers, and other options for the call. - /// Optional containing parameters to bind before executing the statement. /// Optional to observe while waiting for the task to complete. The task will be canceled if the token is canceled. + /// Optional The for the operation, which may include timeouts, headers, and other options for the call. /// A representing the asynchronous operation. The task result contains the describing the executed query results. /// Thrown if the prepared statement is closed or if there is an error during execution. /// Thrown if the operation is canceled by the . - public async Task ExecuteAsync(RecordBatch parameterBatch, FlightCallOptions? options = default, CancellationToken cancellationToken = default) + public async Task ExecuteAsync(CancellationToken cancellationToken = default, FlightCallOptions? options = default) { EnsureStatementIsNotClosed(); var descriptor = FlightDescriptor.CreateCommandDescriptor(_handle); cancellationToken.ThrowIfCancellationRequested(); - if (parameterBatch != null) + if (_recordsBatch != null) { - await BindParametersAsync(descriptor, parameterBatch, options).ConfigureAwait(false); + await BindParametersAsync(descriptor, _recordsBatch, options).ConfigureAwait(false); } cancellationToken.ThrowIfCancellationRequested(); return await _client.GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); @@ -285,11 +284,11 @@ public async Task ExecuteAsync(RecordBatch parameterBatch, FlightCal /// /// This operation is asynchronous and can be canceled via the provided . /// - /// The for this execution, containing headers and other options. /// /// A containing the parameters to be bound to the update statement. /// This batch should match the schema expected by the prepared statement. /// + /// The for this execution, containing headers and other options. /// /// A representing the asynchronous operation. /// The task result contains the number of rows affected by the update. @@ -347,9 +346,9 @@ private long ParseAffectedRows(ByteString metadata) /// /// Binds parameters to the prepared statement by streaming the given RecordBatch to the server asynchronously. /// - /// The for the operation, which may include timeouts, headers, and other options for the call. /// The that identifies the statement or command being executed. /// The containing the parameters to bind to the prepared statement. + /// The for the operation, which may include timeouts, headers, and other options for the call. /// A that represents the asynchronous operation. The task result contains the metadata from the server after binding the parameters. /// Thrown when is null. /// Thrown if the operation is canceled or if there is an error during the DoPut operation. diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index dce21b736cddc..b99a5ccb89023 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -66,7 +66,7 @@ public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided { var validRecordBatch = CreateRecordBatch(_schema, new[] { 1, 2, 3 }); var result = _preparedStatement.SetParameters(validRecordBatch); - var flightInfo = await _preparedStatement.ExecuteAsync(validRecordBatch); + var flightInfo = await _preparedStatement.ExecuteAsync(); Assert.NotNull(flightInfo); Assert.IsType(flightInfo); From f2de0f14ade3485b4b6fa0262cac5444f819de91 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 30 Oct 2024 08:54:12 +0200 Subject: [PATCH 41/58] fix: SqlList default argument --- .../src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs | 8 ++------ csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index cd4eee8b4f3ef..e26312ad6206b 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -819,18 +819,14 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions? options /// The SQL info required. /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. - public async Task GetSqlInfoAsync(List sqlInfo, FlightCallOptions? options = default) + public async Task GetSqlInfoAsync(List? sqlInfo = default, FlightCallOptions? options = default) { if (options == null) { throw new ArgumentNullException(nameof(options)); } - if (sqlInfo == null || sqlInfo.Count == 0) - { - throw new ArgumentException("SQL info list cannot be null or empty", nameof(sqlInfo)); - } - + sqlInfo ??= new List(); try { var command = new CommandGetSqlInfo(); diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index d6bb1056bbe26..39b6fcb2576f1 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -91,7 +91,7 @@ public async Task CloseAsync(FlightCallOptions? options = default) PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8) }; - var action = new FlightAction(SqlAction.CloseRequest, closeRequest.ToByteArray()); + var action = new FlightAction(SqlAction.CloseRequest, closeRequest.PackAndSerialize()); await foreach (var result in _client.DoActionAsync(action, options).ConfigureAwait(false)) { // Just drain the results to complete the operation From efb67902b633d4a2981590ceac82047c79fe0dae Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 30 Oct 2024 09:58:27 +0200 Subject: [PATCH 42/58] refactor: Execute - description: correction aligns the C# code with the expected C++ behavior, focusing on efficient query execution without unnecessary complexity from prepared statement management. --- .../Client/FlightSqlClient.cs | 49 ++++++++-------- .../FlightSqlClientTests.cs | 57 ++++++++++++++++++- 2 files changed, 80 insertions(+), 26 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index e26312ad6206b..92cf405bc58a9 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -37,26 +37,34 @@ public async Task ExecuteAsync(string query, Transaction? transactio try { - var prepareStatementRequest = - new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; - var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options?.Headers); - - await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + var commandQuery = new CommandStatementQuery { Query = query }; + + if (transaction.IsValid()) { - var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(result.Body); - var commandSqlCall = new CommandPreparedStatementQuery - { - PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle - }; - - byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); - var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + commandQuery.TransactionId = transaction.TransactionId; } - - throw new InvalidOperationException("No results returned from the query."); + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandQuery.PackAndSerialize()); + return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + // var prepareStatementRequest = + // new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; + // var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); + // var call = _client.DoAction(action, options?.Headers); + // + // await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + // { + // var preparedStatementResponse = + // FlightSqlUtils.ParseAndUnpack(result.Body); + // var commandSqlCall = new CommandPreparedStatementQuery + // { + // PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + // }; + // + // byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); + // var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); + // return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); + // } + // + // throw new InvalidOperationException("No results returned from the query."); } catch (RpcException ex) { @@ -694,11 +702,6 @@ public async Task GetCrossReferenceSchemaAsync(FlightCallOptions? option /// The FlightInfo describing where to access the dataset. public async Task GetTableTypesAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetTableTypes(); diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 81f51ec8696ba..4cdf15861bf03 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -146,11 +146,63 @@ public async Task ExecuteAsync() Assert.Single(flightInfo.Endpoints); } + [Fact] + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() + { + // Arrange + string query = "SELECT * FROM test_table"; + var transaction = new Transaction("sample-transaction-id"); + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.ExecuteAsync(query, transaction); + + // Assert + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + } + + [Fact] + public async Task ExecuteAsync_ShouldThrowArgumentException_WhenQueryIsEmpty() + { + // Arrange + string emptyQuery = string.Empty; + var transaction = new Transaction("sample-transaction-id"); + + // Act & Assert + await Assert.ThrowsAsync(async () => + await _flightSqlClient.ExecuteAsync(emptyQuery, transaction)); + } + + [Fact] + public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenTransactionIsNoTransaction() + { + // Arrange + string query = "SELECT * FROM test_table"; + var transaction = Transaction.NoTransaction; + var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + var recordBatch = _testUtils.CreateTestBatch(0, 100); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, + _testWebFactory.GetAddress()); + _flightStore.Flights.Add(flightDescriptor, flightHolder); + + // Act + var flightInfo = await _flightSqlClient.ExecuteAsync(query, transaction); + + // Assert + Assert.NotNull(flightInfo); + Assert.IsType(flightInfo); + } + + [Fact] public async Task GetFlightInfoAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -577,7 +629,6 @@ public async Task GetCrossReferenceSchemaAsync() public async Task GetTableTypesAsync() { // Arrange - var options = new FlightCallOptions(); var expectedSchema = new Schema .Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) @@ -589,7 +640,7 @@ public async Task GetTableTypesAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var flightInfo = await _flightSqlClient.GetTableTypesAsync(options); + var flightInfo = await _flightSqlClient.GetTableTypesAsync(); var actualSchema = flightInfo.Schema; // Assert From 012921bb10654e01403aec29b0e3a842f1ee1dae Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 30 Oct 2024 10:52:25 +0200 Subject: [PATCH 43/58] refactor: adding disposing after invoking the client --- .../Client/FlightSqlClient.cs | 51 +++++-------------- .../FlightSqlClientTests.cs | 3 +- 2 files changed, 13 insertions(+), 41 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 92cf405bc58a9..6f6317ad80959 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -13,7 +13,7 @@ namespace Apache.Arrow.Flight.Sql.Client; public class FlightSqlClient { private readonly FlightClient _client; - + public FlightSqlClient(FlightClient client) { _client = client; @@ -45,26 +45,6 @@ public async Task ExecuteAsync(string query, Transaction? transactio } var descriptor = FlightDescriptor.CreateCommandDescriptor(commandQuery.PackAndSerialize()); return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); - // var prepareStatementRequest = - // new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; - // var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); - // var call = _client.DoAction(action, options?.Headers); - // - // await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) - // { - // var preparedStatementResponse = - // FlightSqlUtils.ParseAndUnpack(result.Body); - // var commandSqlCall = new CommandPreparedStatementQuery - // { - // PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle - // }; - // - // byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); - // var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - // return await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); - // } - // - // throw new InvalidOperationException("No results returned from the query."); } catch (RpcException ex) { @@ -137,7 +117,7 @@ public async Task GetFlightInfoAsync(FlightDescriptor descriptor, Fl try { - var flightInfoCall = _client.GetInfo(descriptor, options?.Headers); + using var flightInfoCall = _client.GetInfo(descriptor, options?.Headers); var flightInfo = await flightInfoCall.ResponseAsync.ConfigureAwait(false); return flightInfo; } @@ -158,7 +138,7 @@ public async IAsyncEnumerable DoActionAsync(FlightAction action, F if (action is null) throw new ArgumentNullException(nameof(action)); - var call = _client.DoAction(action, options?.Headers); + using var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -186,7 +166,7 @@ public async Task GetExecuteSchemaAsync(string query, Transaction? trans var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options?.Headers); + using var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -269,7 +249,7 @@ public virtual async Task GetSchemaAsync(FlightDescriptor descriptor, Fl try { - var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); + using var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); return schemaResult; } @@ -353,7 +333,7 @@ public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, Fligh throw new ArgumentNullException(nameof(ticket)); } - var call = _client.GetStream(ticket, options?.Headers); + using var call = _client.GetStream(ticket, options?.Headers); await foreach (var recordBatch in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { yield return recordBatch; @@ -377,7 +357,7 @@ public async Task DoPutAsync(FlightDescriptor descriptor, Schema sc throw new ArgumentNullException(nameof(schema)); try { - var doPutResult = _client.StartPut(descriptor, options?.Headers); + using var doPutResult = _client.StartPut(descriptor, options?.Headers); var writer = doPutResult.RequestStream; var reader = doPutResult.ResponseStream; @@ -674,11 +654,6 @@ public async Task GetCrossReferenceAsync(TableRef pkTableRef, TableR /// The SchemaResult describing the schema of the cross-reference. public async Task GetCrossReferenceSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var commandGetCrossReferenceSchema = new CommandGetCrossReference(); @@ -860,7 +835,7 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions? options = def { var command = new CommandGetSqlInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + using var schemaResultCall = _client.GetSchema(descriptor, options.Headers); var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); return schemaResult; @@ -884,7 +859,7 @@ public async Task CancelFlightInfoAsync(FlightInfoCancel try { var action = new FlightAction(SqlAction.CancelFlightInfoRequest, request.PackAndSerialize()); - var call = _client.DoAction(action, options?.Headers); + using var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { if (Any.Parser.ParseFrom(result.Body) is Any anyResult && @@ -918,7 +893,7 @@ public async Task CancelQueryAsync(FlightInfo info, Flig var cancelQueryRequest = new FlightInfoCancelRequest(info); var cancelQueryAction = new FlightAction(SqlAction.CancelFlightInfoRequest, cancelQueryRequest.PackAndSerialize()); - var cancelQueryCall = _client.DoAction(cancelQueryAction, options?.Headers); + using var cancelQueryCall = _client.DoAction(cancelQueryAction, options?.Headers); await foreach (var result in cancelQueryCall.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -947,7 +922,7 @@ public async Task BeginTransactionAsync(FlightCallOptions? options { var actionBeginTransaction = new ActionBeginTransactionRequest(); var action = new FlightAction(SqlAction.BeginTransactionRequest, actionBeginTransaction.PackAndSerialize()); - var responseStream = _client.DoAction(action, options?.Headers); + using var responseStream = _client.DoAction(action, options?.Headers); await foreach (var result in responseStream.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { string? beginTransactionResult = result.Body.ToStringUtf8(); @@ -1034,8 +1009,7 @@ public async Task PrepareAsync(string query, Transaction? tra }; var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options?.Headers); - + using var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync()) { var preparedStatementResponse = @@ -1050,7 +1024,6 @@ public async Task PrepareAsync(string query, Transaction? tra var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return new PreparedStatement(this, transaction.TransactionId.ToStringUtf8(), flightInfo.Schema, flightInfo.Schema); } - throw new NullReferenceException($"{nameof(PreparedStatement)} was not able to be created"); } catch (RpcException ex) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 4cdf15861bf03..063b3e633873f 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -609,7 +609,6 @@ public async Task GetCrossReferenceAsync() public async Task GetCrossReferenceSchemaAsync() { // Arrange - var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); var recordBatch = _testUtils.CreateTestBatch(0, 100); var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, @@ -617,7 +616,7 @@ public async Task GetCrossReferenceSchemaAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var schema = await _flightSqlClient.GetCrossReferenceSchemaAsync(options); + var schema = await _flightSqlClient.GetCrossReferenceSchemaAsync(); // Assert var expectedSchema = recordBatch.Schema; From 642619b42f1ad14f08e4d235f1910887af34615f Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 30 Oct 2024 15:26:06 +0200 Subject: [PATCH 44/58] chore: missing files updates --- .../Apache.Arrow.Flight.TestWeb.csproj | 1 - csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs | 2 +- csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj index 7a767b10d0396..2282c11c1ed39 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj @@ -11,7 +11,6 @@ - diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs index 43308ac4f8edd..b79edc4ae5466 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/FlightHolder.cs @@ -30,7 +30,7 @@ public class FlightHolder //Not thread safe, but only used in tests private readonly List _recordBatches = new List(); - + public FlightHolder(FlightDescriptor flightDescriptor, Schema schema, string location) { _flightDescriptor = flightDescriptor; diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs index c643646de0023..46c5460912d8c 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightServer.cs @@ -53,7 +53,7 @@ public override async Task DoGet(FlightTicket ticket, FlightServerRecordBatchStr { var batches = flightHolder.GetRecordBatches(); - + foreach(var batch in batches) { await responseStream.WriteAsync(batch.RecordBatch, batch.Metadata); From 2568a9a25cf284dec76f2fe203831ee4c437fd32 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 6 Nov 2024 10:11:15 +0200 Subject: [PATCH 45/58] refactor: removing extensions internal class from single file into its own --- .../Client/FlightSqlClient.cs | 25 +++++-------------- .../FlightExtensions.cs | 10 ++++++++ .../Apache.Arrow.Flight.TestWeb.csproj | 1 + 3 files changed, 17 insertions(+), 19 deletions(-) create mode 100644 csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 6f6317ad80959..617b79e3b42d5 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -357,7 +357,7 @@ public async Task DoPutAsync(FlightDescriptor descriptor, Schema sc throw new ArgumentNullException(nameof(schema)); try { - using var doPutResult = _client.StartPut(descriptor, options?.Headers); + var doPutResult = _client.StartPut(descriptor, options?.Headers); var writer = doPutResult.RequestStream; var reader = doPutResult.ResponseStream; @@ -835,7 +835,7 @@ public async Task GetSqlInfoSchemaAsync(FlightCallOptions? options = def { var command = new CommandGetSqlInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - using var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + var schemaResultCall = _client.GetSchema(descriptor, options.Headers); var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); return schemaResult; @@ -859,7 +859,7 @@ public async Task CancelFlightInfoAsync(FlightInfoCancel try { var action = new FlightAction(SqlAction.CancelFlightInfoRequest, request.PackAndSerialize()); - using var call = _client.DoAction(action, options?.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { if (Any.Parser.ParseFrom(result.Body) is Any anyResult && @@ -893,7 +893,7 @@ public async Task CancelQueryAsync(FlightInfo info, Flig var cancelQueryRequest = new FlightInfoCancelRequest(info); var cancelQueryAction = new FlightAction(SqlAction.CancelFlightInfoRequest, cancelQueryRequest.PackAndSerialize()); - using var cancelQueryCall = _client.DoAction(cancelQueryAction, options?.Headers); + var cancelQueryCall = _client.DoAction(cancelQueryAction, options?.Headers); await foreach (var result in cancelQueryCall.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -922,7 +922,7 @@ public async Task BeginTransactionAsync(FlightCallOptions? options { var actionBeginTransaction = new ActionBeginTransactionRequest(); var action = new FlightAction(SqlAction.BeginTransactionRequest, actionBeginTransaction.PackAndSerialize()); - using var responseStream = _client.DoAction(action, options?.Headers); + var responseStream = _client.DoAction(action, options?.Headers); await foreach (var result in responseStream.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { string? beginTransactionResult = result.Body.ToStringUtf8(); @@ -1009,7 +1009,7 @@ public async Task PrepareAsync(string query, Transaction? tra }; var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); - using var call = _client.DoAction(action, options?.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync()) { var preparedStatementResponse = @@ -1032,16 +1032,3 @@ public async Task PrepareAsync(string query, Transaction? tra } } } - -internal static class FlightDescriptorExtensions -{ - public static byte[] PackAndSerialize(this IMessage command) - { - return Any.Pack(command).Serialize().ToByteArray(); - } - - public static T ParseAndUnpack(this ByteString source) where T : IMessage, new() - { - return Any.Parser.ParseFrom(source).Unpack(); - } -} diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs new file mode 100644 index 0000000000000..a7b68591f57fd --- /dev/null +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs @@ -0,0 +1,10 @@ +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Apache.Arrow.Flight.Sql; + +internal static class FlightExtensions +{ + public static byte[] PackAndSerialize(this IMessage command) => Any.Pack(command).ToByteArray(); + public static T ParseAndUnpack(this ByteString source) where T : IMessage, new() => Any.Parser.ParseFrom(source).Unpack(); +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj index 2282c11c1ed39..7a767b10d0396 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj @@ -11,6 +11,7 @@ + From 97d05ed5130797d7e7d0ceaf3af64e5987506b12 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 6 Nov 2024 12:50:55 +0200 Subject: [PATCH 46/58] refactor: reduce code and concise GetExecuteSchemaAsync fixing the PrepareStatement test code and TestFlightSqlServer adding extension to handle Serialization of schema. --- .../Client/FlightSqlClient.cs | 74 +++++++++--------- .../SchemaExtensions.cs | 12 +++ .../FlightSqlClientTests.cs | 75 ++++++++++++++----- .../TestFlightSqlServer.cs | 16 +++- 4 files changed, 120 insertions(+), 57 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 617b79e3b42d5..c5fb52de6a3a7 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -159,28 +159,24 @@ public async Task GetExecuteSchemaAsync(string query, Transaction? trans if (string.IsNullOrEmpty(query)) throw new ArgumentException($"Query cannot be null or empty: {nameof(query)}"); - - FlightInfo schemaResult = null!; try { var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); using var call = _client.DoAction(action, options?.Headers); + + var preparedStatementResponse = await ReadPreparedStatementAsync(call).ConfigureAwait(false); - await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + if (preparedStatementResponse.PreparedStatementHandle.IsEmpty) + throw new InvalidOperationException("Received an empty or invalid PreparedStatementHandle."); + var commandSqlCall = new CommandPreparedStatementQuery { - var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(result.Body); - var commandSqlCall = new CommandPreparedStatementQuery - { - PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle - }; - byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); - var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - schemaResult = await GetFlightInfoAsync(descriptor, options); - } + PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle + }; + var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCall.PackAndSerialize()); + var schemaResult = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return schemaResult.Schema; } catch (RpcException ex) @@ -992,43 +988,53 @@ public AsyncServerStreamingCall RollbackAsync(Transaction transact /// A transaction to associate this query with. /// RPC-layer hints for this call. /// The created prepared statement. - public async Task PrepareAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) + public async Task PrepareStatementAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { - transaction ??= Transaction.NoTransaction; if (string.IsNullOrEmpty(query)) - { throw new ArgumentException("Query cannot be null or empty", nameof(query)); - } + + transaction ??= Transaction.NoTransaction; try { - var preparedStatementRequest = new ActionCreatePreparedStatementRequest + var command = new ActionCreatePreparedStatementRequest { - Query = query, TransactionId = transaction.TransactionId + Query = query }; - var action = new FlightAction(SqlAction.CreateRequest, preparedStatementRequest.PackAndSerialize()); - var call = _client.DoAction(action, options?.Headers); - await foreach (var result in call.ResponseStream.ReadAllAsync()) + if (transaction.IsValid()) { - var preparedStatementResponse = - FlightSqlUtils.ParseAndUnpack(result.Body); - - var commandSqlCall = new CommandPreparedStatementQuery - { - PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle - }; - byte[] commandSqlCallPackedAndSerialized = commandSqlCall.PackAndSerialize(); - var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCallPackedAndSerialized); - var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); - return new PreparedStatement(this, transaction.TransactionId.ToStringUtf8(), flightInfo.Schema, flightInfo.Schema); + command.TransactionId = transaction.TransactionId; } - throw new NullReferenceException($"{nameof(PreparedStatement)} was not able to be created"); + + var action = new FlightAction(SqlAction.CreateRequest, command.PackAndSerialize()); + using var call = _client.DoAction(action, options?.Headers); + var preparedStatementResponse = await ReadPreparedStatementAsync(call).ConfigureAwait(false); + + return new PreparedStatement(this, + preparedStatementResponse.PreparedStatementHandle.ToStringUtf8(), + SchemaExtensions.DeserializeSchema(preparedStatementResponse.DatasetSchema.ToByteArray()), + SchemaExtensions.DeserializeSchema(preparedStatementResponse.ParameterSchema.ToByteArray()) + ); } catch (RpcException ex) { throw new InvalidOperationException("Failed to prepare statement", ex); } } + + private static async Task ReadPreparedStatementAsync( + AsyncServerStreamingCall call) + { + await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) + { + var response = Any.Parser.ParseFrom(result.Body); + if (response.Is(ActionCreatePreparedStatementResult.Descriptor)) + { + return response.Unpack(); + } + } + throw new InvalidOperationException("Server did not return a valid prepared statement response."); + } } diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs index 4df34ee8e1009..6c6ba0c14d9b0 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs @@ -22,4 +22,16 @@ public static Schema DeserializeSchema(byte[] serializedSchema) var reader = new ArrowStreamReader(stream); return reader.Schema; } + + /// + /// Serializes the provided schema to a byte array. + /// + public static byte[] SerializeSchema(Schema schema) + { + using var memoryStream = new MemoryStream(); + using var writer = new ArrowStreamWriter(memoryStream, schema); + writer.WriteStart(); + writer.WriteEnd(); + return memoryStream.ToArray(); + } } \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 063b3e633873f..e4b264751abeb 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -89,17 +89,60 @@ public async Task PreparedStatementAsync() string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; var transaction = new Transaction("sample-transaction-id"); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - var recordBatch = _testUtils.CreateTestBatch(0, 100); - var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); - flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + // Create a sample schema for the dataset and parameters + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + + var recordBatch = new RecordBatch(schema, new Array[] + { + new Int32Array.Builder().Append(1).Build(), + new StringArray.Builder().Append("John Doe").Build() + }, 1); + + var flightHolder = new FlightHolder(flightDescriptor, schema, _testWebFactory.GetAddress()); + flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); _flightStore.Flights.Add(flightDescriptor, flightHolder); + + var datasetSchemaBytes = SchemaExtensions.SerializeSchema(schema); + var parameterSchemaBytes = SchemaExtensions.SerializeSchema(schema); + + var preparedStatementResponse = new ActionCreatePreparedStatementResult + { + PreparedStatementHandle = ByteString.CopyFromUtf8("prepared-handle"), + DatasetSchema = ByteString.CopyFrom(datasetSchemaBytes), + ParameterSchema = ByteString.CopyFrom(parameterSchemaBytes) + }; + // Act - var preparedStatement = await _flightSqlClient.PrepareAsync(query, transaction); + var preparedStatement = await _flightSqlClient.PrepareStatementAsync(query, transaction); + var deserializedDatasetSchema = SchemaExtensions.DeserializeSchema(preparedStatementResponse.DatasetSchema.ToByteArray()); + var deserializedParameterSchema = SchemaExtensions.DeserializeSchema(preparedStatementResponse.ParameterSchema.ToByteArray()); // Assert Assert.NotNull(preparedStatement); + Assert.NotNull(deserializedDatasetSchema); + Assert.NotNull(deserializedParameterSchema); + CompareSchemas(schema, deserializedDatasetSchema); + CompareSchemas(schema, deserializedParameterSchema); + // // Arrange + // string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; + // var transaction = new Transaction("sample-transaction-id"); + // var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); + // var recordBatch = _testUtils.CreateTestBatch(0, 100); + // var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); + // flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); + // + // _flightStore.Flights.Add(flightDescriptor, flightHolder); + // + // // Act + // var preparedStatement = await _flightSqlClient.PrepareStatementAsync(query, transaction); + // + // // Assert + // Assert.NotNull(preparedStatement); } #endregion @@ -461,8 +504,7 @@ public async Task GetDbSchemasSchemaAsync() public async Task DoPutAsync() { // Arrange - var schema = new Schema - .Builder() + var schema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) @@ -628,8 +670,7 @@ public async Task GetCrossReferenceSchemaAsync() public async Task GetTableTypesAsync() { // Arrange - var expectedSchema = new Schema - .Builder() + var expectedSchema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Build(); var commandGetTableTypes = new CommandGetTableTypes(); @@ -652,8 +693,7 @@ public async Task GetTableTypesSchemaAsync() { // Arrange var options = new FlightCallOptions(); - var expectedSchema = new Schema - .Builder() + var expectedSchema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Build(); var commandGetTableTypesSchema = new CommandGetTableTypes(); @@ -676,8 +716,7 @@ public async Task GetXdbcTypeInfoAsync() { // Arrange var options = new FlightCallOptions(); - var expectedSchema = new Schema - .Builder() + var expectedSchema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) @@ -705,8 +744,7 @@ public async Task GetXdbcTypeInfoSchemaAsync() { // Arrange var options = new FlightCallOptions(); - var expectedSchema = new Schema - .Builder() + var expectedSchema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Field(f => f.Name("TYPE_NAME").DataType(StringType.Default).Nullable(false)) .Field(f => f.Name("PRECISION").DataType(Int32Type.Default).Nullable(false)) @@ -734,8 +772,7 @@ public async Task GetSqlInfoSchemaAsync() // Arrange var options = new FlightCallOptions(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("sqlInfo"); - var expectedSchema = new Schema - .Builder() + var expectedSchema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Build(); var flightHolder = new FlightHolder(flightDescriptor, expectedSchema, _testWebFactory.GetAddress()); @@ -753,8 +790,7 @@ public async Task GetSqlInfoSchemaAsync() public async Task CancelFlightInfoAsync() { // Arrange - var schema = new Schema - .Builder() + var schema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Build(); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); @@ -773,8 +809,7 @@ public async Task CancelQueryAsync() { // Arrange var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - var schema = new Schema - .Builder() + var schema = new Schema.Builder() .Field(f => f.Name("DATA_TYPE_ID").DataType(Int32Type.Default).Nullable(false)) .Build(); var flightInfo = new FlightInfo(schema, flightDescriptor, new List(), 0, 0); diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs index 0463ee79379f3..b2abccd5f8639 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs @@ -18,6 +18,7 @@ using System.Threading.Tasks; using Apache.Arrow.Flight.Server; using Apache.Arrow.Flight.Sql; +using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; @@ -59,11 +60,20 @@ await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample break; case SqlAction.CreateRequest: case SqlAction.CloseRequest: - var prepareStatementResponse = new ActionCreatePreparedStatementResult + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + var datasetSchemaBytes = SchemaExtensions.SerializeSchema(schema); + var parameterSchemaBytes = SchemaExtensions.SerializeSchema(schema); + + var preparedStatementResponse = new ActionCreatePreparedStatementResult { - PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement") + PreparedStatementHandle = ByteString.CopyFromUtf8("ample-testing-prepared-statement"), + DatasetSchema = ByteString.CopyFrom(datasetSchemaBytes), + ParameterSchema = ByteString.CopyFrom(parameterSchemaBytes) }; - byte[] packedResult = Any.Pack(prepareStatementResponse).Serialize().ToByteArray(); + byte[] packedResult = Any.Pack(preparedStatementResponse).Serialize().ToByteArray(); var flightResult = new FlightResult(packedResult); await responseStream.WriteAsync(flightResult).ConfigureAwait(false); break; From 357b4e6b11a7a865f85364071a3d1181b3cf3416 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 6 Nov 2024 13:10:13 +0200 Subject: [PATCH 47/58] refactor: PrepareAsync name --- .../Client/FlightSqlClient.cs | 3 +-- .../FlightSqlClientTests.cs | 20 ++----------------- 2 files changed, 3 insertions(+), 20 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index c5fb52de6a3a7..9f2f6c05e87e9 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -175,7 +175,6 @@ public async Task GetExecuteSchemaAsync(string query, Transaction? trans PreparedStatementHandle = preparedStatementResponse.PreparedStatementHandle }; var descriptor = FlightDescriptor.CreateCommandDescriptor(commandSqlCall.PackAndSerialize()); - var schemaResult = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return schemaResult.Schema; } @@ -988,7 +987,7 @@ public AsyncServerStreamingCall RollbackAsync(Transaction transact /// A transaction to associate this query with. /// RPC-layer hints for this call. /// The created prepared statement. - public async Task PrepareStatementAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) + public async Task PrepareAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { if (string.IsNullOrEmpty(query)) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index e4b264751abeb..2ffd5ce30df3c 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -83,7 +83,7 @@ public async Task RollbackTransactionAsync() #region PreparedStatement [Fact] - public async Task PreparedStatementAsync() + public async Task PreparedAsync() { // Arrange string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; @@ -108,7 +108,6 @@ public async Task PreparedStatementAsync() var datasetSchemaBytes = SchemaExtensions.SerializeSchema(schema); var parameterSchemaBytes = SchemaExtensions.SerializeSchema(schema); - var preparedStatementResponse = new ActionCreatePreparedStatementResult { @@ -118,7 +117,7 @@ public async Task PreparedStatementAsync() }; // Act - var preparedStatement = await _flightSqlClient.PrepareStatementAsync(query, transaction); + var preparedStatement = await _flightSqlClient.PrepareAsync(query, transaction); var deserializedDatasetSchema = SchemaExtensions.DeserializeSchema(preparedStatementResponse.DatasetSchema.ToByteArray()); var deserializedParameterSchema = SchemaExtensions.DeserializeSchema(preparedStatementResponse.ParameterSchema.ToByteArray()); @@ -128,21 +127,6 @@ public async Task PreparedStatementAsync() Assert.NotNull(deserializedParameterSchema); CompareSchemas(schema, deserializedDatasetSchema); CompareSchemas(schema, deserializedParameterSchema); - // // Arrange - // string query = "INSERT INTO users (id, name) VALUES (1, 'John Doe')"; - // var transaction = new Transaction("sample-transaction-id"); - // var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - // var recordBatch = _testUtils.CreateTestBatch(0, 100); - // var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); - // flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); - // - // _flightStore.Flights.Add(flightDescriptor, flightHolder); - // - // // Act - // var preparedStatement = await _flightSqlClient.PrepareStatementAsync(query, transaction); - // - // // Assert - // Assert.NotNull(preparedStatement); } #endregion From df8ca6019d3258c1c7bcdd82c2481cc3a317e74e Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 18 Nov 2024 18:42:03 +0200 Subject: [PATCH 48/58] fix: resolved issues which were found when the client was tested against concrete Database --- .../Client/FlightSqlClient.cs | 175 ++++-------------- .../FlightExtensions.cs | 51 +++++ .../PreparedStatement.cs | 76 ++------ .../FlightSqlClientTests.cs | 5 +- .../FlightSqlPreparedStatementTests.cs | 57 +++--- 5 files changed, 135 insertions(+), 229 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index 9f2f6c05e87e9..bf29695c76cd5 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -2,9 +2,7 @@ using System.Collections.Generic; using System.Threading.Tasks; using Apache.Arrow.Flight.Client; -using Apache.Arrow.Types; using Arrow.Flight.Protocol.Sql; -using Google.Protobuf; using Google.Protobuf.WellKnownTypes; using Grpc.Core; @@ -88,9 +86,13 @@ public async Task ExecuteUpdateAsync(string query, Transaction? transactio var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); var flightInfo = await GetFlightInfoAsync(descriptor, options); var doGetResult = DoGetAsync(flightInfo.Endpoints[0].Ticket, options); + await foreach (var recordBatch in doGetResult.ConfigureAwait(false)) { - affectedRows += recordBatch.Column(0).Length; + foreach (var rowCount in recordBatch.ExtractRowCount()) + { + affectedRows += rowCount; + } } } @@ -117,7 +119,7 @@ public async Task GetFlightInfoAsync(FlightDescriptor descriptor, Fl try { - using var flightInfoCall = _client.GetInfo(descriptor, options?.Headers); + var flightInfoCall = _client.GetInfo(descriptor, options?.Headers); var flightInfo = await flightInfoCall.ResponseAsync.ConfigureAwait(false); return flightInfo; } @@ -138,7 +140,7 @@ public async IAsyncEnumerable DoActionAsync(FlightAction action, F if (action is null) throw new ArgumentNullException(nameof(action)); - using var call = _client.DoAction(action, options?.Headers); + var call = _client.DoAction(action, options?.Headers); await foreach (var result in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { @@ -164,7 +166,7 @@ public async Task GetExecuteSchemaAsync(string query, Transaction? trans var prepareStatementRequest = new ActionCreatePreparedStatementRequest { Query = query, TransactionId = transaction.TransactionId }; var action = new FlightAction(SqlAction.CreateRequest, prepareStatementRequest.PackAndSerialize()); - using var call = _client.DoAction(action, options?.Headers); + var call = _client.DoAction(action, options?.Headers); var preparedStatementResponse = await ReadPreparedStatementAsync(call).ConfigureAwait(false); @@ -191,11 +193,6 @@ public async Task GetExecuteSchemaAsync(string query, Transaction? trans /// The FlightInfo describing where to access the dataset. public async Task GetCatalogsAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetCatalogs(); @@ -244,7 +241,7 @@ public virtual async Task GetSchemaAsync(FlightDescriptor descriptor, Fl try { - using var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); + var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); return schemaResult; } @@ -297,11 +294,6 @@ public async Task GetDbSchemasAsync(string? catalog = null, string? /// The SchemaResult describing the schema of the database schemas. public async Task GetDbSchemasSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetDbSchemas(); @@ -328,7 +320,7 @@ public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, Fligh throw new ArgumentNullException(nameof(ticket)); } - using var call = _client.GetStream(ticket, options?.Headers); + var call = _client.GetStream(ticket, options?.Headers); await foreach (var recordBatch in call.ResponseStream.ReadAllAsync().ConfigureAwait(false)) { yield return recordBatch; @@ -340,27 +332,34 @@ public async IAsyncEnumerable DoGetAsync(FlightTicket ticket, Fligh /// once they are done writing. /// /// The descriptor of the stream. - /// The schema for the data to upload. + /// The record for the data to upload. /// RPC-layer hints for this call. /// A Task representing the asynchronous operation. The task result contains a DoPutResult struct holding a reader and a writer. - public async Task DoPutAsync(FlightDescriptor descriptor, Schema schema, FlightCallOptions? options = default) + public async Task DoPutAsync(FlightDescriptor descriptor, RecordBatch recordBatch, FlightCallOptions? options = default) { if (descriptor is null) throw new ArgumentNullException(nameof(descriptor)); - if (schema is null) - throw new ArgumentNullException(nameof(schema)); + if (recordBatch is null) + throw new ArgumentNullException(nameof(recordBatch)); try { var doPutResult = _client.StartPut(descriptor, options?.Headers); var writer = doPutResult.RequestStream; var reader = doPutResult.ResponseStream; - var recordBatch = new RecordBatch(schema, BuildArrowArraysFromSchema(schema, schema.FieldsList.Count), 0); + if (recordBatch == null || recordBatch.Length == 0) + throw new InvalidOperationException("RecordBatch is empty or improperly initialized."); + await writer.WriteAsync(recordBatch).ConfigureAwait(false); await writer.CompleteAsync().ConfigureAwait(false); - return new DoPutResult(writer, reader); + if (await reader.MoveNext().ConfigureAwait(false)) + { + var putResult = reader.Current; + return new FlightPutResult(putResult.ApplicationMetadata); + } + return FlightPutResult.Empty; } catch (RpcException ex) { @@ -368,73 +367,6 @@ public async Task DoPutAsync(FlightDescriptor descriptor, Schema sc } } - public List BuildArrowArraysFromSchema(Schema schema, int rowCount) - { - var arrays = new List(); - - foreach (var field in schema.FieldsList) - { - switch (field.DataType) - { - case Int32Type _: - // Create an Int32 array - var intArrayBuilder = new Int32Array.Builder(); - for (int i = 0; i < rowCount; i++) - { - intArrayBuilder.Append(i); - } - - arrays.Add(intArrayBuilder.Build()); - break; - - case StringType: - var stringArrayBuilder = new StringArray.Builder(); - for (int i = 0; i < rowCount; i++) - { - stringArrayBuilder.Append($"Value-{i}"); - } - - arrays.Add(stringArrayBuilder.Build()); - break; - - case Int64Type: - var longArrayBuilder = new Int64Array.Builder(); - for (int i = 0; i < rowCount; i++) - { - longArrayBuilder.Append((long)i * 100); - } - - arrays.Add(longArrayBuilder.Build()); - break; - - case FloatType: - var floatArrayBuilder = new FloatArray.Builder(); - for (int i = 0; i < rowCount; i++) - { - floatArrayBuilder.Append((float)(i * 1.1)); - } - - arrays.Add(floatArrayBuilder.Build()); - break; - - case BooleanType: - var boolArrayBuilder = new BooleanArray.Builder(); - for (int i = 0; i < rowCount; i++) - { - boolArrayBuilder.Append(i % 2 == 0); - } - arrays.Add(boolArrayBuilder.Build()); - break; - - default: - throw new NotSupportedException($"Data type {field.DataType} not supported yet."); - } - } - - return arrays; - } - - /// /// Request the primary keys for a table. /// @@ -452,8 +384,8 @@ public async Task GetPrimaryKeysAsync(TableRef tableRef, FlightCallO { Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table }; - byte[] packedRequest = getPrimaryKeysRequest.PackAndSerialize(); - var descriptor = FlightDescriptor.CreateCommandDescriptor(packedRequest); + + var descriptor = FlightDescriptor.CreateCommandDescriptor(getPrimaryKeysRequest.PackAndSerialize()); var flightInfo = await GetFlightInfoAsync(descriptor, options).ConfigureAwait(false); return flightInfo; @@ -474,8 +406,7 @@ public async Task GetPrimaryKeysAsync(TableRef tableRef, FlightCallO /// The table types to include. /// RPC-layer hints for this call. /// The FlightInfo describing where to access the dataset. - public async Task> - GetTablesAsync(string? catalog = null, string? dbSchemaFilterPattern = null, string? tableFilterPattern = null, bool includeSchema = false, IEnumerable? tableTypes = null, FlightCallOptions? options = default) + public async Task> GetTablesAsync(string? catalog = null, string? dbSchemaFilterPattern = null, string? tableFilterPattern = null, bool includeSchema = false, IEnumerable? tableTypes = null, FlightCallOptions? options = default) { var command = new CommandGetTables { @@ -530,18 +461,17 @@ public async Task GetExportedKeysAsync(TableRef tableRef, FlightCall /// /// Get the exported keys schema from the server. /// + /// The table reference. /// RPC-layer hints for this call. /// The SchemaResult describing the schema of the exported keys. - public async Task GetExportedKeysSchemaAsync(FlightCallOptions? options = default) + public async Task GetExportedKeysSchemaAsync(TableRef tableRef, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { - var commandGetExportedKeysSchema = new CommandGetExportedKeys(); + var commandGetExportedKeysSchema = new CommandGetExportedKeys + { + Catalog = tableRef.Catalog ?? string.Empty, DbSchema = tableRef.DbSchema, Table = tableRef.Table + }; var descriptor = FlightDescriptor.CreateCommandDescriptor(commandGetExportedKeysSchema.PackAndSerialize()); var schemaResult = await GetSchemaAsync(descriptor, options).ConfigureAwait(false); return schemaResult; @@ -586,11 +516,6 @@ public async Task GetImportedKeysAsync(TableRef tableRef, FlightCall /// The SchemaResult describing the schema of the imported keys. public async Task GetImportedKeysSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var commandGetImportedKeysSchema = new CommandGetImportedKeys(); @@ -692,11 +617,6 @@ public async Task GetTableTypesAsync(FlightCallOptions? options = de /// The SchemaResult describing the schema of the table types. public async Task GetTableTypesSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetTableTypes(); @@ -718,11 +638,6 @@ public async Task GetTableTypesSchemaAsync(FlightCallOptions? options = /// The FlightInfo describing where to access the dataset. public async Task GetXdbcTypeInfoAsync(int dataType, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetXdbcTypeInfo { DataType = dataType }; @@ -743,11 +658,6 @@ public async Task GetXdbcTypeInfoAsync(int dataType, FlightCallOptio /// The FlightInfo describing where to access the dataset. public async Task GetXdbcTypeInfoAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetXdbcTypeInfo(); @@ -768,11 +678,6 @@ public async Task GetXdbcTypeInfoAsync(FlightCallOptions? options = /// The SchemaResult describing the schema of the type info. public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetXdbcTypeInfo(); @@ -794,11 +699,6 @@ public async Task GetXdbcTypeInfoSchemaAsync(FlightCallOptions? options /// The FlightInfo describing where to access the dataset. public async Task GetSqlInfoAsync(List? sqlInfo = default, FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - sqlInfo ??= new List(); try { @@ -821,16 +721,11 @@ public async Task GetSqlInfoAsync(List? sqlInfo = default, Flig /// The SchemaResult describing the schema of the SQL information. public async Task GetSqlInfoSchemaAsync(FlightCallOptions? options = default) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - try { var command = new CommandGetSqlInfo(); var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); - var schemaResultCall = _client.GetSchema(descriptor, options.Headers); + var schemaResultCall = _client.GetSchema(descriptor, options?.Headers); var schemaResult = await schemaResultCall.ResponseAsync.ConfigureAwait(false); return schemaResult; @@ -989,7 +884,6 @@ public AsyncServerStreamingCall RollbackAsync(Transaction transact /// The created prepared statement. public async Task PrepareAsync(string query, Transaction? transaction = null, FlightCallOptions? options = default) { - if (string.IsNullOrEmpty(query)) throw new ArgumentException("Query cannot be null or empty", nameof(query)); @@ -1008,9 +902,10 @@ public async Task PrepareAsync(string query, Transaction? tra } var action = new FlightAction(SqlAction.CreateRequest, command.PackAndSerialize()); - using var call = _client.DoAction(action, options?.Headers); + var call = _client.DoAction(action, options?.Headers); var preparedStatementResponse = await ReadPreparedStatementAsync(call).ConfigureAwait(false); + return new PreparedStatement(this, preparedStatementResponse.PreparedStatementHandle.ToStringUtf8(), SchemaExtensions.DeserializeSchema(preparedStatementResponse.DatasetSchema.ToByteArray()), diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs index a7b68591f57fd..74b490f2c8151 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs @@ -1,3 +1,5 @@ +using System; +using System.Collections.Generic; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; @@ -7,4 +9,53 @@ internal static class FlightExtensions { public static byte[] PackAndSerialize(this IMessage command) => Any.Pack(command).ToByteArray(); public static T ParseAndUnpack(this ByteString source) where T : IMessage, new() => Any.Parser.ParseFrom(source).Unpack(); + + public static IEnumerable ExtractRowCount(this RecordBatch batch) + { + foreach (var array in batch.Arrays) + { + var values = ExtractValues(array); + foreach (var value in values) + { + yield return value as long? ?? 0; + } + } + } + + private static IEnumerable ExtractValues(IArrowArray array) + { + return array switch + { + Int32Array int32Array => ExtractPrimitiveValues(int32Array), + Int64Array int64Array => ExtractPrimitiveValues(int64Array), + FloatArray floatArray => ExtractPrimitiveValues(floatArray), + BooleanArray booleanArray => ExtractBooleanValues(booleanArray), + StringArray stringArray => ExtractStringValues(stringArray), + _ => throw new NotSupportedException($"Array type {array.GetType().Name} is not supported.") + }; + } + + private static IEnumerable ExtractPrimitiveValues(PrimitiveArray array) where T : struct, IEquatable + { + for (int i = 0; i < array.Length; i++) + { + yield return array.IsNull(i) ? null : array.Values[i]; + } + } + + private static IEnumerable ExtractBooleanValues(BooleanArray array) + { + for (int i = 0; i < array.Length; i++) + { + yield return array.IsNull(i) ? null : array.Values[i]; + } + } + + private static IEnumerable ExtractStringValues(StringArray stringArray) + { + for (int i = 0; i < stringArray.Length; i++) + { + yield return stringArray.IsNull(i) ? null : stringArray.GetString(i); + } + } } \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 39b6fcb2576f1..6f3d88fbb0299 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -12,6 +12,7 @@ using Google.Protobuf; using Grpc.Core; using System.Threading.Channels; +using Google.Protobuf.WellKnownTypes; namespace Apache.Arrow.Flight.Sql; @@ -25,8 +26,7 @@ public class PreparedStatement : IDisposable private bool _isClosed; public bool IsClosed => _isClosed; public string Handle => _handle; - private FlightServerRecordBatchStreamReader? _parameterReader; - public FlightServerRecordBatchStreamReader? ParameterReader => _parameterReader; + public RecordBatch? ParametersBatch => _recordsBatch; /// /// Initializes a new instance of the class. @@ -198,57 +198,9 @@ public async Task ParseResponseAsync(FlightSqlClient client, /// A cancellation token for the binding operation. /// A indicating success or failure. /// Thrown if is null. - public Status SetParameters(RecordBatch parameterBatch, CancellationToken cancellationToken = default) + public void SetParameters(RecordBatch parameterBatch) { - EnsureStatementIsNotClosed(); - _recordsBatch = parameterBatch ?? throw new ArgumentNullException(nameof(parameterBatch)); - - var channel = Channel.CreateUnbounded(); - var task = Task.Run(async () => - { - try - { - using (var memoryStream = new MemoryStream()) - { - var writer = new ArrowStreamWriter(memoryStream, _recordsBatch.Schema); - - cancellationToken.ThrowIfCancellationRequested(); - await writer.WriteRecordBatchAsync(_recordsBatch, cancellationToken).ConfigureAwait(false); - await writer.WriteEndAsync(cancellationToken).ConfigureAwait(false); - - memoryStream.Position = 0; - - cancellationToken.ThrowIfCancellationRequested(); - - var flightData = new FlightData( - FlightDescriptor.CreateCommandDescriptor(_handle), - ByteString.CopyFrom(memoryStream.ToArray()), - ByteString.Empty, - ByteString.Empty - ); - await channel.Writer.WriteAsync(flightData, cancellationToken).ConfigureAwait(false); - } - - channel.Writer.Complete(); - } - catch (OperationCanceledException) - { - channel.Writer.TryComplete(new OperationCanceledException("Task was canceled")); - } - catch (Exception ex) - { - channel.Writer.TryComplete(ex); - } - }, cancellationToken); - - _parameterReader = new FlightServerRecordBatchStreamReader(new ChannelReaderStreamAdapter(channel.Reader)); - if (task.IsCanceled || cancellationToken.IsCancellationRequested) - { - return Status.DefaultCancelled; - } - - return Status.DefaultSuccess; } /// @@ -263,7 +215,12 @@ public async Task ExecuteAsync(CancellationToken cancellationToken = { EnsureStatementIsNotClosed(); - var descriptor = FlightDescriptor.CreateCommandDescriptor(_handle); + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8), + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); cancellationToken.ThrowIfCancellationRequested(); if (_recordsBatch != null) @@ -314,7 +271,13 @@ public async Task ExecuteUpdateAsync(RecordBatch parameterBatch, FlightCal { throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); } - var descriptor = FlightDescriptor.CreateCommandDescriptor(_handle); + + var command = new CommandPreparedStatementQuery + { + PreparedStatementHandle = ByteString.CopyFrom(_handle, Encoding.UTF8), + }; + + var descriptor = FlightDescriptor.CreateCommandDescriptor(command.PackAndSerialize()); var metadata = await BindParametersAsync(descriptor, parameterBatch, options).ConfigureAwait(false); try @@ -358,13 +321,10 @@ public async Task BindParametersAsync(FlightDescriptor descriptor, R { throw new ArgumentNullException(nameof(parameterBatch), "Parameter batch cannot be null."); } - - var putResult = await _client.DoPutAsync(descriptor, parameterBatch.Schema, options).ConfigureAwait(false); - + var putResult = await _client.DoPutAsync(descriptor, parameterBatch, options).ConfigureAwait(false); try { - var metadata = await putResult.ReadMetadataAsync().ConfigureAwait(false); - await putResult.CompleteAsync().ConfigureAwait(false); + var metadata = putResult.ApplicationMetadata; return metadata; } catch (OperationCanceledException) diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 2ffd5ce30df3c..454a009f7ecf2 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -520,7 +520,7 @@ public async Task DoPutAsync() var expectedBatch = _testUtils.CreateTestBatch(0, 100); // Act - var result = await _flightSqlClient.DoPutAsync(flightDescriptor, expectedBatch.Schema); + var result = await _flightSqlClient.DoPutAsync(flightDescriptor, expectedBatch); // Assert Assert.NotNull(result); @@ -557,7 +557,8 @@ public async Task GetExportedKeysSchemaAsync() _flightStore.Flights.Add(flightDescriptor, flightHolder); // Act - var schema = await _flightSqlClient.GetExportedKeysSchemaAsync(options); + var tableRef = new TableRef { Catalog = "test-catalog", Table = "test-table", DbSchema = "test-schema" }; + var schema = await _flightSqlClient.GetExportedKeysSchemaAsync(tableRef); // Assert Assert.NotNull(schema); diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index b99a5ccb89023..d360c25cc5ca7 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -53,24 +53,25 @@ private RecordBatch CreateParameterBatch() new Int32Array.Builder().AppendRange(new[] { 32, 255, 1 }).Build() }, 3); } - - [Fact] - public async Task GetSchemaAsync_ShouldThrowInvalidOperationException_WhenStatementIsClosed() - { - await _preparedStatement.CloseAsync(new FlightCallOptions()); - await Assert.ThrowsAsync(() => _preparedStatement.GetSchemaAsync(new FlightCallOptions())); - } - + + [Fact] public async Task ExecuteAsync_ShouldReturnFlightInfo_WhenValidInputsAreProvided() { - var validRecordBatch = CreateRecordBatch(_schema, new[] { 1, 2, 3 }); - var result = _preparedStatement.SetParameters(validRecordBatch); + var validRecordBatch = CreateParameterBatch(); + _preparedStatement.SetParameters(validRecordBatch); var flightInfo = await _preparedStatement.ExecuteAsync(); Assert.NotNull(flightInfo); Assert.IsType(flightInfo); - Assert.Equal(Status.DefaultSuccess, result); + } + + [Fact] + public async Task GetSchemaAsync_ShouldThrowInvalidOperationException_WhenStatementIsClosed() + { + await _preparedStatement.CloseAsync(new FlightCallOptions()); + await Assert.ThrowsAsync(() => + _preparedStatement.GetSchemaAsync(new FlightCallOptions())); } [Fact] @@ -95,33 +96,28 @@ public async Task TestSetParameters(RecordBatch parameterBatch, Schema parameter var preparedStatement = new PreparedStatement(_flightSqlClient, "TestHandle", _schema, parameterSchema); if (expectedException != null) { - var exception = await Record.ExceptionAsync(() => Task.Run(() => preparedStatement.SetParameters(parameterBatch))); + var exception = + await Record.ExceptionAsync(() => Task.Run(() => preparedStatement.SetParameters(parameterBatch))); Assert.NotNull(exception); Assert.IsType(expectedException, exception); } - else - { - var result = await Task.Run(() => preparedStatement.SetParameters(parameterBatch)); - Assert.NotNull(preparedStatement.ParameterReader); - Assert.Equal(Status.DefaultSuccess, result); - } } [Fact] public async Task TestSetParameters_Cancelled() { - var validRecordBatch = CreateRecordBatch(_schema, new[] { 1, 2, 3 }); + var validRecordBatch = CreateRecordBatch([1, 2, 3]); var cts = new CancellationTokenSource(); await cts.CancelAsync(); - var result = _preparedStatement.SetParameters(validRecordBatch, cts.Token); - Assert.Equal(Status.DefaultCancelled, result); + _preparedStatement.SetParameters(validRecordBatch); } [Fact] public async Task TestCloseAsync() { await _preparedStatement.CloseAsync(new FlightCallOptions()); - Assert.True(_preparedStatement.IsClosed, "PreparedStatement should be marked as closed after calling CloseAsync."); + Assert.True(_preparedStatement.IsClosed, + "PreparedStatement should be marked as closed after calling CloseAsync."); } [Fact] @@ -168,11 +164,13 @@ public async Task ParseResponseAsync_ShouldReturnPreparedStatement_WhenValidData [Theory] [InlineData(null)] [InlineData("")] - public async Task ParseResponseAsync_ShouldThrowException_WhenPreparedStatementHandleIsNullOrEmpty(string handle) + public async Task ParseResponseAsync_ShouldThrowException_WhenPreparedStatementHandleIsNullOrEmpty( + string handle) { ActionCreatePreparedStatementResult actionResult = string.IsNullOrEmpty(handle) ? new ActionCreatePreparedStatementResult() - : new ActionCreatePreparedStatementResult { PreparedStatementHandle = ByteString.CopyFrom(handle, Encoding.UTF8) }; + : new ActionCreatePreparedStatementResult + { PreparedStatementHandle = ByteString.CopyFrom(handle, Encoding.UTF8) }; var flightData = new FlightData(_flightDescriptor, ByteString.CopyFrom(actionResult.ToByteArray())); var results = GetAsyncEnumerable(new List { flightData }); @@ -193,9 +191,10 @@ private async IAsyncEnumerable GetAsyncEnumerable(IEnumerable enumerabl public static IEnumerable GetTestData() { var schema = new Schema.Builder().Field(f => f.Name("field1").DataType(Int32Type.Default)).Build(); - var validRecordBatch = CreateRecordBatch(schema, new[] { 1, 2, 3 }); - var invalidSchema = new Schema.Builder().Field(f => f.Name("invalid_field").DataType(Int32Type.Default)).Build(); - var invalidRecordBatch = CreateRecordBatch(invalidSchema, new[] { 4, 5, 6 }); + var validRecordBatch = CreateRecordBatch([1, 2, 3]); + var invalidSchema = new Schema.Builder().Field(f => f.Name("invalid_field").DataType(Int32Type.Default)) + .Build(); + var invalidRecordBatch = CreateRecordBatch([4, 5, 6]); return new List { @@ -204,10 +203,10 @@ public static IEnumerable GetTestData() }; } - public static RecordBatch CreateRecordBatch(Schema schema, int[] values) + public static RecordBatch CreateRecordBatch(int[] values) { var int32Array = new Int32Array.Builder().AppendRange(values).Build(); return new RecordBatch.Builder().Append("field1", true, int32Array).Build(); } } -} +} \ No newline at end of file From 9caa6f2fda666c7db563405613ef6d86e2726476 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 18 Nov 2024 19:01:43 +0200 Subject: [PATCH 49/58] chore: appended the build reference for nuget. --- dev/release/post-08-csharp.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/release/post-08-csharp.sh b/dev/release/post-08-csharp.sh index 8c86b36774887..75bf2b963aa8e 100755 --- a/dev/release/post-08-csharp.sh +++ b/dev/release/post-08-csharp.sh @@ -39,6 +39,7 @@ base_names=() base_names+=(Apache.Arrow.${version}) base_names+=(Apache.Arrow.Flight.${version}) base_names+=(Apache.Arrow.Flight.AspNetCore.${version}) +base_names+=(Apache.Arrow.Flight.Sql.${version}) base_names+=(Apache.Arrow.Compression.${version}) for base_name in ${base_names[@]}; do for extension in nupkg snupkg; do From 8a14561832302a45fa471d5f438da77865468604 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 20 Nov 2024 08:49:28 +0200 Subject: [PATCH 50/58] chore: adding official Apache header to all the source files --- .../ChannelReaderStreamAdapter.cs | 15 +++++++++++++++ .../Client/FlightSqlClient.cs | 15 +++++++++++++++ csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs | 15 +++++++++++++++ .../Apache.Arrow.Flight.Sql/FlightCallOptions.cs | 15 +++++++++++++++ .../Apache.Arrow.Flight.Sql/FlightExtensions.cs | 15 +++++++++++++++ .../Apache.Arrow.Flight.Sql/PreparedStatement.cs | 15 +++++++++++++++ .../RecordBatchExtensions.cs | 15 +++++++++++++++ .../Apache.Arrow.Flight.Sql/SchemaExtensions.cs | 15 +++++++++++++++ csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs | 15 +++++++++++++++ csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs | 15 +++++++++++++++ .../FlightSqlClientTests.cs | 15 +++++++++++++++ .../FlightSqlPreparedStatementTests.cs | 15 +++++++++++++++ .../FlightSqlTestUtils.cs | 15 +++++++++++++++ .../test/Apache.Arrow.Flight.Sql.Tests/Startup.cs | 15 +++++++++++++++ .../TestFlightSqlWebFactory.cs | 15 +++++++++++++++ 15 files changed, 225 insertions(+) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs b/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs index bbcbe17d32685..14cf03ca40771 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Threading; using System.Threading.Channels; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs index bf29695c76cd5..55c23f66f356b 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Client/FlightSqlClient.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Collections.Generic; using System.Threading.Tasks; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs index 9c5bdc0c099a9..646ed38647d44 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/DoPutResult.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System.Threading.Tasks; using Apache.Arrow.Flight.Client; using Grpc.Core; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs index 461bee550a97d..d5b275b816a4f 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightCallOptions.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Buffers; using System.Threading; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs index 74b490f2c8151..1e6bb5f924204 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Collections.Generic; using Google.Protobuf; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 6f3d88fbb0299..4d2c304a31688 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Collections.Generic; using System.IO; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs index 67f889565f318..5aecc084fd5f4 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/RecordBatchExtensions.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Collections.Generic; using System.IO; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs index 6c6ba0c14d9b0..146293ec118a3 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/SchemaExtensions.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.IO; using Apache.Arrow.Ipc; diff --git a/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs b/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs index 704e9fedcd44d..b4026ab348d0d 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/TableRef.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + namespace Apache.Arrow.Flight.Sql; public class TableRef diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs index 32c882b94c0bc..a85ea9ca6a77d 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/Transaction.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using Google.Protobuf; namespace Apache.Arrow.Flight.Sql; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index 454a009f7ecf2..bac686da352f0 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Collections.Generic; using System.Linq; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs index d360c25cc5ca7..637e632640c9b 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlPreparedStatementTests.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Collections.Generic; using System.Text; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs index 8094a0b491bd3..e0f22d74bbaba 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestUtils.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System.Linq; using Apache.Arrow.Flight.Tests; using Apache.Arrow.Flight.TestWeb; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs index fedb7de114498..b99418ce99611 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Startup.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using Apache.Arrow.Flight.TestWeb; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs index 714d801f59fde..594c5d884b30d 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/TestFlightSqlWebFactory.cs @@ -1,3 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + using System; using System.Linq; using Apache.Arrow.Flight.TestWeb; From 842e1abff10259fe3907fd38854264d8d5c6272d Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 20 Nov 2024 08:52:32 +0200 Subject: [PATCH 51/58] chore: appended not existing project which was added after the project was forked --- csharp/Apache.Arrow.sln | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csharp/Apache.Arrow.sln b/csharp/Apache.Arrow.sln index f08302f1d944f..0dd6853a1c38a 100644 --- a/csharp/Apache.Arrow.sln +++ b/csharp/Apache.Arrow.sln @@ -27,6 +27,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.Sql", " EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Apache.Arrow.Flight.IntegrationTest", "test\Apache.Arrow.Flight.IntegrationTest\Apache.Arrow.Flight.IntegrationTest.csproj", "{7E66CBB4-D921-41E7-A98A-7C6DEA521696}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Apache.Arrow.IntegrationTest", "test\Apache.Arrow.IntegrationTest\Apache.Arrow.IntegrationTest.csproj", "{E8264B7F-B680-4A55-939B-85DB628164BB}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU From 9f166c267ae79600c0667e4b74ec24b0180c1f48 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 20 Nov 2024 08:54:16 +0200 Subject: [PATCH 52/58] chore: reverted version to existing one 2.66.0 -> 2.67.0 --- .../src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj index bec7dbdf54f7f..2bf25ee756059 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj +++ b/csharp/src/Apache.Arrow.Flight.Sql/Apache.Arrow.Flight.Sql.csproj @@ -5,7 +5,7 @@ - + From be6527187c362f362b6fa913f4a344a8dbc2f9d7 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Wed, 20 Nov 2024 09:00:21 +0200 Subject: [PATCH 53/58] chore: removed empty line --- cpp/src/arrow/flight/sql/client.h | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 483bdadbc3878..4d9793a9e27d1 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -443,7 +443,6 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { PreparedStatement(FlightSqlClient* client, std::string handle, std::shared_ptr dataset_schema, std::shared_ptr parameter_schema); - /// \brief Default destructor for the PreparedStatement class. /// The destructor will call the Close method from the class in order, /// to send a request to close the PreparedStatement. From 2b615d640f615d46755f523d76675dfb3361910b Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 24 Nov 2024 12:20:03 +0200 Subject: [PATCH 54/58] fix: conversion from long to int and default value --- .../FlightExtensions.cs | 7 +- .../FlightSqlClientTests.cs | 15 +- .../FlightSqlTestExtensions.cs | 240 ++++++++++++++++++ .../TestFlightSqlServer.cs | 2 +- 4 files changed, 260 insertions(+), 4 deletions(-) diff --git a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs index 1e6bb5f924204..1b9d3e84f25d0 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/FlightExtensions.cs @@ -32,7 +32,12 @@ public static IEnumerable ExtractRowCount(this RecordBatch batch) var values = ExtractValues(array); foreach (var value in values) { - yield return value as long? ?? 0; + yield return value switch + { + long l => l, + int i => i != 0 ? i : 0, + _ => 0L + }; } } } diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs index bac686da352f0..92311171bc6bd 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlClientTests.cs @@ -153,7 +153,18 @@ public async Task ExecuteUpdateAsync() string query = "UPDATE test_table SET column1 = 'value' WHERE column2 = 'condition'"; var transaction = new Transaction("sample-transaction-id"); var flightDescriptor = FlightDescriptor.CreateCommandDescriptor("test"); - var recordBatch = _testUtils.CreateTestBatch(0, 100); + + var schema = new Schema.Builder() + .Field(f => f.Name("id").DataType(Int32Type.Default)) + .Field(f => f.Name("name").DataType(StringType.Default)) + .Build(); + + var recordBatch = new RecordBatch(schema, new Array[] + { + new Int32Array.Builder().Append(1).Build(), + new StringArray.Builder().Append("John Doe").Build() + }, 1); + var flightHolder = new FlightHolder(flightDescriptor, recordBatch.Schema, _testWebFactory.GetAddress()); flightHolder.AddBatch(new RecordBatchWithMetadata(recordBatch)); @@ -163,7 +174,7 @@ public async Task ExecuteUpdateAsync() long affectedRows = await _flightSqlClient.ExecuteUpdateAsync(query, transaction); // Assert - Assert.Equal(100, affectedRows); + Assert.Equal(1, affectedRows); } [Fact] diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs index 031495fffdcc7..c1cd8f2bded0d 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/FlightSqlTestExtensions.cs @@ -13,8 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +using System; +using System.Collections.Generic; +using System.Linq; +using Apache.Arrow.Memory; +using Apache.Arrow.Types; using Google.Protobuf; using Google.Protobuf.WellKnownTypes; +using Type = System.Type; namespace Apache.Arrow.Flight.Sql.Tests; @@ -25,3 +31,237 @@ public static ByteString PackAndSerialize(this IMessage command) return Any.Pack(command).Serialize(); } } + +internal static class TestSchemaExtensions +{ + public static void PrintSchema(this RecordBatch recordBatchResult) + { + // Display column headers + foreach (var field in recordBatchResult.Schema.FieldsList) + { + Console.Write($"{field.Name}\t"); + } + + Console.WriteLine(); + + int rowCount = recordBatchResult.Length; + + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) + { + foreach (var array in recordBatchResult.Arrays) + { + // Retrieve value based on array type + if (array is Int32Array intArray) + { + Console.Write($"{intArray.GetValue(rowIndex)}\t"); + } + else if (array is StringArray stringArray) + { + Console.Write($"{stringArray.GetString(rowIndex)}\t"); + } + else if (array is Int64Array longArray) + { + Console.Write($"{longArray.GetValue(rowIndex)}\t"); + } + else if (array is FloatArray floatArray) + { + Console.Write($"{floatArray.GetValue(rowIndex)}\t"); + } + else if (array is BooleanArray boolArray) + { + Console.Write($"{boolArray.GetValue(rowIndex)}\t"); + } + else + { + Console.Write("N/A\t"); // Fallback for unsupported types + } + } + + Console.WriteLine(); // Move to the next row + } + } + + public static RecordBatch CreateRecordBatch(int[] values) + { + var paramsList = new List(); + var schema = new Schema.Builder(); + for (var index = 0; index < values.Length; index++) + { + var val = values[index]; + var builder = new Int32Array.Builder(); + builder.Append(val); + var paramsArray = builder.Build(); + paramsList.Add(paramsArray); + schema.Field(f => f.Name($"param{index}").DataType(Int32Type.Default).Nullable(false)); + } + + return new RecordBatch(schema.Build(), paramsList, values.Length); + } + + public static void PrintSchema(this Schema schema) + { + Console.WriteLine("Schema Fields:"); + Console.WriteLine("{0,-20} {1,-20} {2,-20}", "Field Name", "Field Type", "Is Nullable"); + Console.WriteLine(new string('-', 60)); + + foreach (var field in schema.FieldsLookup) + { + string fieldName = field.First().Name; + string fieldType = field.First().DataType.TypeId.ToString(); + string isNullable = field.First().IsNullable ? "Yes" : "No"; + + Console.WriteLine("{0,-20} {1,-20} {2,-20}", fieldName, fieldType, isNullable); + } + } + + public static string GetStringValue(IArrowArray array, int index) + { + return array switch + { + StringArray stringArray => stringArray.GetString(index), + Int32Array intArray => intArray.GetValue(index).ToString(), + Int64Array longArray => longArray.GetValue(index).ToString(), + BooleanArray boolArray => boolArray.GetValue(index).Value ? "true" : "false", + _ => "Unsupported Type" + }; + } + + public static void PrintRecordBatch(RecordBatch recordBatch) + { + int rowCount = recordBatch.Length; + for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) + { + string catalogName = GetStringValue(recordBatch.Column(0), rowIndex); + string schemaName = GetStringValue(recordBatch.Column(1), rowIndex); + string tableName = GetStringValue(recordBatch.Column(2), rowIndex); + string tableType = GetStringValue(recordBatch.Column(3), rowIndex); + + Console.WriteLine("{0,-20} {1,-20} {2,-20} {3,-20}", catalogName, schemaName, tableName, tableType); + } + } + + public static RecordBatch CreateRecordBatch(int[] ids, string[] values) + { + var idArrayBuilder = new Int32Array.Builder(); + var valueArrayBuilder = new StringArray.Builder(); + + for (int i = 0; i < ids.Length; i++) + { + idArrayBuilder.Append(ids[i]); + valueArrayBuilder.Append(values[i]); + } + + var schema = new Schema.Builder() + .Field(f => f.Name("Id").DataType(Int32Type.Default).Nullable(false)) + .Field(f => f.Name("Value").DataType(StringType.Default).Nullable(false)) + .Build(); + + return new RecordBatch(schema, [idArrayBuilder.Build(), valueArrayBuilder.Build()], ids.Length); + } + + public static RecordBatch CreateRecordBatch(T[] items) + { + if (items is null || items.Length == 0) + { + throw new ArgumentException("Items array cannot be null or empty."); + } + + var schema = BuildSchema(typeof(T)); + + var arrays = new List(); + foreach (var field in schema.FieldsList) + { + var property = typeof(T).GetProperty(field.Name); + if (property is null) + { + throw new InvalidOperationException($"Property {field.Name} not found in type {typeof(T).Name}."); + } + + // extract values and build the array + var values = items.Select(item => property.GetValue(item, null)).ToArray(); + var array = BuildArrowArray(field.DataType, values); + arrays.Add(array); + } + return new RecordBatch(schema, arrays, items.Length); + } + private static Schema BuildSchema(Type type) + { + var builder = new Schema.Builder(); + + foreach (var property in type.GetProperties()) + { + var fieldType = InferArrowType(property.PropertyType); + builder.Field(f => f.Name(property.Name).DataType(fieldType).Nullable(true)); + } + + return builder.Build(); + } + + private static IArrowType InferArrowType(Type type) + { + return type switch + { + { } t when t == typeof(string) => StringType.Default, + { } t when t == typeof(int) => Int32Type.Default, + { } t when t == typeof(float) => FloatType.Default, + { } t when t == typeof(bool) => BooleanType.Default, + { } t when t == typeof(long) => Int64Type.Default, + _ => throw new NotSupportedException($"Unsupported type: {type}") + }; + } + + private static IArrowArray BuildArrowArray(IArrowType dataType, object[] values, MemoryAllocator allocator = default) + { + allocator ??= MemoryAllocator.Default.Value; + + return dataType switch + { + StringType => BuildStringArray(values), + Int32Type => BuildArray(values, allocator), + FloatType => BuildArray(values, allocator), + BooleanType => BuildArray(values, allocator), + Int64Type => BuildArray(values, allocator), + _ => throw new NotSupportedException($"Unsupported Arrow type: {dataType}") + }; + } + + private static IArrowArray BuildStringArray(object[] values) + { + var builder = new StringArray.Builder(); + + foreach (var value in values) + { + if (value is null) + { + builder.AppendNull(); + } + else + { + builder.Append(value.ToString()); + } + } + + return builder.Build(); + } + + private static IArrowArray BuildArray(object[] values, MemoryAllocator allocator) + where TArray : IArrowArray + where TBuilder : IArrowArrayBuilder, new() + { + var builder = new TBuilder(); + + foreach (var value in values) + { + if (value == null) + { + builder.AppendNull(); + } + else + { + builder.Append((T)value); + } + } + + return builder.Build(allocator); + } +} \ No newline at end of file diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs index b2abccd5f8639..a7aaad4fb2d84 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/TestFlightSqlServer.cs @@ -69,7 +69,7 @@ await responseStream.WriteAsync(new FlightResult(ByteString.CopyFromUtf8("sample var preparedStatementResponse = new ActionCreatePreparedStatementResult { - PreparedStatementHandle = ByteString.CopyFromUtf8("ample-testing-prepared-statement"), + PreparedStatementHandle = ByteString.CopyFromUtf8("sample-testing-prepared-statement"), DatasetSchema = ByteString.CopyFrom(datasetSchemaBytes), ParameterSchema = ByteString.CopyFrom(parameterSchemaBytes) }; From 96768cd0d449b117a8fa0fcb08147dcc7256e1c4 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 24 Nov 2024 12:26:01 +0200 Subject: [PATCH 55/58] chore: clear out the empty line --- cpp/src/arrow/flight/sql/client.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 4d9793a9e27d1..9d3f0004ada9a 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -525,4 +525,4 @@ class ARROW_FLIGHT_SQL_EXPORT Transaction { } // namespace sql } // namespace flight -} // namespace arrow +} // namespace arrow \ No newline at end of file From 02a9cb138674a766286da889e81e6cbd1cbdb014 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Sun, 24 Nov 2024 12:30:45 +0200 Subject: [PATCH 56/58] chore: remove redundant change --- cpp/src/arrow/flight/sql/CMakeLists.txt | 879 ++++++++++++++++++++---- 1 file changed, 740 insertions(+), 139 deletions(-) diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index b32f731496749..2e82be0c68229 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -15,147 +15,748 @@ # specific language governing permissions and limitations # under the License. -add_custom_target(arrow_flight_sql) - -arrow_install_all_headers("arrow/flight/sql") - -set(FLIGHT_SQL_PROTO_PATH "${ARROW_SOURCE_DIR}/../format") -set(FLIGHT_SQL_PROTO ${ARROW_SOURCE_DIR}/../format/FlightSql.proto) - -set(FLIGHT_SQL_GENERATED_PROTO_FILES "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.cc" - "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.h") - -set(PROTO_DEPENDS ${FLIGHT_SQL_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF}) - -set(FLIGHT_SQL_PROTOC_COMMAND - ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_SQL_PROTO_PATH}" - "--cpp_out=dllexport_decl=ARROW_FLIGHT_SQL_EXPORT:${CMAKE_CURRENT_BINARY_DIR}") -if(Protobuf_VERSION VERSION_LESS 3.15) - list(APPEND FLIGHT_SQL_PROTOC_COMMAND "--experimental_allow_proto3_optional") -endif() -list(APPEND FLIGHT_SQL_PROTOC_COMMAND "${FLIGHT_SQL_PROTO}") - -add_custom_command(OUTPUT ${FLIGHT_SQL_GENERATED_PROTO_FILES} - COMMAND ${FLIGHT_SQL_PROTOC_COMMAND} - DEPENDS ${PROTO_DEPENDS}) - -set_source_files_properties(${FLIGHT_SQL_GENERATED_PROTO_FILES} PROPERTIES GENERATED TRUE) -add_custom_target(flight_sql_protobuf_gen ALL DEPENDS ${FLIGHT_SQL_GENERATED_PROTO_FILES}) - -set(ARROW_FLIGHT_SQL_SRCS - server.cc - sql_info_internal.cc - column_metadata.cc - client.cc - protocol_internal.cc - server_session_middleware.cc) - -add_arrow_lib(arrow_flight_sql - CMAKE_PACKAGE_NAME - ArrowFlightSql - PKG_CONFIG_NAME - arrow-flight-sql - OUTPUTS - ARROW_FLIGHT_SQL_LIBRARIES - SOURCES - ${ARROW_FLIGHT_SQL_SRCS} - DEPENDENCIES - flight_sql_protobuf_gen - SHARED_LINK_FLAGS - ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt - SHARED_LINK_LIBS - arrow_flight_shared - SHARED_INSTALL_INTERFACE_LIBS - ArrowFlight::arrow_flight_shared - STATIC_LINK_LIBS - arrow_flight_static - STATIC_INSTALL_INTERFACE_LIBS - ArrowFlight::arrow_flight_static - PRIVATE_INCLUDES - "${Protobuf_INCLUDE_DIRS}") - -if(ARROW_BUILD_STATIC AND WIN32) - target_compile_definitions(arrow_flight_sql_static PUBLIC ARROW_FLIGHT_SQL_STATIC) -endif() - -if(MSVC) - # Suppress warnings caused by Protobuf (casts) - set_source_files_properties(protocol_internal.cc PROPERTIES COMPILE_FLAGS "/wd4267") -endif() -foreach(LIB_TARGET ${ARROW_FLIGHT_SQL_LIBRARIES}) - target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_SQL_EXPORTING) -endforeach() - -if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static" AND ARROW_BUILD_STATIC) - set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_static) +cmake_minimum_required(VERSION 3.16) +message(STATUS "Building using CMake version: ${CMAKE_VERSION}") + +# https://www.cmake.org/cmake/help/latest/policy/CMP0025.html +# +# Compiler id for Apple Clang is now AppleClang. +cmake_policy(SET CMP0025 NEW) + +# https://cmake.org/cmake/help/latest/policy/CMP0042.html +# +# Enable MACOSX_RPATH by default. @rpath in a target's install name is +# a more flexible and powerful mechanism than @executable_path or +# @loader_path for locating shared libraries. +cmake_policy(SET CMP0042 NEW) + +# https://www.cmake.org/cmake/help/latest/policy/CMP0054.html +# +# Only interpret if() arguments as variables or keywords when unquoted. +cmake_policy(SET CMP0054 NEW) + +# https://www.cmake.org/cmake/help/latest/policy/CMP0057.html +# +# Support new if() IN_LIST operator. +cmake_policy(SET CMP0057 NEW) + +# https://www.cmake.org/cmake/help/latest/policy/CMP0063.html +# +# Adapted from Apache Kudu: https://github.com/apache/kudu/commit/bd549e13743a51013585 +# Honor visibility properties for all target types. +cmake_policy(SET CMP0063 NEW) + +# https://cmake.org/cmake/help/latest/policy/CMP0068.html +# +# RPATH settings on macOS do not affect install_name. +cmake_policy(SET CMP0068 NEW) + +# https://cmake.org/cmake/help/latest/policy/CMP0074.html +# +# find_package() uses _ROOT variables. +cmake_policy(SET CMP0074 NEW) + +# https://cmake.org/cmake/help/latest/policy/CMP0091.html +# +# MSVC runtime library flags are selected by an abstraction. +cmake_policy(SET CMP0091 NEW) + +# https://cmake.org/cmake/help/latest/policy/CMP0135.html +# +# CMP0135 is for solving re-building and re-downloading. +# We don't have a real problem with the OLD behavior for now +# but we use the NEW behavior explicitly to suppress CMP0135 +# warnings. +if(POLICY CMP0135) + cmake_policy(SET CMP0135 NEW) +endif() + +# https://cmake.org/cmake/help/latest/policy/CMP0170.html +# +# CMP0170 is for enforcing dependency populations by users with +# FETCHCONTENT_FULLY_DISCONNECTED=ON. +if(POLICY CMP0170) + cmake_policy(SET CMP0170 NEW) +endif() + +set(ARROW_VERSION "19.0.0-SNAPSHOT") + +string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ARROW_BASE_VERSION "${ARROW_VERSION}") + +# if no build type is specified, default to release builds +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE + Release + CACHE STRING "Choose the type of build.") +endif() +string(TOLOWER ${CMAKE_BUILD_TYPE} LOWERCASE_BUILD_TYPE) +string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE) + +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules") + +# this must be included before the project() command, because of the way +# vcpkg (ab)uses CMAKE_TOOLCHAIN_FILE to inject its logic into CMake +if(ARROW_DEPENDENCY_SOURCE STREQUAL "VCPKG") + include(Usevcpkg) +endif() + +project(arrow VERSION "${ARROW_BASE_VERSION}") + +set(ARROW_VERSION_MAJOR "${arrow_VERSION_MAJOR}") +set(ARROW_VERSION_MINOR "${arrow_VERSION_MINOR}") +set(ARROW_VERSION_PATCH "${arrow_VERSION_PATCH}") +if(ARROW_VERSION_MAJOR STREQUAL "" + OR ARROW_VERSION_MINOR STREQUAL "" + OR ARROW_VERSION_PATCH STREQUAL "") + message(FATAL_ERROR "Failed to determine Arrow version from '${ARROW_VERSION}'") +endif() + +# The SO version is also the ABI version +if(ARROW_VERSION_MAJOR STREQUAL "0") + # Arrow 0.x.y => SO version is "x", full SO version is "x.y.0" + set(ARROW_SO_VERSION "${ARROW_VERSION_MINOR}") + set(ARROW_FULL_SO_VERSION "${ARROW_SO_VERSION}.${ARROW_VERSION_PATCH}.0") +else() + # Arrow 1.x.y => SO version is "10x", full SO version is "10x.y.0" + math(EXPR ARROW_SO_VERSION "${ARROW_VERSION_MAJOR} * 100 + ${ARROW_VERSION_MINOR}") + set(ARROW_FULL_SO_VERSION "${ARROW_SO_VERSION}.${ARROW_VERSION_PATCH}.0") +endif() + +message(STATUS "Arrow version: " + "${ARROW_VERSION_MAJOR}.${ARROW_VERSION_MINOR}.${ARROW_VERSION_PATCH} " + "(full: '${ARROW_VERSION}')") +message(STATUS "Arrow SO version: ${ARROW_SO_VERSION} (full: ${ARROW_FULL_SO_VERSION})") + +set(ARROW_SOURCE_DIR ${PROJECT_SOURCE_DIR}) +set(ARROW_BINARY_DIR ${PROJECT_BINARY_DIR}) + +include(CMakePackageConfigHelpers) +include(CMakeParseArguments) +include(ExternalProject) +include(FindPackageHandleStandardArgs) + +include(GNUInstallDirs) +if(IS_ABSOLUTE "${CMAKE_INSTALL_BINDIR}") + set(ARROW_PKG_CONFIG_BINDIR "${CMAKE_INSTALL_BINDIR}") +else() + set(ARROW_PKG_CONFIG_BINDIR "\${prefix}/${CMAKE_INSTALL_BINDIR}") +endif() +if(IS_ABSOLUTE "${CMAKE_INSTALL_INCLUDEDIR}") + set(ARROW_PKG_CONFIG_INCLUDEDIR "${CMAKE_INSTALL_INCLUDEDIR}") +else() + set(ARROW_PKG_CONFIG_INCLUDEDIR "\${prefix}/${CMAKE_INSTALL_INCLUDEDIR}") +endif() +if(IS_ABSOLUTE "${CMAKE_INSTALL_LIBDIR}") + set(ARROW_PKG_CONFIG_LIBDIR "${CMAKE_INSTALL_LIBDIR}") +else() + set(ARROW_PKG_CONFIG_LIBDIR "\${prefix}/${CMAKE_INSTALL_LIBDIR}") +endif() +set(ARROW_GDB_DIR "${CMAKE_INSTALL_DATADIR}/${PROJECT_NAME}/gdb") +set(ARROW_FULL_GDB_DIR "${CMAKE_INSTALL_FULL_DATADIR}/${PROJECT_NAME}/gdb") +set(ARROW_GDB_AUTO_LOAD_DIR "${CMAKE_INSTALL_DATADIR}/gdb/auto-load") +set(ARROW_CMAKE_DIR "${CMAKE_INSTALL_LIBDIR}/cmake") +set(ARROW_DOC_DIR "share/doc/${PROJECT_NAME}") + +set(BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support") + +set(ARROW_LLVM_VERSIONS + "19.1" + "18.1" + "17.0" + "16.0" + "15.0" + "14.0" + "13.0" + "12.0" + "11.1" + "11.0" + "10" + "9" + "8" + "7") + +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/../.env ARROW_ENV) +string(REGEX MATCH "CLANG_TOOLS=[^\n]+" ARROW_ENV_CLANG_TOOLS_VERSION "${ARROW_ENV}") +string(REGEX REPLACE "^CLANG_TOOLS=" "" ARROW_CLANG_TOOLS_VERSION + "${ARROW_ENV_CLANG_TOOLS_VERSION}") +string(REGEX REPLACE "^([0-9]+)(\\..+)?" "\\1" ARROW_CLANG_TOOLS_VERSION_MAJOR + "${ARROW_CLANG_TOOLS_VERSION}") + +if(WIN32 AND NOT MINGW) + # This is used to handle builds using e.g. clang in an MSVC setting. + set(MSVC_TOOLCHAIN TRUE) +else() + set(MSVC_TOOLCHAIN FALSE) +endif() + +find_package(ClangTools) +find_package(InferTools) +if("$ENV{CMAKE_EXPORT_COMPILE_COMMANDS}" STREQUAL "1" + OR CLANG_TIDY_FOUND + OR INFER_FOUND) + # Generate a Clang compile_commands.json "compilation database" file for use + # with various development tools, such as Vim's YouCompleteMe plugin. + # See http://clang.llvm.org/docs/JSONCompilationDatabase.html + set(CMAKE_EXPORT_COMPILE_COMMANDS 1) +endif() + +# Needed for linting targets, etc. +# Use the first Python installation on PATH, not the newest one +set(Python3_FIND_STRATEGY "LOCATION") +# On Windows, use registry last, not first +set(Python3_FIND_REGISTRY "LAST") +# On macOS, use framework last, not first +set(Python3_FIND_FRAMEWORK "LAST") + +find_package(Python3) +set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) + +# ---------------------------------------------------------------------- +# cmake options +include(DefineOptions) + +if(ARROW_BUILD_SHARED AND NOT ARROW_POSITION_INDEPENDENT_CODE) + message(WARNING "Can't disable position-independent code to build shared libraries, enabling" + ) + set(ARROW_POSITION_INDEPENDENT_CODE ON) +endif() + +if(ARROW_USE_SCCACHE + AND NOT CMAKE_C_COMPILER_LAUNCHER + AND NOT CMAKE_CXX_COMPILER_LAUNCHER) + + find_program(SCCACHE_FOUND sccache) + + if(NOT SCCACHE_FOUND AND DEFINED ENV{SCCACHE_PATH}) + # cmake has problems finding sccache from within mingw + message(STATUS "Did not find sccache, using envvar fallback.") + set(SCCACHE_FOUND $ENV{SCCACHE_PATH}) + endif() + + # Only use sccache if a storage backend is configured + if(SCCACHE_FOUND + AND (DEFINED ENV{SCCACHE_AZURE_BLOB_CONTAINER} + OR DEFINED ENV{SCCACHE_BUCKET} + OR DEFINED ENV{SCCACHE_DIR} + OR DEFINED ENV{SCCACHE_GCS_BUCKET} + OR DEFINED ENV{SCCACHE_MEMCACHED} + OR DEFINED ENV{SCCACHE_REDIS} + )) + message(STATUS "Using sccache: ${SCCACHE_FOUND}") + set(CMAKE_C_COMPILER_LAUNCHER ${SCCACHE_FOUND}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${SCCACHE_FOUND}) + endif() +endif() + +if(ARROW_USE_CCACHE + AND NOT CMAKE_C_COMPILER_LAUNCHER + AND NOT CMAKE_CXX_COMPILER_LAUNCHER) + + find_program(CCACHE_FOUND ccache) + + if(CCACHE_FOUND) + message(STATUS "Using ccache: ${CCACHE_FOUND}") + set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_FOUND}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_FOUND}) + # ARROW-3985: let ccache preserve C++ comments, because some of them may be + # meaningful to the compiler + set(ENV{CCACHE_COMMENTS} "1") + endif() +endif() + +if(ARROW_OPTIONAL_INSTALL) + set(INSTALL_IS_OPTIONAL OPTIONAL) +endif() + +# +# "make lint" target +# +if(NOT ARROW_VERBOSE_LINT) + set(ARROW_LINT_QUIET "--quiet") +endif() + +if(NOT LINT_EXCLUSIONS_FILE) + # source files matching a glob from a line in this file + # will be excluded from linting (cpplint, clang-tidy, clang-format) + set(LINT_EXCLUSIONS_FILE ${BUILD_SUPPORT_DIR}/lint_exclusions.txt) +endif() + +find_program(CPPLINT_BIN + NAMES cpplint cpplint.py + HINTS ${BUILD_SUPPORT_DIR}) +message(STATUS "Found cpplint executable at ${CPPLINT_BIN}") + +set(COMMON_LINT_OPTIONS + --exclude_globs + ${LINT_EXCLUSIONS_FILE} + --source_dir + ${CMAKE_CURRENT_SOURCE_DIR}/src + --source_dir + ${CMAKE_CURRENT_SOURCE_DIR}/examples + --source_dir + ${CMAKE_CURRENT_SOURCE_DIR}/tools) + +add_custom_target(lint + ${PYTHON_EXECUTABLE} + ${BUILD_SUPPORT_DIR}/run_cpplint.py + --cpplint_binary + ${CPPLINT_BIN} + ${COMMON_LINT_OPTIONS} + ${ARROW_LINT_QUIET} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) + +# +# "make format" and "make check-format" targets +# +if(${CLANG_FORMAT_FOUND}) + # runs clang format and updates files in place. + add_custom_target(format + ${PYTHON_EXECUTABLE} + ${BUILD_SUPPORT_DIR}/run_clang_format.py + --clang_format_binary + ${CLANG_FORMAT_BIN} + ${COMMON_LINT_OPTIONS} + --fix + ${ARROW_LINT_QUIET}) + + # runs clang format and exits with a non-zero exit code if any files need to be reformatted + add_custom_target(check-format + ${PYTHON_EXECUTABLE} + ${BUILD_SUPPORT_DIR}/run_clang_format.py + --clang_format_binary + ${CLANG_FORMAT_BIN} + ${COMMON_LINT_OPTIONS} + ${ARROW_LINT_QUIET}) +endif() + +add_custom_target(lint_cpp_cli ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/lint_cpp_cli.py + ${CMAKE_CURRENT_SOURCE_DIR}/src) + +if(ARROW_LINT_ONLY) + message("ARROW_LINT_ONLY was specified, this is only a partial build directory") + return() +endif() + +# +# "make clang-tidy" and "make check-clang-tidy" targets +# +if(${CLANG_TIDY_FOUND}) + # TODO check to make sure .clang-tidy is being respected + + # runs clang-tidy and attempts to fix any warning automatically + add_custom_target(clang-tidy + ${PYTHON_EXECUTABLE} + ${BUILD_SUPPORT_DIR}/run_clang_tidy.py + --clang_tidy_binary + ${CLANG_TIDY_BIN} + --compile_commands + ${CMAKE_BINARY_DIR}/compile_commands.json + ${COMMON_LINT_OPTIONS} + --fix + ${ARROW_LINT_QUIET}) + + # runs clang-tidy and exits with a non-zero exit code if any errors are found. + add_custom_target(check-clang-tidy + ${PYTHON_EXECUTABLE} + ${BUILD_SUPPORT_DIR}/run_clang_tidy.py + --clang_tidy_binary + ${CLANG_TIDY_BIN} + --compile_commands + ${CMAKE_BINARY_DIR}/compile_commands.json + ${COMMON_LINT_OPTIONS} + ${ARROW_LINT_QUIET}) +endif() + +if(UNIX) + add_custom_target(iwyu + ${CMAKE_COMMAND} + -E + env + "PYTHON=${PYTHON_EXECUTABLE}" + ${BUILD_SUPPORT_DIR}/iwyu/iwyu.sh) + add_custom_target(iwyu-all + ${CMAKE_COMMAND} + -E + env + "PYTHON=${PYTHON_EXECUTABLE}" + ${BUILD_SUPPORT_DIR}/iwyu/iwyu.sh + all) +endif(UNIX) + +# datetime code used by iOS requires zlib support +if(IOS) + set(ARROW_WITH_ZLIB ON) +endif() + +if(NOT ARROW_BUILD_TESTS) + set(NO_TESTS 1) +else() + add_custom_target(all-tests) + add_custom_target(unittest + ctest + -j4 + -L + unittest + --output-on-failure) + add_dependencies(unittest all-tests) +endif() + +if(ARROW_ENABLE_TIMING_TESTS) + add_definitions(-DARROW_WITH_TIMING_TESTS) +endif() + +if(NOT ARROW_BUILD_BENCHMARKS) + set(NO_BENCHMARKS 1) else() - set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_shared) -endif() -list(APPEND ARROW_FLIGHT_SQL_TEST_LINK_LIBS ${ARROW_FLIGHT_TEST_LINK_LIBS}) - -# Build test server for unit tests -if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) - find_package(SQLite3Alt REQUIRED) - - set(ARROW_FLIGHT_SQL_TEST_SERVER_SRCS - example/sqlite_sql_info.cc - example/sqlite_type_info.cc - example/sqlite_statement.cc - example/sqlite_statement_batch_reader.cc - example/sqlite_server.cc - example/sqlite_tables_schema_batch_reader.cc) - - set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc - server_session_middleware_internals_test.cc) - - set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES}) - set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc) - - if(ARROW_COMPUTE - AND ARROW_PARQUET - AND ARROW_SUBSTRAIT) - list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS ${ARROW_FLIGHT_SQL_ACERO_SRCS} acero_test.cc) - if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") - list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_static) - else() - list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared) - endif() - - if(ARROW_BUILD_EXAMPLES) - add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS} - example/acero_main.cc) - target_link_libraries(acero-flight-sql-server - PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} - ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES}) - endif() + add_custom_target(all-benchmarks) + add_custom_target(benchmark ctest -L benchmark) + add_dependencies(benchmark all-benchmarks) + if(ARROW_BUILD_BENCHMARKS_REFERENCE) + add_definitions(-DARROW_WITH_BENCHMARKS_REFERENCE) + endif() +endif() + +if(NOT ARROW_BUILD_EXAMPLES) + set(NO_EXAMPLES 1) +endif() + +if(ARROW_FUZZING) + # Fuzzing builds enable ASAN without setting our home-grown option for it. + add_definitions(-DADDRESS_SANITIZER) +endif() + +if(ARROW_LARGE_MEMORY_TESTS) + add_definitions(-DARROW_LARGE_MEMORY_TESTS) +endif() + +if(ARROW_TEST_MEMCHECK) + add_definitions(-DARROW_VALGRIND) +endif() + +if(ARROW_USE_UBSAN) + add_definitions(-DARROW_UBSAN) +endif() + +# +# Compiler flags +# + +if(ARROW_EXTRA_ERROR_CONTEXT) + add_definitions(-DARROW_EXTRA_ERROR_CONTEXT) +endif() + +include(SetupCxxFlags) + +# +# Linker flags +# + +# Localize thirdparty symbols using a linker version script. This hides them +# from the client application. The OS X linker does not support the +# version-script option. +if(CMAKE_VERSION VERSION_LESS 3.18) + if(APPLE OR WIN32) + set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT FALSE) + else() + set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT TRUE) endif() +else() + include(CheckLinkerFlag) + check_linker_flag(CXX + "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/src/arrow/symbols.map" + CXX_LINKER_SUPPORTS_VERSION_SCRIPT) +endif() + +# +# Build output directory +# - add_arrow_test(flight_sql_test - SOURCES - ${ARROW_FLIGHT_SQL_TEST_SRCS} - ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS} - STATIC_LINK_LIBS - ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} - ${ARROW_FLIGHT_SQL_TEST_LIBS} - EXTRA_INCLUDES - "${CMAKE_CURRENT_BINARY_DIR}/../" - LABELS - "arrow_flight_sql") - - add_executable(flight-sql-test-server test_server_cli.cc - ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS}) - target_link_libraries(flight-sql-test-server - PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} - ${SQLite3_LIBRARIES}) - - add_executable(flight-sql-test-app test_app_cli.cc) - target_link_libraries(flight-sql-test-app PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES}) - - if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static" AND ARROW_BUILD_STATIC) - foreach(TEST_TARGET arrow-flight-sql-test flight-sql-test-server flight-sql-test-app) - target_compile_definitions(${TEST_TARGET} PUBLIC ARROW_FLIGHT_STATIC - ARROW_FLIGHT_SQL_STATIC) - endforeach() +# set compile output directory +string(TOLOWER ${CMAKE_BUILD_TYPE} BUILD_SUBDIR_NAME) + +# If build in-source, create the latest symlink. If build out-of-source, which is +# preferred, simply output the binaries in the build folder +if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_BINARY_DIR}) + set(BUILD_OUTPUT_ROOT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/build/${BUILD_SUBDIR_NAME}/") + # Link build/latest to the current build directory, to avoid developers + # accidentally running the latest debug build when in fact they're building + # release builds. + file(MAKE_DIRECTORY ${BUILD_OUTPUT_ROOT_DIRECTORY}) + if(NOT APPLE) + set(MORE_ARGS "-T") endif() + execute_process(COMMAND ln ${MORE_ARGS} -sf ${BUILD_OUTPUT_ROOT_DIRECTORY} + ${CMAKE_CURRENT_BINARY_DIR}/build/latest) +else() + set(BUILD_OUTPUT_ROOT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${BUILD_SUBDIR_NAME}/") +endif() + +# where to put generated archives (.a files) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") +set(ARCHIVE_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") + +# where to put generated libraries (.so files) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") +set(LIBRARY_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") + +# where to put generated binaries +set(EXECUTABLE_OUTPUT_PATH "${BUILD_OUTPUT_ROOT_DIRECTORY}") + +if(CMAKE_GENERATOR STREQUAL Xcode) + # Xcode projects support multi-configuration builds. This forces a single output directory + # when building with Xcode that is consistent with single-configuration Makefile driven build. + set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE} + "${BUILD_OUTPUT_ROOT_DIRECTORY}") + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE} + "${BUILD_OUTPUT_ROOT_DIRECTORY}") + set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE} + "${BUILD_OUTPUT_ROOT_DIRECTORY}") +endif() + +# +# Dependencies +# + +include(BuildUtils) +enable_testing() + +# For arrow.pc. Cflags.private, Libs.private and Requires.private are +# used when "pkg-config --cflags --libs --static arrow" is used. +set(ARROW_PC_CFLAGS "") +set(ARROW_PC_CFLAGS_PRIVATE " -DARROW_STATIC") +set(ARROW_PC_LIBS_PRIVATE "") +set(ARROW_PC_REQUIRES_PRIVATE "") + +# For arrow-flight.pc. +set(ARROW_FLIGHT_PC_REQUIRES_PRIVATE "") + +# For arrow-testing.pc. +set(ARROW_TESTING_PC_CFLAGS "") +set(ARROW_TESTING_PC_CFLAGS_PRIVATE " -DARROW_TESTING_STATIC") +set(ARROW_TESTING_PC_LIBS "") +set(ARROW_TESTING_PC_REQUIRES "") + +# For parquet.pc. +set(PARQUET_PC_CFLAGS "") +set(PARQUET_PC_CFLAGS_PRIVATE " -DPARQUET_STATIC") +set(PARQUET_PC_REQUIRES "") +set(PARQUET_PC_REQUIRES_PRIVATE "") + +include(ThirdpartyToolchain) + +# Add common flags +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMMON_FLAGS}") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARROW_CXXFLAGS}") + +# For any C code, use the same flags. These flags don't contain +# C++ specific flags. +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CXX_COMMON_FLAGS} ${ARROW_CXXFLAGS}") + +# Remove --std=c++17 to avoid errors from C compilers +string(REPLACE "-std=c++17" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) + +# Add C++-only flags, like -std=c++17 +set(CMAKE_CXX_FLAGS "${CXX_ONLY_FLAGS} ${CMAKE_CXX_FLAGS}") + +# ASAN / TSAN / UBSAN +if(ARROW_FUZZING) + set(ARROW_USE_COVERAGE ON) +endif() +include(san-config) + +# Code coverage +if("${ARROW_GENERATE_COVERAGE}") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --coverage -DCOVERAGE_BUILD") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --coverage -DCOVERAGE_BUILD") endif() + +# CMAKE_CXX_FLAGS now fully assembled +message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}") +message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") +message(STATUS "CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}: ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}" +) +message(STATUS "CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}: ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}" +) + +include_directories(${CMAKE_CURRENT_BINARY_DIR}/src) +include_directories(src) + +# Compiled flatbuffers files +include_directories(src/generated) + +# +# Visibility +# +if(PARQUET_BUILD_SHARED) + set_target_properties(arrow_shared + PROPERTIES C_VISIBILITY_PRESET hidden + CXX_VISIBILITY_PRESET hidden + VISIBILITY_INLINES_HIDDEN 1) +endif() + +# +# "make ctags" target +# +if(UNIX) + add_custom_target(ctags ctags -R --languages=c++,c) +endif(UNIX) + +# +# "make etags" target +# +if(UNIX) + add_custom_target(tags + etags + --members + --declarations + `find + ${CMAKE_CURRENT_SOURCE_DIR}/src + -name + \\*.cc + -or + -name + \\*.hh + -or + -name + \\*.cpp + -or + -name + \\*.h + -or + -name + \\*.c + -or + -name + \\*.f`) + add_custom_target(etags DEPENDS tags) +endif(UNIX) + +# +# "make cscope" target +# +if(UNIX) + add_custom_target(cscope + find + ${CMAKE_CURRENT_SOURCE_DIR} + (-name + \\*.cc + -or + -name + \\*.hh + -or + -name + \\*.cpp + -or + -name + \\*.h + -or + -name + \\*.c + -or + -name + \\*.f) + -exec + echo + \"{}\" + \; + > + cscope.files + && + cscope + -q + -b + VERBATIM) +endif(UNIX) + +# +# "make infer" target +# + +if(${INFER_FOUND}) + # runs infer capture + add_custom_target(infer ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN} + ${CMAKE_BINARY_DIR}/compile_commands.json 1) + # runs infer analyze + add_custom_target(infer-analyze ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN} + ${CMAKE_BINARY_DIR}/compile_commands.json 2) + # runs infer report + add_custom_target(infer-report ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN} + ${CMAKE_BINARY_DIR}/compile_commands.json 3) +endif() + +# +# Link targets +# + +if("${ARROW_TEST_LINKAGE}" STREQUAL "shared") + if(ARROW_BUILD_TESTS AND NOT ARROW_BUILD_SHARED) + message(FATAL_ERROR "If using ARROW_TEST_LINKAGE=shared, must also \ +pass ARROW_BUILD_SHARED=on") + endif() + # Use shared linking for unit tests if it's available + set(ARROW_TEST_LINK_LIBS arrow_testing_shared) + set(ARROW_EXAMPLE_LINK_LIBS arrow_shared) +else() + if(ARROW_BUILD_TESTS AND NOT ARROW_BUILD_STATIC) + message(FATAL_ERROR "If using static linkage for unit tests, must also \ +pass ARROW_BUILD_STATIC=on") + endif() + set(ARROW_TEST_LINK_LIBS arrow_testing_static) + set(ARROW_EXAMPLE_LINK_LIBS arrow_static) +endif() +# arrow::flatbuffers isn't needed for all tests but we specify it as +# the first link library. It's for prioritizing bundled FlatBuffers +# than system FlatBuffers. +list(PREPEND ARROW_TEST_LINK_LIBS arrow::flatbuffers) +list(APPEND ARROW_TEST_LINK_LIBS ${ARROW_GTEST_GMOCK} ${ARROW_GTEST_GTEST_MAIN}) + +if(ARROW_BUILD_BENCHMARKS) + set(ARROW_BENCHMARK_LINK_LIBS benchmark::benchmark_main ${ARROW_TEST_LINK_LIBS}) + if(WIN32) + list(APPEND ARROW_BENCHMARK_LINK_LIBS shlwapi) + endif() +endif() + +# +# Subdirectories +# + +add_subdirectory(src/arrow) + +if(ARROW_PARQUET) + add_subdirectory(src/parquet) + add_subdirectory(tools/parquet) + if(PARQUET_BUILD_EXAMPLES) + add_subdirectory(examples/parquet) + endif() +endif() + +if(ARROW_GANDIVA) + add_subdirectory(src/gandiva) +endif() + +if(ARROW_SKYHOOK) + add_subdirectory(src/skyhook) +endif() + +if(ARROW_BUILD_EXAMPLES) + add_custom_target(runexample ctest -L example) + add_subdirectory(examples/arrow) +endif() + +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt + ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt + ${CMAKE_CURRENT_SOURCE_DIR}/README.md DESTINATION "${ARROW_DOC_DIR}") + +install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/gdb_arrow.py DESTINATION "${ARROW_GDB_DIR}") + +# +# Validate and print out Arrow configuration options +# + +validate_config() +config_summary_message() +if(${ARROW_BUILD_CONFIG_SUMMARY_JSON}) + config_summary_json() +endif() \ No newline at end of file From d4afd4e13f7b745895048ec02f6f410fc348aeea Mon Sep 17 00:00:00 2001 From: HackPoint Date: Thu, 28 Nov 2024 12:32:27 +0200 Subject: [PATCH 57/58] chore: revert changes --- cpp/CMakeLists.txt | 2 +- cpp/src/arrow/flight/sql/CMakeLists.txt | 879 ++++-------------------- cpp/src/arrow/flight/sql/client.h | 3 +- 3 files changed, 142 insertions(+), 742 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2e82be0c68229..97cbb74d1ffda 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -759,4 +759,4 @@ validate_config() config_summary_message() if(${ARROW_BUILD_CONFIG_SUMMARY_JSON}) config_summary_json() -endif() \ No newline at end of file +endif() diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index 2e82be0c68229..b32f731496749 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -15,748 +15,147 @@ # specific language governing permissions and limitations # under the License. -cmake_minimum_required(VERSION 3.16) -message(STATUS "Building using CMake version: ${CMAKE_VERSION}") - -# https://www.cmake.org/cmake/help/latest/policy/CMP0025.html -# -# Compiler id for Apple Clang is now AppleClang. -cmake_policy(SET CMP0025 NEW) - -# https://cmake.org/cmake/help/latest/policy/CMP0042.html -# -# Enable MACOSX_RPATH by default. @rpath in a target's install name is -# a more flexible and powerful mechanism than @executable_path or -# @loader_path for locating shared libraries. -cmake_policy(SET CMP0042 NEW) - -# https://www.cmake.org/cmake/help/latest/policy/CMP0054.html -# -# Only interpret if() arguments as variables or keywords when unquoted. -cmake_policy(SET CMP0054 NEW) - -# https://www.cmake.org/cmake/help/latest/policy/CMP0057.html -# -# Support new if() IN_LIST operator. -cmake_policy(SET CMP0057 NEW) - -# https://www.cmake.org/cmake/help/latest/policy/CMP0063.html -# -# Adapted from Apache Kudu: https://github.com/apache/kudu/commit/bd549e13743a51013585 -# Honor visibility properties for all target types. -cmake_policy(SET CMP0063 NEW) - -# https://cmake.org/cmake/help/latest/policy/CMP0068.html -# -# RPATH settings on macOS do not affect install_name. -cmake_policy(SET CMP0068 NEW) - -# https://cmake.org/cmake/help/latest/policy/CMP0074.html -# -# find_package() uses _ROOT variables. -cmake_policy(SET CMP0074 NEW) - -# https://cmake.org/cmake/help/latest/policy/CMP0091.html -# -# MSVC runtime library flags are selected by an abstraction. -cmake_policy(SET CMP0091 NEW) - -# https://cmake.org/cmake/help/latest/policy/CMP0135.html -# -# CMP0135 is for solving re-building and re-downloading. -# We don't have a real problem with the OLD behavior for now -# but we use the NEW behavior explicitly to suppress CMP0135 -# warnings. -if(POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) -endif() - -# https://cmake.org/cmake/help/latest/policy/CMP0170.html -# -# CMP0170 is for enforcing dependency populations by users with -# FETCHCONTENT_FULLY_DISCONNECTED=ON. -if(POLICY CMP0170) - cmake_policy(SET CMP0170 NEW) -endif() - -set(ARROW_VERSION "19.0.0-SNAPSHOT") - -string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ARROW_BASE_VERSION "${ARROW_VERSION}") - -# if no build type is specified, default to release builds -if(NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE - Release - CACHE STRING "Choose the type of build.") -endif() -string(TOLOWER ${CMAKE_BUILD_TYPE} LOWERCASE_BUILD_TYPE) -string(TOUPPER ${CMAKE_BUILD_TYPE} UPPERCASE_BUILD_TYPE) - -list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake_modules") - -# this must be included before the project() command, because of the way -# vcpkg (ab)uses CMAKE_TOOLCHAIN_FILE to inject its logic into CMake -if(ARROW_DEPENDENCY_SOURCE STREQUAL "VCPKG") - include(Usevcpkg) -endif() - -project(arrow VERSION "${ARROW_BASE_VERSION}") - -set(ARROW_VERSION_MAJOR "${arrow_VERSION_MAJOR}") -set(ARROW_VERSION_MINOR "${arrow_VERSION_MINOR}") -set(ARROW_VERSION_PATCH "${arrow_VERSION_PATCH}") -if(ARROW_VERSION_MAJOR STREQUAL "" - OR ARROW_VERSION_MINOR STREQUAL "" - OR ARROW_VERSION_PATCH STREQUAL "") - message(FATAL_ERROR "Failed to determine Arrow version from '${ARROW_VERSION}'") -endif() - -# The SO version is also the ABI version -if(ARROW_VERSION_MAJOR STREQUAL "0") - # Arrow 0.x.y => SO version is "x", full SO version is "x.y.0" - set(ARROW_SO_VERSION "${ARROW_VERSION_MINOR}") - set(ARROW_FULL_SO_VERSION "${ARROW_SO_VERSION}.${ARROW_VERSION_PATCH}.0") -else() - # Arrow 1.x.y => SO version is "10x", full SO version is "10x.y.0" - math(EXPR ARROW_SO_VERSION "${ARROW_VERSION_MAJOR} * 100 + ${ARROW_VERSION_MINOR}") - set(ARROW_FULL_SO_VERSION "${ARROW_SO_VERSION}.${ARROW_VERSION_PATCH}.0") -endif() - -message(STATUS "Arrow version: " - "${ARROW_VERSION_MAJOR}.${ARROW_VERSION_MINOR}.${ARROW_VERSION_PATCH} " - "(full: '${ARROW_VERSION}')") -message(STATUS "Arrow SO version: ${ARROW_SO_VERSION} (full: ${ARROW_FULL_SO_VERSION})") - -set(ARROW_SOURCE_DIR ${PROJECT_SOURCE_DIR}) -set(ARROW_BINARY_DIR ${PROJECT_BINARY_DIR}) - -include(CMakePackageConfigHelpers) -include(CMakeParseArguments) -include(ExternalProject) -include(FindPackageHandleStandardArgs) - -include(GNUInstallDirs) -if(IS_ABSOLUTE "${CMAKE_INSTALL_BINDIR}") - set(ARROW_PKG_CONFIG_BINDIR "${CMAKE_INSTALL_BINDIR}") -else() - set(ARROW_PKG_CONFIG_BINDIR "\${prefix}/${CMAKE_INSTALL_BINDIR}") -endif() -if(IS_ABSOLUTE "${CMAKE_INSTALL_INCLUDEDIR}") - set(ARROW_PKG_CONFIG_INCLUDEDIR "${CMAKE_INSTALL_INCLUDEDIR}") -else() - set(ARROW_PKG_CONFIG_INCLUDEDIR "\${prefix}/${CMAKE_INSTALL_INCLUDEDIR}") -endif() -if(IS_ABSOLUTE "${CMAKE_INSTALL_LIBDIR}") - set(ARROW_PKG_CONFIG_LIBDIR "${CMAKE_INSTALL_LIBDIR}") -else() - set(ARROW_PKG_CONFIG_LIBDIR "\${prefix}/${CMAKE_INSTALL_LIBDIR}") -endif() -set(ARROW_GDB_DIR "${CMAKE_INSTALL_DATADIR}/${PROJECT_NAME}/gdb") -set(ARROW_FULL_GDB_DIR "${CMAKE_INSTALL_FULL_DATADIR}/${PROJECT_NAME}/gdb") -set(ARROW_GDB_AUTO_LOAD_DIR "${CMAKE_INSTALL_DATADIR}/gdb/auto-load") -set(ARROW_CMAKE_DIR "${CMAKE_INSTALL_LIBDIR}/cmake") -set(ARROW_DOC_DIR "share/doc/${PROJECT_NAME}") - -set(BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support") - -set(ARROW_LLVM_VERSIONS - "19.1" - "18.1" - "17.0" - "16.0" - "15.0" - "14.0" - "13.0" - "12.0" - "11.1" - "11.0" - "10" - "9" - "8" - "7") - -file(READ ${CMAKE_CURRENT_SOURCE_DIR}/../.env ARROW_ENV) -string(REGEX MATCH "CLANG_TOOLS=[^\n]+" ARROW_ENV_CLANG_TOOLS_VERSION "${ARROW_ENV}") -string(REGEX REPLACE "^CLANG_TOOLS=" "" ARROW_CLANG_TOOLS_VERSION - "${ARROW_ENV_CLANG_TOOLS_VERSION}") -string(REGEX REPLACE "^([0-9]+)(\\..+)?" "\\1" ARROW_CLANG_TOOLS_VERSION_MAJOR - "${ARROW_CLANG_TOOLS_VERSION}") - -if(WIN32 AND NOT MINGW) - # This is used to handle builds using e.g. clang in an MSVC setting. - set(MSVC_TOOLCHAIN TRUE) -else() - set(MSVC_TOOLCHAIN FALSE) -endif() - -find_package(ClangTools) -find_package(InferTools) -if("$ENV{CMAKE_EXPORT_COMPILE_COMMANDS}" STREQUAL "1" - OR CLANG_TIDY_FOUND - OR INFER_FOUND) - # Generate a Clang compile_commands.json "compilation database" file for use - # with various development tools, such as Vim's YouCompleteMe plugin. - # See http://clang.llvm.org/docs/JSONCompilationDatabase.html - set(CMAKE_EXPORT_COMPILE_COMMANDS 1) -endif() - -# Needed for linting targets, etc. -# Use the first Python installation on PATH, not the newest one -set(Python3_FIND_STRATEGY "LOCATION") -# On Windows, use registry last, not first -set(Python3_FIND_REGISTRY "LAST") -# On macOS, use framework last, not first -set(Python3_FIND_FRAMEWORK "LAST") - -find_package(Python3) -set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) - -# ---------------------------------------------------------------------- -# cmake options -include(DefineOptions) - -if(ARROW_BUILD_SHARED AND NOT ARROW_POSITION_INDEPENDENT_CODE) - message(WARNING "Can't disable position-independent code to build shared libraries, enabling" - ) - set(ARROW_POSITION_INDEPENDENT_CODE ON) -endif() - -if(ARROW_USE_SCCACHE - AND NOT CMAKE_C_COMPILER_LAUNCHER - AND NOT CMAKE_CXX_COMPILER_LAUNCHER) - - find_program(SCCACHE_FOUND sccache) - - if(NOT SCCACHE_FOUND AND DEFINED ENV{SCCACHE_PATH}) - # cmake has problems finding sccache from within mingw - message(STATUS "Did not find sccache, using envvar fallback.") - set(SCCACHE_FOUND $ENV{SCCACHE_PATH}) - endif() - - # Only use sccache if a storage backend is configured - if(SCCACHE_FOUND - AND (DEFINED ENV{SCCACHE_AZURE_BLOB_CONTAINER} - OR DEFINED ENV{SCCACHE_BUCKET} - OR DEFINED ENV{SCCACHE_DIR} - OR DEFINED ENV{SCCACHE_GCS_BUCKET} - OR DEFINED ENV{SCCACHE_MEMCACHED} - OR DEFINED ENV{SCCACHE_REDIS} - )) - message(STATUS "Using sccache: ${SCCACHE_FOUND}") - set(CMAKE_C_COMPILER_LAUNCHER ${SCCACHE_FOUND}) - set(CMAKE_CXX_COMPILER_LAUNCHER ${SCCACHE_FOUND}) - endif() -endif() - -if(ARROW_USE_CCACHE - AND NOT CMAKE_C_COMPILER_LAUNCHER - AND NOT CMAKE_CXX_COMPILER_LAUNCHER) - - find_program(CCACHE_FOUND ccache) - - if(CCACHE_FOUND) - message(STATUS "Using ccache: ${CCACHE_FOUND}") - set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_FOUND}) - set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_FOUND}) - # ARROW-3985: let ccache preserve C++ comments, because some of them may be - # meaningful to the compiler - set(ENV{CCACHE_COMMENTS} "1") - endif() -endif() - -if(ARROW_OPTIONAL_INSTALL) - set(INSTALL_IS_OPTIONAL OPTIONAL) -endif() - -# -# "make lint" target -# -if(NOT ARROW_VERBOSE_LINT) - set(ARROW_LINT_QUIET "--quiet") -endif() - -if(NOT LINT_EXCLUSIONS_FILE) - # source files matching a glob from a line in this file - # will be excluded from linting (cpplint, clang-tidy, clang-format) - set(LINT_EXCLUSIONS_FILE ${BUILD_SUPPORT_DIR}/lint_exclusions.txt) -endif() - -find_program(CPPLINT_BIN - NAMES cpplint cpplint.py - HINTS ${BUILD_SUPPORT_DIR}) -message(STATUS "Found cpplint executable at ${CPPLINT_BIN}") - -set(COMMON_LINT_OPTIONS - --exclude_globs - ${LINT_EXCLUSIONS_FILE} - --source_dir - ${CMAKE_CURRENT_SOURCE_DIR}/src - --source_dir - ${CMAKE_CURRENT_SOURCE_DIR}/examples - --source_dir - ${CMAKE_CURRENT_SOURCE_DIR}/tools) - -add_custom_target(lint - ${PYTHON_EXECUTABLE} - ${BUILD_SUPPORT_DIR}/run_cpplint.py - --cpplint_binary - ${CPPLINT_BIN} - ${COMMON_LINT_OPTIONS} - ${ARROW_LINT_QUIET} - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/..) - -# -# "make format" and "make check-format" targets -# -if(${CLANG_FORMAT_FOUND}) - # runs clang format and updates files in place. - add_custom_target(format - ${PYTHON_EXECUTABLE} - ${BUILD_SUPPORT_DIR}/run_clang_format.py - --clang_format_binary - ${CLANG_FORMAT_BIN} - ${COMMON_LINT_OPTIONS} - --fix - ${ARROW_LINT_QUIET}) - - # runs clang format and exits with a non-zero exit code if any files need to be reformatted - add_custom_target(check-format - ${PYTHON_EXECUTABLE} - ${BUILD_SUPPORT_DIR}/run_clang_format.py - --clang_format_binary - ${CLANG_FORMAT_BIN} - ${COMMON_LINT_OPTIONS} - ${ARROW_LINT_QUIET}) -endif() - -add_custom_target(lint_cpp_cli ${PYTHON_EXECUTABLE} ${BUILD_SUPPORT_DIR}/lint_cpp_cli.py - ${CMAKE_CURRENT_SOURCE_DIR}/src) - -if(ARROW_LINT_ONLY) - message("ARROW_LINT_ONLY was specified, this is only a partial build directory") - return() -endif() - -# -# "make clang-tidy" and "make check-clang-tidy" targets -# -if(${CLANG_TIDY_FOUND}) - # TODO check to make sure .clang-tidy is being respected - - # runs clang-tidy and attempts to fix any warning automatically - add_custom_target(clang-tidy - ${PYTHON_EXECUTABLE} - ${BUILD_SUPPORT_DIR}/run_clang_tidy.py - --clang_tidy_binary - ${CLANG_TIDY_BIN} - --compile_commands - ${CMAKE_BINARY_DIR}/compile_commands.json - ${COMMON_LINT_OPTIONS} - --fix - ${ARROW_LINT_QUIET}) - - # runs clang-tidy and exits with a non-zero exit code if any errors are found. - add_custom_target(check-clang-tidy - ${PYTHON_EXECUTABLE} - ${BUILD_SUPPORT_DIR}/run_clang_tidy.py - --clang_tidy_binary - ${CLANG_TIDY_BIN} - --compile_commands - ${CMAKE_BINARY_DIR}/compile_commands.json - ${COMMON_LINT_OPTIONS} - ${ARROW_LINT_QUIET}) -endif() - -if(UNIX) - add_custom_target(iwyu - ${CMAKE_COMMAND} - -E - env - "PYTHON=${PYTHON_EXECUTABLE}" - ${BUILD_SUPPORT_DIR}/iwyu/iwyu.sh) - add_custom_target(iwyu-all - ${CMAKE_COMMAND} - -E - env - "PYTHON=${PYTHON_EXECUTABLE}" - ${BUILD_SUPPORT_DIR}/iwyu/iwyu.sh - all) -endif(UNIX) - -# datetime code used by iOS requires zlib support -if(IOS) - set(ARROW_WITH_ZLIB ON) -endif() - -if(NOT ARROW_BUILD_TESTS) - set(NO_TESTS 1) -else() - add_custom_target(all-tests) - add_custom_target(unittest - ctest - -j4 - -L - unittest - --output-on-failure) - add_dependencies(unittest all-tests) -endif() - -if(ARROW_ENABLE_TIMING_TESTS) - add_definitions(-DARROW_WITH_TIMING_TESTS) -endif() - -if(NOT ARROW_BUILD_BENCHMARKS) - set(NO_BENCHMARKS 1) +add_custom_target(arrow_flight_sql) + +arrow_install_all_headers("arrow/flight/sql") + +set(FLIGHT_SQL_PROTO_PATH "${ARROW_SOURCE_DIR}/../format") +set(FLIGHT_SQL_PROTO ${ARROW_SOURCE_DIR}/../format/FlightSql.proto) + +set(FLIGHT_SQL_GENERATED_PROTO_FILES "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.cc" + "${CMAKE_CURRENT_BINARY_DIR}/FlightSql.pb.h") + +set(PROTO_DEPENDS ${FLIGHT_SQL_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF}) + +set(FLIGHT_SQL_PROTOC_COMMAND + ${ARROW_PROTOBUF_PROTOC} "-I${FLIGHT_SQL_PROTO_PATH}" + "--cpp_out=dllexport_decl=ARROW_FLIGHT_SQL_EXPORT:${CMAKE_CURRENT_BINARY_DIR}") +if(Protobuf_VERSION VERSION_LESS 3.15) + list(APPEND FLIGHT_SQL_PROTOC_COMMAND "--experimental_allow_proto3_optional") +endif() +list(APPEND FLIGHT_SQL_PROTOC_COMMAND "${FLIGHT_SQL_PROTO}") + +add_custom_command(OUTPUT ${FLIGHT_SQL_GENERATED_PROTO_FILES} + COMMAND ${FLIGHT_SQL_PROTOC_COMMAND} + DEPENDS ${PROTO_DEPENDS}) + +set_source_files_properties(${FLIGHT_SQL_GENERATED_PROTO_FILES} PROPERTIES GENERATED TRUE) +add_custom_target(flight_sql_protobuf_gen ALL DEPENDS ${FLIGHT_SQL_GENERATED_PROTO_FILES}) + +set(ARROW_FLIGHT_SQL_SRCS + server.cc + sql_info_internal.cc + column_metadata.cc + client.cc + protocol_internal.cc + server_session_middleware.cc) + +add_arrow_lib(arrow_flight_sql + CMAKE_PACKAGE_NAME + ArrowFlightSql + PKG_CONFIG_NAME + arrow-flight-sql + OUTPUTS + ARROW_FLIGHT_SQL_LIBRARIES + SOURCES + ${ARROW_FLIGHT_SQL_SRCS} + DEPENDENCIES + flight_sql_protobuf_gen + SHARED_LINK_FLAGS + ${ARROW_VERSION_SCRIPT_FLAGS} # Defined in cpp/arrow/CMakeLists.txt + SHARED_LINK_LIBS + arrow_flight_shared + SHARED_INSTALL_INTERFACE_LIBS + ArrowFlight::arrow_flight_shared + STATIC_LINK_LIBS + arrow_flight_static + STATIC_INSTALL_INTERFACE_LIBS + ArrowFlight::arrow_flight_static + PRIVATE_INCLUDES + "${Protobuf_INCLUDE_DIRS}") + +if(ARROW_BUILD_STATIC AND WIN32) + target_compile_definitions(arrow_flight_sql_static PUBLIC ARROW_FLIGHT_SQL_STATIC) +endif() + +if(MSVC) + # Suppress warnings caused by Protobuf (casts) + set_source_files_properties(protocol_internal.cc PROPERTIES COMPILE_FLAGS "/wd4267") +endif() +foreach(LIB_TARGET ${ARROW_FLIGHT_SQL_LIBRARIES}) + target_compile_definitions(${LIB_TARGET} PRIVATE ARROW_FLIGHT_SQL_EXPORTING) +endforeach() + +if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static" AND ARROW_BUILD_STATIC) + set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_static) else() - add_custom_target(all-benchmarks) - add_custom_target(benchmark ctest -L benchmark) - add_dependencies(benchmark all-benchmarks) - if(ARROW_BUILD_BENCHMARKS_REFERENCE) - add_definitions(-DARROW_WITH_BENCHMARKS_REFERENCE) - endif() -endif() - -if(NOT ARROW_BUILD_EXAMPLES) - set(NO_EXAMPLES 1) -endif() - -if(ARROW_FUZZING) - # Fuzzing builds enable ASAN without setting our home-grown option for it. - add_definitions(-DADDRESS_SANITIZER) -endif() - -if(ARROW_LARGE_MEMORY_TESTS) - add_definitions(-DARROW_LARGE_MEMORY_TESTS) -endif() - -if(ARROW_TEST_MEMCHECK) - add_definitions(-DARROW_VALGRIND) -endif() - -if(ARROW_USE_UBSAN) - add_definitions(-DARROW_UBSAN) -endif() - -# -# Compiler flags -# - -if(ARROW_EXTRA_ERROR_CONTEXT) - add_definitions(-DARROW_EXTRA_ERROR_CONTEXT) -endif() - -include(SetupCxxFlags) - -# -# Linker flags -# - -# Localize thirdparty symbols using a linker version script. This hides them -# from the client application. The OS X linker does not support the -# version-script option. -if(CMAKE_VERSION VERSION_LESS 3.18) - if(APPLE OR WIN32) - set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT FALSE) - else() - set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT TRUE) + set(ARROW_FLIGHT_SQL_TEST_LINK_LIBS arrow_flight_sql_shared) +endif() +list(APPEND ARROW_FLIGHT_SQL_TEST_LINK_LIBS ${ARROW_FLIGHT_TEST_LINK_LIBS}) + +# Build test server for unit tests +if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) + find_package(SQLite3Alt REQUIRED) + + set(ARROW_FLIGHT_SQL_TEST_SERVER_SRCS + example/sqlite_sql_info.cc + example/sqlite_type_info.cc + example/sqlite_statement.cc + example/sqlite_statement_batch_reader.cc + example/sqlite_server.cc + example/sqlite_tables_schema_batch_reader.cc) + + set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc + server_session_middleware_internals_test.cc) + + set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES}) + set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc) + + if(ARROW_COMPUTE + AND ARROW_PARQUET + AND ARROW_SUBSTRAIT) + list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS ${ARROW_FLIGHT_SQL_ACERO_SRCS} acero_test.cc) + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_static) + else() + list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared) + endif() + + if(ARROW_BUILD_EXAMPLES) + add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS} + example/acero_main.cc) + target_link_libraries(acero-flight-sql-server + PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES}) + endif() endif() -else() - include(CheckLinkerFlag) - check_linker_flag(CXX - "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/src/arrow/symbols.map" - CXX_LINKER_SUPPORTS_VERSION_SCRIPT) -endif() - -# -# Build output directory -# -# set compile output directory -string(TOLOWER ${CMAKE_BUILD_TYPE} BUILD_SUBDIR_NAME) - -# If build in-source, create the latest symlink. If build out-of-source, which is -# preferred, simply output the binaries in the build folder -if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_CURRENT_BINARY_DIR}) - set(BUILD_OUTPUT_ROOT_DIRECTORY - "${CMAKE_CURRENT_BINARY_DIR}/build/${BUILD_SUBDIR_NAME}/") - # Link build/latest to the current build directory, to avoid developers - # accidentally running the latest debug build when in fact they're building - # release builds. - file(MAKE_DIRECTORY ${BUILD_OUTPUT_ROOT_DIRECTORY}) - if(NOT APPLE) - set(MORE_ARGS "-T") + add_arrow_test(flight_sql_test + SOURCES + ${ARROW_FLIGHT_SQL_TEST_SRCS} + ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS} + STATIC_LINK_LIBS + ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LIBS} + EXTRA_INCLUDES + "${CMAKE_CURRENT_BINARY_DIR}/../" + LABELS + "arrow_flight_sql") + + add_executable(flight-sql-test-server test_server_cli.cc + ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS}) + target_link_libraries(flight-sql-test-server + PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} + ${SQLite3_LIBRARIES}) + + add_executable(flight-sql-test-app test_app_cli.cc) + target_link_libraries(flight-sql-test-app PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES}) + + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static" AND ARROW_BUILD_STATIC) + foreach(TEST_TARGET arrow-flight-sql-test flight-sql-test-server flight-sql-test-app) + target_compile_definitions(${TEST_TARGET} PUBLIC ARROW_FLIGHT_STATIC + ARROW_FLIGHT_SQL_STATIC) + endforeach() endif() - execute_process(COMMAND ln ${MORE_ARGS} -sf ${BUILD_OUTPUT_ROOT_DIRECTORY} - ${CMAKE_CURRENT_BINARY_DIR}/build/latest) -else() - set(BUILD_OUTPUT_ROOT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${BUILD_SUBDIR_NAME}/") -endif() - -# where to put generated archives (.a files) -set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") -set(ARCHIVE_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") - -# where to put generated libraries (.so files) -set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") -set(LIBRARY_OUTPUT_DIRECTORY "${BUILD_OUTPUT_ROOT_DIRECTORY}") - -# where to put generated binaries -set(EXECUTABLE_OUTPUT_PATH "${BUILD_OUTPUT_ROOT_DIRECTORY}") - -if(CMAKE_GENERATOR STREQUAL Xcode) - # Xcode projects support multi-configuration builds. This forces a single output directory - # when building with Xcode that is consistent with single-configuration Makefile driven build. - set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE} - "${BUILD_OUTPUT_ROOT_DIRECTORY}") - set(CMAKE_LIBRARY_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE} - "${BUILD_OUTPUT_ROOT_DIRECTORY}") - set(CMAKE_RUNTIME_OUTPUT_DIRECTORY_${UPPERCASE_BUILD_TYPE} - "${BUILD_OUTPUT_ROOT_DIRECTORY}") -endif() - -# -# Dependencies -# - -include(BuildUtils) -enable_testing() - -# For arrow.pc. Cflags.private, Libs.private and Requires.private are -# used when "pkg-config --cflags --libs --static arrow" is used. -set(ARROW_PC_CFLAGS "") -set(ARROW_PC_CFLAGS_PRIVATE " -DARROW_STATIC") -set(ARROW_PC_LIBS_PRIVATE "") -set(ARROW_PC_REQUIRES_PRIVATE "") - -# For arrow-flight.pc. -set(ARROW_FLIGHT_PC_REQUIRES_PRIVATE "") - -# For arrow-testing.pc. -set(ARROW_TESTING_PC_CFLAGS "") -set(ARROW_TESTING_PC_CFLAGS_PRIVATE " -DARROW_TESTING_STATIC") -set(ARROW_TESTING_PC_LIBS "") -set(ARROW_TESTING_PC_REQUIRES "") - -# For parquet.pc. -set(PARQUET_PC_CFLAGS "") -set(PARQUET_PC_CFLAGS_PRIVATE " -DPARQUET_STATIC") -set(PARQUET_PC_REQUIRES "") -set(PARQUET_PC_REQUIRES_PRIVATE "") - -include(ThirdpartyToolchain) - -# Add common flags -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_COMMON_FLAGS}") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARROW_CXXFLAGS}") - -# For any C code, use the same flags. These flags don't contain -# C++ specific flags. -set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${CXX_COMMON_FLAGS} ${ARROW_CXXFLAGS}") - -# Remove --std=c++17 to avoid errors from C compilers -string(REPLACE "-std=c++17" "" CMAKE_C_FLAGS ${CMAKE_C_FLAGS}) - -# Add C++-only flags, like -std=c++17 -set(CMAKE_CXX_FLAGS "${CXX_ONLY_FLAGS} ${CMAKE_CXX_FLAGS}") - -# ASAN / TSAN / UBSAN -if(ARROW_FUZZING) - set(ARROW_USE_COVERAGE ON) -endif() -include(san-config) - -# Code coverage -if("${ARROW_GENERATE_COVERAGE}") - set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} --coverage -DCOVERAGE_BUILD") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --coverage -DCOVERAGE_BUILD") endif() - -# CMAKE_CXX_FLAGS now fully assembled -message(STATUS "CMAKE_C_FLAGS: ${CMAKE_C_FLAGS}") -message(STATUS "CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}") -message(STATUS "CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}: ${CMAKE_C_FLAGS_${UPPERCASE_BUILD_TYPE}}" -) -message(STATUS "CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}: ${CMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}}" -) - -include_directories(${CMAKE_CURRENT_BINARY_DIR}/src) -include_directories(src) - -# Compiled flatbuffers files -include_directories(src/generated) - -# -# Visibility -# -if(PARQUET_BUILD_SHARED) - set_target_properties(arrow_shared - PROPERTIES C_VISIBILITY_PRESET hidden - CXX_VISIBILITY_PRESET hidden - VISIBILITY_INLINES_HIDDEN 1) -endif() - -# -# "make ctags" target -# -if(UNIX) - add_custom_target(ctags ctags -R --languages=c++,c) -endif(UNIX) - -# -# "make etags" target -# -if(UNIX) - add_custom_target(tags - etags - --members - --declarations - `find - ${CMAKE_CURRENT_SOURCE_DIR}/src - -name - \\*.cc - -or - -name - \\*.hh - -or - -name - \\*.cpp - -or - -name - \\*.h - -or - -name - \\*.c - -or - -name - \\*.f`) - add_custom_target(etags DEPENDS tags) -endif(UNIX) - -# -# "make cscope" target -# -if(UNIX) - add_custom_target(cscope - find - ${CMAKE_CURRENT_SOURCE_DIR} - (-name - \\*.cc - -or - -name - \\*.hh - -or - -name - \\*.cpp - -or - -name - \\*.h - -or - -name - \\*.c - -or - -name - \\*.f) - -exec - echo - \"{}\" - \; - > - cscope.files - && - cscope - -q - -b - VERBATIM) -endif(UNIX) - -# -# "make infer" target -# - -if(${INFER_FOUND}) - # runs infer capture - add_custom_target(infer ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN} - ${CMAKE_BINARY_DIR}/compile_commands.json 1) - # runs infer analyze - add_custom_target(infer-analyze ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN} - ${CMAKE_BINARY_DIR}/compile_commands.json 2) - # runs infer report - add_custom_target(infer-report ${BUILD_SUPPORT_DIR}/run-infer.sh ${INFER_BIN} - ${CMAKE_BINARY_DIR}/compile_commands.json 3) -endif() - -# -# Link targets -# - -if("${ARROW_TEST_LINKAGE}" STREQUAL "shared") - if(ARROW_BUILD_TESTS AND NOT ARROW_BUILD_SHARED) - message(FATAL_ERROR "If using ARROW_TEST_LINKAGE=shared, must also \ -pass ARROW_BUILD_SHARED=on") - endif() - # Use shared linking for unit tests if it's available - set(ARROW_TEST_LINK_LIBS arrow_testing_shared) - set(ARROW_EXAMPLE_LINK_LIBS arrow_shared) -else() - if(ARROW_BUILD_TESTS AND NOT ARROW_BUILD_STATIC) - message(FATAL_ERROR "If using static linkage for unit tests, must also \ -pass ARROW_BUILD_STATIC=on") - endif() - set(ARROW_TEST_LINK_LIBS arrow_testing_static) - set(ARROW_EXAMPLE_LINK_LIBS arrow_static) -endif() -# arrow::flatbuffers isn't needed for all tests but we specify it as -# the first link library. It's for prioritizing bundled FlatBuffers -# than system FlatBuffers. -list(PREPEND ARROW_TEST_LINK_LIBS arrow::flatbuffers) -list(APPEND ARROW_TEST_LINK_LIBS ${ARROW_GTEST_GMOCK} ${ARROW_GTEST_GTEST_MAIN}) - -if(ARROW_BUILD_BENCHMARKS) - set(ARROW_BENCHMARK_LINK_LIBS benchmark::benchmark_main ${ARROW_TEST_LINK_LIBS}) - if(WIN32) - list(APPEND ARROW_BENCHMARK_LINK_LIBS shlwapi) - endif() -endif() - -# -# Subdirectories -# - -add_subdirectory(src/arrow) - -if(ARROW_PARQUET) - add_subdirectory(src/parquet) - add_subdirectory(tools/parquet) - if(PARQUET_BUILD_EXAMPLES) - add_subdirectory(examples/parquet) - endif() -endif() - -if(ARROW_GANDIVA) - add_subdirectory(src/gandiva) -endif() - -if(ARROW_SKYHOOK) - add_subdirectory(src/skyhook) -endif() - -if(ARROW_BUILD_EXAMPLES) - add_custom_target(runexample ctest -L example) - add_subdirectory(examples/arrow) -endif() - -install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/../LICENSE.txt - ${CMAKE_CURRENT_SOURCE_DIR}/../NOTICE.txt - ${CMAKE_CURRENT_SOURCE_DIR}/README.md DESTINATION "${ARROW_DOC_DIR}") - -install(FILES ${CMAKE_CURRENT_SOURCE_DIR}/gdb_arrow.py DESTINATION "${ARROW_GDB_DIR}") - -# -# Validate and print out Arrow configuration options -# - -validate_config() -config_summary_message() -if(${ARROW_BUILD_CONFIG_SUMMARY_JSON}) - config_summary_json() -endif() \ No newline at end of file diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 9d3f0004ada9a..9541432114f9c 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -443,6 +443,7 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { PreparedStatement(FlightSqlClient* client, std::string handle, std::shared_ptr dataset_schema, std::shared_ptr parameter_schema); + /// \brief Default destructor for the PreparedStatement class. /// The destructor will call the Close method from the class in order, /// to send a request to close the PreparedStatement. @@ -525,4 +526,4 @@ class ARROW_FLIGHT_SQL_EXPORT Transaction { } // namespace sql } // namespace flight -} // namespace arrow \ No newline at end of file +} // namespace arrow From e1241510a1e11e99546b6ddf9d83a8ce5e51c0e8 Mon Sep 17 00:00:00 2001 From: HackPoint Date: Mon, 23 Dec 2024 13:45:44 +0200 Subject: [PATCH 58/58] chore: clean up redundant namespaces --- .../ChannelReaderStreamAdapter.cs | 54 ------------------- .../PreparedStatement.cs | 5 -- 2 files changed, 59 deletions(-) delete mode 100644 csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs diff --git a/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs b/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs deleted file mode 100644 index 14cf03ca40771..0000000000000 --- a/csharp/src/Apache.Arrow.Flight.Sql/ChannelReaderStreamAdapter.cs +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -using System; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; -using Grpc.Core; - -namespace Apache.Arrow.Flight.Sql; - -internal class ChannelReaderStreamAdapter : IAsyncStreamReader -{ - private readonly ChannelReader _channelReader; - - public ChannelReaderStreamAdapter(ChannelReader channelReader) - { - _channelReader = channelReader ?? throw new ArgumentNullException(nameof(channelReader)); - Current = default!; - } - - public T Current { get; private set; } - - public async Task MoveNext(CancellationToken cancellationToken) - { - if (await _channelReader.WaitToReadAsync(cancellationToken).ConfigureAwait(false)) - { - if (_channelReader.TryRead(out var item)) - { - Current = item; - return true; - } - } - - return false; - } - - public void Dispose() - { - // No additional cleanup is required here since we are using a channel - } -} \ No newline at end of file diff --git a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs index 4d2c304a31688..0399e1636762a 100644 --- a/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs +++ b/csharp/src/Apache.Arrow.Flight.Sql/PreparedStatement.cs @@ -15,19 +15,14 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; -using Apache.Arrow.Flight.Server; using Apache.Arrow.Flight.Sql.Client; -using Apache.Arrow.Ipc; using Arrow.Flight.Protocol.Sql; using Google.Protobuf; using Grpc.Core; -using System.Threading.Channels; -using Google.Protobuf.WellKnownTypes; namespace Apache.Arrow.Flight.Sql;