Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Apr 4, 2024
1 parent cfe1457 commit 03ab6cc
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1211,10 +1211,6 @@ void addConflict(String parameter, String existing, String toMerge) {
conflicts.add("Cannot update parameter [" + parameter + "] from [" + existing + "] to [" + toMerge + "]");
}

/**
 * Returns {@code true} if at least one merge conflict has been recorded
 * (via {@code addConflict}), i.e. the {@code conflicts} list is non-empty.
 * Unlike {@code check()}, this is a non-throwing query.
 */
public boolean hasConflicts() {
return conflicts.isEmpty() == false;
}

public void check() {
if (conflicts.isEmpty()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
Expand Down Expand Up @@ -133,24 +134,24 @@ private record InferenceProvider(InferenceService service, Model model) {}
* @param field The target field.
* @param input The input to run inference on.
* @param inputOrder The original order of the input.
* @param isRawInput Whether the input is part of the raw values of the original field.
* @param isOriginalFieldInput Whether the input is part of the original values of the field.
*/
private record FieldInferenceRequest(int id, String field, String input, int inputOrder, boolean isRawInput) {}
private record FieldInferenceRequest(int id, String field, String input, int inputOrder, boolean isOriginalFieldInput) {}

/**
* The field inference response.
* @param field The target field.
* @param input The input that was used to run inference.
* @param inputOrder The original order of the input.
* @param isRawInput Whether the input is part of the raw values of the original field.
* @param isOriginalFieldInput Whether the input is part of the original values of the field.
* @param model The model used to run inference.
* @param chunkedResults The actual results.
*/
private record FieldInferenceResponse(
String field,
String input,
int inputOrder,
boolean isRawInput,
boolean isOriginalFieldInput,
Model model,
ChunkedInferenceServiceResults chunkedResults
) {}
Expand Down Expand Up @@ -286,7 +287,7 @@ public void onResponse(List<ChunkedInferenceServiceResults> results) {
request.field(),
request.input(),
request.inputOrder(),
request.isRawInput(),
request.isOriginalFieldInput(),
inferenceProvider.model,
result
)
Expand Down Expand Up @@ -370,13 +371,10 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
var fieldName = entry.getKey();
var responses = entry.getValue();
var model = responses.get(0).model();
// ensure that the order in the raw field is consistent in case of multiple inputs
// ensure that the order in the original field is consistent in case of multiple inputs
Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
List<String> inputs = responses.stream().filter(r -> r.isRawInput).map(r -> r.input).collect(Collectors.toList());
List<ChunkedInferenceServiceResults> results = entry.getValue()
.stream()
.map(r -> r.chunkedResults)
.collect(Collectors.toList());
List<String> inputs = responses.stream().filter(r -> r.isOriginalFieldInput).map(r -> r.input).collect(Collectors.toList());
List<ChunkedInferenceServiceResults> results = responses.stream().map(r -> r.chunkedResults).collect(Collectors.toList());
var result = new SemanticTextField(
fieldName,
inputs,
Expand All @@ -398,7 +396,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
* field is skipped, and the existing results remain unchanged.
* Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing,
* where an error will be thrown if they mismatch or if the content is malformed.
*
* <p>
* TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ?
*/
private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
Expand Down Expand Up @@ -434,13 +432,13 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
String inferenceId = entry.getInferenceId();
var rawValue = XContentMapValues.extractValue(field, docMap);
if (rawValue instanceof Map) {
var originalFieldValue = XContentMapValues.extractValue(field, docMap);
if (originalFieldValue instanceof Map) {
continue;
}
int order = 0;
for (var sourceField : entry.getSourceFields()) {
boolean isRawField = sourceField.equals(field);
boolean isOriginalFieldInput = sourceField.equals(field);
var valueObj = XContentMapValues.extractValue(sourceField, docMap);
if (valueObj == null) {
if (isUpdateRequest) {
Expand All @@ -458,39 +456,55 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
continue;
}
ensureResponseAccumulatorSlot(item.id());
if (valueObj instanceof String valueStr) {
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(
inferenceId,
k -> new ArrayList<>()
);
fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr, order++, isRawField));
} else if (valueObj instanceof List<?> valueList) {
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(
inferenceId,
k -> new ArrayList<>()
);
for (var value : valueList) {
fieldRequests.add(new FieldInferenceRequest(item.id(), field, value.toString(), order++, isRawField));
}
} else {
addInferenceResponseFailure(
item.id(),
new ElasticsearchStatusException(
"Invalid format for field [{}], expected [String] got [{}]",
RestStatus.BAD_REQUEST,
field,
valueObj.getClass().getSimpleName()
)
);
final List<String> values;
try {
values = nodeStringValues(field, valueObj);
} catch (Exception exc) {
addInferenceResponseFailure(item.id(), exc);
break;
}
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
for (var v : values) {
fieldRequests.add(new FieldInferenceRequest(item.id(), field, v, order++, isOriginalFieldInput));
}
}
}
}
return fieldRequestsMap;
}
}

