Add LSH implementation for vector embedding indexing #568

Merged
merged 4 commits on Jan 9, 2025
Changes from 1 commit
106 changes: 57 additions & 49 deletions lakesoul-spark/pom.xml
@@ -142,6 +142,14 @@ SPDX-License-Identifier: Apache-2.0
</exclusion>
</exclusions>
</dependency>
<!-- https://mvnrepository.com/artifact/io.jhdf/jhdf -->
<dependency>
<groupId>io.jhdf</groupId>
<artifactId>jhdf</artifactId>
<version>0.6.10</version>
</dependency>
Contributor review comment:
test scope is required
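
Following up on this comment, a minimal sketch of how the jhdf dependency might be declared with test scope, assuming (as the reviewer implies) that jhdf is only exercised by the test suite:

<!-- https://mvnrepository.com/artifact/io.jhdf/jhdf -->
<dependency>
  <groupId>io.jhdf</groupId>
  <artifactId>jhdf</artifactId>
  <version>0.6.10</version>
  <!-- assumed: jhdf is only needed by tests, so test scope keeps it out of the packaged artifact -->
  <scope>test</scope>
</dependency>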




<dependency>
<groupId>org.apache.spark</groupId>
@@ -322,55 +330,55 @@ SPDX-License-Identifier: Apache-2.0
</dependency>

<!-- for test only. we don't rely on gluten during package and runtime -->
<!-- <dependency>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>backends-velox</artifactId>-->
<!-- <version>${gluten.version}</version>-->
<!-- <scope>${local.scope}</scope>-->
<!-- </dependency>-->
<!-- <dependency>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>gluten-core</artifactId>-->
<!-- <version>${gluten.version}</version>-->
<!-- <exclusions>-->
<!-- <exclusion>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>spark-sql-columnar-shims-spark32</artifactId>-->
<!-- </exclusion>-->
<!-- </exclusions>-->
<!-- <scope>${local.scope}</scope>-->
<!-- </dependency>-->
<!-- <dependency>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>gluten-core</artifactId>-->
<!-- <version>${gluten.version}</version>-->
<!-- <exclusions>-->
<!-- <exclusion>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>spark-sql-columnar-shims-spark32</artifactId>-->
<!-- </exclusion>-->
<!-- </exclusions>-->
<!-- <classifier>tests</classifier>-->
<!-- <scope>test</scope>-->
<!-- </dependency>-->
<!-- <dependency>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>gluten-data</artifactId>-->
<!-- <version>${gluten.version}</version>-->
<!-- <scope>${local.scope}</scope>-->
<!-- </dependency>-->
<!-- <dependency>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>spark-sql-columnar-shims-common</artifactId>-->
<!-- <version>${gluten.version}</version>-->
<!-- <scope>${local.scope}</scope>-->
<!-- </dependency>-->
<!-- currently this jar is missing on maven repo -->
<!-- <dependency>-->
<!-- <groupId>io.glutenproject</groupId>-->
<!-- <artifactId>spark-sql-columnar-shims-spark33</artifactId>-->
<!-- <version>${gluten.version}</version>-->
<!-- </dependency>-->
</dependencies>

<profiles>
@@ -646,4 +654,4 @@ SPDX-License-Identifier: Apache-2.0

</plugins>
</build>
</project>
@@ -20,7 +20,6 @@ import scala.collection.JavaConverters.{asScalaBufferConverter, seqAsJavaListCon
class NativeParquetCompactionColumnarOutputWriter(path: String, dataSchema: StructType, timeZoneId: String,
context: TaskAttemptContext)
extends NativeParquetOutputWriter(path, dataSchema, timeZoneId, context) {

override def write(row: InternalRow): Unit = {
if (!row.isInstanceOf[ArrowFakeRow]) {
throw new IllegalStateException(
@@ -28,6 +28,8 @@ class NativeParquetOutputWriter(val path: String, dataSchema: StructType, timeZo

protected val nativeIOWriter: NativeIOWriter = new NativeIOWriter(arrowSchema)



GlutenUtils.setArrowAllocator(nativeIOWriter)
nativeIOWriter.setRowGroupRowNumber(NATIVE_IO_WRITE_MAX_ROW_GROUP_SIZE)
nativeIOWriter.addFile(path)
@@ -4,6 +4,8 @@

package org.apache.spark.sql.lakesoul.commands

import com.dmetasoul.lakesoul.meta.LakeSoulOptions
import com.dmetasoul.lakesoul.tables.LakeSoulTable
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SparkSession}
import org.apache.spark.sql.lakesoul.test.{LakeSoulSQLCommandTest, LakeSoulTestBeforeAndAfterEach, LakeSoulTestSparkSession, LakeSoulTestUtils}
import org.apache.spark.util.Utils
@@ -15,6 +17,19 @@ import org.apache.spark.sql.lakesoul.sources.LakeSoulSQLConf
import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}
import org.junit.runner.RunWith
import org.scalatestplus.junit.JUnitRunner
import io.jhdf.HdfFile
import io.jhdf.api.Dataset
import org.apache.commons.lang3.ArrayUtils
import org.apache.spark.sql.types.{ArrayType, ByteType, FloatType, IntegerType, LongType, StructField, StructType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.functions.{collect_list, sum, udf}

import java.nio.file.Paths
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration.DurationLong
import scala.math.{pow, sqrt}

@RunWith(classOf[JUnitRunner])
class MergeIntoSQLSuite extends QueryTest
@@ -63,6 +78,208 @@ class MergeIntoSQLSuite extends QueryTest
}
}

test("test lsh"){
val filepath = "/Users/beidu/Documents/dataset/glove-200-angular.hdf5"
val trainPath = "/Users/beidu/Documents/LakeSoul/train"
val testPath = "/Users/beidu/Documents/LakeSoul/test"

println(filepath)
val spark = SparkSession.builder
.appName("Array to RDD")
.master("local[*]")
.getOrCreate()
try {
val hdfFile = new HdfFile(Paths.get(filepath))

val trainDataset = hdfFile.getDatasetByPath("train")
val testDataset = hdfFile.getDatasetByPath("test")
val neighborDataset = hdfFile.getDatasetByPath("neighbors")
val trainData = trainDataset.getData()
val testData = testDataset.getData()
val neighborData = neighborDataset.getData()
println(trainData)
var float2DDataNeighbor: Array[Array[Int]] = null
neighborData match {
case data: Array[Array[Int]] =>
float2DDataNeighbor = data
case _ =>
println("unexpected data type for neighbors")
}
// the smaller the Hamming distance, the greater the similarity
val calculateHammingDistanceUDF = udf((trainLSH: Seq[Long], testLSH: Seq[Long]) => {
require(trainLSH.length == testLSH.length, "The input sequences must have the same length")
trainLSH.zip(testLSH).map { case (train, test) =>
java.lang.Long.bitCount(train ^ test)
}.sum
})
// the smaller the Euclidean distance, the greater the similarity
val calculateEuclideanDistanceUDF = udf((trainEmbedding: Seq[Float], testEmbedding: Seq[Float]) => {
require(testEmbedding.length == trainEmbedding.length, "The input sequences must have the same length")
sqrt(trainEmbedding.zip(testEmbedding).map { case (train, test) =>
pow(train - test, 2)
}.sum)
})
// the greater the cosine similarity, the greater the similarity (this UDF computes cosine similarity, not a distance)
val calculateCosineDistanceUDF = udf((trainEmbedding: Seq[Float], testEmbedding: Seq[Float]) => {
require(testEmbedding.length == trainEmbedding.length, "The input sequences must have the same length")
trainEmbedding.zip(testEmbedding).map { case (train, test) =>
train * test
}.sum / (sqrt(trainEmbedding.map { train => train * train }.sum) * sqrt(testEmbedding.map { test => test * test }.sum))
})
// the smaller the Jaccard distance, the greater the similarity
val calculateJaccardDistanceUDF = udf((trainEmbedding: Seq[Float], testEmbedding: Seq[Float]) => {
require(testEmbedding.length == trainEmbedding.length, "The input sequences must have the same length")
val anb = testEmbedding.intersect(trainEmbedding).distinct
val aub = testEmbedding.union(trainEmbedding).distinct
val jaccardCoefficient = anb.length.toDouble / aub.length
1 - jaccardCoefficient
})
spark.udf.register("calculateHammingDistance", calculateHammingDistanceUDF)
spark.udf.register("calculateEuclideanDistance", calculateEuclideanDistanceUDF)
spark.udf.register("calculateCosineDistance", calculateCosineDistanceUDF)
spark.udf.register("calculateJaccardDistance", calculateJaccardDistanceUDF)
// println(float2DDataNeighbor.length)
trainData match {
case float2DData: Array[Array[Float]] =>
val classIds = (1 to float2DData.length).toArray
val schema = StructType(Array(
StructField("IndexId", IntegerType, true),
StructField("Embedding", ArrayType(FloatType), true),
StructField("LSH", ArrayType(LongType), true)
))
val rows = float2DData.zip(classIds).map {
case (embedding, indexId) =>
Row(indexId, embedding, null)
}
val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
df.write.format("lakesoul")
.option("hashPartitions", "IndexId")
.option("hashBucketNum", 4)
.option(LakeSoulOptions.SHORT_TABLE_NAME, "trainData")
.mode("Overwrite").save(trainPath)
// val startTime1 = System.nanoTime()
val lakeSoulTable = LakeSoulTable.forPath(trainPath)
lakeSoulTable.compaction()

testData match {
case float2DTestData: Array[Array[Float]] =>
val classIdsTest = (1 to float2DTestData.length).toArray
val schemaTest = StructType(Array(
StructField("IndexId", IntegerType, true),
StructField("Embedding", ArrayType(FloatType), true),
StructField("LSH", ArrayType(LongType), true)
))
val rowsTest = float2DTestData.zip(classIdsTest).map {
case (embedding, indexId) =>
Row(indexId, embedding, null)
}

val num = 50
val dfTest = spark.createDataFrame(spark.sparkContext.parallelize(rowsTest), schemaTest).limit(num)
dfTest.write.format("lakesoul")
.option("hashPartitions", "IndexId")
.option("hashBucketNum", 4)
.option(LakeSoulOptions.SHORT_TABLE_NAME, "testData")
.mode("Overwrite").save(testPath)
val lakeSoulTableTest = LakeSoulTable.forPath(testPath)

lakeSoulTableTest.compaction()
// val endTime1 = System.nanoTime()
// val duration1 = (endTime1 - startTime1).nanos
// println(s"time:${duration1.toMillis}")
// val lshTrain = sql("select LSH from trainData")
// val lshTest = sql("select LSH from testData")

// val arr = Array(1,5,10,20,40,60,80,100,150,200,250,300)
// for(n <- arr) {
val n = 300
val topk = 100
val topkFirst = n * topk

// val result = sql("select testData.IndexId as indexId,trainData.LSH as trainLSH,testData.LSH as testLSH," +
// "calculateHammingDistance(testData.LSH,trainData.LSH) AS hamming_distance " +
// "from testData " +
// "cross join trainData " +
// "order by indexId,hamming_distance")

// val result = spark.sql(s"""
// SELECT *
// FROM (
// SELECT
// testData.IndexId AS indexIdTest,
// trainData.IndexId AS indexIdTrain,
// trainData.LSH AS trainLSH,
// testData.LSH AS testLSH,
// calculateHammingDistance(testData.LSH, trainData.LSH) AS hamming_distance,
// ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank
// FROM testData
// CROSS JOIN trainData
// ) ranked
// WHERE rank <= $topk
// """).groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList"))
val startTime = System.nanoTime()
val result = spark.sql(
s"""
SELECT *
FROM (
SELECT
testData.IndexId AS indexIdTest,
trainData.IndexId AS indexIdTrain,
testData.Embedding as EmbeddingTest,
trainData.Embedding as EmbeddingTrain,
ROW_NUMBER() OVER (PARTITION BY testData.IndexId ORDER BY calculateHammingDistance(testData.LSH, trainData.LSH) asc) AS rank
FROM testData
CROSS JOIN trainData
) ranked
WHERE rank <= $topkFirst
""")
result.createOrReplaceTempView("rank")
val reResult = spark.sql(
s"""
SELECT *
FROM (
SELECT
rank.indexIdTest,
rank.indexIdTrain,
ROW_NUMBER() OVER(PARTITION BY rank.indexIdTest ORDER BY calculateEuclideanDistance(rank.EmbeddingTest,rank.EmbeddingTrain) asc) AS reRank
FROM rank
) reRanked
WHERE reRank <= $topk
""").groupBy("indexIdTest").agg(collect_list("indexIdTrain").alias("indexIdTrainList"))


val endTime = System.nanoTime()
val duration = (endTime - startTime).nanos
println(s"time for query n4topk ${n} :${duration.toMillis} milliseconds")

val startTime2 = System.nanoTime()

// val (totalRecall, count) = reResult.map(row => {
// val indexIdTest = row.getAs[Int]("indexIdTest")
// val indexIdTrainList: Array[Int] = row.getAs[Seq[Int]]("indexIdTrainList").toArray
// val updatedList = indexIdTrainList.map(_ - 1)
// val count = float2DDataNeighbor(indexIdTest - 1).take(topk).count(updatedList.contains)
// val recall = (count * 1.0 / topk)
// (recall, 1)
// }).reduce((acc1, acc2) => {
// (acc1._1 + acc2._1, acc1._2 + acc2._2)
// })
// println(totalRecall / count)
val endTime2 = System.nanoTime()
val duration2 = (endTime2 - startTime2).nanos
println(s"time for sort:${duration2.toMillis} milliseconds")
// }
case _ =>
println("unexpected data type for testData")
}
case _ =>
println("unexpected data type for trainData")
}
}
finally {

}

}

test("merge into table with hash partition -- supported case") {
initHashTable()
withViewNamed(Seq((20201102, 4, 5)).toDF("range", "hash", "value"), "source_table") {
@@ -89,6 +89,10 @@ public void initializeWriter() throws IOException {
assert tokioRuntimeBuilder != null;
assert ioConfigBuilder != null;

setOption("is_lsh","true");
Contributor review comment:
do not set option by hard-coding, pass options by configuration

setOption("nbits","400");
setOption("d","200");

tokioRuntime = libLakeSoulIO.create_tokio_runtime_from_builder(tokioRuntimeBuilder);
config = libLakeSoulIO.create_lakesoul_io_config_from_builder(ioConfigBuilder);
writer = libLakeSoulIO.create_lakesoul_writer_from_config(config, tokioRuntime);
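
Following up on the reviewer's note above about not hard-coding these options, here is a minimal sketch of what this step could look like if the LSH parameters were taken from configuration instead. The conf object (an org.apache.hadoop.conf.Configuration) and the option key names below are assumptions for illustration, not the project's actual API:

// Hypothetical replacement for the hard-coded calls in initializeWriter():
// enable LSH and pick its parameters from the job configuration.
if (Boolean.parseBoolean(conf.get("lakesoul.lsh.enabled", "false"))) {
    setOption("is_lsh", "true");
    setOption("nbits", conf.get("lakesoul.lsh.nbits", "400"));
    setOption("d", conf.get("lakesoul.lsh.d", "200"));
}

This would keep the hash width (nbits) and vector dimension (d) tunable per job without recompiling the writer.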