Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Handling of UUID and Other ID Formats in PgVectorStore #2111

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,8 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.util.JacksonUtils;
Expand Down Expand Up @@ -152,6 +150,7 @@
* @author Thomas Vitale
* @author Soby Chacko
* @author Sebastien Deleuze
* @author Jihoon Kim
* @since 1.0.0
*/
public class PgVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand All @@ -162,6 +161,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

public static final String DEFAULT_TABLE_NAME = "vector_store";

public static final PgIdType DEFAULT_ID_TYPE = PgIdType.UUID;

public static final String DEFAULT_VECTOR_INDEX_NAME = "spring_ai_vector_index";

public static final String DEFAULT_SCHEMA_NAME = "public";
Expand All @@ -187,6 +188,8 @@ public class PgVectorStore extends AbstractObservationVectorStore implements Ini

private final String schemaName;

private final PgIdType idType;

private final boolean schemaValidation;

private final boolean initializeSchema;
Expand Down Expand Up @@ -224,6 +227,7 @@ protected PgVectorStore(PgVectorStoreBuilder builder) {
: this.vectorTableName + "_index";

this.schemaName = builder.schemaName;
this.idType = builder.idType;
this.schemaValidation = builder.vectorTableValidationsEnabled;

this.jdbcTemplate = builder.jdbcTemplate;
Expand Down Expand Up @@ -272,13 +276,13 @@ private void insertOrUpdateBatch(List<Document> batch, List<Document> documents,
public void setValues(PreparedStatement ps, int i) throws SQLException {

var document = batch.get(i);
var id = convertIdToPgType(document.getId());
var content = document.getText();
var json = toJson(document.getMetadata());
var embedding = embeddings.get(documents.indexOf(document));
var pGvector = new PGvector(embedding);

StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN,
UUID.fromString(document.getId()));
StatementCreatorUtils.setParameterValue(ps, 1, SqlTypeValue.TYPE_UNKNOWN, id);
StatementCreatorUtils.setParameterValue(ps, 2, SqlTypeValue.TYPE_UNKNOWN, content);
StatementCreatorUtils.setParameterValue(ps, 3, SqlTypeValue.TYPE_UNKNOWN, json);
StatementCreatorUtils.setParameterValue(ps, 4, SqlTypeValue.TYPE_UNKNOWN, pGvector);
Expand All @@ -303,6 +307,19 @@ private String toJson(Map<String, Object> map) {
}
}

private Object convertIdToPgType(String id) {
if (this.initializeSchema) {
return UUID.fromString(id);
}

return switch (getIdType()) {
case UUID -> UUID.fromString(id);
case TEXT -> id;
case INTEGER, SERIAL -> Integer.valueOf(id);
case BIGSERIAL -> Long.valueOf(id);
};
}

@Override
public Optional<Boolean> doDelete(List<String> idList) {
int updateCount = 0;
Expand Down Expand Up @@ -412,6 +429,10 @@ private String getFullyQualifiedTableName() {
return this.schemaName + "." + this.vectorTableName;
}

private PgIdType getIdType() {
return this.idType;
}

private String getVectorTableName() {
return this.vectorTableName;
}
Expand Down Expand Up @@ -489,6 +510,12 @@ public enum PgIndexType {

}

public enum PgIdType {

UUID, TEXT, INTEGER, SERIAL, BIGSERIAL

}

/**
* Defaults to CosineDistance. But if vectors are normalized to length 1 (like OpenAI
* embeddings), use inner product (NegativeInnerProduct) for best performance.
Expand Down Expand Up @@ -584,6 +611,8 @@ public static final class PgVectorStoreBuilder extends AbstractVectorStoreBuilde

private String vectorTableName = PgVectorStore.DEFAULT_TABLE_NAME;

private PgIdType idType = PgVectorStore.DEFAULT_ID_TYPE;

private boolean vectorTableValidationsEnabled = PgVectorStore.DEFAULT_SCHEMA_VALIDATION;

private int dimensions = PgVectorStore.INVALID_EMBEDDING_DIMENSION;
Expand Down Expand Up @@ -614,6 +643,11 @@ public PgVectorStoreBuilder vectorTableName(String vectorTableName) {
return this;
}

public PgVectorStoreBuilder idType(PgIdType idType) {
this.idType = idType;
return this;
}

public PgVectorStoreBuilder vectorTableValidationsEnabled(boolean vectorTableValidationsEnabled) {
this.vectorTableValidationsEnabled = vectorTableValidationsEnabled;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -29,6 +30,7 @@

import com.zaxxer.hikari.HikariDataSource;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
Expand All @@ -40,13 +42,15 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.document.id.RandomIdGenerator;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIdType;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser.FilterExpressionParseException;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore.PgIndexType;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
Expand All @@ -67,6 +71,7 @@
* @author Muthukumaran Navaneethakrishnan
* @author Christian Tzolov
* @author Thomas Vitale
* @author Jihoon Kim
*/
@Testcontainers
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
Expand Down Expand Up @@ -103,6 +108,27 @@ public static String getText(String uri) {
}
}

private static void initSchema(ApplicationContext context) {
PgVectorStore vectorStore = context.getBean(PgVectorStore.class);
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
// Enable the PGVector, JSONB and UUID support.
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector");
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS hstore");
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"");

jdbcTemplate.execute(String.format("CREATE SCHEMA IF NOT EXISTS %s", PgVectorStore.DEFAULT_SCHEMA_NAME));

jdbcTemplate.execute(String.format("""
CREATE TABLE IF NOT EXISTS %s.%s (
id text PRIMARY KEY,
content text,
metadata json,
embedding vector(%d)
)
""", PgVectorStore.DEFAULT_SCHEMA_NAME, PgVectorStore.DEFAULT_TABLE_NAME,
vectorStore.embeddingDimensions()));
}