/**
 * Converts the given {@code valueObj} into a list of strings.
 * <p>
 * A single {@link String} yields a singleton list; a {@link Collection} whose elements are all
 * strings yields a list of those elements. Any other shape is rejected.
 *
 * @param field    the field name, used only for the error message
 * @param valueObj the value extracted from the document source; expected to be a {@code String}
 *                 or a {@code Collection} of {@code String}
 * @return the extracted string values
 * @throws ElasticsearchStatusException (BAD_REQUEST) if {@code valueObj} is neither a string nor
 *         a collection of strings
 */
private static List<String> nodeStringValues(String field, Object valueObj) {
    if (valueObj instanceof String value) {
        return List.of(value);
    } else if (valueObj instanceof Collection<?> values) {
        // Presize: every element must be a string, so the result has the same size.
        List<String> valuesString = new ArrayList<>(values.size());
        for (var v : values) {
            if (v instanceof String value) {
                valuesString.add(value);
            } else {
                // Report the offending element's type, not the collection's type: the previous
                // code passed valueObj here, producing messages like "got [ArrayList]" when the
                // actual problem was e.g. an Integer element inside the list.
                throw invalidElementException(field, v);
            }
        }
        return valuesString;
    }
    throw invalidElementException(field, valueObj);
}

/**
 * Builds the BAD_REQUEST exception reported when a field value (or one of its elements)
 * is not a {@link String}.
 */
private static ElasticsearchStatusException invalidElementException(String field, Object value) {
    return new ElasticsearchStatusException(
        "Invalid format for field [{}], expected [String] got [{}]",
        RestStatus.BAD_REQUEST,
        field,
        value.getClass().getSimpleName()
    );
}

