Skip to content

Commit

Permalink
Fixing the conflicts while merging from main branch
Browse files Browse the repository at this point in the history
Signed-off-by: Navneet Verma <navneev@amazon.com>
  • Loading branch information
navneet1v committed Jun 4, 2024
1 parent e727e79 commit 85a844c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ public ScriptDocValues<float[]> getScriptValues() {
default:
throw new IllegalStateException("Unsupported Lucene vector encoding: " + fieldInfo.getVectorEncoding());
}
} else if (fieldInfo.getDocValuesType() == DocValuesType.BINARY) {
values = DocValues.getBinary(reader, fieldName);
} else {
return KNNVectorScriptDocValues.emptyValues(fieldName, vectorDataType);
values = DocValues.getBinary(reader, fieldName);
}
return KNNVectorScriptDocValues.create(values, fieldName, vectorDataType);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,9 @@ private MethodComponentContext getMethodComponentContext(KNNMethodContext knnMet
* @param fieldType {@link FieldType}
* @return {@link List} of {@link Field}
*/
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType, final SpaceType spaceType) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
fields.add(createVectorField(array, dimension, spaceType, fieldType));
if (this.stored) {
fields.add(createStoredFieldForFloatVector(name(), array));
}
Expand All @@ -567,20 +567,20 @@ protected List<Field> getFieldsForFloatVector(final float[] array, final FieldTy
* @param fieldType {@link FieldType}
* @return {@link List} of {@link Field}
*/
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType, final SpaceType spaceType) {
final List<Field> fields = new ArrayList<>();
fields.add(new VectorField(name(), array, fieldType));
fields.add(createVectorField(array, dimension, spaceType, fieldType));
if (this.stored) {
fields.add(createStoredFieldForByteVector(name(), array));
}
return fields;
}

protected Field createVectorField(float[] vectorValue, int dimension, SpaceType spaceType) {
protected Field createVectorField(float[] vectorValue, int dimension, final SpaceType spaceType, final FieldType fieldType) {
// Because we will come to this function only in case when Native engines are getting used. So I am avoiding the
// check of use Native engines here.
// Also dimension field is only accessible here hence we have to use this function to create fieldType too
if (this.indexCreatedVersion.onOrAfter(Version.V_2_15_0) && SpaceType.VECTOR_FIELD_SUPPORTED_SPACE_TYPES.contains(spaceType)) {
if (this.indexCreatedVersion.onOrAfter(Version.V_3_0_0) && SpaceType.VECTOR_FIELD_SUPPORTED_SPACE_TYPES.contains(spaceType)) {
FieldType tempFieldType = new FieldType(fieldType);
tempFieldType.setVectorAttributes(dimension, VectorEncoding.FLOAT32, spaceType.getVectorSimilarityFunction());
tempFieldType.freeze();
Expand All @@ -589,11 +589,11 @@ protected Field createVectorField(float[] vectorValue, int dimension, SpaceType
return new VectorField(name(), vectorValue, fieldType);
}

protected Field createVectorField(byte[] vectorValue, int dimension, SpaceType spaceType) {
protected Field createVectorField(byte[] vectorValue, int dimension, final SpaceType spaceType, final FieldType fieldType) {
// Because we will come to this function only in case when Native engines are getting used. So I am avoiding the
// check of use Native engines here.
// Also dimension field is only accessible here hence we have to use this function to create fieldType too
if (this.indexCreatedVersion.onOrAfter(Version.V_2_15_0) && SpaceType.VECTOR_FIELD_SUPPORTED_SPACE_TYPES.contains(spaceType)) {
if (this.indexCreatedVersion.onOrAfter(Version.V_3_0_0) && SpaceType.VECTOR_FIELD_SUPPORTED_SPACE_TYPES.contains(spaceType)) {
FieldType tempFieldType = new FieldType(fieldType);
tempFieldType.setVectorAttributes(dimension, VectorEncoding.BYTE, spaceType.getVectorSimilarityFunction());
tempFieldType.freeze();
Expand All @@ -616,7 +616,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
context.doc().addAll(getFieldsForByteVector(array, fieldType));
context.doc().addAll(getFieldsForByteVector(array, fieldType, spaceType));
} else if (VectorDataType.FLOAT == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

Expand All @@ -625,7 +625,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
}
final float[] array = floatsArrayOptional.get();
spaceType.validateVector(array);
context.doc().addAll(getFieldsForFloatVector(array, fieldType));
context.doc().addAll(getFieldsForFloatVector(array, fieldType, spaceType));
} else {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "Cannot parse context for unsupported values provided for field [%s]", VECTOR_DATA_TYPE_FIELD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.opensearch.common.Explicit;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorField;
import org.opensearch.knn.index.util.KNNEngine;
Expand Down Expand Up @@ -74,7 +75,7 @@ public class LuceneFieldMapper extends KNNVectorFieldMapper {
}

@Override
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType) {
protected List<Field> getFieldsForFloatVector(final float[] array, final FieldType fieldType, final SpaceType spaceType) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnVectorField(name(), array, fieldType));

Expand All @@ -89,7 +90,7 @@ protected List<Field> getFieldsForFloatVector(final float[] array, final FieldTy
}

@Override
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType) {
protected List<Field> getFieldsForByteVector(final byte[] array, final FieldType fieldType, final SpaceType spaceType) {
final List<Field> fieldsToBeAdded = new ArrayList<>();
fieldsToBeAdded.add(new KnnByteVectorField(name(), array, fieldType));

Expand Down

0 comments on commit 85a844c

Please sign in to comment.