private static void dropTable(ApplicationContext context) {
JdbcTemplate jdbcTemplate = context.getBean(JdbcTemplate.class);
jdbcTemplate.execute("DROP TABLE IF EXISTS vector_store");
Expand Down Expand Up @@ -166,6 +192,35 @@ public void addAndSearch(String distanceType) {
});
}

@Test
public void testToPgTypeWithUuidIdType() {
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
.run(context -> {

VectorStore vectorStore = context.getBean(VectorStore.class);

vectorStore.add(List.of(new Document(new RandomIdGenerator().generateId(), "TEXT", new HashMap<>())));

dropTable(context);
});
}

@Test
public void testToPgTypeWithNonUuidIdType() {
this.contextRunner.withPropertyValues("test.spring.ai.vectorstore.pgvector.distanceType=" + "COSINE_DISTANCE")
.withPropertyValues("test.spring.ai.vectorstore.pgvector.initializeSchema=" + false)
.withPropertyValues("test.spring.ai.vectorstore.pgvector.idType=" + "TEXT")
.run(context -> {

VectorStore vectorStore = context.getBean(VectorStore.class);
initSchema(context);

vectorStore.add(List.of(new Document("NOT_UUID", "TEXT", new HashMap<>())));

dropTable(context);
});
}

@ParameterizedTest(name = "Filter expression {0} should return {1} records ")
@MethodSource("provideFilters")
public void searchWithInFilter(String expression, Integer expectedRecords) {
Expand Down Expand Up @@ -371,12 +426,19 @@ public static class TestApplication {
@Value("${test.spring.ai.vectorstore.pgvector.distanceType}")
PgVectorStore.PgDistanceType distanceType;

@Value("${test.spring.ai.vectorstore.pgvector.initializeSchema:true}")
boolean initializeSchema;

@Value("${test.spring.ai.vectorstore.pgvector.idType:UUID}")
PgIdType idType;

@Bean
public VectorStore vectorStore(JdbcTemplate jdbcTemplate, EmbeddingModel embeddingModel) {
return PgVectorStore.builder(jdbcTemplate, embeddingModel)
.dimensions(PgVectorStore.INVALID_EMBEDDING_DIMENSION)
.idType(idType)
.distanceType(this.distanceType)
.initializeSchema(true)
.initializeSchema(initializeSchema)
.indexType(PgIndexType.HNSW)
.removeExistingVectorStoreTable(true)
.build();
Expand Down
Loading