Skip to content

Commit

Permalink
use non results fields
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Oct 3, 2023
1 parent d8803b5 commit 2a9395d
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,29 @@ static void writeResult(InferenceResults results, IngestDocument ingestDocument,
}
}

static void writeResultToField(InferenceResults results, IngestDocument ingestDocument, String outputField, String modelId, boolean includeModelId) {
Objects.requireNonNull(results, "results");
Objects.requireNonNull(ingestDocument, "ingestDocument");
Objects.requireNonNull(outputField, "outputField");
Map<String, Object> resultMap = results.asMap();
if (includeModelId) {
resultMap.put(MODEL_ID_RESULTS_FIELD, modelId);
}
Object predictedValue = results.predictedValue();
if (predictedValue != null) {
if (ingestDocument.hasField(outputField)) {
ingestDocument.appendFieldValue(outputField, results.predictedValue());
} else {
ingestDocument.setFieldValue(outputField, resultMap);
}
}
}

String getResultsField();

Map<String, Object> asMap();

Map<String, Object> nonResultFeatures();

Object predictedValue();
}
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,14 @@ public String getResultsField() {

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = new LinkedHashMap<>();
Map<String, Object> map = nonResultFeatures();
map.put(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
return map;
}

@Override
public Map<String, Object> nonResultFeatures() {
Map<String, Object> map = new LinkedHashMap<>();
if (topClasses.isEmpty() == false) {
map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public String getResultsField() {

@Override
public Map<String, Object> asMap() {
return nonResultFeatures();
}

@Override
public Map<String, Object> nonResultFeatures() {
Map<String, Object> asMap = new LinkedHashMap<>();
asMap.put(NAME, exception.getMessage());
return asMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ void addMapFields(Map<String, Object> map) {
map.put(resultsField + "_sequence", predictedSequence);
}

@Override
public Map<String, Object> nonResultFeatures() {
var map = super.nonResultFeatures();
map.put(resultsField + "_sequence", predictedSequence); // TODO
return map;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ void addMapFields(Map<String, Object> map) {
map.put(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
}

@Override
public Map<String, Object> nonResultFeatures() {
var map = super.nonResultFeatures();
return Map.of(ENTITY_FIELD, entityGroups.stream().map(EntityGroup::toMap).collect(Collectors.toList()));
}

@Override
public Object predictedValue() {
return annotatedResult;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ public Object predictedValue() {
@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, classificationLabel);
addNonResultFeaturesToMap(map);
}

@Override
public Map<String, Object> nonResultFeatures() {
var map = super.nonResultFeatures();
addNonResultFeaturesToMap(map);
return map;
}

private void addNonResultFeaturesToMap(Map<String, Object> map) {
if (topClasses.isEmpty() == false) {
map.put(
NlpConfig.DEFAULT_TOP_CLASSES_RESULTS_FIELD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -65,6 +66,13 @@ public final Map<String, Object> asMap() {
return map;
}

@Override
public Map<String, Object> nonResultFeatures() {
var map = new HashMap<String, Object>();
map.put("is_truncated", isTruncated);
return map;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,18 @@ public String predictedValue() {

@Override
void addMapFields(Map<String, Object> map) {
addNonResultFeaturesToMap(map);
map.put(resultsField, answer);
}

@Override
public Map<String, Object> nonResultFeatures() {
var map = super.nonResultFeatures();
addNonResultFeaturesToMap(map);
return map;
}

private void addNonResultFeaturesToMap(Map<String, Object> map) {
map.put(START_OFFSET.getPreferredName(), startOffset);
map.put(END_OFFSET.getPreferredName(), endOffset);
if (topClasses.isEmpty() == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public Map<String, Object> asMap() {
throw new UnsupportedOperationException("[raw] does not support map conversion");
}

@Override
public Map<String, Object> nonResultFeatures() {
throw new UnsupportedOperationException("[raw] does not support map conversion of features");
}

@Override
public Object predictedValue() {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,14 @@ public String getResultsField() {

@Override
public Map<String, Object> asMap() {
Map<String, Object> map = nonResultFeatures();
map.put(resultsField, predictedValue());
return map;
}

@Override
public Map<String, Object> nonResultFeatures() {
Map<String, Object> map = new LinkedHashMap<>();
map.put(resultsField, value());
if (featureImportance.isEmpty() == false) {
map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ public String getResultsField() {

@Override
public Map<String, Object> asMap() {
return nonResultFeatures();
}

@Override
public Map<String, Object> nonResultFeatures() {
Map<String, Object> asMap = new LinkedHashMap<>();
asMap.put(NAME, warning);
return asMap;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
if (configuredWithInputsFields) {
List<String> requestInputs = new ArrayList<>();
for (var inputFields : inputs) {
var lookup = (String)fields.get(inputFields.inputField);
var lookup = (String) fields.get(inputFields.inputField);
if (lookup == null) {
lookup = ""; // need to send a non-null request to the same number of results back
}
Expand Down Expand Up @@ -249,11 +249,15 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc

if (configuredWithInputsFields) {
if (response.getInferenceResults().size() != inputs.size()) {
throw new ElasticsearchStatusException("number of results [{}] does not match the number of inputs [{}]",
RestStatus.INTERNAL_SERVER_ERROR, response.getInferenceResults().size(), inputs.size());
throw new ElasticsearchStatusException(
"number of results [{}] does not match the number of inputs [{}]",
RestStatus.INTERNAL_SERVER_ERROR,
response.getInferenceResults().size(),
inputs.size()
);
}

for (int i=0; i< inputs.size(); i++) {
for (int i = 0; i < inputs.size(); i++) {
InferenceResults.writeResult(
response.getInferenceResults().get(i),
ingestDocument,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.PipelineConfiguration;
import org.elasticsearch.test.ESTestCase;
Expand Down

0 comments on commit 2a9395d

Please sign in to comment.