Skip to content

Commit

Permalink
Backport missing PR (#1443) to enable backward compatibility (BWC) in 2.10 (#2090)
Browse files: browse the repository at this point in the history
* add status code to model tensor (#1443)

* add status code to model tensor

Signed-off-by: Yaliang Wu <[email protected]>

* fix unit tests

Signed-off-by: Yaliang Wu <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>

* fix cherry-pick conflict

Signed-off-by: Sicheng Song <[email protected]>

---------

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
Co-authored-by: Yaliang Wu <[email protected]>
  • Loading branch information
b4sjoo and ylwu-amzn authored Feb 19, 2024
1 parent dedaefc commit 4013820
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
Expand All @@ -24,7 +25,10 @@
@Getter
public class ModelTensors implements Writeable, ToXContentObject {
public static final String OUTPUT_FIELD = "output";
public static final String STATUS_CODE_FIELD = "status_code";
private List<ModelTensor> mlModelTensors;
@Setter
private Integer statusCode;

@Builder
public ModelTensors(List<ModelTensor> mlModelTensors) {
Expand All @@ -41,6 +45,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();
}
if (statusCode != null) {
builder.field(STATUS_CODE_FIELD, statusCode);
}
builder.endObject();
return builder;
}
Expand All @@ -53,6 +60,7 @@ public ModelTensors(StreamInput in) throws IOException {
mlModelTensors.add(new ModelTensor(in));
}
}
statusCode = in.readOptionalInt();
}

@Override
Expand All @@ -66,6 +74,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeOptionalInt(statusCode);
}

public void filter(ModelResultFilter resultFilter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
HttpExecuteResponse response = AccessController.doPrivileged((PrivilegedExceptionAction<HttpExecuteResponse>) () -> {
return httpClient.prepareRequest(executeRequest).call();
});
int statusCode = response.httpResponse().statusCode();

AbortableInputStream body = null;
if (response.responseBody().isPresent()) {
Expand All @@ -102,6 +103,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String modelResponse = responseBuilder.toString();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCode);
tensorOutputs.add(tensors);
} catch (RuntimeException exception) {
log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public HttpJsonConnectorExecutor(Connector connector) {
public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
try {
AtomicReference<String> responseRef = new AtomicReference<>("");
AtomicReference<Integer> statusCodeRef = new AtomicReference<>();

HttpUriRequest request;
switch (connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
Expand Down Expand Up @@ -97,12 +98,14 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
String responseBody = EntityUtils.toString(responseEntity);
EntityUtils.consume(responseEntity);
responseRef.set(responseBody);
statusCodeRef.set(response.getStatusLine().getStatusCode());
}
return null;
});
String modelResponse = responseRef.get();

ModelTensors tensors = processOutput(modelResponse, connector, scriptService, parameters);
tensors.setStatusCode(statusCodeRef.get());
tensorOutputs.add(tensors);
} catch (RuntimeException e) {
log.error("Fail to execute http connector", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import software.amazon.awssdk.http.ExecutableHttpRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpResponse;

import java.io.ByteArrayInputStream;
import java.io.IOException;
Expand All @@ -38,6 +39,7 @@
import java.util.Optional;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD;
Expand Down Expand Up @@ -89,6 +91,9 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
exceptionRule.expectMessage("No response from model");
when(response.responseBody()).thenReturn(Optional.empty());
when(httpRequest.call()).thenReturn(response);
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
Expand All @@ -113,6 +118,9 @@ public void executePredict_RemoteInferenceInput() throws IOException {
InputStream inputStream = new ByteArrayInputStream(jsonString.getBytes());
AbortableInputStream abortableInputStream = AbortableInputStream.create(inputStream);
when(response.responseBody()).thenReturn(Optional.of(abortableInputStream));
SdkHttpResponse httpResponse = mock(SdkHttpResponse.class);
when(httpResponse.statusCode()).thenReturn(200);
when(response.httpResponse()).thenReturn(httpResponse);
when(httpRequest.call()).thenReturn(response);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

Expand All @@ -136,4 +144,4 @@ public void executePredict_RemoteInferenceInput() throws IOException {
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size());
Assert.assertEquals("value", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().get("key"));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@

import com.google.common.collect.ImmutableMap;
import org.apache.http.HttpEntity;
import org.apache.http.ProtocolVersion;
import org.apache.http.StatusLine;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.message.BasicStatusLine;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.cluster.ClusterStateTaskConfig;
import org.opensearch.ingest.TestTemplateService;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
Expand All @@ -32,6 +36,7 @@
import org.opensearch.script.ScriptService;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;

import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -84,6 +89,8 @@ public void executePredict_RemoteInferenceInput() throws IOException {
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
when(executor.getHttpClient()).thenReturn(httpClient);
MLInputDataset inputDataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", "test input data")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build());
Expand All @@ -94,7 +101,7 @@ public void executePredict_RemoteInferenceInput() throws IOException {
}

@Test
public void executePredict_TextDocsInput_NoPreprocessFunction() {
public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOException {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("Must provide pre_process_function for predict action to process text docs input.");
ConnectorAction predictAction = ConnectorAction.builder()
Expand All @@ -103,6 +110,11 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() {
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
when(httpClient.execute(any())).thenReturn(response);
HttpEntity entity = new StringEntity("{\"response\": \"test result\"}");
when(response.getEntity()).thenReturn(entity);
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build();
HttpJsonConnectorExecutor executor = new HttpJsonConnectorExecutor(connector);
MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build();
Expand Down Expand Up @@ -133,7 +145,16 @@ public void executePredict_TextDocsInput() throws IOException {
HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector));
executor.setScriptService(scriptService);
when(httpClient.execute(any())).thenReturn(response);
String modelResponse = "{\"object\":\"list\",\"data\":[{\"object\":\"embedding\",\"index\":0,\"embedding\":[-0.014555434,-0.0002135904,0.0035105038]}],\"model\":\"text-embedding-ada-002-v2\",\"usage\":{\"prompt_tokens\":5,\"total_tokens\":5}}";
String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n"
+ " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n"
+ " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n"
+ " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n"
+ " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n"
+ " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n"
+ " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n"
+ " \"total_tokens\": 5\n" + " }\n" + "}";
StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK");
when(response.getStatusLine()).thenReturn(statusLine);
HttpEntity entity = new StringEntity(modelResponse);
when(response.getEntity()).thenReturn(entity);
when(executor.getHttpClient()).thenReturn(httpClient);
Expand Down

0 comments on commit 4013820

Please sign in to comment.