Skip to content

Commit

Permalink
fix format violations
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Nov 20, 2023
1 parent 81c6691 commit ff8a160
Show file tree
Hide file tree
Showing 9 changed files with 199 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,7 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
)
);
}
}, e-> {
log.error("Failed to get model", e);
});
}, e -> { log.error("Failed to get model", e); });
client.get(getModelRequest, ActionListener.runBefore(listener, context::restore));
}
}
Expand All @@ -199,12 +197,19 @@ public MetricsCorrelationOutput execute(Input input) throws ExecuteException {
waitUntil(() -> {
if (modelId != null) {
MLModelState modelState = getModel(modelId).getModelState();
if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED){
if (modelState == MLModelState.DEPLOYED || modelState == MLModelState.PARTIALLY_DEPLOYED) {
log.info("Model deployed: " + modelState);
return true;
} else if (modelState == MLModelState.UNDEPLOYED || modelState == MLModelState.DEPLOY_FAILED) {
log.info("Model not deployed: " + modelState);
deployModel(modelId, ActionListener.wrap(deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(), e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)));
deployModel(
modelId,
ActionListener
.wrap(
deployModelResponse -> modelId = getTask(deployModelResponse.getTaskId()).getModelId(),
e -> log.error("Metrics correlation model didn't get deployed to the index successfully", e)
)
);
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ public static ModelTensors processOutput(
// execute user defined painless script.
Optional<String> processedResponse = executePostProcessFunction(scriptService, postProcessFunction, modelResponse);
String response = processedResponse.orElse(modelResponse);
boolean scriptReturnModelTensor = postProcessFunction != null && processedResponse.isPresent() && org.opensearch.ml.common.utils.StringUtils.isJson(response);
boolean scriptReturnModelTensor = postProcessFunction != null
&& processedResponse.isPresent()
&& org.opensearch.ml.common.utils.StringUtils.isJson(response);
if (responseFilter == null) {
connector.parseResponse(response, modelTensors, scriptReturnModelTensor);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,17 @@ default ModelTensorOutput executePredict(MLInput mlInput) {
if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) mlInput.getInputDataset();
int processedDocs = 0;
while(processedDocs < textDocsInputDataSet.getDocs().size()) {
while (processedDocs < textDocsInputDataSet.getDocs().size()) {
List<String> textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
List<ModelTensors> tempTensorOutputs = new ArrayList<>();
preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
preparePayloadAndInvokeRemoteModel(
MLInput
.builder()
.algorithm(FunctionName.TEXT_EMBEDDING)
.inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
.build(),
tempTensorOutputs
);
int tensorCount = 0;
if (tempTensorOutputs.size() > 0 && tempTensorOutputs.get(0).getMlModelTensors() != null) {
tensorCount = tempTensorOutputs.get(0).getMlModelTensors().size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import javax.crypto.spec.SecretKeySpec;

import org.opensearch.ResourceNotFoundException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.exception.MLException;

import com.amazonaws.encryptionsdk.AwsCrypto;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,4 +327,4 @@ private MLModel trainLinearRegressionModel() {

return mlEngine.train(mlInput);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,41 @@

package org.opensearch.ml.engine.algorithms.metrics_correlation;

import com.google.common.collect.ImmutableMap;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE;
import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER;
import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE;
import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION;
import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.junit.Ignore;
Expand Down Expand Up @@ -81,41 +115,7 @@
import org.opensearch.search.suggest.Suggest;
import org.opensearch.threadpool.ThreadPool;

import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.engine.algorithms.DLModel.ML_ENGINE;
import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_HELPER;
import static org.opensearch.ml.engine.algorithms.DLModel.MODEL_ZIP_FILE;
import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MCORR_ML_VERSION;
import static org.opensearch.ml.engine.algorithms.metrics_correlation.MetricsCorrelation.MODEL_CONTENT_HASH;

import com.google.common.collect.ImmutableMap;

public class MetricsCorrelationTest {
@Rule
Expand Down Expand Up @@ -295,7 +295,6 @@ public void setUp() throws IOException, URISyntaxException {
extendedInput = MetricsCorrelationInput.builder().inputData(extendedInputData).build();
}


@Ignore
@Test
public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteException {
Expand All @@ -306,16 +305,17 @@ public void testWhenModelIdNotNullButModelIsNotDeployed() throws ExecuteExceptio

doAnswer(invocation -> {

MLModel smallModel = MLModel.builder()
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.name(FunctionName.METRICS_CORRELATION.name())
.modelId(modelId)
.modelGroupId(modelGroupId)
.algorithm(FunctionName.METRICS_CORRELATION)
.version(MCORR_ML_VERSION)
.modelConfig(modelConfig)
.modelState(MLModelState.UNDEPLOYED)
.build();
MLModel smallModel = MLModel
.builder()
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.name(FunctionName.METRICS_CORRELATION.name())
.modelId(modelId)
.modelGroupId(modelGroupId)
.algorithm(FunctionName.METRICS_CORRELATION)
.version(MCORR_ML_VERSION)
.modelConfig(modelConfig)
.modelState(MLModelState.UNDEPLOYED)
.build();
MLModelGetResponse responseTemp = new MLModelGetResponse(smallModel);
ActionFuture<MLModelGetResponse> mockedFutureTemp = mock(ActionFuture.class);
MLTaskGetResponse taskResponse = new MLTaskGetResponse(mlTask);
Expand Down Expand Up @@ -436,7 +436,6 @@ public void testExecuteWithNoModelIndexAndOneEvent() throws ExecuteException, UR
assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics());
}


@Ignore
@Test
public void testExecuteWithModelInIndexAndInvokeDeployAndOneEvent() throws ExecuteException, URISyntaxException {
Expand Down Expand Up @@ -526,8 +525,7 @@ public void testExecuteWithNoModelInIndexAndOneEvent() throws ExecuteException,
assertNotNull(mlModelOutputs.get(0).getMCorrModelTensors().get(0).getSuspected_metrics());
}


//working
// working
@Test
public void testGetModel() {
ActionFuture<MLModelGetResponse> mockedFuture = mock(ActionFuture.class);
Expand Down Expand Up @@ -560,7 +558,7 @@ public static XContentBuilder builder() throws IOException {
return XContentBuilder.builder(XContentType.JSON.xContent());
}

//working
// working
@Test
public void testSearchRequest() {
String expectedIndex = CommonValue.ML_MODEL_INDEX;
Expand Down Expand Up @@ -770,49 +768,56 @@ public static ClusterState setupTestClusterState() {
Set<DiscoveryNodeRole> roleSet = new HashSet<>();
roleSet.add(DiscoveryNodeRole.DATA_ROLE);
DiscoveryNode node = new DiscoveryNode(
"node",
new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()),
new HashMap<>(),
roleSet,
Version.CURRENT
"node",
new TransportAddress(TransportAddress.META_ADDRESS, new AtomicInteger().incrementAndGet()),
new HashMap<>(),
roleSet,
Version.CURRENT
);
Metadata metadata = new Metadata.Builder()
.indices(
ImmutableMap
.<String, IndexMetadata>builder()
.put(
ML_MODEL_INDEX,
IndexMetadata
.builder("test")
.settings(
Settings
.builder()
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put("index.version.created", Version.CURRENT.id)
)
.build()
)
.put(ML_MODEL_GROUP_INDEX, IndexMetadata.builder(ML_MODEL_GROUP_INDEX)
.settings(Settings.builder()
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put("index.version.created", Version.CURRENT.id))
.build())
.build()
)
.build();
.indices(
ImmutableMap
.<String, IndexMetadata>builder()
.put(
ML_MODEL_INDEX,
IndexMetadata
.builder("test")
.settings(
Settings
.builder()
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put("index.version.created", Version.CURRENT.id)
)
.build()
)
.put(
ML_MODEL_GROUP_INDEX,
IndexMetadata
.builder(ML_MODEL_GROUP_INDEX)
.settings(
Settings
.builder()
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put("index.version.created", Version.CURRENT.id)
)
.build()
)
.build()
)
.build();
return new ClusterState(
new ClusterName("test cluster"),
123l,
"111111",
metadata,
null,
DiscoveryNodes.builder().add(node).build(),
null,
Map.of(),
0,
false
new ClusterName("test cluster"),
123l,
"111111",
metadata,
null,
DiscoveryNodes.builder().add(node).build(),
null,
Map.of(),
0,
false
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,25 @@ public void executePredict_RemoteInferenceInput_NullResponse() throws IOExceptio
when(response.httpResponse()).thenReturn(httpResponse);
when(httpClient.prepareRequest(any())).thenReturn(httpRequest);

ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
ConnectorAction predictAction = ConnectorAction
.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
.method("POST")
.url("http://test.com/mock")
.requestBody("{\"input\": \"${parameters.input}\"}")
.build();
Map<String, String> credential = ImmutableMap
.of(ACCESS_KEY_FIELD, encryptor.encrypt("test_key"), SECRET_KEY_FIELD, encryptor.encrypt("test_secret_key"));
Map<String, String> parameters = ImmutableMap.of(REGION_FIELD, "us-west-2", SERVICE_NAME_FIELD, "sagemaker");
Connector connector = AwsConnector.awsConnectorBuilder().name("test connector").version("1").protocol("http").parameters(parameters).credential(credential).actions(Arrays.asList(predictAction)).build();
Connector connector = AwsConnector
.awsConnectorBuilder()
.name("test connector")
.version("1")
.protocol("http")
.parameters(parameters)
.credential(credential)
.actions(Arrays.asList(predictAction))
.build();
connector.decrypt((c) -> encryptor.decrypt(c));
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

Expand Down Expand Up @@ -244,7 +254,8 @@ public void executePredict_TextDocsInferenceInput() throws IOException {
AwsConnectorExecutor executor = spy(new AwsConnectorExecutor(connector, httpClient));

MLInputDataset inputDataSet = TextDocsInputDataSet.builder().docs(ImmutableList.of("input")).build();
ModelTensorOutput modelTensorOutput = executor.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
ModelTensorOutput modelTensorOutput = executor
.executePredict(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().size());
Assert.assertEquals(1, modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().size());
Assert.assertEquals("response", modelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getName());
Expand Down
Loading

0 comments on commit ff8a160

Please sign in to comment.