From 1183230639ab8af4c3e4182bce4038bc6ec93337 Mon Sep 17 00:00:00 2001 From: ryjiang Date: Thu, 11 Apr 2024 16:26:33 +0800 Subject: [PATCH] Add nq > 1 tests for sparse vectors and upgrade protos (#297) * add more test Signed-off-by: ryjiang * add nq > 1 tests for sparse vectors Signed-off-by: ryjiang * update test version Signed-off-by: ryjiang --------- Signed-off-by: ryjiang --- milvus/utils/Bytes.ts | 32 ++++++++-- milvus/utils/Format.ts | 35 +++++++++-- package.json | 2 +- proto | 2 +- test/grpc/Basic.spec.ts | 14 +++++ test/grpc/Float16Vector.spec.ts | 15 ++++- test/grpc/MultipleVectors.spec.ts | 38 ++++++++---- test/grpc/SparseVector.array.spec.ts | 12 ++++ test/grpc/SparseVector.coo.spec.ts | 12 ++++ test/grpc/SparseVector.csr.spec.ts | 16 ++++- test/grpc/SparseVector.dict.spec.ts | 12 ++++ test/utils/Bytes.spec.ts | 11 ++++ test/utils/Format.spec.ts | 93 ++++++++++++++++++++++++++++ 13 files changed, 267 insertions(+), 27 deletions(-) diff --git a/milvus/utils/Bytes.ts b/milvus/utils/Bytes.ts index 95d09c65..3ba60b3f 100644 --- a/milvus/utils/Bytes.ts +++ b/milvus/utils/Bytes.ts @@ -58,17 +58,37 @@ export const parseBytesToFloat16Vector = (float16Bytes: Uint8Array) => { * * @returns string, 'array' | 'coo' | 'csr' | 'dict' */ -export const getSparseFloatVectorType = (vector: SparseFloatVector) => { +export const getSparseFloatVectorType = ( + vector: SparseFloatVector +): 'array' | 'coo' | 'csr' | 'dict' | 'unknown' => { if (Array.isArray(vector)) { + if (vector.length === 0) { + return 'array'; + } if (typeof vector[0] === 'number' || typeof vector[0] === 'undefined') { return 'array'; - } else { + } else if ( + (vector as SparseVectorCOO).every( + item => typeof item === 'object' && 'index' in item && 'value' in item + ) + ) { return 'coo'; + } else { + return 'unknown'; } - } else if ('indices' in vector && 'values' in vector) { + } else if ( + typeof vector === 'object' && + 'indices' in vector && + 'values' in vector + ) { return 'csr'; - } else { + } 
else if ( + typeof vector === 'object' && + Object.keys(vector).every(key => typeof vector[key] === 'number') + ) { return 'dict'; + } else { + return 'unknown'; } }; @@ -86,8 +106,8 @@ export const parseSparseVectorToBytes = ( // detect the format of the sparse vector const type = getSparseFloatVectorType(data); - let indices: number[]; - let values: number[]; + let indices: number[] = []; + let values: number[] = []; switch (type) { case 'array': diff --git a/milvus/utils/Format.ts b/milvus/utils/Format.ts index fece793d..ff02d0c0 100644 --- a/milvus/utils/Format.ts +++ b/milvus/utils/Format.ts @@ -34,6 +34,7 @@ import { parseBytesToFloat16Vector, parseFloat16VectorToBytes, Float16Vector, + getSparseFloatVectorType, } from '../'; /** @@ -691,10 +692,8 @@ export const buildSearchRequest = ( searchSimpleReq.vector || searchSimpleReq.data; - // make sure the vector format - if (!Array.isArray(searchingVector[0])) { - searchingVector = [searchingVector as unknown] as VectorTypes[]; - } + // format searching vector + searchingVector = formatSearchVector(searchingVector, field.dataType!); // create search request requests.push({ @@ -842,3 +841,31 @@ export const formatSearchResult = ( return results; }; + +/** + * Formats the search vector to match a specific data type. + * @param {VectorTypes | VectorTypes[]} searchVector - The search vector or array of vectors to be formatted. + * @param {DataType} dataType - The specified data type. + * @returns {VectorTypes[]} The formatted search vector or array of vectors. 
+ */ +export const formatSearchVector = ( + searchVector: VectorTypes | VectorTypes[], + dataType: DataType +): VectorTypes[] => { + switch (dataType) { + case DataType.FloatVector: + case DataType.BinaryVector: + case DataType.Float16Vector: + case DataType.BFloat16Vector: + if (!Array.isArray(searchVector)) { + return [searchVector] as VectorTypes[]; + } + case DataType.SparseFloatVector: + const type = getSparseFloatVectorType(searchVector as VectorTypes); + if (type !== 'unknown') { + return [searchVector] as VectorTypes[]; + } + default: + return searchVector as VectorTypes[]; + } +}; diff --git a/package.json b/package.json index 66d45a57..938f7abc 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "@zilliz/milvus2-sdk-node", "author": "ued@zilliz.com", "version": "2.3.6", - "milvusVersion": "2.4-20240327-96cec787-amd64", + "milvusVersion": "2.4-20240411-246ef454-amd64", "main": "dist/milvus", "files": [ "dist" diff --git a/proto b/proto index 55a0bcee..4c648ad1 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 55a0bceef3e286134ee67472f0fdb6e4baf84a52 +Subproject commit 4c648ad172269eed9a823d1a4d6a70d04a4486a4 diff --git a/test/grpc/Basic.spec.ts b/test/grpc/Basic.spec.ts index 05e2b13c..84b5b563 100644 --- a/test/grpc/Basic.spec.ts +++ b/test/grpc/Basic.spec.ts @@ -80,6 +80,20 @@ describe(`Basic API without database`, () => { expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); }); + it(`search nq > 1 should be successful`, async () => { + const search = await milvusClient.search({ + collection_name: COLLECTION_NAME, + data: [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ], + }); + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toEqual(2); + expect(search.results[0].length).toEqual(10); + expect(search.results[1].length).toEqual(10); + }); + it(`release and drop should be successful`, async () => { // releases const release = await milvusClient.releaseCollection({ diff --git 
a/test/grpc/Float16Vector.spec.ts b/test/grpc/Float16Vector.spec.ts index 79b4640d..910a362d 100644 --- a/test/grpc/Float16Vector.spec.ts +++ b/test/grpc/Float16Vector.spec.ts @@ -4,7 +4,6 @@ import { DataType, IndexType, MetricType, - parseBytesToFloat16Vector, } from '../../milvus'; import { IP, @@ -126,4 +125,18 @@ describe(`Float16 vector API testing`, () => { expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); expect(search.results.length).toBeGreaterThan(0); }); + + it(`search with float16 vector and nq > 0 should be successful`, async () => { + const search = await milvusClient.search({ + vector: [data[0].vector, data[1].vector], + collection_name: COLLECTION_NAME, + output_fields: ['id', 'vector'], + limit: 5, + }); + + // console.log('search', search); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toBeGreaterThan(0); + }); }); diff --git a/test/grpc/MultipleVectors.spec.ts b/test/grpc/MultipleVectors.spec.ts index 8525cd8e..b480c769 100644 --- a/test/grpc/MultipleVectors.spec.ts +++ b/test/grpc/MultipleVectors.spec.ts @@ -14,7 +14,7 @@ import { generateInsertData, } from '../tools'; -const milvusClient = new MilvusClient({ address: IP }); +const milvusClient = new MilvusClient({ address: IP, logLevel: 'info' }); const COLLECTION_NAME = GENERATE_NAME(); const dbParam = { @@ -23,7 +23,11 @@ const dbParam = { const p = { collectionName: COLLECTION_NAME, - vectorType: [DataType.FloatVector, DataType.FloatVector], + vectorType: [ + DataType.FloatVector, + DataType.FloatVector, + // DataType.Float16Vector, + ], dim: [8, 16], }; const collectionParams = genCollectionParams(p); @@ -85,6 +89,12 @@ describe(`Multiple vectors API testing`, () => { metric_type: MetricType.COSINE, index_type: IndexType.AUTOINDEX, }, + // { + // collection_name: COLLECTION_NAME, + // field_name: 'vector2', + // metric_type: MetricType.COSINE, + // index_type: IndexType.AUTOINDEX, + // }, ]); 
expect(indexes.error_code).toEqual(ErrorCode.SUCCESS); @@ -110,20 +120,21 @@ describe(`Multiple vectors API testing`, () => { const item = query.data[0]; expect(item.vector.length).toEqual(p.dim[0]); expect(item.vector1.length).toEqual(p.dim[1]); + // expect(item.vector2.length).toEqual(p.dim[2]); }); it(`search multiple vector collection with old search api should be successful`, async () => { // search default first vector field const search0 = await milvusClient.search({ collection_name: COLLECTION_NAME, - vector: [1, 2, 3, 4, 5, 6, 7, 8], + data: [1, 2, 3, 4, 5, 6, 7, 8], }); expect(search0.status.error_code).toEqual(ErrorCode.SUCCESS); // search specific vector field const search = await milvusClient.search({ collection_name: COLLECTION_NAME, - vector: [1, 2, 3, 4, 5, 6, 7, 8], + data: [1, 2, 3, 4, 5, 6, 7, 8], anns_field: 'vector', }); expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); @@ -131,7 +142,7 @@ describe(`Multiple vectors API testing`, () => { // search second vector field const search2 = await milvusClient.search({ collection_name: COLLECTION_NAME, - vector: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + data: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], anns_field: 'vector1', limit: 5, }); @@ -139,15 +150,16 @@ describe(`Multiple vectors API testing`, () => { expect(search2.status.error_code).toEqual(ErrorCode.SUCCESS); expect(search2.results.length).toEqual(5); - // search first vector field - const search3 = await milvusClient.search({ - collection_name: COLLECTION_NAME, - vector: [1, 2, 3, 4, 5, 6, 7, 8], - }); + // // search third vector field + // const search3 = await milvusClient.search({ + // collection_name: COLLECTION_NAME, + // data: [1, 2, 3, 4, 5, 6, 7, 8], + // anns_field: 'vector2', + // }); - expect(search3.status.error_code).toEqual(ErrorCode.SUCCESS); - expect(search3.results.length).toEqual(search.results.length); - expect(search3.results).toEqual(search.results); + // 
expect(search3.status.error_code).toEqual(ErrorCode.SUCCESS); + // expect(search3.results.length).toEqual(search.results.length); + // expect(search3.results).toEqual(search.results); }); it(`hybrid search with rrf ranker set should be successful`, async () => { diff --git a/test/grpc/SparseVector.array.spec.ts b/test/grpc/SparseVector.array.spec.ts index 2c7e8770..2f1364b5 100644 --- a/test/grpc/SparseVector.array.spec.ts +++ b/test/grpc/SparseVector.array.spec.ts @@ -128,4 +128,16 @@ describe(`Sparse vectors type:object API testing`, () => { expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); expect(search.results.length).toBeGreaterThan(0); }); + + it(`search with sparse vector with nq > 1 should be successful`, async () => { + const search = await milvusClient.search({ + vectors: [data[0].vector, data[1].vector], + collection_name: COLLECTION_NAME, + output_fields: ['id', 'vector'], + limit: 5, + }); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toEqual(2); + }); }); diff --git a/test/grpc/SparseVector.coo.spec.ts b/test/grpc/SparseVector.coo.spec.ts index 871a5d29..abb5a25e 100644 --- a/test/grpc/SparseVector.coo.spec.ts +++ b/test/grpc/SparseVector.coo.spec.ts @@ -128,4 +128,16 @@ describe(`Sparse vectors type:coo API testing`, () => { expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); expect(search.results.length).toBeGreaterThan(0); }); + + it(`search with sparse vector with nq > 1 should be successful`, async () => { + const search = await milvusClient.search({ + vectors: [data[0].vector, data[1].vector], + collection_name: COLLECTION_NAME, + output_fields: ['id', 'vector'], + limit: 5, + }); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toEqual(2); + }); }); diff --git a/test/grpc/SparseVector.csr.spec.ts b/test/grpc/SparseVector.csr.spec.ts index 2b0ed040..c87e501f 100644 --- a/test/grpc/SparseVector.csr.spec.ts +++ 
b/test/grpc/SparseVector.csr.spec.ts @@ -101,7 +101,9 @@ describe(`Sparse vectors type:CSR API testing`, () => { output_fields: ['vector', 'id'], }); - const originKeys = data[0].vector.indices.map((index: number) => index.toString()); + const originKeys = data[0].vector.indices.map((index: number) => + index.toString() + ); const originValues = data[0].vector.values; const outputKeys: string[] = Object.keys(query.data[0].vector); @@ -124,4 +126,16 @@ describe(`Sparse vectors type:CSR API testing`, () => { expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); expect(search.results.length).toBeGreaterThan(0); }); + + it(`search with sparse vector with nq > 1 should be successful`, async () => { + const search = await milvusClient.search({ + vectors: [data[0].vector, data[1].vector], + collection_name: COLLECTION_NAME, + output_fields: ['id', 'vector'], + limit: 5, + }); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toEqual(2); + }); }); diff --git a/test/grpc/SparseVector.dict.spec.ts b/test/grpc/SparseVector.dict.spec.ts index e01bd4f0..bdcaa569 100644 --- a/test/grpc/SparseVector.dict.spec.ts +++ b/test/grpc/SparseVector.dict.spec.ts @@ -120,4 +120,16 @@ describe(`Sparse vectors type:dict API testing`, () => { expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); expect(search.results.length).toBeGreaterThan(0); }); + + it(`search with sparse vector with nq > 1 should be successful`, async () => { + const search = await milvusClient.search({ + vectors: [data[0].vector, data[1].vector], + collection_name: COLLECTION_NAME, + output_fields: ['id', 'vector'], + limit: 5, + }); + + expect(search.status.error_code).toEqual(ErrorCode.SUCCESS); + expect(search.results.length).toEqual(2); + }); }); diff --git a/test/utils/Bytes.spec.ts b/test/utils/Bytes.spec.ts index f753e2dd..4f994c81 100644 --- a/test/utils/Bytes.spec.ts +++ b/test/utils/Bytes.spec.ts @@ -63,4 +63,15 @@ describe('Sparse rows <-> Bytes 
conversion', () => { ]; expect(getSparseFloatVectorType(data)).toEqual('coo'); }); + + it('should return "unknown" if the input is not recognized', () => { + const data: any = 'invalid'; + expect(getSparseFloatVectorType(data)).toEqual('unknown'); + + const data2: any = [ + [1, 2, 3], + [4, 5, 6], + ]; + expect(getSparseFloatVectorType(data2)).toEqual('unknown'); + }); }); diff --git a/test/utils/Format.spec.ts b/test/utils/Format.spec.ts index 614b888d..ba4ec45a 100644 --- a/test/utils/Format.spec.ts +++ b/test/utils/Format.spec.ts @@ -26,6 +26,7 @@ import { buildFieldData, formatSearchResult, Field, + formatSearchVector, } from '../../milvus'; describe('utils/format', () => { @@ -586,4 +587,96 @@ describe('utils/format', () => { expect(results).toEqual(expectedResults); }); + + it('should format search vector correctly', () => { + // float vector + const floatVector = [1, 2, 3]; + const formattedVector = formatSearchVector( + floatVector, + DataType.FloatVector + ); + expect(formattedVector).toEqual([floatVector]); + + const floatVectors = [ + [1, 2, 3], + [4, 5, 6], + ]; + expect(formatSearchVector(floatVectors, DataType.FloatVector)).toEqual( + floatVectors + ); + }); + + // sparse coo vector + const sparseCooVector = [ + { index: 1, value: 2 }, + { index: 3, value: 4 }, + ]; + const formattedSparseCooVector = formatSearchVector( + sparseCooVector, + DataType.SparseFloatVector + ); + expect(formattedSparseCooVector).toEqual([sparseCooVector]); + + // sparse csr vector + const sparseCsrVector = { + indices: [1, 3], + values: [2, 4], + }; + const formattedSparseCsrVector = formatSearchVector( + sparseCsrVector, + DataType.SparseFloatVector + ); + expect(formattedSparseCsrVector).toEqual([sparseCsrVector]); + + const sparseCsrVectors = [ + { + indices: [1, 3], + values: [2, 4], + }, + { + indices: [2, 4], + values: [3, 5], + }, + ]; + const formattedSparseCsrVectors = formatSearchVector( + sparseCsrVectors, + DataType.SparseFloatVector + ); + 
expect(formattedSparseCsrVectors).toEqual(sparseCsrVectors); + + // sparse array vector + const sparseArrayVector = [0.1, 0.2, 0.3]; + const formattedSparseArrayVector = formatSearchVector( + sparseArrayVector, + DataType.SparseFloatVector + ); + expect(formattedSparseArrayVector).toEqual([sparseArrayVector]); + + const sparseArrayVectors = [ + [0.1, 0.2, 0.3], + [0.4, 0.5, 0.6], + ]; + const formattedSparseArrayVectors = formatSearchVector( + sparseArrayVectors, + DataType.SparseFloatVector + ); + expect(formattedSparseArrayVectors).toEqual(sparseArrayVectors); + + // sparse dict vector + const sparseDictVector = { 1: 2, 3: 4 }; + const formattedSparseDictVector = formatSearchVector( + sparseDictVector, + DataType.SparseFloatVector + ); + expect(formattedSparseDictVector).toEqual([sparseDictVector]); + + const sparseDictVectors = [ + { 1: 2, 3: 4 }, + { 1: 2, 3: 4 }, + ]; + const formattedSparseDictVectors = formatSearchVector( + sparseDictVectors, + DataType.SparseFloatVector + ); + expect(formattedSparseDictVectors).toEqual(sparseDictVectors); });