Fix handling of explicit null values for semantic text fields
Previously, explicitly setting a field to null in an update request did not work correctly with semantic text fields.
This change fixes the issue by adding an explicit null entry to the `_inference_fields` metadata whenever that happens.

The explicit null ensures that any previously computed inference results are overwritten when the partial update is merged with the latest version of the document.
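
The fix hinges on telling a field that is absent from the source apart from one that is explicitly set to null. The diff below does this by passing a sentinel object to the three-argument `XContentMapValues.extractValue` overload. A minimal, self-contained sketch of the same idea, using a hypothetical `extract` helper rather than the Elasticsearch API:

import java.util.HashMap;
import java.util.Map;

public class ExplicitNullSentinel {
    // A unique marker object: reference equality cannot collide with any real value.
    private static final Object EXPLICIT_NULL = new Object();

    // Hypothetical helper mirroring the idea behind the extractValue overload used
    // in the diff: returns the sentinel when the key is present but null, and plain
    // null when the key is absent.
    static Object extract(Map<String, Object> source, String key) {
        if (source.containsKey(key) == false) {
            return null; // field not provided at all
        }
        Object value = source.get(key);
        return value == null ? EXPLICIT_NULL : value;
    }

    public static void main(String[] args) {
        Map<String, Object> doc = new HashMap<>();
        doc.put("field1", null); // explicitly cleared, as in a partial update
        System.out.println(extract(doc, "field1") == EXPLICIT_NULL); // true: clear prior inference
        System.out.println(extract(doc, "field2") == null);          // true: absent, leave inference untouched
    }
}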
jimczi committed Dec 18, 2024
1 parent 9533c7b commit 77c361c
Showing 2 changed files with 113 additions and 5 deletions.
@@ -394,6 +394,16 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponse response)
for (var entry : response.responses.entrySet()) {
    var fieldName = entry.getKey();
    var responses = entry.getValue();
    if (responses == null) {
        if (item.request() instanceof UpdateRequest == false) {
            // this could be an assertion: only update requests should carry cleared inference results
            throw new IllegalArgumentException(
                "Inference results can only be cleared for update requests where a field is explicitly set to null."
            );
        }
        // record the explicit null so that the merge overwrites any previously computed inference results
        inferenceFieldsMap.put(fieldName, null);
        continue;
    }
    var model = responses.get(0).model();
    // ensure that the order in the original field is consistent in case of multiple inputs
    Collections.sort(responses, Comparator.comparingInt(FieldInferenceResponse::inputOrder));
@@ -480,17 +490,19 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest)
}

final Map<String, Object> docMap = indexRequest.sourceAsMap();
Object explicitNull = new Object();
for (var entry : fieldInferenceMap.values()) {
    String field = entry.getName();
    String inferenceId = entry.getInferenceId();

    if (useInferenceMetadataFieldsFormat) {
        var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
            InferenceMetadataFieldsMapper.NAME + "." + field,
            docMap,
            explicitNull
        );
        if (inferenceMetadataFieldsValue != null) {
            // Inference has already been computed for this source field
            continue;
        }
    } else {
@@ -503,9 +515,20 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest)

    int order = 0;
    for (var sourceField : entry.getSourceFields()) {
        var valueObj = XContentMapValues.extractValue(sourceField, docMap, explicitNull);
        if (useInferenceMetadataFieldsFormat && isUpdateRequest && valueObj == explicitNull) {
            /*
             * It's an update request, and the source field is explicitly set to null,
             * so we need to propagate this information to the inference fields metadata
             * to overwrite any inference previously computed on the field.
             * This ensures that the field is treated as intentionally cleared,
             * preventing any unintended carryover of prior inference results.
             */
            var slot = ensureResponseAccumulatorSlot(itemIndex);
            slot.responses.put(sourceField, null);
            continue;
        }
        if (valueObj == null || valueObj == explicitNull) {
            if (isUpdateRequest && (useInferenceMetadataFieldsFormat == false)) {
                addInferenceResponseFailure(
                    item.id(),
(Second changed file)
@@ -16,6 +16,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.ActionFilterChain;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
@@ -212,6 +213,11 @@ public void testItemFailures() throws Exception {
                    ),
                    equalTo("I am a success")
                );
                if (useInferenceMetadataFieldsFormat) {
                    assertNotNull(
                        XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME + ".field1", actualRequest.sourceAsMap())
                    );
                }

                // item 2 is a failure
                assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse());
@@ -239,6 +245,85 @@ public void testItemFailures() throws Exception {
        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
    }

    @SuppressWarnings({ "unchecked", "rawtypes" })
    public void testExplicitNull() throws Exception {
        StaticModel model = StaticModel.createRandomInstance();

        ShardBulkInferenceActionFilter filter = createFilter(
            threadPool,
            Map.of(model.getInferenceEntityId(), model),
            randomIntBetween(1, 10),
            IndexVersion.current()
        );
        model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
        model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
        CountDownLatch chainExecuted = new CountDownLatch(1);
        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
            try {
                BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
                assertNull(bulkShardRequest.getInferenceFieldMap());
                assertThat(bulkShardRequest.items().length, equalTo(4));

                Object explicitNull = new Object();
                // item 0: an index request with an explicit null; no inference metadata is attached
                assertNull(bulkShardRequest.items()[0].getPrimaryResponse());
                IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request());
                assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull);
                assertNull(XContentMapValues.extractValue(InferenceMetadataFieldsMapper.NAME, actualRequest.sourceAsMap(), explicitNull));

                // item 1 is a success
                assertNull(bulkShardRequest.items()[1].getPrimaryResponse());
                actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request());
                assertThat(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap()), equalTo("I am a success"));
                assertNotNull(
                    XContentMapValues.extractValue(
                        InferenceMetadataFieldsMapper.NAME + ".field1",
                        actualRequest.sourceAsMap(),
                        explicitNull
                    )
                );

                // item 2 is a failure
                assertNotNull(bulkShardRequest.items()[2].getPrimaryResponse());
                assertTrue(bulkShardRequest.items()[2].getPrimaryResponse().isFailed());
                var failure = bulkShardRequest.items()[2].getPrimaryResponse().getFailure();
                assertThat(failure.getCause().getCause().getMessage(), containsString("boom"));

                // item 3: an update request with an explicit null; the inference metadata records the explicit null
                assertNull(bulkShardRequest.items()[3].getPrimaryResponse());
                actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request());
                assertTrue(XContentMapValues.extractValue("field1", actualRequest.sourceAsMap(), explicitNull) == explicitNull);
                assertTrue(
                    XContentMapValues.extractValue(
                        InferenceMetadataFieldsMapper.NAME + ".field1",
                        actualRequest.sourceAsMap(),
                        explicitNull
                    ) == explicitNull
                );
            } finally {
                chainExecuted.countDown();
            }
        };
        ActionListener actionListener = mock(ActionListener.class);
        Task task = mock(Task.class);

        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
            "field1",
            new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" })
        );
        BulkItemRequest[] items = new BulkItemRequest[4];
        Map<String, Object> sourceWithNull = new HashMap<>();
        sourceWithNull.put("field1", null);
        items[0] = new BulkItemRequest(0, new IndexRequest("index").source(sourceWithNull));
        items[1] = new BulkItemRequest(1, new IndexRequest("index").source("field1", "I am a success"));
        items[2] = new BulkItemRequest(2, new IndexRequest("index").source("field1", "I am a failure"));
        items[3] = new BulkItemRequest(3, new UpdateRequest().doc(new IndexRequest("index").source(sourceWithNull)));
        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
        request.setInferenceFieldMap(inferenceFieldMap);
        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
    }

    @SuppressWarnings({ "unchecked", "rawtypes" })
    public void testManyRandomDocs() throws Exception {
        IndexVersion indexVersion = getRandomIndexVersion();
