From 5a1b2b9456e8ca480013f6553ad2039201c223cf Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Mon, 20 Nov 2023 22:33:18 +0530 Subject: [PATCH] fix format violations Signed-off-by: Bhavana Ramaram --- .../MetricsCorrelation.java | 49 ++--- .../algorithms/remote/ConnectorUtils.java | 4 +- .../remote/RemoteConnectorExecutor.java | 11 +- .../ml/engine/encryptor/EncryptorImpl.java | 2 +- .../opensearch/ml/engine/MLEngineTest.java | 2 +- .../MetricsCorrelationTest.java | 187 +++++++++--------- .../remote/AwsConnectorExecutorTest.java | 29 ++- .../remote/HttpJsonConnectorExecutorTest.java | 69 +++++-- .../GetConnectorTransportAction.java | 15 +- 9 files changed, 215 insertions(+), 153 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index 1c96a048fc..d41f0c41dc 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -5,20 +5,10 @@ package org.opensearch.ml.engine.algorithms.metrics_correlation; -import static org.opensearch.index.query.QueryBuilders.termQuery; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.common.MLModel.MODEL_STATE_FIELD; - -import java.io.IOException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.function.BooleanSupplier; - +import ai.djl.modality.Output; +import ai.djl.translate.TranslateException; +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.get.GetRequest; @@ -70,11 +60,19 @@ import org.opensearch.ml.engine.annotation.Function; import org.opensearch.search.builder.SearchSourceBuilder; -import com.google.common.annotations.VisibleForTesting; +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.function.BooleanSupplier; -import ai.djl.modality.Output; -import ai.djl.translate.TranslateException; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.index.query.QueryBuilders.termQuery; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.MLModel.MODEL_STATE_FIELD; @Log4j2 @Function(FunctionName.METRICS_CORRELATION) @@ -175,9 +173,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { ) ); } - }, e-> { - log.error("Failed to get model", e); - }); + }, e -> { log.error("Failed to get model", e); }); client.get(getModelRequest, ActionListener.runBefore(listener, context::restore)); } } @@ -199,12 +195,19 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException { waitUntil(() -> { if (modelId != null) { MLModelState modelState = getModel(modelId).getModelState(); - if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED){ + if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED) { log.info("Model deployed: " + modelState); return true; } else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) { log.info("Model not deployed: " + modelState); - deployModel(modelId, ActionListener.wrap(deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e))); + deployModel( + modelId, + ActionListener + .wrap( + deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), + e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e) + ) + ); return false; } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 704d6e3e05..88c43a969e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -186,7 +186,9 @@ public static ModelTensors processOutput( // execute user defined painless script. Optional processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse); String response = processedResponse.orElse(modelResponse); - boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response); + boolean scriptReturnModelTensor = postProcessFunction != null + && processedResponse.isPresent() + && org.opensearch.ml.common.utils.StringUtils.isJson(response); if (responseFilter == null) { connector.parseResponse(response, modelTensors, scriptReturnModelTensor); } else { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index c26c79b452..aa471ba3fe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -33,10 +33,17 @@ default ModelTensorOutput executePredict(MLInput mlInput) { if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) { TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset(); int processedDocs = 0; - while(processedDocs < textDocsInputDataSet.getDocs().size()) { + while (processedDocs < textDocsInputDataSet.getDocs().size()) { List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size()); List tempTensorOutputs = new ArrayList<>(); - preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs); + preparePayloadAndInvokeRemoteModel( + MLInput + .builder() + .algorithm(FunctionName.TEXT_EMBEDDING) + .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) + .build(), + tempTensorOutputs + ); int tensorCount = 0; if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) { tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 1c23ae4043..b500709bc5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -18,13 +18,13 @@ import javax.crypto.spec.SecretKeySpec; import org.opensearch.ResourceNotFoundException; -import org.opensearch.core.action.ActionListener; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.get.GetRequest; import org.opensearch.action.get.GetResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.exception.MLException; import com.amazonaws.encryptionsdk.AwsCrypto; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 609777dcf7..11f0c207e6 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -327,4 +327,4 @@ private MLModel trainLinearRegressionModel() { return mlEngine.train(mlInput); } -} \ No newline at end of file +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index d3b8003756..223cb22289 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -5,7 +5,41 @@ package org.opensearch.ml.engine.algorithms.metrics_correlation; -import com.google.common.collect.ImmutableMap; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; +import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; +import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; +import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; + +import java.io.File; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Ignore; @@ -81,41 +115,7 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.threadpool.ThreadPool; -import java.io.File; -import java.io.IOException; -import java.net.URISyntaxException; -import java.nio.file.Path; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.anyLong; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX; -import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; -import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE; -import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER; -import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE; -import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION; -import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH; - +import com.google.common.collect.ImmutableMap; public class MetricsCorrelationTest { @Rule @@ -295,7 +295,6 @@ public void setUp() throws IOException, URISyntaxException { extendedInput = MetricsCorrelationInput.builder().inputData(extendedInputData).build(); } - @Ignore @Test public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteException { @@ -306,16 +305,17 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio doAnswer(invocation -> { - MLModel smallModel = MLModel.builder() - .modelFormat(MLModelFormat.TORCH_SCRIPT) - .name(FunctionName.METRICS_CORRELATION.name()) - .modelId(modelId) - .modelGroupId(modelGroupId) - .algorithm(FunctionName.METRICS_CORRELATION) - .version(MCORR_ML_VERSION) - .modelConfig(modelConfig) - .modelState(MLModelState.UNDEPLOYED) - .build(); + MLModel smallModel = MLModel + .builder() + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .name(FunctionName.METRICS_CORRELATION.name()) + .modelId(modelId) + .modelGroupId(modelGroupId) + .algorithm(FunctionName.METRICS_CORRELATION) + .version(MCORR_ML_VERSION) + .modelConfig(modelConfig) + .modelState(MLModelState.UNDEPLOYED) + .build(); MLModelGetResponse responseTemp = new MLModelGetResponse(smallModel); ActionFuture mockedFutureTemp = mock(ActionFuture.class); MLTaskGetResponse taskResponse = new MLTaskGetResponse(mlTask); @@ -436,7 +436,6 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - @Ignore @Test public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws ExecuteException, URISyntaxException { @@ -526,8 +525,7 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException, assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics()); } - - //working + // working @Test public void testGetModel() { ActionFuture mockedFuture = mock(ActionFuture.class); @@ -560,7 +558,7 @@ public static XContentBuilder builder() throws IOException { return XContentBuilder.builder(XContentType.JSON.xContent()); } - //working + // working @Test public void testSearchRequest() { String expectedIndex = CommonValue.ML_MODEL_INDEX; @@ -770,49 +768,56 @@ public static ClusterState setupTestClusterState() { Set roleSet = new HashSet<>(); roleSet.add(DiscoveryNodeRole.DATA_ROLE); DiscoveryNode node = new DiscoveryNode( - "node", - new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), - new HashMap<>(), - roleSet, - Version.CURRENT + "node", + new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()), + new HashMap<>(), + roleSet, + Version.CURRENT ); Metadata metadata = new Metadata.Builder() - .indices( - ImmutableMap - .builder() - .put( - ML_MODEL_INDEX, - IndexMetadata - .builder("test") - .settings( - Settings - .builder() - .put("index.number_of_shards", 1) - .put("index.number_of_replicas", 1) - .put("index.version.created", Version.CURRENT.id) - ) - .build() - ) - .put(ML_MODEL_GROUP_INDEX, IndexMetadata.builder(ML_MODEL_GROUP_INDEX) - .settings(Settings.builder() - .put("index.number_of_shards", 1) - .put("index.number_of_replicas", 1) - .put("index.version.created", Version.CURRENT.id)) - .build()) - .build() - ) - .build(); + .indices( + ImmutableMap + .builder() + .put( + ML_MODEL_INDEX, + IndexMetadata + .builder("test") + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .put( + ML_MODEL_GROUP_INDEX, + IndexMetadata + .builder(ML_MODEL_GROUP_INDEX) + .settings( + Settings + .builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id) + ) + .build() + ) + .build() + ) + .build(); return new ClusterState( - new ClusterName("test cluster"), - 123l, - "111111", - metadata, - null, - DiscoveryNodes.builder().add(node).build(), - null, - Map.of(), - 0, - false + new ClusterName("test cluster"), + 123l, + "111111", + metadata, + null, + DiscoveryNodes.builder().add(node).build(), + null, + Map.of(), + 0, + false ); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 70b45fe309..b35f9b0eac 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -108,15 +108,25 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio when(response.httpResponse()).thenReturn(httpResponse); when(httpClient.prepareRequest(any())).thenReturn(httpRequest); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": \"${parameters.input}\"}") - .build(); - Map credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": \"${parameters.input}\"}") + .build(); + Map credential = ImmutableMap + .of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key")); Map parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker"); - Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build(); + Connector connector = AwsConnector + .awsConnectorBuilder() + .name("test connector") + .version("1") + .protocol("http") + .parameters(parameters) + .credential(credential) + .actions(Arrays.asList(predictAction)) + .build(); connector.decrypt((c) -> encryptor.decrypt(c)); AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); @@ -244,7 +254,8 @@ public void executePredict_TextDocsInferenceInput() throws IOException { AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient)); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index d62a684ce3..e91110b74e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -27,7 +27,6 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchStatusException; -import org.opensearch.cluster.ClusterStateTaskConfig; import org.opensearch.ingest.TestTemplateService; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.connector.Connector; @@ -134,11 +133,18 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); - ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); + ModelTensorOutput modelTensorOutput = executor + .executePredict(MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build()); Assert.assertEquals(2, modelTensorOutput.getMlModelOutputs().size()); Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName()); Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap().size()); @@ -153,18 +159,25 @@ public void executePredict_TextDocsInput_NoPreprocessFunction() throws IOExcepti public void executePredict_TextDocsInput_LimitExceed() throws IOException { exceptionRule.expect(OpenSearchStatusException.class); exceptionRule.expectMessage("{\"message\": \"Too many requests\"}"); - ConnectorAction predictAction = ConnectorAction.builder() - .actionType(ConnectorAction.ActionType.PREDICT) - .method("POST") - .url("http://test.com/mock") - .requestBody("{\"input\": ${parameters.input}}") - .build(); + ConnectorAction predictAction = ConnectorAction + .builder() + .actionType(ConnectorAction.ActionType.PREDICT) + .method("POST") + .url("http://test.com/mock") + .requestBody("{\"input\": ${parameters.input}}") + .build(); when(httpClient.execute(any())).thenReturn(response); HttpEntity entity = new StringEntity("{\"message\": \"Too many requests\"}"); when(response.getEntity()).thenReturn(entity); StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 429, "OK"); when(response.getStatusLine()).thenReturn(statusLine); - Connector connector = HttpConnector.builder().name("test connector").version("1").protocol("http").actions(Arrays.asList(predictAction)).build(); + Connector connector = HttpConnector + .builder() + .name("test connector") + .version("1") + .protocol("http") + .actions(Arrays.asList(predictAction)) + .build(); HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); when(executor.getHttpClient()).thenReturn(httpClient); MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(Arrays.asList("test doc1", "test doc2")).build(); @@ -198,14 +211,34 @@ public void executePredict_TextDocsInput() throws IOException { HttpJsonConnectorExecutor executor = spy(new HttpJsonConnectorExecutor(connector)); executor.setScriptService(scriptService); when(httpClient.execute(any())).thenReturn(response); - String modelResponse = "{\n" + " \"object\": \"list\",\n" + " \"data\": [\n" + " {\n" - + " \"object\": \"embedding\",\n" + " \"index\": 0,\n" + " \"embedding\": [\n" - + " -0.014555434,\n" + " -0.002135904,\n" + " 0.0035105038\n" + " ]\n" - + " },\n" + " {\n" + " \"object\": \"embedding\",\n" + " \"index\": 1,\n" - + " \"embedding\": [\n" + " -0.014555434,\n" + " -0.002135904,\n" - + " 0.0035105038\n" + " ]\n" + " }\n" + " ],\n" - + " \"model\": \"text-embedding-ada-002-v2\",\n" + " \"usage\": {\n" + " \"prompt_tokens\": 5,\n" - + " \"total_tokens\": 5\n" + " }\n" + "}"; + String modelResponse = "{\n" + + " \"object\": \"list\",\n" + + " \"data\": [\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 0,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " },\n" + + " {\n" + + " \"object\": \"embedding\",\n" + + " \"index\": 1,\n" + + " \"embedding\": [\n" + + " -0.014555434,\n" + + " -0.002135904,\n" + + " 0.0035105038\n" + + " ]\n" + + " }\n" + + " ],\n" + + " \"model\": \"text-embedding-ada-002-v2\",\n" + + " \"usage\": {\n" + + " \"prompt_tokens\": 5,\n" + + " \"total_tokens\": 5\n" + + " }\n" + + "}"; StatusLine statusLine = new BasicStatusLine(new ProtocolVersion("HTTP", 1, 1), 200, "OK"); when(response.getStatusLine()).thenReturn(statusLine); HttpEntity entity = new StringEntity(modelResponse); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java index c78cf11831..7b953c341a 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/GetConnectorTransportAction.java @@ -5,9 +5,11 @@ package org.opensearch.ml.action.connector; -import lombok.AccessLevel; -import lombok.experimental.FieldDefaults; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; +import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; +import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; + import org.opensearch.OpenSearchStatusException; import org.opensearch.action.ActionRequest; import org.opensearch.action.get.GetRequest; @@ -32,10 +34,9 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; -import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; -import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; @Log4j2 @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)