Skip to content

Commit

Permalink
fix api change in inference service
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Apr 4, 2024
1 parent 7d03ac0 commit fea9138
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ private void onFinish() {
inferenceProvider.service()
.chunkedInfer(
inferenceProvider.model(),
null,
inputs,
Map.of(),
InputType.INGEST,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void testManyRandomDocs() throws Exception {
IndexRequest actualRequest = getIndexRequestOrNull(items[id].request());
IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request());
try {
assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType());
assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), expectedRequest.getContentType());
} catch (Exception exc) {
throw new IllegalStateException(exc);
}
Expand Down Expand Up @@ -228,9 +228,9 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool
InferenceService inferenceService = mock(InferenceService.class);
Answer<?> chunkedInferAnswer = invocationOnMock -> {
StaticModel model = (StaticModel) invocationOnMock.getArguments()[0];
List<String> inputs = (List<String>) invocationOnMock.getArguments()[1];
List<String> inputs = (List<String>) invocationOnMock.getArguments()[2];
ActionListener<List<ChunkedInferenceServiceResults>> listener = (ActionListener<
List<ChunkedInferenceServiceResults>>) invocationOnMock.getArguments()[5];
List<ChunkedInferenceServiceResults>>) invocationOnMock.getArguments()[6];
Runnable runnable = () -> {
List<ChunkedInferenceServiceResults> results = new ArrayList<>();
for (String input : inputs) {
Expand All @@ -249,7 +249,7 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool
}
return null;
};
doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any());
doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any());

Answer<Model> modelAnswer = invocationOnMock -> {
String inferenceId = (String) invocationOnMock.getArguments()[0];
Expand All @@ -270,6 +270,7 @@ private static BulkItemRequest[] randomBulkItemRequest(
) {
Map<String, Object> docMap = new LinkedHashMap<>();
Map<String, Object> expectedDocMap = new LinkedHashMap<>();
XContentType requestContentType = randomFrom(XContentType.values());
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
var model = modelMap.get(entry.getInferenceId());
Expand All @@ -280,11 +281,10 @@ private static BulkItemRequest[] randomBulkItemRequest(
// ignore results, the doc should fail with a resource not found exception
continue;
}
var result = randomSemanticText(field, model, List.of(text), randomFrom(XContentType.values()));
var result = randomSemanticText(field, model, List.of(text), requestContentType);
model.putResult(text, result);
expectedDocMap.put(field, result);
}
XContentType requestContentType = randomFrom(XContentType.values());
return new BulkItemRequest[] {
new BulkItemRequest(id, new IndexRequest("index").source(docMap, requestContentType)),
new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
Expand Down

0 comments on commit fea9138

Please sign in to comment.