Skip to content

Commit

Permalink
[ML] Add stream flag to inference providers (elastic#113424)
Browse files Browse the repository at this point in the history
Pass the stream flag from the REST request through to the inference
providers via the InferenceInputs.

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
prwhelan and elasticmachine authored Sep 26, 2024
1 parent d8a3215 commit 1e2c19f
Show file tree
Hide file tree
Showing 29 changed files with 91 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ void parseRequestConfig(
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
* @param stream Whether the inference results should be streamed back to the caller
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
Expand All @@ -94,6 +95,7 @@ void infer(
Model model,
@Nullable String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public void infer(
Model model,
@Nullable String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public void infer(
Model model,
@Nullable String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ public void infer(
Model model,
@Nullable String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ public void infer(
Model model,
String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ private void inferOnService(
model,
request.getQuery(),
request.getInput(),
request.isStreaming(),
request.getTaskSettings(),
request.getInputType(),
request.getInferenceTimeout(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ private static String getStatusCodeErrorMessage(Request request, HttpResult resu
}

public static void checkForEmptyBody(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
if (result.isBodyEmpty()) {
if (result.isBodyEmpty() && (request.isStreaming() == false)) {
String message = format("Response body was empty for request from inference entity id [%s]", request.getInferenceEntityId());
throttlerManager.warn(logger, message);
throw new IllegalStateException(message);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,23 @@ public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) {
}

private final List<String> input;
private final boolean stream;

/**
 * Wraps documents-only inference input that is not streamed.
 *
 * @param input the inference input texts; must not be null
 */
public DocumentsOnlyInput(List<String> input) {
    this(input, false);
}

/**
 * Wraps documents-only inference input.
 *
 * @param input the inference input texts; must not be null
 * @param stream whether the inference results should be streamed back to the caller
 */
public DocumentsOnlyInput(List<String> input, boolean stream) {
    super();
    this.input = Objects.requireNonNull(input);
    this.stream = stream;
}

/** Returns the inference input texts held by this instance. */
public List<String> getInputs() {
    return input;
}

/** Whether the inference results should be streamed back to the caller. */
public boolean stream() {
    return this.stream;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,19 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
}

private final String query;
private final List<String> chunks;
private final boolean stream;

/**
 * Wraps query-and-documents inference input (e.g. for re-ranking) that is not streamed.
 *
 * @param query the inference query; must not be null
 * @param chunks the inference input texts; must not be null
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, false);
}

/**
 * Wraps query-and-documents inference input (e.g. for re-ranking).
 *
 * @param query the inference query; must not be null
 * @param chunks the inference input texts; must not be null
 * @param stream whether the inference results should be streamed back to the caller
 */
public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
super();
this.query = Objects.requireNonNull(query);
this.chunks = Objects.requireNonNull(chunks);
this.stream = stream;
}

public String getQuery() {
return query;
Expand All @@ -30,12 +43,8 @@ public List<String> getChunks() {
return chunks;
}

List<String> chunks;

public QueryAndDocsInputs(String query, List<String> chunks) {
super();
this.query = Objects.requireNonNull(query);
this.chunks = Objects.requireNonNull(chunks);
/** Whether the inference results should be streamed back to the caller. */
public boolean stream() {
    return this.stream;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,17 @@ public void infer(
Model model,
@Nullable String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
init();
if (query != null) {
doInfer(model, new QueryAndDocsInputs(query, input), taskSettings, inputType, timeout, listener);
doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener);
} else {
doInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener);
doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ public static void getEmbeddingSize(Model model, InferenceService service, Actio
model,
null,
List.of(TEST_EMBEDDING_INPUT),
false,
Map.of(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ private void checkAlibabaCloudSearchServiceConfig(Model model, InferenceService
model,
query,
List.of(input),
false,
Map.of(),
InputType.INGEST,
DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ public void infer(
Model model,
@Nullable String query,
List<String> input,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void infer(
Model model,
@Nullable String query,
List<String> inputs,
boolean stream,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public void validate(InferenceService service, Model model, ActionListener<Infer
model,
model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
TEST_INPUT,
false,
Map.of(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -855,12 +856,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingResults_IsEmpty()
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
listener.onResponse(new InferenceTextEmbeddingFloatResults(List.of()));

return Void.TYPE;
}).when(service).infer(any(), any(), any(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());

PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
Expand All @@ -878,12 +878,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmp
when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
listener.onResponse(new InferenceTextEmbeddingByteResults(List.of()));

return Void.TYPE;
}).when(service).infer(any(), any(), any(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());

PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
Expand All @@ -903,12 +902,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() {
var textEmbedding = TextEmbeddingResultsTests.createRandomResults();

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
listener.onResponse(textEmbedding);

return Void.TYPE;
}).when(service).infer(any(), any(), any(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());

PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
Expand All @@ -927,12 +925,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() {
var textEmbedding = InferenceTextEmbeddingByteResultsTests.createRandomResults();

doAnswer(invocation -> {
@SuppressWarnings("unchecked")
ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
listener.onResponse(textEmbedding);

return Void.TYPE;
}).when(service).infer(any(), any(), any(), any(), any(), any(), any());
}).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());

PlainActionFuture<Integer> listener = new PlainActionFuture<>();
getEmbeddingSize(model, service, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc
mockModel,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -721,6 +722,7 @@ public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -762,6 +764,7 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -1025,6 +1028,7 @@ public void testInfer_UnauthorizedResponse() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException
mockModel,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -506,6 +507,7 @@ public void testInfer_SendsCompletionRequest() throws IOException {
model,
null,
List.of("input"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc
mockModel,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -954,6 +955,7 @@ public void testInfer_WithChatCompletionModel() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -1004,6 +1006,7 @@ public void testInfer_UnauthorisedResponse() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep
mockModel,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -656,6 +657,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -1051,6 +1053,7 @@ public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxExcept
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException
mockModel,
null,
List.of(""),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -689,6 +690,7 @@ public void testInfer_SendsRequest() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -932,6 +934,7 @@ public void testInfer_UnauthorisedResponse() throws IOException {
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -991,6 +994,7 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -1064,6 +1068,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs
model,
null,
List.of("abc"),
false,
CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null),
InputType.INGEST,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down Expand Up @@ -1135,6 +1140,7 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec
model,
null,
List.of("abc"),
false,
new HashMap<>(),
InputType.UNSPECIFIED,
InferenceAction.Request.DEFAULT_TIMEOUT,
Expand Down
Loading

0 comments on commit 1e2c19f

Please sign in to comment.