diff --git a/lakesoul-spark/pom.xml b/lakesoul-spark/pom.xml index 241300a67..f880dd64d 100644 --- a/lakesoul-spark/pom.xml +++ b/lakesoul-spark/pom.xml @@ -142,6 +142,14 @@ SPDX-License-Identifier: Apache-2.0 + + + io.jhdf + jhdf + 0.6.10 + + + org.apache.spark @@ -322,55 +330,55 @@ SPDX-License-Identifier: Apache-2.0 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - + + + + + @@ -646,4 +654,4 @@ SPDX-License-Identifier: Apache-2.0 - + \ No newline at end of file diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala index 5c3de0663..c3b7e2841 100644 --- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala +++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala @@ -20,7 +20,6 @@ import scala.collection.JavaConverters.{asScalaBufferConverter, seqAsJavaListCon class NativeParquetCompactionColumnarOutputWriter(path: String, dataSchema: StructType, timeZoneId: String, context: TaskAttemptContext) extends NativeParquetOutputWriter(path, dataSchema, timeZoneId, context) { - override def write(row: InternalRow): Unit = { if (!row.isInstanceOf[ArrowFakeRow]) { throw new IllegalStateException( diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala index c4d599ef8..c55275424 100644 --- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala +++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala @@ -28,6 +28,8 @@ class NativeParquetOutputWriter(val path: String, dataSchema: StructType, timeZo protected val nativeIOWriter: NativeIOWriter = new NativeIOWriter(arrowSchema) + + GlutenUtils.setArrowAllocator(nativeIOWriter) nativeIOWriter.setRowGroupRowNumber(NATIVE_IO_WRITE_MAX_ROW_GROUP_SIZE) nativeIOWriter.addFile(path) diff --git a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala index acf9d6282..50002ad5a 100644 --- a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala +++ b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala @@ -4,6 +4,8 @@ package org.apache.spark.sql.lakesoul.commands +import com.dmetasoul.lakesoul.meta.LakeSoulOptions +import com.dmetasoul.lakesoul.tables.LakeSoulTable import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SparkSession} import org.apache.spark.sql.lakesoul.test.{LakeSoulSQLCommandTest, LakeSoulTestBeforeAndAfterEach, LakeSoulTestSparkSession, LakeSoulTestUtils} import org.apache.spark.util.Utils @@ -15,6 +17,19 @@ import org.apache.spark.sql.lakesoul.sources.LakeSoulSQLConf import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession} import org.junit.runner.RunWith import org.scalatestplus.junit.JUnitRunner +import 
io.jhdf.HdfFile +import io.jhdf.api.Dataset +import org.apache.commons.lang3.ArrayUtils +import org.apache.spark.sql.types.{ArrayType, ByteType, FloatType, IntegerType, LongType, StructField, StructType} +import org.apache.commons.lang3.ArrayUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions.{collect_list, sum, udf} + +import java.nio.file.Paths +import scala.collection.mutable.ListBuffer +import scala.concurrent.duration.DurationLong +import scala.math.{pow, sqrt} @RunWith(classOf[JUnitRunner]) class MergeIntoSQLSuite extends QueryTest @@ -63,6 +78,208 @@ class MergeIntoSQLSuite extends QueryTest } } +// test("test lsh"){ +// val filepath = "/Users/beidu/Documents/dataset/glove-200-angular.hdf5" +// val trainPath = "/Users/beidu/Documents/LakeSoul/train" +// val testPath = "/Users/beidu/Documents/LakeSoul/test" +// +// println(filepath) +// val spark = SparkSession.builder +// .appName("Array to RDD") +// .master("local[*]") +// .getOrCreate() +// try{ +// val hdfFile = new HdfFile(Paths.get(filepath)) +// +// val trainDataset = hdfFile.getDatasetByPath("train") +// val testDataset = hdfFile.getDatasetByPath("test") +// val neighborDataset = hdfFile.getDatasetByPath("neighbors") +// val trainData = trainDataset.getData() +// val testData = testDataset.getData() +// val neighborData = neighborDataset.getData() +// println(trainData) +// var float2DDataNeighbor: Array[Array[Int]] = null +// neighborData match { +// case data:Array[Array[Int]] => +// float2DDataNeighbor = data +// case _ => +// println("not") +// } +// // the smaller the Hamming distance,the greater the similarity +// val calculateHammingDistanceUDF = udf((trainLSH: Seq[Long], testLSH: Seq[Long]) => { +// require(trainLSH.length == testLSH.length, "The input sequences must have the same length") +// trainLSH.zip(testLSH).map { case (train, test) => +// java.lang.Long.bitCount(train ^ test) +// }.sum +// }) +// // the smaller the Euclidean distance,the greater the similarity +// val calculateEuclideanDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => { +// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length") +// sqrt(trainEmbedding.zip(testEmbedding).map{case (train,test) => +// pow(train - test,2) }.sum) +// }) +// //the greater the Cosine distance,the greater the similarity +// val calculateCosineDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => { +// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length") +// trainEmbedding.zip(testEmbedding).map{case (train,test) => +// train * test}.sum / (sqrt(trainEmbedding.map{train => train * train}.sum) * sqrt(testEmbedding.map{test => test * test}.sum)) +// }) +// //the smaller the Jaccard distance,the greater the similarity +// val calculateJaccardDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => { +// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length") +// val anb = testEmbedding.intersect(trainEmbedding).distinct +// val aub = testEmbedding.union(trainEmbedding).distinct +// val jaccardCoefficient = anb.length.toDouble / aub.length +// 1 - jaccardCoefficient +// }) +// spark.udf.register("calculateHammingDistance",calculateHammingDistanceUDF) +// spark.udf.register("calculateEuclideanDistance",calculateEuclideanDistanceUDF) +// 
spark.udf.register("calculateCosineDistance",calculateCosineDistanceUDF) +// spark.udf.register("calculateJaccardDistance",calculateJaccardDistanceUDF) +//// println(float2DDataNeighbor.length) +// trainData match { +// case float2DData:Array[Array[Float]] => +// val classIds = (1 to float2DData.length).toArray +// val schema = StructType(Array( +// StructField("IndexId",IntegerType,true), +// StructField("Embedding",ArrayType(FloatType),true), +// StructField("LSH",ArrayType(LongType),true) +// )) +// val rows = float2DData.zip(classIds).map { +// case (embedding,indexId)=> +// Row(indexId,embedding,null) +// } +// val df = spark.createDataFrame(spark.sparkContext.parallelize(rows),schema) +// df.write.format("lakesoul") +// .option("hashPartitions", "IndexId") +// .option("hashBucketNum", 4) +// .option(LakeSoulOptions.SHORT_TABLE_NAME,"trainData") +// .mode("Overwrite").save(trainPath) +//// val startTime1 = System.nanoTime() +// val lakeSoulTable = LakeSoulTable.forPath(trainPath) +// lakeSoulTable.compaction() +// +// testData match { +// case float2DTestData:Array[Array[Float]] => +// val classIdsTest = (1 to float2DTestData.length).toArray +// val schemaTest = StructType(Array( +// StructField("IndexId",IntegerType,true), +// StructField("Embedding",ArrayType(FloatType),true), +// StructField("LSH",ArrayType(LongType),true) +// )) +// val rowsTest = float2DTestData.zip(classIdsTest).map{ +// case (embedding,indexId) => +// Row(indexId,embedding,null) +// } +// +// val num = 50 +// val dfTest = spark.createDataFrame(spark.sparkContext.parallelize(rowsTest),schemaTest).limit(num) +// dfTest.write.format("lakesoul") +// .option("hashPartitions","IndexId") +// .option("hashBucketNum",4) +// .option(LakeSoulOptions.SHORT_TABLE_NAME,"testData") +// .mode("Overwrite").save(testPath) +// val lakeSoulTableTest = LakeSoulTable.forPath(testPath) +// +// lakeSoulTableTest.compaction() +//// val endTime1 = System.nanoTime() +//// val duration1 = (endTime1 - startTime1).nanos +//// println(s"time:${duration1.toMillis}") +//// val lshTrain = sql("select LSH from trainData") +//// val lshTest = sql("select LSH from testData") +// +//// val arr = Array(1,5,10,20,40,60,80,100,150,200,250,300) +//// for(n <- arr) { +// val n = 300 +// val topk = 100 +// val topkFirst = n * topk +// +// // val result = sql("select testData.IndexId as indexId,trainData.LSH as trainLSH,testData.LSH as testLSH," + +// // "calculateHammingDistance(testData.LSH,trainData.LSH) AS hamming_distance " + +// // "from testData " + +// // "cross join trainData " + +// // "order by indexId,hamming_distance") +// +// // val result = spark.sql(s""" +// // SELECT * +// // FROM ( +// // SELECT +// // testData.IndexId AS indexIdTest, +// // trainData.IndexId AS indexIdTrain, +// // trainData.LSH AS trainLSH, +// // testData.LSH AS testLSH, +// // calculateHammingDistance(testData.LSH, trainData.LSH) AS hamming_distance, +// // ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank +// // FROM testData +// // CROSS JOIN trainData +// // ) ranked +// // WHERE rank <= $topk +// // """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList")) +// val startTime = System.nanoTime() +// val result = spark.sql( +// s""" +// SELECT * +// FROM ( +// SELECT +// testData.IndexId AS indexIdTest, +// trainData.IndexId AS indexIdTrain, +// testData.Embedding as EmbeddingTest, +// trainData.Embedding as EmbeddingTrain, +// ROW_NUMBER() OVER (PARTITION BY 
testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank +// FROM testData +// CROSS JOIN trainData +// ) ranked +// WHERE rank <= $topkFirst +// """) +// result.createOrReplaceTempView("rank") +// val reResult = spark.sql( +// s""" +// SELECT * +// FROM ( +// SELECT +// rank.indexIdTest, +// rank.indexIDTrain, +// ROW_NUMBER() OVER(PARTITION BY rank.indexIdTest ORDER BY calculateEuclideanDistance(rank.EmbeddingTest,rank.EmbeddingTrain) asc) AS reRank +// FROM rank +// ) reRanked +// WHERE reRank <= $topk +// """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList")) +// +// +// val endTime = System.nanoTime() +// val duration = (endTime - startTime).nanos +// println(s"time for query n4topk ${n} :${duration.toMillis} milliseconds") +// +// val startTime2 = System.nanoTime() +// +//// val (totalRecall, count) = reResult.map(row => { +//// val indexIdTest = row.getAs[Int]("indexIdTest") +//// val indexIdTrainList: Array[Int] = row.getAs[Seq[Int]]("indexIdTrainList").toArray +//// val updatedList = indexIdTrainList.map(_ - 1) +//// val count = float2DDataNeighbor(indexIdTest - 1).take(topk).count(updatedList.contains) +//// val recall = (count * 1.0 / topk) +//// (recall, 1) +//// }).reduce((acc1, acc2) => { +//// (acc1._1 + acc2._1, acc1._2 + acc2._2) +//// }) +//// println(totalRecall / count) +// val endTime2 = System.nanoTime() +// val duration2 = (endTime2 - startTime2).nanos +// println(s"time for sort:${duration2.toMillis} milliseconds") +//// } +// } +// case _ => +// println("unexpected data type") +// case _ => +// println("unexpected data type") +// } +// } +// finally { +// +// } +// +// } + test("merge into table with hash partition -- supported case") { initHashTable() withViewNamed(Seq((20201102, 4, 5)).toDF("range", "hash", "value"), "source_table") { diff --git a/rust/Cargo.lock b/rust/Cargo.lock index a2b4321d3..4e91a2ecc 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "addr2line" @@ -818,6 +818,25 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -1836,12 +1855,14 @@ dependencies = [ "hex", "lazy_static", "log", + "ndarray", "object_store", "parking_lot", "parquet", "prost", "proto", "rand", + "rayon", "serde", "serde_json", "smallvec", @@ -2072,6 +2093,16 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -2148,6 +2179,19 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03" +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2712,6 +2756,32 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.3" diff --git a/rust/lakesoul-io/Cargo.toml b/rust/lakesoul-io/Cargo.toml index d1b12e5cd..f9558ef18 100644 --- a/rust/lakesoul-io/Cargo.toml +++ b/rust/lakesoul-io/Cargo.toml @@ -47,6 +47,10 @@ env_logger = "0.11" hex = "0.4" dhat = { version="0.3.3", optional = true } async-recursion = "1.1.1" +ndarray = "0.15.6" +#hdf5 = {version = "0.8.1"} +rayon = "1.10.0" + [features] hdfs = ["dep:hdrs", "dep:hdfs-sys"] diff --git a/rust/lakesoul-io/src/lakesoul_io_config.rs b/rust/lakesoul-io/src/lakesoul_io_config.rs index 4230e6cc7..eeeb3ed92 100644 --- a/rust/lakesoul-io/src/lakesoul_io_config.rs +++ b/rust/lakesoul-io/src/lakesoul_io_config.rs @@ -51,6 +51,9 @@ pub static OPTION_KEY_HASH_BUCKET_ID: &str = "hash_bucket_id"; pub static OPTION_KEY_CDC_COLUMN: &str = "cdc_column"; pub static OPTION_KEY_IS_COMPACTED: &str = "is_compacted"; pub static OPTION_KEY_MAX_FILE_SIZE: &str = "max_file_size"; +pub static OPTION_KEY_IS_LSH: &str = "is_lsh"; +pub 
static OPTION_KEY_NBITS: &str = "nbits"; +pub static OPTION_KEY_D: &str= "d"; #[derive(Debug, Derivative)] #[derivative(Default, Clone)] @@ -130,6 +133,10 @@ pub struct LakeSoulIOConfig { // max file size of bytes #[derivative(Default(value = "None"))] pub(crate) max_file_size: Option, + + // the seed for rng + #[derivative(Default(value = "1234"))] + pub(crate) seed: u64, } impl LakeSoulIOConfig { @@ -188,6 +195,18 @@ impl LakeSoulIOConfig { pub fn is_compacted(&self) -> bool { self.option(OPTION_KEY_IS_COMPACTED).map_or(false, |x| x.eq("true")) } + + pub fn is_lsh(&self) -> bool { + self.option(OPTION_KEY_IS_LSH).map_or(false,|x| x.eq("true")) + } + + pub fn nbits(&self) -> Option{ + self.option(OPTION_KEY_NBITS).map(|x| x.parse().unwrap()) + } + + pub fn d(&self) -> Option{ + self.option(OPTION_KEY_D).map(|x| x.parse().unwrap()) + } } #[derive(Derivative, Debug)] @@ -348,6 +367,11 @@ impl LakeSoulIOConfigBuilder { self } + pub fn with_seed(mut self,seed:u64) -> Self { + self.config.seed = seed; + self + } + pub fn build(self) -> LakeSoulIOConfig { self.config } @@ -517,7 +541,12 @@ fn register_object_store(path: &str, config: &mut LakeSoulIOConfig, runtime: &Ru Ok(joined_path) } } + // "file" => Ok(path.to_owned()), "file" => Ok(path.to_owned()), + // Support Windows drive letter paths like "c:" or "d:" + scheme if scheme.len() == 1 && scheme.chars().next().unwrap().is_ascii_alphabetic() => { + Ok(format!("file://{}", path)) + }, _ => Err(ObjectStore(object_store::Error::NotSupported { source: "FileSystem not supported".into(), })), @@ -652,5 +681,10 @@ mod tests { "file:///some/absolute/local/file2".to_string(), ] ); + let mut lakesoulconfigbuilder = LakeSoulIOConfigBuilder::from(conf.clone()); + let conf = lakesoulconfigbuilder.with_d(32 as u64).with_nbits(64 as u64).build(); + assert_eq!(conf.seed,1234 as u64); + assert_eq!(conf.d,Some(32)); + assert_eq!(conf.nbits,Some(64)); } } diff --git a/rust/lakesoul-io/src/lakesoul_writer.rs b/rust/lakesoul-io/src/lakesoul_writer.rs index 370a461d4..5d95ac711 100644 --- a/rust/lakesoul-io/src/lakesoul_writer.rs +++ b/rust/lakesoul-io/src/lakesoul_writer.rs @@ -4,8 +4,9 @@ use std::borrow::Borrow; use std::collections::HashMap; -use std::sync::Arc; +use std:: ptr; +use arrow::datatypes::{DataType, Field}; use arrow_array::RecordBatch; use arrow_schema::SchemaRef; use datafusion_common::{DataFusionError, Result}; @@ -13,6 +14,12 @@ use rand::distributions::DistString; use tokio::runtime::Runtime; use tokio::sync::Mutex; use tracing::debug; +use rand::{Rng,SeedableRng,rngs::StdRng}; +use ndarray::{concatenate, s, Array2, Axis,ArrayView2}; +use arrow::array::{Array as OtherArray, Float64Array, ListArray,Float32Array,Int64Array,GenericListArray}; +use arrow::buffer::OffsetBuffer; +use std::sync::Arc; + use crate::async_writer::{AsyncBatchWriter, MultiPartAsyncWriter, PartitioningAsyncWriter, SortAsyncWriter, WriterFlushResult}; use crate::helpers::{get_batch_memory_size, get_file_exist_col}; @@ -36,11 +43,10 @@ impl SyncSendableMutableLakeSoulWriter { pub fn try_new(config: LakeSoulIOConfig, runtime: Runtime) -> Result { let runtime = Arc::new(runtime); runtime.clone().block_on(async move { - let writer_config = config.clone(); let mut config = config.clone(); + let writer_config = config.clone(); let writer = Self::create_writer(writer_config).await?; let schema = writer.schema(); - if let Some(mem_limit) = config.mem_limit() { if config.use_dynamic_partition { config.max_file_size = Some((mem_limit as f64 * 0.15) as u64); @@ -123,7 +129,31 @@ 
impl SyncSendableMutableLakeSoulWriter { // for ffi callers pub fn write_batch(&mut self, record_batch: RecordBatch) -> Result<()> { let runtime = self.runtime.clone(); - runtime.block_on(async move { self.write_batch_async(record_batch, false).await }) + if record_batch.num_rows() == 0{ + runtime.block_on(async move { self.write_batch_async(record_batch, false).await }) + } + else{ + if self.config.is_lsh() { + let projection: ListArray= if let Some(array) = record_batch.column_by_name("Embedding") { + let embedding = array.as_any().downcast_ref::().unwrap(); + let projection_result:Result = self.lsh(&Some(embedding.clone())); + projection_result.unwrap().into() + + } else { + eprintln!("there is no column named Embedding"); + return Ok(()) ; + }; + + let mut new_columns = record_batch.columns().to_vec(); + new_columns[record_batch.schema().index_of("LSH").unwrap()] = Arc::new(projection.clone()); + let new_record_batch = RecordBatch::try_new(self.config.target_schema(),new_columns).unwrap(); + + runtime.block_on(async move { self.write_batch_async(new_record_batch, false).await }) + } + else{ + runtime.block_on(async move { self.write_batch_async(record_batch, false).await }) + } + } } #[async_recursion::async_recursion(?Send)] @@ -263,29 +293,215 @@ impl SyncSendableMutableLakeSoulWriter { pub fn get_schema(&self) -> SchemaRef { self.schema.clone() } + + // generate random digit with fixed seed + fn create_rng_with_seed(&self) -> StdRng { + StdRng::seed_from_u64(self.config.seed) + } + + // generate random planes + fn generate_random_array(&self) -> Result,String>{ + match self.config.nbits() { + Some(nbits) if nbits > 0 => { + match self.config.d() { + Some(d) if d > 0 => { + let mut rng = self.create_rng_with_seed(); +// assert!(d >= nbits,"the dimension of the embedding must be greater than nbits"); + let random_array = Array2::from_shape_fn((nbits as usize, d as usize), |_| rng.gen_range(-1.0..1.0)); + Ok(random_array) + } + Some(_) => Err("the dimension you input in the config must be greater than 0".to_string()), + None => Err("the dimension you input in the config is None".to_string()), + } + } + Some(_) => Err("the number of bits used for binary encoding must be greater than 0".to_string()), + None => Err("the number of bits used for binary encoding must be greater than 0".to_string()), + } + } + + // project the input data + fn project(&self,input_data:&ListArray,random_plans:&Result,String>) -> Result,String>{ + let list_len = input_data.len(); + assert!(list_len > 0,"the length of input data must be large than 0"); + let dimension_len = input_data.value(0).len(); + + let input_values = if let Some(values) = input_data.values().as_any().downcast_ref::(){ + let float64_values: Vec = values.iter().map(|x| x.unwrap() as f64).collect(); + Float64Array::from(float64_values) + } else if let Some(values) = input_data.values().as_any().downcast_ref::(){ + values.clone() + } + else { + return Err("Unsupported data type in ListArray.".to_string()); + }; + + let mut re_array2 = Array2::::zeros((list_len,dimension_len)); + + unsafe { + let data_ptr = input_values.values().as_ptr(); + let data_size = list_len * dimension_len; + ptr::copy_nonoverlapping(data_ptr,re_array2.as_mut_ptr(),data_size); + } + match random_plans { + Ok(random_array) => { + assert!(re_array2.shape()[1] == random_array.shape()[1],"the dimension corresponding to the matrix must be the same"); +// let final_result = re_array2.dot(&random_array.t()); + let batch_size = 1000; + let num_batches = re_array2.shape()[0] / 
batch_size; + let remaining_rows = re_array2.shape()[0] % batch_size; + let mut result = vec![]; + + for batch_idx in 0..num_batches{ + let batch_start = batch_idx * batch_size; + let batch_end = batch_start + batch_size; + + let current_batch = re_array2.slice(s![batch_start..batch_end,..]); + let random_projection = current_batch.dot(&random_array.t()); + + result.push(random_projection); + } + + if remaining_rows > 0{ + let batch_start = num_batches * batch_size; + let batch_end = batch_start + remaining_rows; + + let remaining_batch = re_array2.slice(s![batch_start..batch_end,..]); + + let random_projection = remaining_batch.dot(&random_array.t()); + + result.push(random_projection); + } + + let result_views: Vec> = result.iter().map(|arr| ArrayView2::from(arr)).collect(); + + + let final_result = concatenate(Axis(0),&result_views).expect("Failed to concatenate results"); + + // println!("{:}",end); + + Ok(final_result) + } + Err(e) => { + eprintln!("Error:{}",e); + Err(e.to_string()) + } + } + } + // add the input data with their projection + pub fn lsh(&self,input_embedding:&Option) -> Result + where + { + match input_embedding { + Some(data) => { + let random_plans = self.generate_random_array(); + let data_projection = self.project(data,&random_plans).unwrap(); + match Ok(data_projection) { + Ok(mut projection) => { + projection.mapv_inplace(|x| if x >= 0.0 {1.0} else {0.0}); + let convert:Vec> = Self::convert_array_to_u64_vec(&projection); + Ok(Self::convert_vec_to_byte_u64(convert)) + } + Err(e) => { + eprintln!("Error:{}",e); + Err(e) + } + } + } + None => { + Err("the input data is None".to_string()) + } + } + } + + fn convert_vec_to_byte_u64(array:Vec>) -> ListArray { + let field = Arc::new(Field::new("element", DataType::Int64,true)); + let values = Int64Array::from(array.iter().flatten().map(|&x| x as i64).collect::>()); + let mut offsets = vec![]; + for subarray in array{ + let current_offset = subarray.len() as usize; + offsets.push(current_offset); + } + let offsets_buffer = OffsetBuffer::from_lengths(offsets); + let list_array = GenericListArray::try_new(field,offsets_buffer,Arc::new(values),None).expect("can not list_array"); + list_array + + } + + fn convert_array_to_u64_vec(array:&Array2) -> Vec> + where + T: TryFrom + Copy, + >::Error: std::fmt::Debug, + { + let bianry_encode:Vec> = array + .axis_iter(ndarray::Axis(0)) + .map(|row|{ + let mut results = Vec::new(); + let mut acc = 0u64; + + for(i,&bit) in row.iter().enumerate(){ + acc = (acc << 1) | bit as u64; + if(i + 1) % 64 == 0{ + results.push(acc); + acc = 0; + } + } + if row.len() % 64 != 0{ + results.push(acc); + } + results + }) + .collect(); + + bianry_encode + .into_iter() + .map(|inner_vec|{ + inner_vec + .into_iter() + .map(|x| T::try_from(x).unwrap()) + .collect() + }).collect() + } + } #[cfg(test)] mod tests { + use arrow_array::builder; + use datafusion::catalog::schema; + use hdf5::File as OtherFile; + use hdf5::Group; + use parquet::arrow::ArrowWriter; + use parquet::column; + use parquet::file::properties::WriterProperties; use crate::{ lakesoul_io_config::{LakeSoulIOConfigBuilder, OPTION_KEY_MEM_LIMIT}, lakesoul_reader::LakeSoulReader, lakesoul_writer::{AsyncBatchWriter, MultiPartAsyncWriter, SyncSendableMutableLakeSoulWriter}, }; + use arrow::{ - array::{ArrayRef, Int64Array}, + array::{ArrayRef, Int64Array,FixedSizeListArray,ArrayData}, record_batch::RecordBatch, }; use arrow_array::{Array, StringArray}; use arrow_schema::{DataType, Field, Schema}; - use datafusion::error::Result; + use 
datafusion::{error::Result, physical_expr::math_expressions}; use parquet::arrow::arrow_reader::ParquetRecordBatchReader; use rand::{distributions::DistString, Rng}; use std::{fs::File, sync::Arc}; use tokio::{runtime::Builder, time::Instant}; use tracing_subscriber::layer::SubscriberExt; + use arrow::buffer::{Buffer, NullBuffer}; use super::SortAsyncWriter; + use std::env; + use std::path::Path; + use ndarray::{Array2,s,Dim}; + use crate::helpers::get_batch_memory_size; + use parquet::file::reader::{FileReader,SerializedFileReader}; + use arrow::compute::sort; + use arrow::compute::sort_to_indices; + use arrow::compute::SortOptions; #[test] fn test_parquet_async_write() -> Result<()> { @@ -730,4 +946,5 @@ mod tests { // assert_eq!(num_rows * num_batch, actual_batch.num_rows()); Ok(()) } + }
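
The writer's LSH path in this patch is random-hyperplane hashing: `generate_random_array` draws an `nbits x d` matrix of planes from a seeded `StdRng`, `project` multiplies each embedding batch against the transposed planes, and `lsh` binarizes the projection by sign. Below is a minimal, self-contained sketch of that idea, assuming only `ndarray` and `rand` (both already dependencies of `lakesoul-io` in this diff); `random_planes` and `lsh_bits` are illustrative names, not the functions added by the patch.

```rust
use ndarray::{Array2, Axis};
use rand::{rngs::StdRng, Rng, SeedableRng};

/// Draw `nbits` random hyperplanes in `d` dimensions from a fixed seed,
/// so every writer instance produces the same planes and comparable hashes.
fn random_planes(seed: u64, nbits: usize, d: usize) -> Array2<f64> {
    let mut rng = StdRng::seed_from_u64(seed);
    Array2::from_shape_fn((nbits, d), |_| rng.gen_range(-1.0..1.0))
}

/// Project a batch of embeddings (one row per vector) onto the planes
/// and keep only the sign of each projection as a 0/1 bit.
fn lsh_bits(embeddings: &Array2<f64>, planes: &Array2<f64>) -> Array2<u8> {
    let projection = embeddings.dot(&planes.t()); // shape (rows, nbits)
    projection.mapv(|x| if x >= 0.0 { 1 } else { 0 })
}

fn main() {
    let planes = random_planes(1234, 16, 4); // seed matches the config default in the diff
    let embeddings = Array2::from_shape_vec(
        (2, 4),
        vec![0.1, -0.2, 0.3, 0.4, -0.5, 0.6, -0.7, 0.8],
    )
    .unwrap();
    let bits = lsh_bits(&embeddings, &planes);
    assert_eq!(bits.dim(), (2, 16));
    for row in bits.axis_iter(Axis(0)) {
        println!("{:?}", row);
    }
}
```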
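The sign bits are then packed into 64-bit words (`convert_array_to_u64_vec` and `convert_vec_to_byte_u64` store them as an Arrow `Int64` list column named `LSH`), and the Spark-side `calculateHammingDistance` UDF in the test compares two hashes with `bitCount(train ^ test)`. A hedged sketch of the same packing and XOR-popcount comparison, using the hypothetical helper names `pack_bits` and `hamming`:

```rust
/// Pack a slice of 0/1 bits into 64-bit words, most significant bit first,
/// mirroring the accumulate-and-shift loop in convert_array_to_u64_vec.
fn pack_bits(bits: &[u8]) -> Vec<u64> {
    let mut words = Vec::new();
    let mut acc = 0u64;
    for (i, &bit) in bits.iter().enumerate() {
        acc = (acc << 1) | bit as u64;
        if (i + 1) % 64 == 0 {
            words.push(acc);
            acc = 0;
        }
    }
    if bits.len() % 64 != 0 {
        words.push(acc); // trailing partial word
    }
    words
}

/// Hamming distance between two packed hashes; smaller means more similar.
fn hamming(a: &[u64], b: &[u64]) -> u32 {
    assert_eq!(a.len(), b.len(), "hashes must have the same length");
    a.iter().zip(b).map(|(x, y)| (x ^ y).count_ones()).sum()
}

fn main() {
    let h1 = pack_bits(&[1, 0, 1, 1]);
    let h2 = pack_bits(&[1, 1, 1, 0]);
    assert_eq!(hamming(&h1, &h2), 2); // bits differ at two positions
    println!("distance = {}", hamming(&h1, &h2));
}
```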
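`project` also avoids materializing one large matrix product: rows are multiplied against the planes in fixed-size batches of 1000 and the partial results are concatenated along axis 0, which bounds peak memory for large record batches. A sketch of that chunked projection under the same assumptions (the name `project_in_batches` and the batch-size parameter are illustrative):

```rust
use ndarray::{concatenate, s, Array2, ArrayView2, Axis};

/// Multiply `data` (rows = embeddings) against the hyperplanes chunk by chunk,
/// then stitch the partial projections back together along the row axis.
fn project_in_batches(data: &Array2<f64>, planes: &Array2<f64>, batch_size: usize) -> Array2<f64> {
    assert_eq!(data.ncols(), planes.ncols(), "embedding dimension must match the planes");
    let mut parts: Vec<Array2<f64>> = Vec::new();
    let mut start = 0;
    while start < data.nrows() {
        let end = (start + batch_size).min(data.nrows());
        let chunk = data.slice(s![start..end, ..]);
        parts.push(chunk.dot(&planes.t())); // (chunk_rows, nbits)
        start = end;
    }
    let views: Vec<ArrayView2<f64>> = parts.iter().map(|a| a.view()).collect();
    concatenate(Axis(0), &views).expect("failed to concatenate projected batches")
}

fn main() {
    let data = Array2::from_elem((2500, 8), 0.5);
    let planes = Array2::from_elem((16, 8), 0.25);
    let projected = project_in_batches(&data, &planes, 1000);
    assert_eq!(projected.dim(), (2500, 16));
}
```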
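Finally, the new writer options are plain string options: `is_lsh`, `nbits`, and `d` are read from the config's option map and the RNG seed defaults to 1234. The sketch below mirrors those accessors against a bare `HashMap` stand-in; `LshOptions` is hypothetical (the real accessors live on `LakeSoulIOConfig`), and it uses `and_then(... parse().ok())` instead of the diff's `parse().unwrap()` to avoid panicking on malformed values.

```rust
use std::collections::HashMap;

const OPTION_KEY_IS_LSH: &str = "is_lsh";
const OPTION_KEY_NBITS: &str = "nbits";
const OPTION_KEY_D: &str = "d";

struct LshOptions {
    options: HashMap<String, String>,
    seed: u64, // defaults to 1234 in the real config
}

impl LshOptions {
    fn option(&self, key: &str) -> Option<&String> {
        self.options.get(key)
    }
    fn is_lsh(&self) -> bool {
        self.option(OPTION_KEY_IS_LSH).map_or(false, |x| x.as_str() == "true")
    }
    fn nbits(&self) -> Option<u64> {
        self.option(OPTION_KEY_NBITS).and_then(|x| x.parse().ok())
    }
    fn d(&self) -> Option<u64> {
        self.option(OPTION_KEY_D).and_then(|x| x.parse().ok())
    }
}

fn main() {
    let mut options = HashMap::new();
    options.insert(OPTION_KEY_IS_LSH.to_string(), "true".to_string());
    options.insert(OPTION_KEY_NBITS.to_string(), "64".to_string());
    options.insert(OPTION_KEY_D.to_string(), "200".to_string());
    let cfg = LshOptions { options, seed: 1234 };
    assert!(cfg.is_lsh());
    assert_eq!((cfg.nbits(), cfg.d(), cfg.seed), (Some(64), Some(200), 1234));
}
```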