From e2f667a73f71c2389a7a2e072b8048ec12d4732d Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Thu, 16 Nov 2023 21:41:14 +0530 Subject: [PATCH] fix format violations Signed-off-by: Bhavana Ramaram --- .../opensearch/ml/common/MLModelGroup.java | 1 - .../MetricsCorrelation.java | 15 +- .../remote/AwsConnectorExecutor.java | 10 +- .../algorithms/remote/ConnectorUtils.java | 4 +- .../remote/HttpJsonConnectorExecutor.java | 8 +- .../remote/RemoteConnectorExecutor.java | 11 +- .../MetricsCorrelationTest.java | 204 +++++++++--------- .../remote/AwsConnectorExecutorTest.java | 65 +++--- .../remote/HttpJsonConnectorExecutorTest.java | 72 +++++-- 9 files changed, 222 insertions(+), 168 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java index 718f180636..0b9143f8cd 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModelGroup.java @@ -41,7 +41,6 @@ public class MLModelGroup implements ToXContentObject { @Setter private String name; private String description; - @Setter private int latestVersion; private List backendRoles; private User owner; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index 1c96a048fc..ec2fc1d141 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -175,9 +175,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { ) ); } - }, e-> { - log.error("Failed to get model", e); - }); + }, e -> { log.error("Failed to get model", e); }); client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); } } @@ -199,12 +197,19 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { waitUntil(() -> { if (modelId != null) { MLModelState modelState = getModel(modelId).getModelState(); - if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED){ + if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED) { log.info("Model deployed: " + modelState); return true; } else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) { log.info("Model not deployed: " + modelState); - deployModel(modelId, ActionListener.wrap(deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e))); + deployModel( + modelId, + ActionListener + .wrap( + deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), + e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e) + ) + ); return false; } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index d53db7caca..1472e9bbc9 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -5,6 +5,11 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; +import static software.amazon.awssdk.http.SdkHttpMethod.POST; + import java.io.BufferedReader; import java.io.InputStreamReader; import java.net.URI; @@ -35,11 +40,6 @@ import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.http.SdkHttpFullRequest; -import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; -import static org.opensearch.ml.common.connector.ConnectorProtocols.AWS_SIGV4; -import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; -import static software.amazon.awssdk.http.SdkHttpMethod.POST; - @Log4j2 @ConnectorExecutor(AWS_SIGV4) public class AwsConnectorExecutor implements RemoteConnectorExecutor { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 704d6e3e05..88c43a969e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -186,7 +186,9 @@ public static ModelTensors processOutput( // execute user defined painless script. Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response); + boolean scriptReturnModelTensor = postProcessFunction != null + && processedResponse.isPresent() + && org.opensearch.ml.common.utils.StringUtils.isJson(response); if (responseFilter == null) { connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index b92015bd3a..d08b2186ae 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -5,6 +5,10 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; +import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; +import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; + import java.security.AccessController; import java.security.PrivilegedExceptionAction; import java.util.List; @@ -35,10 +39,6 @@ import lombok.Setter; import lombok.extern.log4j.Log4j2; -import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; -import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; -import static org.opensearch.ml.engine.algorithms.remote.ConnectorUtils.processOutput; - @Log4j2 @ConnectorExecutor(HTTP) public class HttpJsonConnectorExecutor implements RemoteConnectorExecutor { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index c26c79b452..aa471ba3fe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -33,10 +33,17 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); int processedDocs = 0; - while(processedDocs < textDocsInputDataSet.getDocs().size()) { + while (processedDocs < textDocsInputDataSet.getDocs().size()) { List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); List tempTensorOutputs = new ArrayList<>(); - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs); + preparePayloadAndInvokeRemoteModel( + MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) + .build(), + tempTensorOutputs + ); int tensorCount = 0; if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 36a1fe9998..223cb22289 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -5,7 +5,41 @@ package org.opensearch.ml.engine.algorithms.metrics_correlation; -import com.google.common.collect.ImmutableMap; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; +import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; +import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Ignore; @@ -15,6 +49,10 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -22,17 +60,13 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.action.ActionFuture; +import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.ShardSearchFailure; -import org.opensearch.client.Client; -import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.xcontent.XContentBuilder; @@ -81,41 +115,7 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.threadpool.ThreadPool; -import java.io.File; -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; -import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; -import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; -import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; -import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; - +import com.google.common.collect.ImmutableMap; public class MetricsCorrelationTest { @Rule @@ -295,7 +295,6 @@ public void setUp() throws IOException, URISyntaxException { extendedInput = MetricsCorrelationInput.builder().inputData(extendedInputData).build(); } - @Ignore @Test public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteException { @@ -306,16 +305,17 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio doAnswer(invocation -> { - MLModel smallModel = MLModel.builder() - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .name(FunctionName.METRICS_CORRELATION.name()) - .modelId(modelId) - .modelGroupId(modelGroupId) - .algorithm(FunctionName.METRICS_CORRELATION) - .version(MCORR_ML_VERSION) - .modelConfig(modelConfig) - .modelState(MLModelState.UNDEPLOYED) - .build(); + MLModel smallModel = MLModel + .builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name(FunctionName.METRICS_CORRELATION.name()) + .modelId(modelId) + .modelGroupId(modelGroupId) + .algorithm(FunctionName.METRICS_CORRELATION) + .version(MCORR_ML_VERSION) + .modelConfig(modelConfig) + .modelState(MLModelState.UNDEPLOYED) + .build(); MLModelGetResponse responseTemp = new MLModelGetResponse(smallModel); ActionFuture mockedFutureTemp = mock(ActionFuture.class); MLTaskGetResponse taskResponse = new MLTaskGetResponse(mlTask); @@ -334,7 +334,6 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio assertNull(mlModelOutputs.get(0).getMCorrModelTensors()); } - @Ignore @Test public void testExecuteWithModelInIndexAndEmptyOutput() throws ExecuteException, URISyntaxException { @@ -437,7 +436,6 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - @Ignore @Test public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws ExecuteException, URISyntaxException { @@ -485,7 +483,6 @@ public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws Execu assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - @Ignore @Test public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, URISyntaxException { @@ -528,8 +525,7 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - - //working + // working @Test public void testGetModel() { ActionFuture mockedFuture = mock(ActionFuture.class); @@ -562,7 +558,7 @@ public static XContentBuilder builder() throws IOException { return XContentBuilder.builder(XContentType.JSON.xContent()); } - //working + // working @Test public void testSearchRequest() { String expectedIndex = CommonValue.ML_MODEL_INDEX; @@ -601,7 +597,6 @@ public void testSearchRequest() { assertEquals(MLModel.MODEL_VERSION_FIELD, versionQueryBuilder.fieldName()); } - @Ignore @Test public void testRegisterModel() throws InterruptedException { @@ -773,49 +768,56 @@ public static ClusterState setupTestClusterState() { Set roleSet = new HashSet<>(); roleSet.add(DiscoveryNodeRole.DATA_ROLE); DiscoveryNode node = new DiscoveryNode( - "node", - new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), - new HashMap<>(), - roleSet, - Version.CURRENT + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT ); Metadata metadata = new Metadata.Builder() - .indices( - ImmutableMap - .builder() - .put( - ML_MODEL_INDEX, - IndexMetadata - .builder("test") - .settings( - Settings - .builder() - .put("index.number_of_shards", 1) - .put("index.number_of_replicas", 1) - .put("index.version.created", Version.CURRENT.id) - ) - .build() - ) - .put(ML_MODEL_GROUP_INDEX, IndexMetadata.builder(ML_MODEL_GROUP_INDEX) - .settings(Settings.builder() - .put("index.number_of_shards", 1) - .put("index.number_of_replicas", 1) - .put("index.version.created", Version.CURRENT.id)) - .build()) - .build() - ) - .build(); + .indices( + ImmutableMap + .builder() + .put( + ML_MODEL_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .put( + ML_MODEL_GROUP_INDEX, + IndexMetadata + .builder(ML_MODEL_GROUP_INDEX) + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .build() + ) + .build(); return new ClusterState( - new ClusterName("test cluster"), - 123l, - "111111", - metadata, - null, - DiscoveryNodes.builder().add(node).build(), - null, - Map.of(), - 0, - false + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false ); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index f3bdbf0644..b35f9b0eac 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -5,12 +5,22 @@ package org.opensearch.ml.engine.algorithms.remote; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; +import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; +import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.Arrays; +import java.util.Map; +import java.util.Optional; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import org.apache.http.ProtocolVersion; -import org.apache.http.StatusLine; -import org.apache.http.message.BasicStatusLine; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -42,22 +52,6 @@ import software.amazon.awssdk.http.SdkHttpClient; import software.amazon.awssdk.http.SdkHttpResponse; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.util.Arrays; -import java.util.Map; -import java.util.Optional; - -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.connector.AbstractConnector.ACCESS_KEY_FIELD; -import static org.opensearch.ml.common.connector.AbstractConnector.SECRET_KEY_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.REGION_FIELD; -import static org.opensearch.ml.common.connector.HttpConnector.SERVICE_NAME_FIELD; - public class AwsConnectorExecutorTest { @Rule @@ -154,15 +148,25 @@ public void executePredict_RemoteInferenceInput_InvalidToken() throws IOExceptio when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); @@ -250,7 +254,8 @@ public void executePredict_TextDocsInferenceInput() throws IOException { AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 6666628d19..e91110b74e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -27,7 +27,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; -import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -134,11 +133,18 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); @@ -153,18 +159,25 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti public void executePredict_TextDocsInput_LimitExceed() throws IOException { exceptionRule.expect(OpenSearchStatusException.class); exceptionRule.expectMessage("{\"message\": \"Too many requests\"}"); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}"); when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); when(response.getStatusLine()).thenReturn(statusLine); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); @@ -198,21 +211,42 @@ public void executePredict_TextDocsInput() throws IOException { HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" - + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" - + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" - + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" - + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" - + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" + " }\n" + "}"; + String modelResponse = "{\n" + + " \"object\": \"list\",\n" + + " \"data\": [\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 0,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " },\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 1,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + + " \"usage\": {\n" + + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + + " }\n" + + "}"; StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); when(response.getEntity()).thenReturn(entity); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("sentence_embedding", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());