Skip to content

Commit

Permalink
Add download + indexOuput#write implementation to RemoteIndexBuildStr…
Browse files Browse the repository at this point in the history
…ategy

Signed-off-by: Jay Deng <jayd0104@gmail.com>
  • Loading branch information
jed326 authored and Jay Deng committed Feb 24, 2025
1 parent c7ac05c commit 1a11805
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public NativeIndexBuildStrategy getBuildStrategy(final FieldInfo fieldInfo) {
&& indexSettings != null
&& knnEngine.supportsRemoteIndexBuild()
&& RemoteIndexBuildStrategy.shouldBuildIndexRemotely(indexSettings)) {
return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy);
return new RemoteIndexBuildStrategy(repositoriesServiceSupplier, strategy, indexSettings);
} else {
return strategy;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,30 @@

package org.opensearch.knn.index.codec.nativeindex.remote;

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang.NotImplementedException;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.common.StopWatch;
import org.opensearch.common.UUIDs;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.index.IndexSettings;
import org.opensearch.knn.common.featureflags.KNNFeatureFlags;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.repositories.Repository;
import org.opensearch.repositories.RepositoryMissingException;
import org.opensearch.repositories.blobstore.BlobStoreRepository;

import java.io.IOException;
import java.io.InputStream;
import java.util.function.Supplier;

import static org.opensearch.knn.index.KNNSettings.KNN_INDEX_REMOTE_VECTOR_BUILD_SETTING;
Expand All @@ -37,17 +44,25 @@ public class RemoteIndexBuildStrategy implements NativeIndexBuildStrategy {

private final Supplier<RepositoriesService> repositoriesServiceSupplier;
private final NativeIndexBuildStrategy fallbackStrategy;
private final IndexSettings indexSettings;

private static final String VECTOR_BLOB_FILE_EXTENSION = ".knnvec";
private static final String DOC_ID_FILE_EXTENSION = ".knndid";
private static final String VECTORS_PATH = "_vectors";

/**
* Public constructor
*
* @param repositoriesServiceSupplier A supplier for {@link RepositoriesService} used for interacting with repository
*/
public RemoteIndexBuildStrategy(Supplier<RepositoriesService> repositoriesServiceSupplier, NativeIndexBuildStrategy fallbackStrategy) {
public RemoteIndexBuildStrategy(
Supplier<RepositoriesService> repositoriesServiceSupplier,
NativeIndexBuildStrategy fallbackStrategy,
IndexSettings indexSettings
) {
this.repositoriesServiceSupplier = repositoriesServiceSupplier;
this.fallbackStrategy = fallbackStrategy;
this.indexSettings = indexSettings;
}

/**
Expand Down Expand Up @@ -98,7 +113,9 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
log.debug("Await vector build took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());

stopWatch = new StopWatch().start();
readFromRepository();
// TODO: This blob will be retrieved from the remote vector build service status response
String blobName = UUIDs.base64UUID() + "_" + indexInfo.getFieldName() + "_" + indexInfo.getSegmentWriteState().segmentInfo.name;
readFromRepository(blobName + KNNEngine.FAISS.getExtension(), indexInfo.getIndexOutputWithBuffer().getIndexOutput());
time_in_millis = stopWatch.stop().totalTime().millis();
log.debug("Repository read took {} ms for vector field [{}]", time_in_millis, indexInfo.getFieldName());
} catch (Exception e) {
Expand Down Expand Up @@ -126,6 +143,14 @@ private BlobStoreRepository getRepository() throws RepositoryMissingException {
return (BlobStoreRepository) repository;
}

/**
* @return The blob container to read/write from, determined from the repository base path and index settings. This container is where all blobs will be written to.
*/
private BlobContainer getBlobContainer() {
BlobPath path = getRepository().basePath().add(indexSettings.getUUID() + VECTORS_PATH);
return getRepository().blobStore().blobContainer(path);
}

/**
* Write relevant vector data to repository
*
Expand Down Expand Up @@ -163,7 +188,27 @@ private void awaitVectorBuild() {
/**
* Read constructed vector file from remote repository and write to IndexOutput
*/
private void readFromRepository() {
throw new NotImplementedException();
@VisibleForTesting
void readFromRepository(String blobName, IndexOutput indexOutput) throws IOException {
BlobContainer blobContainer = getBlobContainer();
// TODO: We are using the sequential download API as multi-part parallel download is difficult for us to implement today and
// requires some changes in core. For more details, see: https://github.com/opensearch-project/k-NN/issues/2464
InputStream graphStream = blobContainer.readBlob(blobName);

// Allocate buffer of 64KB, same as used for CPU builds, see: IndexOutputWithBuffer
int CHUNK_SIZE = 64 * 1024;
byte[] buffer = new byte[CHUNK_SIZE];

int bytesRead = 0;
// InputStream uses -1 indicates there are no more bytes to be read
while (bytesRead != -1) {
// Try to read CHUNK_SIZE into the buffer. The actual amount read may be less.
bytesRead = graphStream.read(buffer, 0, CHUNK_SIZE);
assert bytesRead <= CHUNK_SIZE;
// However many bytes we read, write it to the IndexOutput if != -1
if (bytesRead != -1) {
indexOutput.writeBytes(buffer, 0, bytesRead);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

package org.opensearch.knn.index.store;

import lombok.Getter;
import org.apache.lucene.store.IndexOutput;

import java.io.IOException;

public class IndexOutputWithBuffer {
// Underlying `IndexOutput` obtained from Lucene's Directory.
@Getter
private IndexOutput indexOutput;
// Write buffer. Native engine will copy bytes into this buffer.
// Allocating 64KB here since it show better performance in NMSLIB with the size. (We had slightly improvement in FAISS than having 4KB)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,19 @@

package org.opensearch.knn.index.codec.nativeindex.remote;

import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.junit.Before;
import org.mockito.Mockito;
import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.blobstore.BlobStore;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.index.IndexSettings;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
Expand All @@ -16,17 +28,21 @@
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.repositories.RepositoryMissingException;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.repositories.blobstore.BlobStoreRepository;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Random;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.index.KNNSettings.KNN_REMOTE_VECTOR_REPO_SETTING;

public class RemoteIndexBuildStrategyTests extends OpenSearchTestCase {
public class RemoteIndexBuildStrategyTests extends KNNTestCase {

static int fallbackCounter = 0;

Expand All @@ -38,6 +54,16 @@ public void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException {
}
}

@Before
@Override
public void setUp() throws Exception {
super.setUp();
ClusterSettings clusterSettings = mock(ClusterSettings.class);
when(clusterSettings.get(KNN_REMOTE_VECTOR_REPO_SETTING)).thenReturn("test-repo-name");
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
KNNSettings.state().setClusterService(clusterService);
}

public void testFallback() throws IOException {
List<float[]> vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 });
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
Expand All @@ -48,7 +74,11 @@ public void testFallback() throws IOException {
RepositoriesService repositoriesService = mock(RepositoriesService.class);
when(repositoriesService.repository(any())).thenThrow(new RepositoryMissingException("Fallback"));

RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(() -> repositoriesService, new TestIndexBuildStrategy());
RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(
() -> repositoriesService,
new TestIndexBuildStrategy(),
mock(IndexSettings.class)
);

IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class);

Expand All @@ -64,4 +94,54 @@ public void testFallback() throws IOException {
objectUnderTest.buildAndWriteIndex(buildIndexParams);
assertEquals(1, fallbackCounter);
}

/**
* Verify the buffered read method in {@link RemoteIndexBuildStrategy#readFromRepository} produces the correct result
*/
public void testRepositoryRead() throws IOException {
// Create an InputStream with random values
int TEST_ARRAY_SIZE = 64 * 1024 * 10;
byte[] byteArray = new byte[TEST_ARRAY_SIZE];
Random random = new Random();
random.nextBytes(byteArray);
InputStream randomStream = new ByteArrayInputStream(byteArray);

// Create a test segment that we will read/write from
Directory directory;
directory = newFSDirectory(createTempDir());
String TEST_SEGMENT_NAME = "test-segment-name";
IndexOutput testIndexOutput = directory.createOutput(TEST_SEGMENT_NAME, IOContext.DEFAULT);

// Set up RemoteIndexBuildStrategy and write to IndexOutput
RepositoriesService repositoriesService = mock(RepositoriesService.class);
BlobStoreRepository mockRepository = mock(BlobStoreRepository.class);
BlobPath testBasePath = new BlobPath().add("testBasePath");
BlobStore mockBlobStore = mock(BlobStore.class);
AsyncMultiStreamBlobContainer mockBlobContainer = mock(AsyncMultiStreamBlobContainer.class);

when(repositoriesService.repository(any())).thenReturn(mockRepository);
when(mockRepository.basePath()).thenReturn(testBasePath);
when(mockRepository.blobStore()).thenReturn(mockBlobStore);
when(mockBlobStore.blobContainer(any())).thenReturn(mockBlobContainer);
when(mockBlobContainer.readBlob("test-blob")).thenReturn(randomStream);

RemoteIndexBuildStrategy objectUnderTest = new RemoteIndexBuildStrategy(
() -> repositoriesService,
mock(NativeIndexBuildStrategy.class),
mock(IndexSettings.class)
);
// This should read from randomStream into testIndexOutput
objectUnderTest.readFromRepository("test-blob", testIndexOutput);
testIndexOutput.close();

// Now try to read from the IndexOutput
IndexInput testIndexInput = directory.openInput(TEST_SEGMENT_NAME, IOContext.DEFAULT);
byte[] resultByteArray = new byte[TEST_ARRAY_SIZE];
testIndexInput.readBytes(resultByteArray, 0, TEST_ARRAY_SIZE);
assertArrayEquals(byteArray, resultByteArray);

// Test Cleanup
testIndexInput.close();
directory.close();
}
}

0 comments on commit 1a11805

Please sign in to comment.