From 2fed70a5f4509dd31bae8c4ccaacd7042ab8468c Mon Sep 17 00:00:00 2001 From: Yaliang <49084640+ylwu-amzn@users.noreply.github.com> Date: Mon, 13 Dec 2021 15:12:44 -0800 Subject: [PATCH] support registering ML objects; refactor ML engine interface (#108) * support registering ML objects; refactor ML engine interface Signed-off-by: Yaliang Wu * update doc Signed-off-by: Yaliang Wu * bump to 1.2.1; update doc Signed-off-by: Yaliang Wu --- .github/workflows/CI-workflow.yml | 4 +- .../ml/client/MachineLearningClientTest.java | 13 +-- .../ml/common/parameter/MLInput.java | 27 +++++- .../MLPredictionTaskRequestTest.java | 7 +- .../training/MLTrainingTaskRequestTest.java | 9 +- docs/how-to-add-new-function.md | 41 ++++++++- gradle.properties | 4 +- .../org/opensearch/ml/engine/MLEngine.java | 69 ++++++++------- .../ml/engine/MLEngineClassLoader.java | 64 +++++++++++++- .../sample/LocalSampleCalculator.java | 21 ++++- .../ml/engine/MLEngineClassLoaderTests.java | 76 +++++++++++++++++ .../opensearch/ml/engine/MLEngineTest.java | 83 +++++++++++++++---- .../ml/plugin/MachineLearningPlugin.java | 8 ++ .../ml/task/MLExecuteTaskRunner.java | 6 -- .../ml/task/MLPredictTaskRunner.java | 3 +- .../org/opensearch/ml/task/MLTaskRunner.java | 6 +- .../ml/task/MLTrainingTaskRunner.java | 23 ++--- 17 files changed, 368 insertions(+), 96 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index be1fb86361..0347da87c8 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -25,10 +25,10 @@ jobs: uses: actions/checkout@v2 - name: Build with Gradle - run: ./gradlew build -Dopensearch.version=1.2.0 + run: ./gradlew build -Dopensearch.version=1.2.1 - name: Publish to Maven Local - run: ./gradlew publishToMavenLocal -Dopensearch.version=1.2.0 + run: ./gradlew publishToMavenLocal -Dopensearch.version=1.2.1 - name: Multi Nodes Integration Testing run: ./gradlew integTest -PnumNodes=3 diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 3b8d1b9b49..2eddbf8918 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -21,6 +21,7 @@ import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; @@ -93,7 +94,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public void predict_WithAlgoAndInputData() { MLInput mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) - .dataFrame(input) + .inputDataset(new DataFrameInputDataset(input)) .build(); assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet()); } @@ -103,7 +104,7 @@ public void predict_WithAlgoAndParametersAndInputData() { MLInput mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) .parameters(mlParameters) - .dataFrame(input) + .inputDataset(new DataFrameInputDataset(input)) .build(); assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet()); } @@ -113,7 +114,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() { MLInput mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) .parameters(mlParameters) - .dataFrame(input) + .inputDataset(new DataFrameInputDataset(input)) .build(); assertEquals(output, machineLearningClient.predict("modelId", mlInput).actionGet()); } @@ -122,7 +123,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() { public void predict_WithAlgoAndInputDataAndListener() { MLInput mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) - .dataFrame(input) + .inputDataset(new DataFrameInputDataset(input)) .build(); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); machineLearningClient.predict(null, mlInput, dataFrameActionListener); @@ -135,7 +136,7 @@ public void predict_WithAlgoAndInputDataAndParametersAndListener() { MLInput mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) .parameters(mlParameters) - .dataFrame(input) + .inputDataset(new DataFrameInputDataset(input)) .build(); ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); machineLearningClient.predict(null, mlInput, dataFrameActionListener); @@ -148,7 +149,7 @@ public void train() { MLInput mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) .parameters(mlParameters) - .dataFrame(input) + .inputDataset(new DataFrameInputDataset(input)) .build(); assertEquals(modekId, ((MLTrainingOutput)machineLearningClient.train(mlInput).actionGet()).getModelId()); } diff --git a/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java b/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java index c471a353e6..fe7a15977d 100644 --- a/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/parameter/MLInput.java @@ -54,11 +54,16 @@ public class MLInput implements Input { private int version = 1; - @Builder + @Builder(toBuilder = true) + public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset inputDataset) { + validate(algorithm); + this.algorithm = algorithm; + this.parameters = parameters; + this.inputDataset = inputDataset; + } + public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) { - if (algorithm == null) { - throw new IllegalArgumentException("algorithm can't be null"); - } + validate(algorithm); this.algorithm = algorithm; this.parameters = parameters; if (inputDataset != null) { @@ -68,6 +73,12 @@ public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuil } } + private void validate(FunctionName algorithm) { + if (algorithm == null) { + throw new IllegalArgumentException("algorithm can't be null"); + } + } + public MLInput(StreamInput in) throws IOException { this.algorithm = in.readEnum(FunctionName.class); if (in.readBoolean()) { @@ -176,4 +187,12 @@ private MLInputDataset createInputDataSet(SearchSourceBuilder searchSourceBuilde public FunctionName getFunctionName() { return this.algorithm; } + + public DataFrame getDataFrame() { + if (inputDataset == null || !(inputDataset instanceof DataFrameInputDataset)) { + return null; + } + return ((DataFrameInputDataset)inputDataset).getDataFrame(); + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java index 5804febb64..e57378220b 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/prediction/MLPredictionTaskRequestTest.java @@ -48,12 +48,13 @@ public class MLPredictionTaskRequestTest { @Before public void setUp() { + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + put("key1", 2.0D); + }})); mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) .parameters(KMeansParams.builder().centroids(1).build()) - .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }}))) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) .build(); } diff --git a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java index 5196d9c6a1..9e01854adf 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/training/MLTrainingTaskRequestTest.java @@ -18,7 +18,9 @@ import org.opensearch.action.ActionRequestValidationException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.FunctionName; @@ -40,12 +42,13 @@ public class MLTrainingTaskRequestTest { @Before public void setUp() { + DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ + put("key1", 2.0D); + }})); mlInput = MLInput.builder() .algorithm(FunctionName.KMEANS) .parameters(KMeansParams.builder().centroids(1).build()) - .dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap() {{ - put("key1", 2.0D); - }}))) + .inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build()) .build(); } diff --git a/docs/how-to-add-new-function.md b/docs/how-to-add-new-function.md index f3f009d4c2..3c066b696b 100644 --- a/docs/how-to-add-new-function.md +++ b/docs/how-to-add-new-function.md @@ -85,8 +85,10 @@ public class SampleAlgo implements MLAlgo { } ``` +### (Optional)Step 5: register function object +You can register instance of thread-safe class in ML plugin. Refer to Example2#Step5. -### Step 5: Run and test +### Step 6: Run and test Run `./gradlew run` and test sample algorithm. Train with sample data @@ -203,7 +205,7 @@ public class LocalSampleCalculatorOutput implements Output{ Create new class `ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator` in `ml-algorithms` package by implementing interface `Executable`. Override `execute` method. -Must add `@Function` annotation with new function name. +If don't register new ML function class(refer to step 5), you must add `@Function` annotation with new function name. ``` @Function(FunctionName.LOCAL_SAMPLE_CALCULATOR) // must add this annotation public class LocalSampleCalculator implements Executable { @@ -215,7 +217,40 @@ public class LocalSampleCalculator implements Executable { } ``` -### Step 5: Run and test +### (Optional)Step 5: register function object +If the new function class is thread-safe, you can register it. If don't register, will create new object for each request. +For example, we can register instance of `LocalSampleCalculator` in `MachineLearningPlugin` like this +``` +public class MachineLearningPlugin extends Plugin implements ActionPlugin { + ... + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { + ... + Settings settings = environment.settings(); + ... + // Register thread-safe ML objects here. + LocalSampleCalculator localSampleCalculator = new LocalSampleCalculator(client, settings); + MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator); + ... + } + ... +} +``` + + +### Step 6: Run and test Run `./gradlew run` and test this sample calculator. ``` diff --git a/gradle.properties b/gradle.properties index f421566939..71def4a23f 100644 --- a/gradle.properties +++ b/gradle.properties @@ -10,5 +10,5 @@ # # -opensearch_version = 1.2.0-SNAPSHOT -opensearchBaseVersion = 1.2.0 \ No newline at end of file +opensearch_version = 1.2.1-SNAPSHOT +opensearchBaseVersion = 1.2.1 \ No newline at end of file diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index 287f413360..d7f7131a27 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -14,56 +14,63 @@ import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.Input; -import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLOutput; import org.opensearch.ml.common.parameter.Output; -import org.opensearch.ml.engine.annotation.Function; -import org.opensearch.ml.engine.exceptions.MetaDataException; -import org.reflections.Reflections; - -import java.lang.reflect.InvocationTargetException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; -import java.util.Set; /** * This is the interface to all ml algorithms. */ public class MLEngine { - public static MLOutput predict(FunctionName algoName, MLAlgoParams parameters, DataFrame dataFrame, Model model) { - if (algoName == null) { - throw new IllegalArgumentException("Algo name should not be null"); + public static Model train(Input input) { + validateMLInput(input); + MLInput mlInput = (MLInput) input; + Trainable trainable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class); + if (trainable == null) { + throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm()); } - Predictable mlAlgo = MLEngineClassLoader.initInstance(algoName, parameters, MLAlgoParams.class); - if (mlAlgo == null) { - throw new IllegalArgumentException("Unsupported algorithm: " + algoName); + return trainable.train(mlInput.getDataFrame()); + } + + public static MLOutput predict(Input input, Model model) { + validateMLInput(input); + MLInput mlInput = (MLInput) input; + Predictable predictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class); + if (predictable == null) { + throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm()); } - return mlAlgo.predict(dataFrame, model); + return predictable.predict(mlInput.getDataFrame(), model); } - public static Model train(FunctionName algoName, MLAlgoParams parameters, DataFrame dataFrame) { - if (algoName == null) { - throw new IllegalArgumentException("Algo name should not be null"); + public static Output execute(Input input) { + validateInput(input); + Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); + if (executable == null) { + throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName()); } - Trainable mlAlgo = MLEngineClassLoader.initInstance(algoName, parameters, MLAlgoParams.class); - if (mlAlgo == null) { - throw new IllegalArgumentException("Unsupported algorithm: " + algoName); + return executable.execute(input); + } + + private static void validateMLInput(Input input) { + validateInput(input); + if (!(input instanceof MLInput)) { + throw new IllegalArgumentException("Input should be MLInput"); + } + MLInput mlInput = (MLInput) input; + DataFrame dataFrame = mlInput.getDataFrame(); + if (dataFrame == null || dataFrame.size() == 0) { + throw new IllegalArgumentException("Input data frame should not be null or empty"); } - return mlAlgo.train(dataFrame); } - public static Output execute(Input input) { + private static void validateInput(Input input) { if (input == null) { - throw new IllegalArgumentException("Algo name should not be null"); + throw new IllegalArgumentException("Input should not be null"); } - Executable function = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); - if (function == null) { - throw new IllegalArgumentException("Unsupported algorithm: " + input.getFunctionName()); + if (input.getFunctionName() == null) { + throw new IllegalArgumentException("Function name should not be null"); } - return function.execute(input); } - } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java index f432a78168..4e0b97751f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngineClassLoader.java @@ -11,6 +11,7 @@ package org.opensearch.ml.engine; +import org.apache.commons.beanutils.BeanUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ml.common.parameter.FunctionName; @@ -29,8 +30,16 @@ public class MLEngineClassLoader { private static final Logger logger = LogManager.getLogger(MLEngineClassLoader.class); + /** + * This map contains class mapping of enum types like {@link FunctionName} + */ private static Map, Class> mlAlgoClassMap = new HashMap<>(); + /** + * This map contains pre-created thread-safe ML objects. + */ + private static Map, Object> mlObjects = new HashMap<>(); + static { try { AccessController.doPrivileged((PrivilegedExceptionAction) () -> { @@ -42,6 +51,27 @@ public class MLEngineClassLoader { } } + /** + * Register thread-safe ML objects. "initInstance" method will get thread-safe object from + * "mlObjects" map first. If not found, will try to create new instance. So if you are not + * sure if the object to be registered is thread-safe or not, you should NOT register it. + * @param functionName function name + * @param obj object + */ + public static void register(Enum functionName, Object obj) { + mlObjects.put(functionName, obj); + } + + /** + * If you are sure some ML objects will not be used anymore, you can deregister it to release + * memory. + * @param functionName function name + * @return removed object + */ + public static Object deregister(Enum functionName) { + return mlObjects.remove(functionName); + } + public static void loadClassMapping() { Reflections reflections = new Reflections("org.opensearch.ml.engine.algorithms"); @@ -58,13 +88,43 @@ public static void loadClassMapping() { @SuppressWarnings("unchecked") public static , S, I extends Object> S initInstance(T type, I in, Class constructorParamClass) { + return initInstance(type, in, constructorParamClass, null); + } + + /** + * Get instance from registered ML objects. If not registered, will create new instance. + * When create new instance, will try constructor with "constructorParamClass" first. If + * not found, will try default constructor without input parameter. + * @param type enum type + * @param in input parameter of constructor + * @param constructorParamClass constructor parameter class + * @param properties class properties + * @param Enum type + * @param return class + * @param input parameter of constructor + * @return + */ + @SuppressWarnings("unchecked") + public static , S, I extends Object> S initInstance(T type, I in, Class constructorParamClass, Map properties) { + if (mlObjects.containsKey(type)) { + return (S) mlObjects.get(type); + } Class clazz = mlAlgoClassMap.get(type); if (clazz == null) { throw new IllegalArgumentException("Can't find class for type " + type); } try { - Constructor constructor = clazz.getConstructor(constructorParamClass); - return (S) constructor.newInstance(in); + Constructor constructor; + S instance; + try { + constructor = clazz.getConstructor(constructorParamClass); + instance = (S) constructor.newInstance(in); + } catch (NoSuchMethodException e) { + constructor = clazz.getConstructor(); + instance = (S) constructor.newInstance(); + } + BeanUtils.populate(instance, properties); + return instance; } catch (Exception e) { logger.error("Failed to init instance for type " + type, e); return null; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java index 1e69e7fd02..404123818f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator.java @@ -11,7 +11,14 @@ package org.opensearch.ml.engine.algorithms.sample; +import lombok.Data; +import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.Setter; +import org.opensearch.client.Client; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput; @@ -23,13 +30,18 @@ import java.util.Comparator; import java.util.List; +@Data @NoArgsConstructor @Function(FunctionName.LOCAL_SAMPLE_CALCULATOR) public class LocalSampleCalculator implements Executable { - private LocalSampleCalculatorInput sampleCalculatorInput; - public LocalSampleCalculator(Input input) { - sampleCalculatorInput = (LocalSampleCalculatorInput) input; + //TODO: support calculate sum/max/min value from index. + private Client client; + private Settings settings; + + public LocalSampleCalculator(Client client, Settings settings) { + this.client = client; + this.settings = settings; } @Override @@ -37,11 +49,12 @@ public Output execute(Input input) { if (input == null || !(input instanceof LocalSampleCalculatorInput)) { throw new IllegalArgumentException("wrong input"); } + LocalSampleCalculatorInput sampleCalculatorInput = (LocalSampleCalculatorInput) input; String operation = sampleCalculatorInput.getOperation(); List inputData = sampleCalculatorInput.getInputData(); switch (operation) { case "sum": - double sum = inputData.stream().mapToDouble(f -> f.doubleValue()).sum(); + double sum = inputData.stream().mapToDouble(f -> f.doubleValue()).sum() ; return new SampleAlgoOutput(sum); case "max": double max = inputData.stream().max(Comparator.naturalOrder()).get(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java new file mode 100644 index 0000000000..b29095ae89 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineClassLoaderTests.java @@ -0,0 +1,76 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.ml.engine; + +import org.junit.Test; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.ml.common.parameter.FunctionName; +import org.opensearch.ml.common.parameter.Input; +import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput; +import org.opensearch.ml.common.parameter.SampleAlgoOutput; +import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.mockito.Mockito.mock; + +public class MLEngineClassLoaderTests { + + @Test + public void initInstance_LocalSampleCalculator() { + List inputData = new ArrayList<>(); + double d1 = 10.0; + double d2 = 20.0; + inputData.add(d1); + inputData.add(d2); + LocalSampleCalculatorInput input = LocalSampleCalculatorInput.builder().operation("sum").inputData(inputData).build(); + + Map properties = new HashMap<>(); + properties.put("wrongField", "test"); + Client client = mock(Client.class); + properties.put("client", client); + Settings settings = Settings.EMPTY; + properties.put("settings", settings); + + // set properties + MLEngineClassLoader.deregister(FunctionName.LOCAL_SAMPLE_CALCULATOR); + LocalSampleCalculator instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class, properties); + SampleAlgoOutput output = (SampleAlgoOutput) instance.execute(input); + assertEquals(d1 + d2, output.getSampleResult(), 1e-6); + assertEquals(client, instance.getClient()); + assertEquals(settings, instance.getSettings()); + + // don't set properties + instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, input, Input.class); + output = (SampleAlgoOutput) instance.execute(input); + assertEquals(d1 + d2, output.getSampleResult(), 1e-6); + assertNull(instance.getClient()); + assertNull(instance.getSettings()); + } + + @Test + public void initInstance_LocalSampleCalculator_RegisterFirst() { + Client client = mock(Client.class); + Settings settings = Settings.EMPTY; + LocalSampleCalculator calculator = new LocalSampleCalculator(client, settings); + MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, calculator); + + LocalSampleCalculator instance = MLEngineClassLoader.initInstance(FunctionName.LOCAL_SAMPLE_CALCULATOR, null, Input.class); + assertEquals(calculator, instance); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index ea50c165c8..4b604029ed 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -13,21 +13,22 @@ package org.opensearch.ml.engine; import org.junit.Assert; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.LinearRegressionParams; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.MLAlgoParams; +import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLPredictionOutput; -import java.util.HashSet; -import java.util.Set; import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame; import static org.opensearch.ml.engine.helper.LinearRegressionHelper.constructLinearRegressionPredictionDataFrame; @@ -41,7 +42,9 @@ public class MLEngineTest { public void predictKMeans() { Model model = trainKMeansModel(); DataFrame predictionDataFrame = constructKMeansDataFrame(10); - MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(FunctionName.KMEANS, null, predictionDataFrame, model); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build(); + Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(inputDataset).build(); + MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(mlInput, model); DataFrame predictions = output.getPredictionResult(); Assert.assertEquals(10, predictions.size()); predictions.forEach(row -> Assert.assertTrue(row.getValue(0).intValue() == 0 || row.getValue(0).intValue() == 1)); @@ -51,7 +54,9 @@ public void predictKMeans() { public void predictLinearRegression() { Model model = trainLinearRegressionModel(); DataFrame predictionDataFrame = constructLinearRegressionPredictionDataFrame(); - MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(FunctionName.LINEAR_REGRESSION, null, predictionDataFrame, model); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(predictionDataFrame).build(); + Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build(); + MLPredictionOutput output = (MLPredictionOutput)MLEngine.predict(mlInput, model); DataFrame predictions = output.getPredictionResult(); Assert.assertEquals(2, predictions.size()); } @@ -73,35 +78,74 @@ public void trainLinearRegression() { } @Test - public void trainWithoutAlgorithm() { + public void train_NullInput() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Algo name should not be null"); - MLEngine.train(null, null, null); + exceptionRule.expectMessage("Input should not be null"); + FunctionName algoName = FunctionName.LINEAR_REGRESSION; + try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { + loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); + MLEngine.train(null); + } } @Test - public void trainUnsupportedAlgorithm() { + public void train_NullDataFrame() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Input data frame should not be null or empty"); + FunctionName algoName = FunctionName.LINEAR_REGRESSION; + try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { + loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); + MLEngine.train(MLInput.builder().algorithm(algoName).build()); + } + } + + @Test + public void train_EmptyDataFrame() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Input data frame should not be null or empty"); + FunctionName algoName = FunctionName.LINEAR_REGRESSION; + try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { + loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(0)).build(); + MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build()); + } + } + + @Test + public void train_UnsupportedAlgorithm() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Unsupported algorithm: LINEAR_REGRESSION"); FunctionName algoName = FunctionName.LINEAR_REGRESSION; try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); - MLEngine.train(algoName, null, null); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build(); + MLEngine.train(MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build()); } } + @Test + public void predictNullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("Input should not be null"); + MLEngine.predict(null, null); + } + @Test public void predictWithoutAlgoName() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Algo name should not be null"); - MLEngine.predict(null, null, null, null); + exceptionRule.expectMessage("algorithm can't be null"); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructKMeansDataFrame(10)).build(); + Input mlInput = MLInput.builder().inputDataset(inputDataset).build(); + MLEngine.predict(mlInput, null); } @Test public void predictWithoutModel() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("No model found for linear regression prediction."); - MLEngine.predict(FunctionName.LINEAR_REGRESSION, null, null, null); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructLinearRegressionPredictionDataFrame()).build(); + Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).inputDataset(inputDataset).build(); + MLEngine.predict(mlInput, null); } @Test @@ -111,7 +155,9 @@ public void predictUnsupportedAlgorithm() { FunctionName algoName = FunctionName.LINEAR_REGRESSION; try (MockedStatic loader = Mockito.mockStatic(MLEngineClassLoader.class)) { loader.when(() -> MLEngineClassLoader.initInstance(algoName, null, MLAlgoParams.class)).thenReturn(null); - MLEngine.predict(algoName, null, null, null); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(constructLinearRegressionPredictionDataFrame()).build(); + Input mlInput = MLInput.builder().algorithm(algoName).inputDataset(inputDataset).build(); + MLEngine.predict(mlInput, null); } } @@ -122,7 +168,9 @@ private Model trainKMeansModel() { .distanceType(KMeansParams.DistanceType.EUCLIDEAN) .build(); DataFrame trainDataFrame = constructKMeansDataFrame(100); - return MLEngine.train(FunctionName.KMEANS, parameters, trainDataFrame); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build(); + Input mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(parameters).inputDataset(inputDataset).build(); + return MLEngine.train(mlInput); } private Model trainLinearRegressionModel() { @@ -136,8 +184,9 @@ private Model trainLinearRegressionModel() { .target("price") .build(); DataFrame trainDataFrame = constructLinearRegressionTrainDataFrame(); + MLInputDataset inputDataset = DataFrameInputDataset.builder().dataFrame(trainDataFrame).build(); + Input mlInput = MLInput.builder().algorithm(FunctionName.LINEAR_REGRESSION).parameters(parameters).inputDataset(inputDataset).build(); - - return MLEngine.train(FunctionName.LINEAR_REGRESSION, parameters, trainDataFrame); + return MLEngine.train(mlInput); } } \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 2bc245fc8e..dabcc9a763 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -43,6 +43,7 @@ import org.opensearch.ml.action.training.MLTrainingTaskExecutionAction; import org.opensearch.ml.action.training.MLTrainingTaskExecutionTransportAction; import org.opensearch.ml.action.training.TransportTrainingTaskAction; +import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.KMeansParams; import org.opensearch.ml.common.parameter.LinearRegressionParams; import org.opensearch.ml.common.parameter.LocalSampleCalculatorInput; @@ -50,6 +51,8 @@ import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.training.MLTrainingTaskAction; +import org.opensearch.ml.engine.MLEngineClassLoader; +import org.opensearch.ml.engine.algorithms.sample.LocalSampleCalculator; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.indices.MLInputDatasetHandler; import org.opensearch.ml.rest.RestMLExecuteAction; @@ -134,6 +137,7 @@ public Collection createComponents( this.client = client; this.threadPool = threadPool; this.clusterService = clusterService; + Settings settings = environment.settings(); Map> stats = ImmutableMap .>builder() @@ -175,6 +179,10 @@ public Collection createComponents( mlTaskDispatcher ); + // Register thread-safe ML objects here. + LocalSampleCalculator localSampleCalculator = new LocalSampleCalculator(client, settings); + MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator); + return ImmutableList .of( mlStats, diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index 22e4c9b151..6db43b6a7f 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -60,12 +60,6 @@ public MLExecuteTaskRunner( * @param transportService transport service * @param listener Action listener */ - public void execute(MLExecuteTaskRequest request, TransportService transportService, ActionListener listener) { - Input input = request.getInput(); - Output output = MLEngine.execute(input); - listener.onResponse(MLExecuteTaskResponse.builder().output(output).build()); - } - @Override public void run(MLExecuteTaskRequest request, TransportService transportService, ActionListener listener) { Input input = request.getInput(); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 33b300d1cc..ce5b99642e 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -38,6 +38,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.action.prediction.MLPredictionTaskExecutionAction; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.Input; import org.opensearch.ml.common.parameter.MLInput; @@ -208,7 +209,7 @@ private void predict( MLOutput output; try { mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING); - output = MLEngine.predict(mlInput.getAlgorithm(), mlInput.getParameters(), inputDataFrame, model); + output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model); if (output instanceof MLPredictionOutput) { ((MLPredictionOutput) output).setTaskId(mlTask.getTaskId()); ((MLPredictionOutput) output).setStatus(mlTaskManager.get(mlTask.getTaskId()).getState().name()); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 37ac190878..9449f8ec46 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -22,8 +22,10 @@ /** * MLTaskRunner has common code for dispatching and running predict/training tasks. + * @param ML task request + * @param ML task request */ -public abstract class MLTaskRunner { +public abstract class MLTaskRunner { protected final MLTaskManager mlTaskManager; protected final MLStats mlStats; protected final MLTaskDispatcher mlTaskDispatcher; @@ -57,5 +59,5 @@ protected void handleMLTaskComplete(MLTask mlTask) { mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.COMPLETED); } - public abstract void run(S request, TransportService transportService, ActionListener listener); + public abstract void run(Request request, TransportService transportService, ActionListener listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java index 9331f3b08b..01f7d62ae5 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTrainingTaskRunner.java @@ -33,6 +33,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.ml.action.training.MLTrainingTaskExecutionAction; import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.parameter.MLInput; import org.opensearch.ml.common.parameter.MLTrainingOutput; @@ -119,31 +120,33 @@ public void startTrainingTask(MLTrainingTaskRequest request, ActionListener dataFrameActionListener = ActionListener - .wrap(dataFrame -> { train(mlTask, dataFrame, mlInput); }, e -> { - log.error("Failed to generate DataFrame from search query", e); - mlTaskManager.addIfAbsent(mlTask); - mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED); - mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage()); - }); + .wrap( + dataFrame -> { train(mlTask, mlInput.toBuilder().inputDataset(new DataFrameInputDataset(dataFrame)).build()); }, + e -> { + log.error("Failed to generate DataFrame from search query", e); + mlTaskManager.addIfAbsent(mlTask); + mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED); + mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage()); + } + ); mlInputDatasetHandler .parseSearchQueryInput( mlInput.getInputDataset(), new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false) ); } else { - DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(mlInput.getInputDataset()); - threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, inputDataFrame, mlInput); }); + threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, mlInput); }); } } - private void train(MLTask mlTask, DataFrame inputDataFrame, MLInput mlInput) { + private void train(MLTask mlTask, MLInput mlInput) { // track ML task count and add ML task into cache mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment(); mlTaskManager.add(mlTask); // run training try { mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING); - Model model = MLEngine.train(mlInput.getAlgorithm(), mlInput.getParameters(), inputDataFrame); + Model model = MLEngine.train(mlInput); String encodedModelContent = Base64.getEncoder().encodeToString(model.getContent()); mlIndicesHandler.initModelIndexIfAbsent(); Map source = new HashMap<>();