diff --git a/lakesoul-spark/pom.xml b/lakesoul-spark/pom.xml
index 241300a67..f880dd64d 100644
--- a/lakesoul-spark/pom.xml
+++ b/lakesoul-spark/pom.xml
@@ -142,6 +142,14 @@ SPDX-License-Identifier: Apache-2.0
+
+ <dependency>
+ <groupId>io.jhdf</groupId>
+ <artifactId>jhdf</artifactId>
+ <version>0.6.10</version>
+ </dependency>
+
+
<dependency>
<groupId>org.apache.spark</groupId>
@@ -646,4 +654,4 @@ SPDX-License-Identifier: Apache-2.0
-</project>
+</project>
\ No newline at end of file
diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala
index 5c3de0663..c3b7e2841 100644
--- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala
+++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetColumnarOutputWriter.scala
@@ -20,7 +20,6 @@ import scala.collection.JavaConverters.{asScalaBufferConverter, seqAsJavaListCon
class NativeParquetCompactionColumnarOutputWriter(path: String, dataSchema: StructType, timeZoneId: String,
context: TaskAttemptContext)
extends NativeParquetOutputWriter(path, dataSchema, timeZoneId, context) {
-
override def write(row: InternalRow): Unit = {
if (!row.isInstanceOf[ArrowFakeRow]) {
throw new IllegalStateException(
diff --git a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala
index c4d599ef8..c55275424 100644
--- a/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala
+++ b/lakesoul-spark/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/NativeParquetOutputWriter.scala
@@ -28,6 +28,8 @@ class NativeParquetOutputWriter(val path: String, dataSchema: StructType, timeZo
protected val nativeIOWriter: NativeIOWriter = new NativeIOWriter(arrowSchema)
+
+
GlutenUtils.setArrowAllocator(nativeIOWriter)
nativeIOWriter.setRowGroupRowNumber(NATIVE_IO_WRITE_MAX_ROW_GROUP_SIZE)
nativeIOWriter.addFile(path)
diff --git a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala
index acf9d6282..50002ad5a 100644
--- a/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala
+++ b/lakesoul-spark/src/test/scala/org/apache/spark/sql/lakesoul/commands/MergeIntoSQLSuite.scala
@@ -4,6 +4,8 @@
package org.apache.spark.sql.lakesoul.commands
+import com.dmetasoul.lakesoul.meta.LakeSoulOptions
+import com.dmetasoul.lakesoul.tables.LakeSoulTable
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SparkSession}
import org.apache.spark.sql.lakesoul.test.{LakeSoulSQLCommandTest, LakeSoulTestBeforeAndAfterEach, LakeSoulTestSparkSession, LakeSoulTestUtils}
import org.apache.spark.util.Utils
@@ -15,6 +17,19 @@ import org.apache.spark.sql.lakesoul.sources.LakeSoulSQLConf
import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}
import org.junit.runner.RunWith
import org.scalatestplus.junit.JUnitRunner
+import io.jhdf.HdfFile
+import io.jhdf.api.Dataset
+import org.apache.commons.lang3.ArrayUtils
+import org.apache.spark.sql.types.{ArrayType, ByteType, FloatType, IntegerType, LongType, StructField, StructType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions.{collect_list, sum, udf}
+
+import java.nio.file.Paths
+import scala.collection.mutable.ListBuffer
+import scala.concurrent.duration.DurationLong
+import scala.math.{pow, sqrt}
@RunWith(classOf[JUnitRunner])
class MergeIntoSQLSuite extends QueryTest
@@ -63,6 +78,208 @@ class MergeIntoSQLSuite extends QueryTest
}
}
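+ // Disabled LSH benchmark: reads an ANN HDF5 dataset (e.g. glove-200-angular), writes train/test embeddings to LakeSoul tables, ranks candidates by Hamming distance over the LSH column, then re-ranks the top candidates by Euclidean distance; the recall computation is left commented out.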
+// test("test lsh"){
+// val filepath = "/Users/beidu/Documents/dataset/glove-200-angular.hdf5"
+// val trainPath = "/Users/beidu/Documents/LakeSoul/train"
+// val testPath = "/Users/beidu/Documents/LakeSoul/test"
+//
+// println(filepath)
+// val spark = SparkSession.builder
+// .appName("Array to RDD")
+// .master("local[*]")
+// .getOrCreate()
+// try{
+// val hdfFile = new HdfFile(Paths.get(filepath))
+//
+// val trainDataset = hdfFile.getDatasetByPath("train")
+// val testDataset = hdfFile.getDatasetByPath("test")
+// val neighborDataset = hdfFile.getDatasetByPath("neighbors")
+// val trainData = trainDataset.getData()
+// val testData = testDataset.getData()
+// val neighborData = neighborDataset.getData()
+// println(trainData)
+// var float2DDataNeighbor: Array[Array[Int]] = null
+// neighborData match {
+// case data:Array[Array[Int]] =>
+// float2DDataNeighbor = data
+// case _ =>
+// println("not")
+// }
+// // the smaller the Hamming distance,the greater the similarity
+// val calculateHammingDistanceUDF = udf((trainLSH: Seq[Long], testLSH: Seq[Long]) => {
+// require(trainLSH.length == testLSH.length, "The input sequences must have the same length")
+// trainLSH.zip(testLSH).map { case (train, test) =>
+// java.lang.Long.bitCount(train ^ test)
+// }.sum
+// })
+// // the smaller the Euclidean distance,the greater the similarity
+// val calculateEuclideanDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => {
+// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length")
+// sqrt(trainEmbedding.zip(testEmbedding).map{case (train,test) =>
+// pow(train - test,2) }.sum)
+// })
+// //the greater the Cosine distance,the greater the similarity
+// val calculateCosineDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => {
+// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length")
+// trainEmbedding.zip(testEmbedding).map{case (train,test) =>
+// train * test}.sum / (sqrt(trainEmbedding.map{train => train * train}.sum) * sqrt(testEmbedding.map{test => test * test}.sum))
+// })
+// //the smaller the Jaccard distance,the greater the similarity
+// val calculateJaccardDistanceUDF = udf((trainEmbedding:Seq[Float],testEmbedding:Seq[Float]) => {
+// require(testEmbedding.length == trainEmbedding.length,"The input sequences must have the same length")
+// val anb = testEmbedding.intersect(trainEmbedding).distinct
+// val aub = testEmbedding.union(trainEmbedding).distinct
+// val jaccardCoefficient = anb.length.toDouble / aub.length
+// 1 - jaccardCoefficient
+// })
+// spark.udf.register("calculateHammingDistance",calculateHammingDistanceUDF)
+// spark.udf.register("calculateEuclideanDistance",calculateEuclideanDistanceUDF)
+// spark.udf.register("calculateCosineDistance",calculateCosineDistanceUDF)
+// spark.udf.register("calculateJaccardDistance",calculateJaccardDistanceUDF)
+//// println(float2DDataNeighbor.length)
+// trainData match {
+// case float2DData:Array[Array[Float]] =>
+// val classIds = (1 to float2DData.length).toArray
+// val schema = StructType(Array(
+// StructField("IndexId",IntegerType,true),
+// StructField("Embedding",ArrayType(FloatType),true),
+// StructField("LSH",ArrayType(LongType),true)
+// ))
+// val rows = float2DData.zip(classIds).map {
+// case (embedding,indexId)=>
+// Row(indexId,embedding,null)
+// }
+// val df = spark.createDataFrame(spark.sparkContext.parallelize(rows),schema)
+// df.write.format("lakesoul")
+// .option("hashPartitions", "IndexId")
+// .option("hashBucketNum", 4)
+// .option(LakeSoulOptions.SHORT_TABLE_NAME,"trainData")
+// .mode("Overwrite").save(trainPath)
+//// val startTime1 = System.nanoTime()
+// val lakeSoulTable = LakeSoulTable.forPath(trainPath)
+// lakeSoulTable.compaction()
+//
+// testData match {
+// case float2DTestData:Array[Array[Float]] =>
+// val classIdsTest = (1 to float2DTestData.length).toArray
+// val schemaTest = StructType(Array(
+// StructField("IndexId",IntegerType,true),
+// StructField("Embedding",ArrayType(FloatType),true),
+// StructField("LSH",ArrayType(LongType),true)
+// ))
+// val rowsTest = float2DTestData.zip(classIdsTest).map{
+// case (embedding,indexId) =>
+// Row(indexId,embedding,null)
+// }
+//
+// val num = 50
+// val dfTest = spark.createDataFrame(spark.sparkContext.parallelize(rowsTest),schemaTest).limit(num)
+// dfTest.write.format("lakesoul")
+// .option("hashPartitions","IndexId")
+// .option("hashBucketNum",4)
+// .option(LakeSoulOptions.SHORT_TABLE_NAME,"testData")
+// .mode("Overwrite").save(testPath)
+// val lakeSoulTableTest = LakeSoulTable.forPath(testPath)
+//
+// lakeSoulTableTest.compaction()
+//// val endTime1 = System.nanoTime()
+//// val duration1 = (endTime1 - startTime1).nanos
+//// println(s"time:${duration1.toMillis}")
+//// val lshTrain = sql("select LSH from trainData")
+//// val lshTest = sql("select LSH from testData")
+//
+//// val arr = Array(1,5,10,20,40,60,80,100,150,200,250,300)
+//// for(n <- arr) {
+// val n = 300
+// val topk = 100
+// val topkFirst = n * topk
+//
+// // val result = sql("select testData.IndexId as indexId,trainData.LSH as trainLSH,testData.LSH as testLSH," +
+// // "calculateHammingDistance(testData.LSH,trainData.LSH) AS hamming_distance " +
+// // "from testData " +
+// // "cross join trainData " +
+// // "order by indexId,hamming_distance")
+//
+// // val result = spark.sql(s"""
+// // SELECT *
+// // FROM (
+// // SELECT
+// // testData.IndexId AS indexIdTest,
+// // trainData.IndexId AS indexIdTrain,
+// // trainData.LSH AS trainLSH,
+// // testData.LSH AS testLSH,
+// // calculateHammingDistance(testData.LSH, trainData.LSH) AS hamming_distance,
+// // ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank
+// // FROM testData
+// // CROSS JOIN trainData
+// // ) ranked
+// // WHERE rank <= $topk
+// // """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList"))
+// val startTime = System.nanoTime()
+// val result = spark.sql(
+// s"""
+// SELECT *
+// FROM (
+// SELECT
+// testData.IndexId AS indexIdTest,
+// trainData.IndexId AS indexIdTrain,
+// testData.Embedding as EmbeddingTest,
+// trainData.Embedding as EmbeddingTrain,
+// ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank
+// FROM testData
+// CROSS JOIN trainData
+// ) ranked
+// WHERE rank <= $topkFirst
+// """)
+// result.createOrReplaceTempView("rank")
+// val reResult = spark.sql(
+// s"""
+// SELECT *
+// FROM (
+// SELECT
+// rank.indexIdTest,
+// rank.indexIDTrain,
+// ROW_NUMBER() OVER(PARTITION BY rank.indexIdTest ORDER BY calculateEuclideanDistance(rank.EmbeddingTest,rank.EmbeddingTrain) asc) AS reRank
+// FROM rank
+// ) reRanked
+// WHERE reRank <= $topk
+// """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList"))
+//
+//
+// val endTime = System.nanoTime()
+// val duration = (endTime - startTime).nanos
+// println(s"time for query n4topk ${n} :${duration.toMillis} milliseconds")
+//
+// val startTime2 = System.nanoTime()
+//
+//// val (totalRecall, count) = reResult.map(row => {
+//// val indexIdTest = row.getAs[Int]("indexIdTest")
+//// val indexIdTrainList: Array[Int] = row.getAs[Seq[Int]]("indexIdTrainList").toArray
+//// val updatedList = indexIdTrainList.map(_ - 1)
+//// val count = float2DDataNeighbor(indexIdTest - 1).take(topk).count(updatedList.contains)
+//// val recall = (count * 1.0 / topk)
+//// (recall, 1)
+//// }).reduce((acc1, acc2) => {
+//// (acc1._1 + acc2._1, acc1._2 + acc2._2)
+//// })
+//// println(totalRecall / count)
+// val endTime2 = System.nanoTime()
+// val duration2 = (endTime2 - startTime2).nanos
+// println(s"time for sort:${duration2.toMillis} milliseconds")
+//// }
+// }
+// case _ =>
+// println("unexpected data type")
+// case _ =>
+// println("unexpected data type")
+// }
+// }
+// finally {
+//
+// }
+//
+// }
+
test("merge into table with hash partition -- supported case") {
initHashTable()
withViewNamed(Seq((20201102, 4, 5)).toDF("range", "hash", "value"), "source_table") {
diff --git a/rust/Cargo.lock b/rust/Cargo.lock
index a2b4321d3..4e91a2ecc 100644
--- a/rust/Cargo.lock
+++ b/rust/Cargo.lock
@@ -1,6 +1,6 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
-version = 3
+version = 4
[[package]]
name = "addr2line"
@@ -818,6 +818,25 @@ dependencies = [
"cfg-if",
]
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
[[package]]
name = "crossbeam-utils"
version = "0.8.20"
@@ -1836,12 +1855,14 @@ dependencies = [
"hex",
"lazy_static",
"log",
+ "ndarray",
"object_store",
"parking_lot",
"parquet",
"prost",
"proto",
"rand",
+ "rayon",
"serde",
"serde_json",
"smallvec",
@@ -2072,6 +2093,16 @@ dependencies = [
"regex-automata 0.1.10",
]
+[[package]]
+name = "matrixmultiply"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
[[package]]
name = "md-5"
version = "0.10.6"
@@ -2148,6 +2179,19 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@@ -2712,6 +2756,32 @@ dependencies = [
"getrandom",
]
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
+name = "rayon"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
[[package]]
name = "redox_syscall"
version = "0.5.3"
diff --git a/rust/lakesoul-io/Cargo.toml b/rust/lakesoul-io/Cargo.toml
index d1b12e5cd..f9558ef18 100644
--- a/rust/lakesoul-io/Cargo.toml
+++ b/rust/lakesoul-io/Cargo.toml
@@ -47,6 +47,10 @@ env_logger = "0.11"
hex = "0.4"
dhat = { version="0.3.3", optional = true }
async-recursion = "1.1.1"
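+# ndarray provides the matrix math used by the LSH random projection in lakesoul_writer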
+ndarray = "0.15.6"
+#hdf5 = {version = "0.8.1"}
+rayon = "1.10.0"
+
[features]
hdfs = ["dep:hdrs", "dep:hdfs-sys"]
diff --git a/rust/lakesoul-io/src/lakesoul_io_config.rs b/rust/lakesoul-io/src/lakesoul_io_config.rs
index 4230e6cc7..eeeb3ed92 100644
--- a/rust/lakesoul-io/src/lakesoul_io_config.rs
+++ b/rust/lakesoul-io/src/lakesoul_io_config.rs
@@ -51,6 +51,9 @@ pub static OPTION_KEY_HASH_BUCKET_ID: &str = "hash_bucket_id";
pub static OPTION_KEY_CDC_COLUMN: &str = "cdc_column";
pub static OPTION_KEY_IS_COMPACTED: &str = "is_compacted";
pub static OPTION_KEY_MAX_FILE_SIZE: &str = "max_file_size";
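+// options controlling LSH signature generation on write: "is_lsh" enables it, "nbits" is the signature length in bits, "d" is the embedding dimension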
+pub static OPTION_KEY_IS_LSH: &str = "is_lsh";
+pub static OPTION_KEY_NBITS: &str = "nbits";
+pub static OPTION_KEY_D: &str = "d";
#[derive(Debug, Derivative)]
#[derivative(Default, Clone)]
@@ -130,6 +133,10 @@ pub struct LakeSoulIOConfig {
// max file size of bytes
#[derivative(Default(value = "None"))]
pub(crate) max_file_size: Option<u64>,
+
+ // the seed for rng
+ #[derivative(Default(value = "1234"))]
+ pub(crate) seed: u64,
}
impl LakeSoulIOConfig {
@@ -188,6 +195,18 @@ impl LakeSoulIOConfig {
pub fn is_compacted(&self) -> bool {
self.option(OPTION_KEY_IS_COMPACTED).map_or(false, |x| x.eq("true"))
}
+
+ pub fn is_lsh(&self) -> bool {
+ self.option(OPTION_KEY_IS_LSH).map_or(false, |x| x.eq("true"))
+ }
+
+ pub fn nbits(&self) -> Option<u64> {
+ self.option(OPTION_KEY_NBITS).map(|x| x.parse().unwrap())
+ }
+
+ pub fn d(&self) -> Option<u64> {
+ self.option(OPTION_KEY_D).map(|x| x.parse().unwrap())
+ }
}
#[derive(Derivative, Debug)]
@@ -348,6 +367,11 @@ impl LakeSoulIOConfigBuilder {
self
}
+ pub fn with_seed(mut self, seed: u64) -> Self {
+ self.config.seed = seed;
+ self
+ }
+
pub fn build(self) -> LakeSoulIOConfig {
self.config
}
@@ -517,7 +541,12 @@ fn register_object_store(path: &str, config: &mut LakeSoulIOConfig, runtime: &Ru
Ok(joined_path)
}
}
"file" => Ok(path.to_owned()),
+ // Support Windows drive letter paths like "c:" or "d:"
+ scheme if scheme.len() == 1 && scheme.chars().next().unwrap().is_ascii_alphabetic() => {
+ Ok(format!("file://{}", path))
+ },
_ => Err(ObjectStore(object_store::Error::NotSupported {
source: "FileSystem not supported".into(),
})),
@@ -652,5 +681,10 @@ mod tests {
"file:///some/absolute/local/file2".to_string(),
]
);
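+ // round-trip the new LSH options through the builder and check the seed default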
+ let lakesoul_config_builder = LakeSoulIOConfigBuilder::from(conf.clone());
+ let conf = lakesoul_config_builder.with_d(32u64).with_nbits(64u64).build();
+ assert_eq!(conf.seed, 1234u64);
+ assert_eq!(conf.d(), Some(32));
+ assert_eq!(conf.nbits(), Some(64));
}
}
diff --git a/rust/lakesoul-io/src/lakesoul_writer.rs b/rust/lakesoul-io/src/lakesoul_writer.rs
index 370a461d4..5d95ac711 100644
--- a/rust/lakesoul-io/src/lakesoul_writer.rs
+++ b/rust/lakesoul-io/src/lakesoul_writer.rs
@@ -4,8 +4,9 @@
use std::borrow::Borrow;
use std::collections::HashMap;
-use std::sync::Arc;
+use std::ptr;
+use arrow::datatypes::{DataType, Field};
use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion_common::{DataFusionError, Result};
@@ -13,6 +14,12 @@ use rand::distributions::DistString;
use tokio::runtime::Runtime;
use tokio::sync::Mutex;
use tracing::debug;
+use rand::{Rng, SeedableRng, rngs::StdRng};
+use ndarray::{concatenate, s, Array2, Axis, ArrayView2};
+use arrow::array::{Array as OtherArray, Float64Array, ListArray, Float32Array, Int64Array, GenericListArray};
+use arrow::buffer::OffsetBuffer;
+use std::sync::Arc;
+
use crate::async_writer::{AsyncBatchWriter, MultiPartAsyncWriter, PartitioningAsyncWriter, SortAsyncWriter, WriterFlushResult};
use crate::helpers::{get_batch_memory_size, get_file_exist_col};
@@ -36,11 +43,10 @@ impl SyncSendableMutableLakeSoulWriter {
pub fn try_new(config: LakeSoulIOConfig, runtime: Runtime) -> Result {
let runtime = Arc::new(runtime);
runtime.clone().block_on(async move {
- let writer_config = config.clone();
let mut config = config.clone();
+ let writer_config = config.clone();
let writer = Self::create_writer(writer_config).await?;
let schema = writer.schema();
-
if let Some(mem_limit) = config.mem_limit() {
if config.use_dynamic_partition {
config.max_file_size = Some((mem_limit as f64 * 0.15) as u64);
@@ -123,7 +129,31 @@ impl SyncSendableMutableLakeSoulWriter {
// for ffi callers
pub fn write_batch(&mut self, record_batch: RecordBatch) -> Result<()> {
let runtime = self.runtime.clone();
- runtime.block_on(async move { self.write_batch_async(record_batch, false).await })
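+ // When LSH is enabled, replace the "LSH" column with bit signatures computed from the "Embedding" column before writing; empty batches and non-LSH writes keep the original path.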
+ if record_batch.num_rows() == 0 {
+ runtime.block_on(async move { self.write_batch_async(record_batch, false).await })
+ } else {
+ if self.config.is_lsh() {
+ let projection: ListArray = if let Some(array) = record_batch.column_by_name("Embedding") {
+ let embedding = array.as_any().downcast_ref::<ListArray>().unwrap();
+ let projection_result: Result<ListArray, String> = self.lsh(&Some(embedding.clone()));
+ projection_result.unwrap().into()
+ } else {
+ eprintln!("there is no column named Embedding");
+ return Ok(());
+ };
+
+ let mut new_columns = record_batch.columns().to_vec();
+ new_columns[record_batch.schema().index_of("LSH").unwrap()] = Arc::new(projection.clone());
+ let new_record_batch = RecordBatch::try_new(self.config.target_schema(), new_columns).unwrap();
+
+ runtime.block_on(async move { self.write_batch_async(new_record_batch, false).await })
+ } else {
+ runtime.block_on(async move { self.write_batch_async(record_batch, false).await })
+ }
+ }
}
#[async_recursion::async_recursion(?Send)]
@@ -263,29 +293,215 @@ impl SyncSendableMutableLakeSoulWriter {
pub fn get_schema(&self) -> SchemaRef {
self.schema.clone()
}
+
+ // create a random number generator with the fixed seed from the config, so the same hyperplanes are generated every time
+ fn create_rng_with_seed(&self) -> StdRng {
+ StdRng::seed_from_u64(self.config.seed)
+ }
+
+ // generate the random hyperplane matrix of shape (nbits, d)
+ fn generate_random_array(&self) -> Result<Array2<f64>, String> {
+ match self.config.nbits() {
+ Some(nbits) if nbits > 0 => {
+ match self.config.d() {
+ Some(d) if d > 0 => {
+ let mut rng = self.create_rng_with_seed();
+// assert!(d >= nbits,"the dimension of the embedding must be greater than nbits");
+ let random_array = Array2::from_shape_fn((nbits as usize, d as usize), |_| rng.gen_range(-1.0..1.0));
+ Ok(random_array)
+ }
+ Some(_) => Err("the dimension you input in the config must be greater than 0".to_string()),
+ None => Err("the dimension you input in the config is None".to_string()),
+ }
+ }
+ Some(_) => Err("the number of bits used for binary encoding must be greater than 0".to_string()),
+ None => Err("the number of bits (nbits) used for binary encoding is not set in the config".to_string()),
+ }
+ }
+
+ // project the input embeddings onto the random hyperplanes
+ fn project(&self, input_data: &ListArray, random_plans: &Result<Array2<f64>, String>) -> Result<Array2<f64>, String> {
+ let list_len = input_data.len();
+ assert!(list_len > 0, "the length of input data must be larger than 0");
+ let dimension_len = input_data.value(0).len();
+
+ let input_values = if let Some(values) = input_data.values().as_any().downcast_ref::<Float32Array>() {
+ let float64_values: Vec<f64> = values.iter().map(|x| x.unwrap() as f64).collect();
+ Float64Array::from(float64_values)
+ } else if let Some(values) = input_data.values().as_any().downcast_ref::<Float64Array>() {
+ values.clone()
+ } else {
+ return Err("Unsupported data type in ListArray.".to_string());
+ };
+
+ let mut re_array2 = Array2::<f64>::zeros((list_len, dimension_len));
+
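+ // copy the flat Float64 values straight into the row-major (list_len x dimension_len) ndarray buffer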
+ unsafe {
+ let data_ptr = input_values.values().as_ptr();
+ let data_size = list_len * dimension_len;
+ ptr::copy_nonoverlapping(data_ptr,re_array2.as_mut_ptr(),data_size);
+ }
+ match random_plans {
+ Ok(random_array) => {
+ assert!(re_array2.shape()[1] == random_array.shape()[1], "the embedding dimension must match the hyperplane dimension d");
+// let final_result = re_array2.dot(&random_array.t());
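+ // multiply the embeddings by the transposed hyperplane matrix in fixed-size row chunks rather than one large dot product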
+ let batch_size = 1000;
+ let num_batches = re_array2.shape()[0] / batch_size;
+ let remaining_rows = re_array2.shape()[0] % batch_size;
+ let mut result = vec![];
+
+ for batch_idx in 0..num_batches{
+ let batch_start = batch_idx * batch_size;
+ let batch_end = batch_start + batch_size;
+
+ let current_batch = re_array2.slice(s![batch_start..batch_end,..]);
+ let random_projection = current_batch.dot(&random_array.t());
+
+ result.push(random_projection);
+ }
+
+ if remaining_rows > 0{
+ let batch_start = num_batches * batch_size;
+ let batch_end = batch_start + remaining_rows;
+
+ let remaining_batch = re_array2.slice(s![batch_start..batch_end,..]);
+
+ let random_projection = remaining_batch.dot(&random_array.t());
+
+ result.push(random_projection);
+ }
+
+ let result_views: Vec<ArrayView2<f64>> = result.iter().map(|arr| ArrayView2::from(arr)).collect();
+
+ let final_result = concatenate(Axis(0), &result_views).expect("Failed to concatenate results");
+
+ // println!("{:}",end);
+
+ Ok(final_result)
+ }
+ Err(e) => {
+ eprintln!("Error:{}",e);
+ Err(e.to_string())
+ }
+ }
+ }
+ // project the input embeddings onto the random hyperplanes and encode the signs as a bit signature
+ pub fn lsh(&self, input_embedding: &Option<ListArray>) -> Result<ListArray, String>
+ {
+ match input_embedding {
+ Some(data) => {
+ let random_plans = self.generate_random_array();
+ match self.project(data, &random_plans) {
+ Ok(mut projection) => {
+ // keep only the sign of each projection: non-negative becomes 1.0, negative becomes 0.0
+ projection.mapv_inplace(|x| if x >= 0.0 { 1.0 } else { 0.0 });
+ let convert: Vec<Vec<u64>> = Self::convert_array_to_u64_vec(&projection);
+ Ok(Self::convert_vec_to_byte_u64(convert))
+ }
+ Err(e) => {
+ eprintln!("Error:{}",e);
+ Err(e)
+ }
+ }
+ }
+ None => {
+ Err("the input data is None".to_string())
+ }
+ }
+ }
+
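+ // build an Arrow ListArray of Int64 from the packed signature words, one list entry per input row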
+ fn convert_vec_to_byte_u64(array: Vec<Vec<u64>>) -> ListArray {
+ let field = Arc::new(Field::new("element", DataType::Int64, true));
+ let values = Int64Array::from(array.iter().flatten().map(|&x| x as i64).collect::<Vec<i64>>());
+ let mut lengths = vec![];
+ for subarray in array {
+ lengths.push(subarray.len());
+ }
+ let offsets_buffer = OffsetBuffer::from_lengths(lengths);
+ let list_array = GenericListArray::try_new(field, offsets_buffer, Arc::new(values), None).expect("can not build the LSH ListArray");
+ list_array
+ }
+
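+ // pack each row of the 0/1 projection matrix into 64-bit words (up to 64 sign bits per u64)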
+ fn convert_array_to_u64_vec<T>(array: &Array2<f64>) -> Vec<Vec<T>>
+ where
+ T: TryFrom<u64> + Copy,
+ <T as TryFrom<u64>>::Error: std::fmt::Debug,
+ {
+ let binary_encode: Vec<Vec<u64>> = array
+ .axis_iter(ndarray::Axis(0))
+ .map(|row| {
+ let mut results = Vec::new();
+ let mut acc = 0u64;
+
+ for (i, &bit) in row.iter().enumerate() {
+ acc = (acc << 1) | bit as u64;
+ if (i + 1) % 64 == 0 {
+ results.push(acc);
+ acc = 0;
+ }
+ }
+ if row.len() % 64 != 0 {
+ results.push(acc);
+ }
+ results
+ })
+ .collect();
+
+ binary_encode
+ .into_iter()
+ .map(|inner_vec| {
+ inner_vec
+ .into_iter()
+ .map(|x| T::try_from(x).unwrap())
+ .collect()
+ })
+ .collect()
+ }
+
}
#[cfg(test)]
mod tests {
+ use arrow_array::builder;
+ use datafusion::catalog::schema;
+ use hdf5::File as OtherFile;
+ use hdf5::Group;
+ use parquet::arrow::ArrowWriter;
+ use parquet::column;
+ use parquet::file::properties::WriterProperties;
use crate::{
lakesoul_io_config::{LakeSoulIOConfigBuilder, OPTION_KEY_MEM_LIMIT},
lakesoul_reader::LakeSoulReader,
lakesoul_writer::{AsyncBatchWriter, MultiPartAsyncWriter, SyncSendableMutableLakeSoulWriter},
};
+
use arrow::{
- array::{ArrayRef, Int64Array},
+ array::{ArrayRef, Int64Array,FixedSizeListArray,ArrayData},
record_batch::RecordBatch,
};
use arrow_array::{Array, StringArray};
use arrow_schema::{DataType, Field, Schema};
- use datafusion::error::Result;
+ use datafusion::{error::Result, physical_expr::math_expressions};
use parquet::arrow::arrow_reader::ParquetRecordBatchReader;
use rand::{distributions::DistString, Rng};
use std::{fs::File, sync::Arc};
use tokio::{runtime::Builder, time::Instant};
use tracing_subscriber::layer::SubscriberExt;
+ use arrow::buffer::{Buffer, NullBuffer};
use super::SortAsyncWriter;
+ use std::env;
+ use std::path::Path;
+ use ndarray::{Array2,s,Dim};
+ use crate::helpers::get_batch_memory_size;
+ use parquet::file::reader::{FileReader,SerializedFileReader};
+ use arrow::compute::sort;
+ use arrow::compute::sort_to_indices;
+ use arrow::compute::SortOptions;
#[test]
fn test_parquet_async_write() -> Result<()> {
@@ -730,4 +946,5 @@ mod tests {
// assert_eq!(num_rows * num_batch, actual_batch.num_rows());
Ok(())
}
+
}