From 13deb30df68c7012e3955e67dcf946f05b16c7ee Mon Sep 17 00:00:00 2001 From: Marc Handalian Date: Mon, 10 Feb 2025 15:51:00 -0800 Subject: [PATCH] Trying a new way by pushing/polling from DF Signed-off-by: Marc Handalian --- .../opensearch/arrow/spi/StreamProducer.java | 4 + libs/datafusion/jni/src/lib.rs | 649 +++++++++++++++++- .../DataFrameStreamProducer.java | 6 - .../org.opensearch.datafusion/DataFusion.java | 3 - .../search/query/StreamSearchPhase.java | 13 +- .../stream/collector/ArrowCollector.java | 93 ++- .../collector/ArrowStreamingCollector.java | 195 ++++++ .../collector/DataFusionAggregator.java | 100 +++ .../collector/PushStreamingCollector.java | 193 ++++++ .../stream/collector/StreamingCollector.java | 135 ++++ 10 files changed, 1304 insertions(+), 87 deletions(-) create mode 100644 server/src/main/java/org/opensearch/search/stream/collector/ArrowStreamingCollector.java create mode 100644 server/src/main/java/org/opensearch/search/stream/collector/DataFusionAggregator.java create mode 100644 server/src/main/java/org/opensearch/search/stream/collector/PushStreamingCollector.java create mode 100644 server/src/main/java/org/opensearch/search/stream/collector/StreamingCollector.java diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java index cbf3e35b0508e..63b998fb12b3b 100644 --- a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java +++ b/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java @@ -100,6 +100,10 @@ public interface StreamProducer extends Closeable { */ BatchedJob createJob(BufferAllocator allocator); + /** + * + * @return + */ default Set partitions() { return Collections.emptySet(); } diff --git a/libs/datafusion/jni/src/lib.rs b/libs/datafusion/jni/src/lib.rs index 6d16d3824d803..84a500073a7e5 100644 --- a/libs/datafusion/jni/src/lib.rs +++ b/libs/datafusion/jni/src/lib.rs @@ -5,15 +5,23 @@ use std::time::Duration; use arrow::array::ffi::{FFI_ArrowArray, FFI_ArrowSchema}; use arrow::array::{Array, RecordBatch, StructArray}; +use arrow::datatypes::Schema; use arrow::ffi::{self}; use arrow::ipc::writer::FileWriter; use bytes::Bytes; -use datafusion::execution::SendableRecordBatchStream; +use datafusion::catalog::Session; +use datafusion::error::DataFusionError; +use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream}; +use datafusion::functions_aggregate::count::count; +use datafusion::functions_aggregate::sum::sum; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::metrics::MetricsSet; use datafusion::prelude::{col, DataFrame, SessionConfig, SessionContext}; use futures::stream::TryStreamExt; use jni::objects::{JByteArray, JClass, JObject, JString}; use jni::sys::{jint, jlong}; -use jni::JNIEnv; +use jni::{AttachGuard, JNIEnv}; + use std::io::BufWriter; use tokio::runtime::Runtime; mod provider; @@ -31,10 +39,13 @@ pub extern "system" fn Java_org_opensearch_datafusion_DataFusion_load( ) { let context = unsafe { &mut *(ctx as *mut SessionContext) }; let runtime = unsafe { &mut *(runtime as *mut Runtime) }; - let term_str: String = format!("`{}`", env.get_string(&term) - .expect("Invalid term string") - .to_string_lossy() - .into_owned()); + let term_str: String = format!( + "`{}`", + env.get_string(&term) + .expect("Invalid term string") + .to_string_lossy() + .into_owned() + ); // Take ownership of FFI structs immediately outside the async block let array = 
unsafe { ffi::FFI_ArrowArray::from_raw(array_ptr as *mut _) }; let schema = unsafe { ffi::FFI_ArrowSchema::from_raw(schema_ptr as *mut _) }; @@ -62,7 +73,15 @@ pub extern "system" fn Java_org_opensearch_datafusion_DataFusion_load( ], ) }) - .and_then(|df| df.sort(vec![col("ord").sort(false, true)])); + .and_then(|agg_df| { + agg_df.sort(vec![col("count").sort(true, false)]) + }) + .and_then(|sorted| { + sorted.limit(0, Some(500)) + }) + .and_then(|limited| { + limited.sort(vec![col("ord").sort(false, true)]) + }); df }; @@ -186,7 +205,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_SessionContext_destroySess pub extern "system" fn Java_org_opensearch_datafusion_SessionContext_createSessionContext( _env: JNIEnv, _class: JClass, - size: jint + size: jint, ) -> jlong { let config = SessionConfig::new().with_batch_size(size.try_into().unwrap()); let context = SessionContext::new_with_config(config); @@ -285,7 +304,8 @@ pub extern "system" fn Java_org_opensearch_datafusion_RecordBatchStream_next( let runtime = unsafe { &mut *(runtime as *mut Runtime) }; let stream = unsafe { &mut *(stream as *mut SendableRecordBatchStream) }; runtime.block_on(async { - let next: Result, datafusion::error::DataFusionError> = stream.try_next().await; + let next: Result, datafusion::error::DataFusionError> = + stream.try_next().await; match next { Ok(Some(batch)) => { // Convert to struct array for compatibility with FFI @@ -309,7 +329,7 @@ pub extern "system" fn Java_org_opensearch_datafusion_RecordBatchStream_next( pub extern "system" fn Java_org_opensearch_datafusion_RecordBatchStream_destroy( mut env: JNIEnv, _class: JClass, - pointer: jlong + pointer: jlong, ) { let _ = unsafe { Box::from_raw(pointer as *mut SendableRecordBatchStream) }; } @@ -344,23 +364,616 @@ pub extern "system" fn Java_org_opensearch_datafusion_DataFrame_destroyDataFrame let _ = unsafe { Box::from_raw(pointer as *mut DataFrame) }; } +pub struct DataFusionAggregator { + context: SessionContext, + current_aggregation: Option, + term_column: String, +} + +impl DataFusionAggregator { + pub fn new(context: SessionContext, term_column: String) -> Self { + DataFusionAggregator { + context, + current_aggregation: None, + term_column, + } + } + + pub async fn push_batch(&mut self, batch: RecordBatch) -> Result<(), DataFusionError> { + // agg and collect the new batch immediately + let aggregated = self.context.read_batch(batch)? + .filter(col(&self.term_column).is_not_null())? + .aggregate( + vec![col(&self.term_column).alias("ord")], + vec![count(col(&self.term_column)).alias("count")], + )?.collect().await?; + + let incoming_frame = self.context.read_batches(aggregated).unwrap(); + + // Merge with existing aggregation if we have one + self.current_aggregation = match self.current_aggregation.take() { + Some(existing) => { + Some(existing + .union(incoming_frame)? + .aggregate( + vec![col("ord")], + vec![sum(col("count")).alias("count")] + )? 
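+                    // The union above concatenates the running totals with the new batch's
+                    // partial counts; re-grouping on `ord` and summing collapses them back
+                    // into one row per distinct term.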
+ ) + }, + None => Some(incoming_frame), + }; + Ok(()) + } + + pub fn take_results(&mut self) -> Option { + self.current_aggregation.take() + } +} + +// JNI bindings +#[no_mangle] +pub extern "system" fn Java_org_opensearch_search_stream_collector_DataFusionAggregator_create( + mut env: JNIEnv, + _class: JClass, + ctx: jlong, + term: JString, +) -> jlong { + let context = unsafe { &*(ctx as *const SessionContext) }; + let term_str = format!( + "`{}`", + env.get_string(&term) + .expect("Invalid term string") + .to_string_lossy() + .into_owned() + ); + + let aggregator = DataFusionAggregator::new(context.clone(), term_str); + Box::into_raw(Box::new(aggregator)) as jlong +} + #[no_mangle] -pub extern "system" fn Java_org_opensearch_datafusion_DataFrame_union( +pub extern "system" fn Java_org_opensearch_search_stream_collector_DataFusionAggregator_pushBatch( mut env: JNIEnv, _class: JClass, runtime: jlong, - df1: jlong, - df2: jlong, + agg_ptr: jlong, + array_ptr: jlong, + schema_ptr: jlong, callback: JObject, ) { let runtime = unsafe { &mut *(runtime as *mut Runtime) }; - let dataframe1 = unsafe { &mut *(df1 as *mut DataFrame) }; - let dataframe2 = unsafe { &mut *(df2 as *mut DataFrame) }; + let aggregator = unsafe { &mut *(agg_ptr as *mut DataFusionAggregator) }; + + let array = unsafe { ffi::FFI_ArrowArray::from_raw(array_ptr as *mut _) }; + let schema = unsafe { ffi::FFI_ArrowSchema::from_raw(schema_ptr as *mut _) }; runtime.block_on(async { - let result = dataframe1.clone() - .union(dataframe2.clone()) - .map(|df| Box::into_raw(Box::new(df))); + let result = unsafe { + let data = arrow::ffi::from_ffi(array, &schema).unwrap(); + let arrow_array = arrow::array::make_array(data); + let struct_array = arrow_array + .as_any() + .downcast_ref::() + .unwrap(); + let record_batch = RecordBatch::try_from(struct_array).unwrap(); + + aggregator + .push_batch(record_batch) + .await + .map(|_| Box::into_raw(Box::new(1i32))) + }; + set_object_result(&mut env, callback, result); }); +} + +#[no_mangle] +pub extern "system" fn Java_org_opensearch_search_stream_collector_DataFusionAggregator_getResults( + mut env: JNIEnv, + _class: JClass, + runtime: jlong, + agg_ptr: jlong, + limit: jint, + callback: JObject, +) { + let runtime = unsafe { &mut *(runtime as *mut Runtime) }; + let aggregator = unsafe { &mut *(agg_ptr as *mut DataFusionAggregator) }; + + let result = match aggregator.take_results() { + Some(mut df) => { + if limit > 0 { + // Apply limit and THEN sort by ord, this ensures we do lookups only on the top ords. 
+ match df.limit(0, Some(limit as usize)).and_then(|limited| { + runtime.block_on(async { + println!("Limited DataFrame:"); + limited.clone().show().await.unwrap(); + }); + limited.sort(vec![col("ord").sort(false, true)]) + }) { + Ok(limited_df) => Ok(Box::into_raw(Box::new(limited_df))), + Err(e) => Err(e) + } + } else { + Ok(Box::into_raw(Box::new(df))) + } + }, + None => Ok(std::ptr::null_mut()) + }; + + set_object_result::(&mut env, callback, result); +} + +#[no_mangle] +pub extern "system" fn Java_org_opensearch_search_stream_collector_DataFusionAggregator_destroy( + _env: JNIEnv, + _class: JClass, + pointer: jlong, +) { + let _ = unsafe { Box::from_raw(pointer as *mut DataFusionAggregator) }; +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int64Array}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use std::sync::Arc; + + fn create_test_batch(values: Vec) -> RecordBatch { + let schema = Schema::new(vec![Field::new("category", DataType::Int64, false)]); + let array = Int64Array::from(values); + + RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(array)] + ).unwrap() + } + + #[tokio::test] + async fn test_aggregator_single_batch() { + let ctx = SessionContext::new(); + let mut aggregator = DataFusionAggregator::new(ctx, "`category`".to_string()); + + // Create a batch with [1, 1, 2, 2, 2] + let batch = create_test_batch(vec![1, 1, 2, 2, 2]); + + // Push batch and check results + aggregator.push_batch(batch).await.unwrap(); + + // Get results and verify + let batches = aggregator.take_results().unwrap().collect().await.unwrap(); + println!("BATCHES ARE {:?}", batches); + assert_eq!(batches.len(), 1); + + let batch = &batches[0]; + assert_eq!(batch.num_rows(), 2); // Should have two rows (for values 1 and 2) + + // Verify counts + let counts = batch.column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(counts.value(0), 2); // Count for value 1 + assert_eq!(counts.value(1), 3); // Count for value 2 + } + + #[tokio::test] + async fn test_aggregator_multiple_batches() { + let ctx = SessionContext::new(); + let mut aggregator = DataFusionAggregator::new(ctx, "`category`".to_string()); + + // Push multiple batches + aggregator.push_batch(create_test_batch(vec![1, 1])).await.unwrap(); + aggregator.push_batch(create_test_batch(vec![2, 2])).await.unwrap(); + aggregator.push_batch(create_test_batch(vec![1, 2])).await.unwrap(); + + // Get results and verify + let batches = aggregator.take_results().unwrap().collect().await.unwrap(); + let batch = &batches[0]; + assert_eq!(batch.num_rows(), 2); // Should have two rows + + // Verify counts + let counts = batch.column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(counts.value(0), 3); // Total count for value 1 + assert_eq!(counts.value(1), 3); // Total count for value 2 + } + + #[tokio::test] + async fn test_aggregator_empty_batch() { + let ctx = SessionContext::new(); + let mut aggregator = DataFusionAggregator::new(ctx, "`category`".to_string()); + + // Push empty batch + let batch = create_test_batch(vec![]); + aggregator.push_batch(batch).await.unwrap(); + + // Get results and verify + let batches = aggregator.take_results().unwrap().collect().await.unwrap(); + assert!(batches.is_empty()); + } + + #[tokio::test] + async fn test_aggregator_limit() { + let ctx = SessionContext::new(); + let mut aggregator = DataFusionAggregator::new(ctx, "`category`".to_string()); + + // Create batch with multiple values + let batch = create_test_batch(vec![1, 1, 
2, 2, 3, 3, 4, 4]); + aggregator.push_batch(batch).await.unwrap(); + + // Get results with limit + let batches = aggregator.take_results().unwrap().collect().await.unwrap(); + assert_eq!(batches[0].num_rows(), 2); // Should only have 2 rows due to limit + } + + #[tokio::test] + async fn test_aggregator_ordering() { + let ctx = SessionContext::new(); + let mut aggregator = DataFusionAggregator::new(ctx, "`category`".to_string()); + + // Push values in random order + aggregator.push_batch(create_test_batch(vec![3, 1])).await.unwrap(); + aggregator.push_batch(create_test_batch(vec![2, 4])).await.unwrap(); + + // Get results and verify ordering + let batches = aggregator.take_results().unwrap().collect().await.unwrap(); + + let batch = &batches[0]; + let ords = batch.column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // Verify values are in ascending order + let mut prev = ords.value(0); + for i in 1..batch.num_rows() { + let curr = ords.value(i); + assert!(curr > prev); + prev = curr; + } + } + + #[tokio::test] + async fn test_aggregator_null_handling() { + let ctx = SessionContext::new(); + let mut aggregator = DataFusionAggregator::new(ctx, "`category`".to_string()); + + // Create a batch with some null values + let schema = Schema::new(vec![Field::new("category", DataType::Int64, true)]); + let array = Int64Array::from(vec![Some(1), None, Some(2), None, Some(1)]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(array)] + ).unwrap(); + + aggregator.push_batch(batch).await.unwrap(); + + // Get results and verify + let result = aggregator.take_results().unwrap().collect().await.unwrap(); + let batch = &result[0]; + + // Verify counts (nulls should be filtered out) + let counts = batch.column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(counts.value(0), 2); // Count for value 1 + assert_eq!(counts.value(1), 1); // Count for value 2 + } +} + +use std::sync::Arc; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::fmt; +use std::any::Any; + +use async_trait::async_trait; +use futures::Stream; + +use arrow::datatypes::{SchemaRef}; +use datafusion::physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties +}; +use datafusion::execution::{ + TaskContext, + context::{SessionState}, +}; +use datafusion::datasource::{TableProvider, TableType}; +use datafusion::logical_expr::Expr; + +use jni::JavaVM; + +use jni::objects::GlobalRef; + + +struct CollectorStream { + collector_ref: GlobalRef, + schema: SchemaRef, + finished: bool, + metrics: MetricsSet, + jvm: Arc, // Store JavaVM instead of AttachGuard +} + +impl Stream for CollectorStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + if this.finished { + return Poll::Ready(None); + } + + let mut env_guard = match this.jvm.attach_current_thread() { + Ok(guard) => guard, + Err(e) => return Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))) + }; + + let result = env_guard.call_method( + this.collector_ref.as_obj(), + "getNextBatch", + "()LBatchPointers;", + &[] + ); + + match result { + Ok(output) => { + match output.l() { + Ok(obj) => { + if obj.is_null() { + this.finished = true; + Poll::Ready(None) + } else { + let schema_ptr = env_guard.get_field(&obj, "schemaPtr", "J") + .map_err(|e| DataFusionError::External(Box::new(e)))? 
+ .j() + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let array_ptr = env_guard.get_field(&obj, "arrayPtr", "J") + .map_err(|e| DataFusionError::External(Box::new(e)))? + .j() + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + match convert_to_record_batch(schema_ptr, array_ptr, &this.jvm) { + Ok(batch) => Poll::Ready(Some(Ok(batch))), + Err(e) => Poll::Ready(Some(Err(e))) + } + } + }, + Err(e) => Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))) + } + }, + Err(e) => Poll::Ready(Some(Err(DataFusionError::External(Box::new(e))))) + } + } +} + +#[derive(Debug)] +struct CollectorScan { + collector_ref: GlobalRef, + jvm: Arc, // Wrap JavaVM in Arc here too + projected_schema: SchemaRef, + metrics: MetricsSet, + properties: PlanProperties, +} + +impl CollectorScan { + fn new(collector_ref: GlobalRef, jvm: Arc, schema: SchemaRef) -> Self { + CollectorScan { + collector_ref, + jvm, + projected_schema: schema.clone(), + metrics: MetricsSet::new(), + properties: PlanProperties::new( + EquivalenceProperties::new(schema.clone()), + Partitioning::RoundRobinBatch(5), + datafusion::physical_plan::execution_plan::EmissionType::Incremental, + datafusion::physical_plan::execution_plan::Boundedness::Bounded) + } + } +} + +impl DisplayAs for CollectorScan { + fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default => { + write!(f, "CollectorScan") + } + DisplayFormatType::Verbose => todo!(), + } + } +} + +impl<'a> RecordBatchStream for CollectorStream { + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +unsafe impl<'a> Send for CollectorStream {} + +impl ExecutionPlan for CollectorScan { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.projected_schema.clone() + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> Result, DataFusionError> { + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + // let env_guard = self.jvm.attach_current_thread() + // .map_err(|e| DataFusionError::External(Box::new(e)))?; + + let stream = CollectorStream { + collector_ref: self.collector_ref.clone(), + schema: self.projected_schema.clone(), + finished: false, + metrics: self.metrics.clone(), + jvm: self.jvm.clone() + }; + + // Just use one Box with pin + Ok(Box::pin(stream)) + } + + fn name(&self) -> &str { + "CollectorScan" + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } +} + +#[derive(Debug)] +pub struct CollectorTable { + collector_ref: GlobalRef, + schema: SchemaRef, + jvm: Arc, // Wrap JavaVM in Arc +} + +impl CollectorTable { + pub fn new(collector_ref: GlobalRef, schema: SchemaRef, jvm: JavaVM) -> Self { + Self { + collector_ref, + schema, + jvm: Arc::new(jvm), // Wrap in Arc when creating + } + } +} + +#[async_trait] +impl TableProvider for CollectorTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result, DataFusionError> { + Ok(Arc::new(CollectorScan::new( + self.collector_ref.clone(), + self.jvm.clone(), + project_schema(&self.schema, projection), + ))) + } +} + +fn convert_to_record_batch(array_ptr: i64, schema_ptr: i64, jvm: &JavaVM) -> Result { + let mut env = jvm.attach_current_thread() + 
.map_err(|e| DataFusionError::External(Box::new(e)))?; + + let array = unsafe { ffi::FFI_ArrowArray::from_raw(array_ptr as *mut _) }; + let schema = unsafe { ffi::FFI_ArrowSchema::from_raw(schema_ptr as *mut _) }; + + unsafe { + let data = arrow::ffi::from_ffi(array, &schema)?; + let arrow_array = arrow::array::make_array(data); + let struct_array = arrow_array + .as_any() + .downcast_ref::() + .ok_or_else(|| DataFusionError::Internal("Failed to convert to StructArray".to_string()))?; + + RecordBatch::try_from(struct_array.clone()) + .map_err(|e| DataFusionError::Internal(format!("Failed to convert to RecordBatch: {}", e))) + } +} + +fn project_schema(schema: &SchemaRef, projection: Option<&Vec>) -> SchemaRef { + match projection { + Some(proj) => Arc::new(schema.project(proj).unwrap()), + None => schema.clone(), + } +} + +#[no_mangle] +pub extern "system" fn Java_org_opensearch_search_stream_collector_StreamingCollector_create( + mut env: JNIEnv, + _class: JClass, + ctx: jlong, + collector: JObject, + limit: jint, + term: JString, + schema_ptr: jlong, +) -> jlong { + let context = unsafe { &*(ctx as *const SessionContext) }; + let schema = unsafe { &*(schema_ptr as *const Schema) }; + let jvm = env.get_java_vm().unwrap(); + + let term_str: String = format!( + "`{}`", + env.get_string(&term) + .expect("Invalid term string") + .to_string_lossy() + .into_owned() + ); + + let collector_ref = match env.new_global_ref(collector) { + Ok(global) => global, + Err(_) => return -1, + }; + + let table = Arc::new(CollectorTable::new( + collector_ref, + Arc::new(schema.clone()), + jvm, + )); + + let df = match context.read_table(table) { + Ok(df) => df + .aggregate( + vec![col(&term_str).alias("ord")], + vec![count(col("ord")).alias("count")] + ) + .and_then(|agg_df| { + agg_df.sort(vec![col("count").sort(true, false)]) + }) + .and_then(|sorted| { + sorted.limit(0, Some(limit as usize)) + }) + .and_then(|limited| { + limited.sort(vec![col("ord").sort(false, true)]) + }), + Err(_) => return -1, + }; + + match df { + Ok(plan) => Box::into_raw(Box::new(plan)) as jlong, + Err(_) => -1, + } } \ No newline at end of file diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java index 3f5043abde1b3..f03d3532f55b4 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java +++ b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFrameStreamProducer.java @@ -10,10 +10,6 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.arrow.spi.StreamProducer; @@ -21,8 +17,6 @@ import org.opensearch.common.unit.TimeValue; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; diff --git a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java index 497d68aeda761..0dd71798fb2d3 100644 --- a/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java +++ 
b/libs/datafusion/src/main/java/org.opensearch.datafusion/DataFusion.java @@ -10,15 +10,12 @@ import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; -import org.apache.arrow.c.CDataDictionaryProvider; import org.apache.arrow.c.Data; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.arrow.spi.StreamTicket; import java.util.concurrent.CompletableFuture; import java.util.function.BiConsumer; diff --git a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java index 654b39c34c21e..814fe7be3f47b 100644 --- a/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java +++ b/server/src/main/java/org/opensearch/search/query/StreamSearchPhase.java @@ -43,6 +43,7 @@ import org.opensearch.search.stream.StreamSearchResult; import org.opensearch.search.stream.collector.ArrowCollector; import org.opensearch.search.stream.collector.ArrowFieldAdaptor; +import org.opensearch.search.stream.collector.PushStreamingCollector; import java.io.IOException; import java.util.ArrayList; @@ -206,16 +207,8 @@ public void run(VectorSchemaRoot root, StreamProducer.FlushSignal flushSignal) { new ArrayList<>(fieldVectors.values()), 0 ); - final ArrowCollector arrowDocIdCollector = new ArrowCollector(finalCollector, provider, collectionRoot, root, allocator, arrowFieldAdaptors, 10_000_000, flushSignal, searchContext.shardTarget().getShardId()); - -// final StreamingAggregator arrowDocIdCollector = new StreamingAggregator( -// (Aggregator) QueryCollectorContext.createQueryCollector(collectors), -// searchContext, -// root, -// 1_000_000, -// flushSignal, -// searchContext.shardTarget().getShardId() -// ); + final PushStreamingCollector arrowDocIdCollector = new PushStreamingCollector(finalCollector, provider, collectionRoot, root, allocator, arrowFieldAdaptors, 1_000_000, flushSignal, searchContext.shardTarget().getShardId()); +// final ArrowCollector arrowDocIdCollector = new ArrowCollector(finalCollector, provider, collectionRoot, root, allocator, arrowFieldAdaptors, 1_000_000, flushSignal, searchContext.shardTarget().getShardId()); try { searcher.addQueryCancellation(() -> { if (isCancelled[0] == true) { diff --git a/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java b/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java index fc1922e91f859..13b99a4aca320 100644 --- a/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java +++ b/server/src/main/java/org/opensearch/search/stream/collector/ArrowCollector.java @@ -20,14 +20,10 @@ import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.dictionary.Dictionary; -import org.apache.arrow.vector.dictionary.DictionaryEncoder; import org.apache.arrow.vector.dictionary.DictionaryProvider; -import org.apache.arrow.vector.util.TransferPair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.index.SortedSetDocValues; import 
org.apache.lucene.search.Collector; import org.apache.lucene.search.FilterCollector; @@ -41,16 +37,13 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.core.index.shard.ShardId; import org.opensearch.datafusion.DataFrame; -import org.opensearch.datafusion.DataFusion; import org.opensearch.datafusion.RecordBatchStream; +import org.opensearch.datafusion.SessionContext; import java.io.IOException; import java.security.AccessController; import java.security.PrivilegedAction; -import java.util.Arrays; -import java.util.HashMap; import java.util.List; -import java.util.Map; /** * Arrow collector for OpenSearch fields values @@ -99,9 +92,19 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept final int maxOrd = (int) dv.getValueCount(); // vector to hold ordinals - FieldVector vector = collectionRoot.getVector(arrowFieldAdaptor.fieldName); + final FieldVector vector = collectionRoot.getVector(arrowFieldAdaptor.fieldName); vector.setInitialCapacity(maxOrd); + StreamingCollector streamingCollector = new StreamingCollector( + new SessionContext(), + allocator, + collectionRoot, + batchSize, + TimeValue.timeValueMillis(1000 * 120), + term, + 500 + ); + final int[] currentRow = {0}; return new LeafCollector() { @@ -120,65 +123,55 @@ public void collect(int docId) throws IOException { } private void flushDocs() throws IOException { - - // set ord value count vector.setValueCount(currentRow[0]); collectionRoot.setRowCount(currentRow[0]); - logger.info("Starting flush of {} docs", currentRow[0]); - // DATAFUSION HANDOFF - AccessController.doPrivileged((PrivilegedAction) () -> { - try { - DataFrame dataFrame = DataFusion.from_vsr(allocator, collectionRoot, provider, term).get(); - RecordBatchStream recordBatchStream = dataFrame.getStream(allocator).get(); + // Offer batch to streaming collector + streamingCollector.offerBatch(collectionRoot); + + // Clear for next batch + collectionRoot.clear(); + currentRow[0] = 0; + } + + @Override + public void finish() throws IOException { + if (currentRow[0] > 0) { + flushDocs(); + } + + try { + // Get final results and process into bucketRoot + DataFrame results = streamingCollector.getResults(); + if (results != null) { + RecordBatchStream recordBatchStream = results.getStream(allocator).get(); VectorSchemaRoot root = recordBatchStream.getVectorSchemaRoot(); VarCharVector ordVector = (VarCharVector) bucketRoot.getVector("ord"); BigIntVector countVector = (BigIntVector) bucketRoot.getVector("count"); int row = 0; + while (recordBatchStream.loadNextBatch().join()) { UInt8Vector dfVector = (UInt8Vector) root.getVector("ord"); FieldVector cv = root.getVector("count"); - logger.info("DF VECTOR {}", dfVector); - logger.info("COUNT VECTOR {}", cv); - // Create transfer pair for the count vector -// TransferPair countTransfer = cv.makeTransferPair(countVector); - - // Transfer the counts -// countTransfer.transfer(); - // for each row + for (int i = 0; i < dfVector.getValueCount(); i++) { - // look up ord value BytesRef bytesRef = dv.lookupOrd(dfVector.get(i)); ordVector.setSafe(row, bytesRef.bytes, 0, bytesRef.length); countVector.setSafe(row, ((BigIntVector) cv).get(i)); row++; } - ordVector.setValueCount(row); - countVector.setValueCount(row); - bucketRoot.setRowCount(row); - flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000 * 120)); - row = 0; - ordVector.clear(); - countVector.clear();; - bucketRoot.setRowCount(0); } -// flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000 * 120)); - 
recordBatchStream.close(); // clean up DF pointers - collectionRoot.clear(); - dataFrame.close(); - } catch (Exception e) { - logger.error("eh", e); - throw new RuntimeException(e); + ordVector.setValueCount(row); + countVector.setValueCount(row); + bucketRoot.setRowCount(row); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000 * 120)); + recordBatchStream.close(); + results.close(); } - currentRow[0] = 0; - return null; - }); - } - - @Override - public void finish() throws IOException { - if (currentRow[0] > 0) { - flushDocs(); + } catch (Exception e) { + logger.error("Fack", e); + } finally { + streamingCollector.close(); } } diff --git a/server/src/main/java/org/opensearch/search/stream/collector/ArrowStreamingCollector.java b/server/src/main/java/org/opensearch/search/stream/collector/ArrowStreamingCollector.java new file mode 100644 index 0000000000000..944f038db1dbb --- /dev/null +++ b/server/src/main/java/org/opensearch/search/stream/collector/ArrowStreamingCollector.java @@ -0,0 +1,195 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.stream.collector;/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.FilterCollector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.BytesRef; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.datafusion.DataFrame; +import org.opensearch.datafusion.DataFusion; +import org.opensearch.datafusion.RecordBatchStream; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.List; + +/** + * Arrow collector for OpenSearch fields values + */ +@ExperimentalApi +public class ArrowStreamingCollector extends FilterCollector { + + private final VectorSchemaRoot collectionRoot; + private final DictionaryProvider provider; + private final BufferAllocator allocator; + List fields; + private final VectorSchemaRoot bucketRoot; + private final StreamProducer.FlushSignal flushSignal; + public static Logger logger = LogManager.getLogger(ArrowStreamingCollector.class); + // Pre-allocate reusable buffers + private byte[] terms; + private int batchSize; + + public ArrowStreamingCollector( + Collector in, + DictionaryProvider provider, + VectorSchemaRoot collectionRoot, + VectorSchemaRoot root, + BufferAllocator 
allocator, + List fields, + int batchSize, + StreamProducer.FlushSignal flushSignal, + ShardId shardId + ) { + super(in); + this.provider = provider; + this.allocator = allocator; + this.fields = fields; + this.bucketRoot = root; + this.collectionRoot = collectionRoot; + this.flushSignal = flushSignal; + this.batchSize = batchSize; + // Pre-allocate arrays + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + ArrowFieldAdaptor arrowFieldAdaptor = fields.get(0); + String term = arrowFieldAdaptor.fieldName; + SortedSetDocValues dv = ((ArrowFieldAdaptor.SortedDocValuesType) arrowFieldAdaptor.getDocValues(context.reader())).getSortedDocValues(); + final int maxOrd = (int) dv.getValueCount(); + + // vector to hold ordinals + FieldVector vector = collectionRoot.getVector(arrowFieldAdaptor.fieldName); + vector.setInitialCapacity(maxOrd); + + final int[] currentRow = {0}; + return new LeafCollector() { + + @Override + public void collect(int docId) throws IOException { + if (currentRow[0] >= batchSize) { + flushDocs(); + } + // dump all the ords into an arrow vector, df will aggregate on these and then decode + dv.advance(docId); + for (int i = 0; i < dv.docValueCount(); i++) { + long ord = dv.nextOrd(); + ((UInt8Vector) vector).setSafe(currentRow[0], ord); + } + currentRow[0] += dv.docValueCount(); + } + + private void flushDocs() throws IOException { + + // set ord value count + vector.setValueCount(currentRow[0]); + collectionRoot.setRowCount(currentRow[0]); + logger.info("Starting flush of {} docs", currentRow[0]); + + // DATAFUSION HANDOFF + AccessController.doPrivileged((PrivilegedAction) () -> { + try { + DataFrame dataFrame = DataFusion.from_vsr(allocator, collectionRoot, provider, term).get(); + RecordBatchStream recordBatchStream = dataFrame.getStream(allocator).get(); + VectorSchemaRoot root = recordBatchStream.getVectorSchemaRoot(); + VarCharVector ordVector = (VarCharVector) bucketRoot.getVector("ord"); + BigIntVector countVector = (BigIntVector) bucketRoot.getVector("count"); + int row = 0; + while (recordBatchStream.loadNextBatch().join()) { + UInt8Vector dfVector = (UInt8Vector) root.getVector("ord"); + FieldVector cv = root.getVector("count"); + logger.info("DF VECTOR {}", dfVector); + logger.info("COUNT VECTOR {}", cv); + // Create transfer pair for the count vector +// TransferPair countTransfer = cv.makeTransferPair(countVector); + + // Transfer the counts +// countTransfer.transfer(); + // for each row + for (int i = 0; i < dfVector.getValueCount(); i++) { + // look up ord value + BytesRef bytesRef = dv.lookupOrd(dfVector.get(i)); + ordVector.setSafe(row, bytesRef.bytes, 0, bytesRef.length); + countVector.setSafe(row, ((BigIntVector) cv).get(i)); + row++; + } + ordVector.setValueCount(row); + countVector.setValueCount(row); + bucketRoot.setRowCount(row); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000 * 120)); + row = 0; + ordVector.clear(); + countVector.clear();; + bucketRoot.setRowCount(0); + } +// flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000 * 120)); + recordBatchStream.close(); // clean up DF pointers + collectionRoot.clear(); + dataFrame.close(); + } catch (Exception e) { + logger.error("eh", e); + throw new RuntimeException(e); + } + currentRow[0] = 0; + return null; + }); + } + + @Override + public void finish() throws IOException { + if (currentRow[0] > 0) { + flushDocs(); + } + } + + @Override + public void setScorer(Scorable scorable) throws IOException { + } + }; + } + + @Override + 
public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + + @Override + public void setWeight(Weight weight) { + if (this.in != null) { + this.in.setWeight(weight); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/stream/collector/DataFusionAggregator.java b/server/src/main/java/org/opensearch/search/stream/collector/DataFusionAggregator.java new file mode 100644 index 0000000000000..dc60d1aad15fb --- /dev/null +++ b/server/src/main/java/org/opensearch/search/stream/collector/DataFusionAggregator.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.stream.collector; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.CDataDictionaryProvider; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.opensearch.datafusion.DataFrame; +import org.opensearch.datafusion.ObjectResultCallback; +import org.opensearch.datafusion.SessionContext; + +import java.util.concurrent.CompletableFuture; + +public class DataFusionAggregator implements AutoCloseable { + static { + System.loadLibrary("datafusion_jni"); + } + private final SessionContext context; + private final long ptr; + private final DictionaryProvider dictionaryProvider; + + public DataFusionAggregator(String term, int batchSize) { + this.context = new SessionContext(batchSize); + this.ptr = create(context.getPointer(), term); + this.dictionaryProvider = new CDataDictionaryProvider(); + } + + // use default DF batch size (8192) + public DataFusionAggregator(String term) { + this.context = new SessionContext(); + this.ptr = create(context.getPointer(), term); + this.dictionaryProvider = new CDataDictionaryProvider(); + } + + public CompletableFuture pushBatch(BufferAllocator allocator, VectorSchemaRoot root) { + CompletableFuture result = new CompletableFuture<>(); + + try { + ArrowArray array = ArrowArray.allocateNew(allocator); + ArrowSchema schema = ArrowSchema.allocateNew(allocator); + + Data.exportVectorSchemaRoot(allocator, root, dictionaryProvider, array, schema); + + pushBatch( + context.getRuntime(), + ptr, + array.memoryAddress(), + schema.memoryAddress(), + (String errString, long ptr) -> { + if (errString != null && !errString.isEmpty()) { + result.completeExceptionally(new RuntimeException(errString)); + } else { + result.complete(null); + } + }); + } catch (Exception e) { + result.completeExceptionally(e); + } + return result; + } + + public CompletableFuture getResults(int limit) { + CompletableFuture result = new CompletableFuture<>(); + getResults( + context.getRuntime(), + ptr, + limit, + (String errString, long ptr) -> { + if (errString != null && !errString.isEmpty()) { + result.completeExceptionally(new RuntimeException(errString)); + } else if (ptr == 0) { + result.complete(null); + } else { + result.complete(new DataFrame(context, ptr)); + } + }); + return result; + } + + @Override + public void close() { + destroy(ptr); + } + + private static native long create(long ctx, String term); + private static native void pushBatch(long runtime, long ptr, long schema, long array, ObjectResultCallback callback); + private static native void getResults(long runtime, long ptr, int 
limit, ObjectResultCallback callback); + private static native void destroy(long ptr); + +} diff --git a/server/src/main/java/org/opensearch/search/stream/collector/PushStreamingCollector.java b/server/src/main/java/org/opensearch/search/stream/collector/PushStreamingCollector.java new file mode 100644 index 0000000000000..d5797c2934e67 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/stream/collector/PushStreamingCollector.java @@ -0,0 +1,193 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.stream.collector;/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.search.Collector; +import org.apache.lucene.search.FilterCollector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.Scorable; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.BytesRef; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.datafusion.DataFrame; +import org.opensearch.datafusion.RecordBatchStream; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.List; + +/** + * Arrow collector for OpenSearch fields values + */ +@ExperimentalApi +public class PushStreamingCollector extends FilterCollector { + + private final VectorSchemaRoot collectionRoot; + private final DictionaryProvider provider; + private final BufferAllocator allocator; + List fields; + private final VectorSchemaRoot bucketRoot; + private final StreamProducer.FlushSignal flushSignal; + public static Logger logger = LogManager.getLogger(PushStreamingCollector.class); + // Pre-allocate reusable buffers + private byte[] terms; + private int batchSize; + + public PushStreamingCollector( + Collector in, + DictionaryProvider provider, + VectorSchemaRoot collectionRoot, + VectorSchemaRoot root, + BufferAllocator allocator, + List fields, + int batchSize, + StreamProducer.FlushSignal flushSignal, + ShardId shardId + ) { + super(in); + this.provider = provider; + this.allocator = allocator; + this.fields = fields; + this.bucketRoot = root; + this.collectionRoot = collectionRoot; + this.flushSignal = flushSignal; + this.batchSize = batchSize; + // Pre-allocate arrays + } + + @Override + public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { + ArrowFieldAdaptor arrowFieldAdaptor = fields.get(0); + String term = arrowFieldAdaptor.fieldName; + SortedSetDocValues dv = 
((ArrowFieldAdaptor.SortedDocValuesType) arrowFieldAdaptor.getDocValues(context.reader())).getSortedDocValues(); + final int maxOrd = (int) dv.getValueCount(); + + // vector to hold ordinals + FieldVector vector = collectionRoot.getVector(arrowFieldAdaptor.fieldName); + vector.setInitialCapacity(maxOrd); + DataFusionAggregator aggregator = new DataFusionAggregator(term); + + final int[] currentRow = {0}; + return new LeafCollector() { + + @Override + public void collect(int docId) throws IOException { + if (currentRow[0] >= batchSize) { + pushBatch(); + } + // dump all the ords into an arrow vector, df will aggregate on these and then decode + dv.advance(docId); + for (int i = 0; i < dv.docValueCount(); i++) { + long ord = dv.nextOrd(); + ((UInt8Vector) vector).setSafe(currentRow[0], ord); + } + currentRow[0] += dv.docValueCount(); + } + + private void pushBatch() throws IOException { + vector.setValueCount(currentRow[0]); + collectionRoot.setRowCount(currentRow[0]); + logger.info("Starting flush of {} docs", currentRow[0]); + + AccessController.doPrivileged((PrivilegedAction) () -> { + try { + // Push batch to streaming aggregator + aggregator.pushBatch(allocator, collectionRoot).get(); + collectionRoot.clear(); + currentRow[0] = 0; + } catch (Exception e) { + throw new RuntimeException(e); + } + return null; + }); + } + + @Override + public void finish() throws IOException { + if (currentRow[0] > 0) { + pushBatch(); + } + + try { + // Get final results and process into bucketRoot + DataFrame results = aggregator.getResults(500).get(); + if (results != null) { + RecordBatchStream recordBatchStream = results.getStream(allocator).get(); + VectorSchemaRoot root = recordBatchStream.getVectorSchemaRoot(); + VarCharVector ordVector = (VarCharVector) bucketRoot.getVector("ord"); + BigIntVector countVector = (BigIntVector) bucketRoot.getVector("count"); + int row = 0; + + while (recordBatchStream.loadNextBatch().join()) { + UInt8Vector dfVector = (UInt8Vector) root.getVector("ord"); + FieldVector cv = root.getVector("count"); + + + for (int i = 0; i < dfVector.getValueCount(); i++) { + BytesRef bytesRef = dv.lookupOrd(dfVector.get(i)); + ordVector.setSafe(row, bytesRef.bytes, 0, bytesRef.length); + countVector.setSafe(row, ((BigIntVector) cv).get(i)); + row++; + } + } + logger.info(ordVector.get(0)); + ordVector.setValueCount(row); + countVector.setValueCount(row); + bucketRoot.setRowCount(row); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000 * 120)); + recordBatchStream.close(); + results.close(); + } + + aggregator.close(); + } catch (Exception e) { + logger.error("Fack", e); + } + } + + @Override + public void setScorer(Scorable scorable) throws IOException { + } + }; + } + + @Override + public ScoreMode scoreMode() { + return ScoreMode.COMPLETE_NO_SCORES; + } + + @Override + public void setWeight(Weight weight) { + if (this.in != null) { + this.in.setWeight(weight); + } + } +} diff --git a/server/src/main/java/org/opensearch/search/stream/collector/StreamingCollector.java b/server/src/main/java/org/opensearch/search/stream/collector/StreamingCollector.java new file mode 100644 index 0000000000000..79fd30ec6526c --- /dev/null +++ b/server/src/main/java/org/opensearch/search/stream/collector/StreamingCollector.java @@ -0,0 +1,135 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.stream.collector; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.CDataDictionaryProvider; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.datafusion.DataFrame; +import org.opensearch.datafusion.SessionContext; + +import java.util.List; + +public class StreamingCollector implements AutoCloseable { + private final SessionContext context; + private final BufferAllocator allocator; + private VectorSchemaRoot currentRoot; + private final Object lock = new Object(); + private boolean finished = false; + private int batchSize; + private TimeValue waitTimeout; + private DataFrame dataFrame; + + public StreamingCollector( + SessionContext context, + BufferAllocator allocator, + VectorSchemaRoot root, + int batchSize, + TimeValue waitTimeout, + String term, + int limit) { + this.context = context; + this.allocator = allocator; + this.currentRoot = root; + this.batchSize = batchSize; + this.waitTimeout = waitTimeout; + ArrowSchema schema = ArrowSchema.allocateNew(allocator); + try { + Data.exportSchema(allocator, root.getSchema(), new CDataDictionaryProvider(), schema); + this.dataFrame = new DataFrame(context, create(context.getPointer(), this, limit, term, schema.memoryAddress())); + } finally { + schema.close(); + } + } + + public DataFrame getResults() { + synchronized (lock) { + finish(); + return dataFrame; + } + } + + // Inner class to hold both pointers + public static class BatchPointers { + public long schemaPtr; + public long arrayPtr; + + public BatchPointers(long schemaPtr, long arrayPtr) { + this.schemaPtr = schemaPtr; + this.arrayPtr = arrayPtr; + } + } + + // Called by DataFusion through JNI + public BatchPointers getNextBatch() { + synchronized (lock) { + if (finished && currentRoot.getRowCount() == 0) { + return null; + } + + if (currentRoot.getRowCount() > 0) { + VectorSchemaRoot batchToReturn = currentRoot; + ArrowArray array = ArrowArray.allocateNew(allocator); + ArrowSchema schema = ArrowSchema.allocateNew(allocator); + + try { + Data.exportVectorSchemaRoot(allocator, batchToReturn, new CDataDictionaryProvider(), array, schema); + return new BatchPointers(schema.memoryAddress(), array.memoryAddress()); + } catch (Exception e) { + // Handle any export errors + array.close(); + schema.close(); + return null; + } + } + + try { + // Wait for more data with timeout + lock.wait(waitTimeout.getMillis()); + return getNextBatch(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + } + } + + // Called by ArrowCollector when a batch is ready + public void offerBatch(VectorSchemaRoot batch) { + synchronized (lock) { + this.currentRoot = batch; + lock.notify(); + } + } + + public void finish() { + synchronized (lock) { + finished = true; + lock.notify(); + } + } + + @Override + public void close() { + if (currentRoot != null) { + currentRoot.close(); + } + try { + dataFrame.close(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static native long create(long ctx, Object collector, int limit, String term, long schema); +}
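
Illustrative usage sketch (not part of the patch): the snippet below shows how the push/poll flow introduced here is meant to be driven from the Java side, with one DataFusionAggregator per term field, one pushBatch call per filled ordinal batch, and a single getResults call that returns the limited, ord-sorted DataFrame. The class and method names come from the files added above; the field name "category", the example ordinal values, and the PushPollExample wrapper are placeholders, and error handling plus ordinal-to-term decoding via SortedSetDocValues#lookupOrd are omitted for brevity.

    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.UInt8Vector;
    import org.apache.arrow.vector.VectorSchemaRoot;
    import org.opensearch.datafusion.DataFrame;
    import org.opensearch.datafusion.RecordBatchStream;
    import org.opensearch.search.stream.collector.DataFusionAggregator;

    public class PushPollExample {
        public static void main(String[] args) throws Exception {
            try (BufferAllocator allocator = new RootAllocator();
                 DataFusionAggregator aggregator = new DataFusionAggregator("category")) {
                // Collect ordinals into an Arrow batch, then hand the batch off to DataFusion.
                try (UInt8Vector ords = new UInt8Vector("category", allocator)) {
                    ords.allocateNew(3);
                    ords.setSafe(0, 1);
                    ords.setSafe(1, 1);
                    ords.setSafe(2, 2);
                    ords.setValueCount(3);
                    try (VectorSchemaRoot batch = VectorSchemaRoot.of(ords)) {
                        aggregator.pushBatch(allocator, batch).get();
                    }
                }
                // Poll the merged result: top ords with their running counts.
                DataFrame results = aggregator.getResults(500).get();
                if (results != null) {
                    RecordBatchStream stream = results.getStream(allocator).get();
                    VectorSchemaRoot root = stream.getVectorSchemaRoot();
                    while (stream.loadNextBatch().join()) {
                        UInt8Vector ord = (UInt8Vector) root.getVector("ord");
                        for (int i = 0; i < ord.getValueCount(); i++) {
                            // In the real collector each ord is decoded back to its term
                            // with SortedSetDocValues#lookupOrd before being emitted.
                            System.out.println(ord.get(i) + " -> " + root.getVector("count").getObject(i));
                        }
                    }
                    stream.close();
                    results.close();
                }
            }
        }
    }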