static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
if (docWriteRequest instanceof IndexRequest indexRequest) {
return indexRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,19 @@

/**
* A {@link ToXContentObject} that is used to represent the transformation of the semantic text field's inputs.
* The resulting object preserves the original input under the {@link SemanticTextField#RAW_FIELD} and exposes
* The resulting object preserves the original input under the {@link SemanticTextField#TEXT_FIELD} and exposes
* the inference results under the {@link SemanticTextField#INFERENCE_FIELD}.
*
* @param fieldName The original field name.
* @param raw The raw values associated with the field name.
* @param originalValues The original values associated with the field name.
* @param inference The inference result.
* @param contentType The {@link XContentType} used to store the embeddings chunks.
*/
public record SemanticTextField(String fieldName, List<String> raw, InferenceResult inference, XContentType contentType)
public record SemanticTextField(String fieldName, List<String> originalValues, InferenceResult inference, XContentType contentType)
implements
ToXContentObject {

static final ParseField RAW_FIELD = new ParseField("raw");
static final ParseField TEXT_FIELD = new ParseField("text");
static final ParseField INFERENCE_FIELD = new ParseField("inference");
static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
static final ParseField CHUNKS_FIELD = new ParseField("chunks");
Expand Down Expand Up @@ -132,8 +132,8 @@ private void validate() {
}
}

public static String getRawFieldName(String fieldName) {
return fieldName + "." + RAW_FIELD.getPreferredName();
public static String getOriginalTextFieldName(String fieldName) {
return fieldName + "." + TEXT_FIELD.getPreferredName();
}

public static String getInferenceFieldName(String fieldName) {
Expand Down Expand Up @@ -177,8 +177,8 @@ static ModelSettings parseModelSettingsFromMap(Object node) {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (raw.isEmpty() == false) {
builder.field(RAW_FIELD.getPreferredName(), raw.size() == 1 ? raw.get(0) : raw);
if (originalValues.isEmpty() == false) {
builder.field(TEXT_FIELD.getPreferredName(), originalValues.size() == 1 ? originalValues.get(0) : originalValues);
}
builder.startObject(INFERENCE_FIELD.getPreferredName());
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inference.inferenceId);
Expand All @@ -204,7 +204,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<SemanticTextField, Tuple<String, XContentType>> SEMANTIC_TEXT_FIELD_PARSER =
new ConstructingObjectParser<>(
"semantic",
SemanticTextFieldMapper.CONTENT_TYPE,
true,
(args, context) -> new SemanticTextField(
context.v1(),
Expand All @@ -216,20 +216,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<InferenceResult, Void> INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>(
"inference",
INFERENCE_FIELD.getPreferredName(),
true,
args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (List<Chunk>) args[2])
);

@SuppressWarnings("unchecked")
private static final ConstructingObjectParser<Chunk, Void> CHUNKS_PARSER = new ConstructingObjectParser<>(
"chunks",
CHUNKS_FIELD.getPreferredName(),
true,
args -> new Chunk((String) args[0], (BytesReference) args[1])
);

private static final ConstructingObjectParser<ModelSettings, Void> MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>(
"model_settings",
MODEL_SETTINGS_FIELD.getPreferredName(),
true,
args -> {
TaskType taskType = TaskType.fromString((String) args[0]);
Expand All @@ -240,7 +240,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
);

static {
SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), RAW_FIELD);
SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), TEXT_FIELD);
SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), (p, c) -> INFERENCE_RESULT_PARSER.parse(p, null), INFERENCE_FIELD);

INFERENCE_RESULT_PARSER.declareString(constructorArg(), INFERENCE_ID_FIELD);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.index.mapper.Mapper;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.MapperMergeContext;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.NestedObjectMapper;
import org.elasticsearch.index.mapper.ObjectMapper;
import org.elasticsearch.index.mapper.SimpleMappedFieldType;
Expand Down Expand Up @@ -59,7 +58,7 @@
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getRawFieldName;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;

/**
* A {@link FieldMapper} for semantic text fields.
Expand Down Expand Up @@ -131,11 +130,7 @@ protected void merge(FieldMapper mergeWith, Conflicts conflicts, MapperMergeCont
var context = mapperMergeContext.createChildContext(mergeWith.simpleName(), ObjectMapper.Dynamic.FALSE);
var inferenceField = inferenceFieldBuilder.apply(context.getMapperBuilderContext());
var childContext = context.createChildContext(inferenceField.simpleName(), ObjectMapper.Dynamic.FALSE);
var mergedInferenceField = inferenceField.merge(
semanticMergeWith.fieldType().getInferenceField(),
MapperService.MergeReason.MAPPING_UPDATE,
childContext
);
var mergedInferenceField = inferenceField.merge(semanticMergeWith.fieldType().getInferenceField(), childContext);
inferenceFieldBuilder = c -> mergedInferenceField;
}

Expand Down Expand Up @@ -217,7 +212,7 @@ protected void parseCreateField(DocumentParserContext context) throws IOExceptio
context.path().add(simpleName());
}
} else {
SemanticTextFieldMapper.Conflicts conflicts = new Conflicts(fullFieldName);
Conflicts conflicts = new Conflicts(fullFieldName);
canMergeModelSettings(field.inference().modelSettings(), fieldType().getModelSettings(), conflicts);
try {
conflicts.check();
Expand Down Expand Up @@ -311,8 +306,8 @@ public Query termQuery(Object value, SearchExecutionContext context) {

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
// Redirect the fetcher to load the value from the raw field
return SourceValueFetcher.toString(getRawFieldName(name()), context, format);
// Redirect the fetcher to load the original values of the field
return SourceValueFetcher.toString(getOriginalTextFieldName(name()), context, format);
}

@Override
Expand Down Expand Up @@ -376,7 +371,7 @@ private static Mapper.Builder createEmbeddingsField(IndexVersion indexVersionCre
private static boolean canMergeModelSettings(
SemanticTextField.ModelSettings previous,
SemanticTextField.ModelSettings current,
FieldMapper.Conflicts conflicts
Conflicts conflicts
) {
if (Objects.equals(previous, current)) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.toChunkedResult;
import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -179,8 +180,8 @@ public void testManyRandomDocs() throws Exception {
BulkItemRequest[] items = bulkShardRequest.items();
assertThat(items.length, equalTo(originalRequests.length));
for (int id = 0; id < items.length; id++) {
IndexRequest actualRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(items[id].request());
IndexRequest expectedRequest = ShardBulkInferenceActionFilter.getIndexRequestOrNull(modifiedRequests[id].request());
IndexRequest actualRequest = getIndexRequestOrNull(items[id].request());
IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request());
try {
assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType());
} catch (Exception exc) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ protected Predicate<String> getRandomFieldsExcludeFilter() {
@Override
protected void assertEqualInstances(SemanticTextField expectedInstance, SemanticTextField newInstance) {
assertThat(newInstance.fieldName(), equalTo(expectedInstance.fieldName()));
assertThat(newInstance.raw(), equalTo(expectedInstance.raw()));
assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues()));
assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings()));
assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size()));
SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings();
Expand Down
Loading

0 comments on commit 03ab6cc

Please sign in to comment.