Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
141802: cspann: finish vecstore refactor r=drewkimball a=andy-kimball

cspann: finish vecstore refactor

Update all references to vecstore types so that they point at the types'
new package locations, either cspann or memstore. Common store tests now
live in a new commontest package. Various shared utility functions have
been moved to the separate utils and testutils packages. Rename various
functions to reflect the new packaging.

Epic: CRDB-42943

Release note: None


Co-authored-by: Andrew Kimball <andyk@cockroachlabs.com>
  • Loading branch information
craig[bot] and andy-kimball committed Feb 21, 2025
2 parents fe860f6 + 9651b3f commit 6b6f886
Show file tree
Hide file tree
Showing 44 changed files with 1,868 additions and 1,752 deletions.
5 changes: 5 additions & 0 deletions pkg/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ ALL_TESTS = [
"//pkg/sql/ttl/ttljob:ttljob_test",
"//pkg/sql/types:types_disallowed_imports_test",
"//pkg/sql/types:types_test",
"//pkg/sql/vecindex/cspann/memstore:memstore_test",
"//pkg/sql/vecindex/cspann/quantize:quantize_test",
"//pkg/sql/vecindex/cspann:cspann_test",
"//pkg/sql/vecindex/veclib:veclib_test",
Expand Down Expand Up @@ -2335,9 +2336,13 @@ GO_TARGETS = [
"//pkg/sql/ttl/ttlschedule:ttlschedule",
"//pkg/sql/types:types",
"//pkg/sql/types:types_test",
"//pkg/sql/vecindex/cspann/commontest:commontest",
"//pkg/sql/vecindex/cspann/memstore:memstore",
"//pkg/sql/vecindex/cspann/memstore:memstore_test",
"//pkg/sql/vecindex/cspann/quantize:quantize",
"//pkg/sql/vecindex/cspann/quantize:quantize_test",
"//pkg/sql/vecindex/cspann/testutils:testutils",
"//pkg/sql/vecindex/cspann/utils:utils",
"//pkg/sql/vecindex/cspann:cspann",
"//pkg/sql/vecindex/cspann:cspann_test",
"//pkg/sql/vecindex/veclib:veclib",
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/vecbench/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ go_library(
visibility = ["//visibility:private"],
deps = [
"//pkg/sql/vecindex/cspann",
"//pkg/sql/vecindex/cspann/memstore",
"//pkg/sql/vecindex/cspann/quantize",
"//pkg/sql/vecindex/veclib",
"//pkg/sql/vecindex/vecstore",
"//pkg/util/stop",
"//pkg/util/syncutil",
"//pkg/util/timeutil",
Expand Down
54 changes: 26 additions & 28 deletions pkg/cmd/vecbench/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import (

"cloud.google.com/go/storage"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/memstore"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/quantize"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/veclib"
"github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore"
"github.com/cockroachdb/cockroach/pkg/util/stop"
"github.com/cockroachdb/cockroach/pkg/util/timeutil"
"github.com/cockroachdb/cockroach/pkg/util/vector"
Expand Down Expand Up @@ -176,19 +176,19 @@ func searchIndex(ctx context.Context, stopper *stop.Stopper, datasetName string)

// If index file has not been built, then do so now. Otherwise, load it from
// disk.
var inMemStore *vecstore.InMemoryStore
var index *cspann.VectorIndex
var memStore *memstore.Store
var index *cspann.Index
_, err := os.Stat(indexFileName)
if err != nil {
if !oserror.IsNotExist(err) {
panic(err)
}

inMemStore, index = buildIndex(ctx, stopper, datasetName)
saveStore(inMemStore, indexFileName)
memStore, index = buildIndex(ctx, stopper, datasetName)
saveStore(memStore, indexFileName)
} else {
inMemStore = loadStore(indexFileName)
index = createIndex(ctx, stopper, inMemStore)
memStore = loadStore(indexFileName)
index = createIndex(ctx, stopper, memStore)
}

// Load test data.
Expand All @@ -209,7 +209,7 @@ func searchIndex(ctx context.Context, stopper *stop.Stopper, datasetName string)
// Calculate truth set for the vector.
queryVector := data.Test.At(i)

searchSet := vecstore.SearchSet{MaxResults: *flagMaxResults}
searchSet := cspann.SearchSet{MaxResults: *flagMaxResults}
searchOptions := cspann.SearchOptions{BaseBeamSize: beamSize}

// Calculate prediction set for the vector.
Expand All @@ -219,13 +219,13 @@ func searchIndex(ctx context.Context, stopper *stop.Stopper, datasetName string)
}
results := searchSet.PopResults()

prediction := make([]vecstore.KeyBytes, searchSet.MaxResults)
prediction := make([]cspann.KeyBytes, searchSet.MaxResults)
for res := 0; res < len(results); res++ {
prediction[res] = results[res].ChildKey.KeyBytes
}

primaryKeys := make([]byte, searchSet.MaxResults*4)
truth := make([]vecstore.KeyBytes, searchSet.MaxResults)
truth := make([]cspann.KeyBytes, searchSet.MaxResults)
for neighbor := 0; neighbor < searchSet.MaxResults; neighbor++ {
primaryKey := primaryKeys[neighbor*4 : neighbor*4+4]
binary.BigEndian.PutUint32(primaryKey, uint32(data.Neighbors[i][neighbor]))
Expand All @@ -248,7 +248,7 @@ func searchIndex(ctx context.Context, stopper *stop.Stopper, datasetName string)
}

fmt.Printf(White+"%s\n"+Reset, datasetName)
trainVectors := inMemStore.GetAllVectors()
trainVectors := memStore.GetAllVectors()
fmt.Printf(
White+"%d train vectors, %d test vectors, %d dimensions, %d/%d min/max partitions, base beam size %d\n"+Reset,
len(trainVectors), data.Test.Count, data.Test.Dims,
Expand Down Expand Up @@ -380,7 +380,7 @@ func downloadDataset(ctx context.Context, datasetName string) dataset {
// built index to the tmp directory.
func buildIndex(
ctx context.Context, stopper *stop.Stopper, datasetName string,
) (*vecstore.InMemoryStore, *cspann.VectorIndex) {
) (*memstore.Store, *cspann.Index) {
// Ensure dataset file has been downloaded.
data := downloadDataset(ctx, datasetName)
if *flagBuildCount != 0 {
Expand All @@ -393,7 +393,7 @@ func buildIndex(
}

// Create index.
store := vecstore.NewInMemoryStore(data.Train.Dims, seed)
store := memstore.New(data.Train.Dims, seed)
index := createIndex(ctx, stopper, store)

// Create unique primary key for each vector in a single large byte buffer.
Expand Down Expand Up @@ -496,28 +496,26 @@ func buildIndex(
}

// createIndex returns a vector index created using the given store.
func createIndex(
ctx context.Context, stopper *stop.Stopper, store vecstore.Store,
) *cspann.VectorIndex {
inMemStore := store.(*vecstore.InMemoryStore)
quantizer := quantize.NewRaBitQuantizer(inMemStore.Dims(), seed)
options := cspann.VectorIndexOptions{
func createIndex(ctx context.Context, stopper *stop.Stopper, store cspann.Store) *cspann.Index {
memStore := store.(*memstore.Store)
quantizer := quantize.NewRaBitQuantizer(memStore.Dims(), seed)
options := cspann.IndexOptions{
MinPartitionSize: minPartitionSize,
MaxPartitionSize: maxPartitionSize,
BaseBeamSize: *flagBeamSize,
}
index, err := cspann.NewVectorIndex(ctx, store, quantizer, seed, &options, stopper)
index, err := cspann.NewIndex(ctx, store, quantizer, seed, &options, stopper)
if err != nil {
panic(err)
}
return index
}

// saveStore serializes the store as a protobuf and saves it to the given file.
func saveStore(inMemStore *vecstore.InMemoryStore, fileName string) {
func saveStore(memStore *memstore.Store, fileName string) {
startTime := timeutil.Now()

indexBytes, err := inMemStore.MarshalBinary()
indexBytes, err := memStore.MarshalBinary()
if err != nil {
panic(err)
}
Expand All @@ -538,21 +536,21 @@ func saveStore(inMemStore *vecstore.InMemoryStore, fileName string) {
}

// loadStore deserializes a previously saved protobuf of a vector store.
func loadStore(fileName string) *vecstore.InMemoryStore {
func loadStore(fileName string) *memstore.Store {
startTime := timeutil.Now()

data, err := os.ReadFile(fileName)
if err != nil {
panic(err)
}
inMemStore, err := vecstore.LoadInMemoryStore(data)
memStore, err := memstore.Load(data)
if err != nil {
panic(err)
}

elapsed := timeutil.Since(startTime)
fmt.Printf(Cyan+"Loaded %s index from disk in %v\n"+Reset, fileName, roundDuration(elapsed))
return inMemStore
return memStore
}

// loadDataset deserializes a dataset saved as a gob file.
Expand All @@ -577,15 +575,15 @@ func loadDataset(fileName string) dataset {
return data
}

func beginTransaction(ctx context.Context, w *veclib.Workspace, store vecstore.Store) vecstore.Txn {
func beginTransaction(ctx context.Context, w *veclib.Workspace, store cspann.Store) cspann.Txn {
txn, err := store.Begin(ctx, w)
if err != nil {
panic(err)
}
return txn
}

func commitTransaction(ctx context.Context, store vecstore.Store, txn vecstore.Txn) {
func commitTransaction(ctx context.Context, store cspann.Store, txn cspann.Txn) {
if err := store.Commit(ctx, txn); err != nil {
panic(err)
}
Expand All @@ -595,7 +593,7 @@ func commitTransaction(ctx context.Context, store vecstore.Store, txn vecstore.T
// results with the true set of results. Both sets are expected to be of equal
// length. It returns the percentage overlap of the predicted set with the truth
// set.
func findMAP(prediction, truth []vecstore.KeyBytes) float64 {
func findMAP(prediction, truth []cspann.KeyBytes) float64 {
if len(prediction) != len(truth) {
panic(errors.AssertionFailedf("prediction and truth sets are not same length"))
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/gen/protobuf.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,10 @@ PROTOBUF_SRCS = [
"//pkg/sql/sqlstats/persistedsqlstats:persistedsqlstats_go_proto",
"//pkg/sql/stats:stats_go_proto",
"//pkg/sql/types:types_go_proto",
"//pkg/sql/vecindex/cspann/memstore:memstore_go_proto",
"//pkg/sql/vecindex/cspann/quantize:quantize_go_proto",
"//pkg/sql/vecindex/cspann:cspann_go_proto",
"//pkg/sql/vecindex/vecpb:vecpb_go_proto",
"//pkg/sql/vecindex/vecstore:vecstore_go_proto",
"//pkg/storage/enginepb:enginepb_go_proto",
"//pkg/storage/storagepb:storagepb_go_proto",
"//pkg/testutils/grpcutils:grpcutils_go_proto",
Expand Down
39 changes: 35 additions & 4 deletions pkg/sql/vecindex/cspann/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
load("@rules_proto//proto:defs.bzl", "proto_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

filegroup(
Expand All @@ -9,48 +11,60 @@ filegroup(
go_library(
name = "cspann",
srcs = [
"cspannpb.go",
"fixup_processor.go",
"fixup_worker.go",
"index.go",
"index_stats.go",
"kmeans.go",
"pacer.go",
"partition.go",
"search_set.go",
"split_data.go",
"vector_index.go",
"store.go",
],
embed = [":cspann_go_proto"],
importpath = "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann",
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/vecindex/cspann/quantize",
"//pkg/sql/vecindex/cspann/utils",
"//pkg/sql/vecindex/veclib",
"//pkg/sql/vecindex/vecstore",
"//pkg/util/container/heap",
"//pkg/util/log",
"//pkg/util/num32",
"//pkg/util/stop",
"//pkg/util/syncutil",
"//pkg/util/vector",
"@com_github_cockroachdb_crlib//crtime",
"@com_github_cockroachdb_errors//:errors",
"@com_github_gogo_protobuf//gogoproto",
"@org_gonum_v1_gonum//stat",
],
)

go_test(
name = "cspann_test",
srcs = [
"cspannpb_test.go",
"fixup_processor_test.go",
"fixup_worker_test.go",
"index_stats_test.go",
"index_test.go",
"kmeans_test.go",
"pacer_test.go",
"vector_index_test.go",
"partition_test.go",
"search_set_test.go",
],
data = ["//pkg/sql/vecindex/cspann:testdata"],
embed = [":cspann"],
deps = [
"//pkg/sql/vecindex/cspann/commontest",
"//pkg/sql/vecindex/cspann/memstore",
"//pkg/sql/vecindex/cspann/quantize",
"//pkg/sql/vecindex/cspann/testutils",
"//pkg/sql/vecindex/cspann/utils",
"//pkg/sql/vecindex/veclib",
"//pkg/sql/vecindex/vecstore",
"//pkg/util/leaktest",
"//pkg/util/log",
"//pkg/util/num32",
Expand All @@ -65,3 +79,20 @@ go_test(
"@org_gonum_v1_gonum//stat",
],
)

proto_library(
name = "cspann_proto",
srcs = ["cspann.proto"],
strip_import_prefix = "/pkg",
visibility = ["//visibility:public"],
deps = ["@com_github_gogo_protobuf//gogoproto:gogo_proto"],
)

go_proto_library(
name = "cspann_go_proto",
compilers = ["//pkg/cmd/protoc-gen-gogoroach:protoc-gen-gogoroach_compiler"],
importpath = "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann",
proto = ":cspann_proto",
visibility = ["//visibility:public"],
deps = ["@com_github_gogo_protobuf//gogoproto"],
)
17 changes: 17 additions & 0 deletions pkg/sql/vecindex/cspann/commontest/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")

go_library(
name = "commontest",
srcs = ["storetests.go"],
importpath = "github.com/cockroachdb/cockroach/pkg/sql/vecindex/cspann/commontest",
visibility = ["//visibility:public"],
deps = [
"//pkg/sql/vecindex/cspann",
"//pkg/sql/vecindex/cspann/quantize",
"//pkg/sql/vecindex/cspann/testutils",
"//pkg/sql/vecindex/veclib",
"//pkg/util/vector",
"@com_github_stretchr_testify//require",
"@org_gonum_v1_gonum//floats/scalar",
],
)
Loading

0 comments on commit 6b6f886

Please sign in to comment.