Skip to content

Commit

Permalink
support registering ML objects; refactor ML engine interface (opensea…
Browse files Browse the repository at this point in the history
…rch-project#108)

* support registering ML objects; refactor ML engine interface

Signed-off-by: Yaliang Wu <[email protected]>

* update doc

Signed-off-by: Yaliang Wu <[email protected]>

* bump to 1.2.1; update doc

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Dec 13, 2021
1 parent 88c373a commit 2fed70a
Show file tree
Hide file tree
Showing 17 changed files with 368 additions and 96 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ jobs:
uses: actions/checkout@v2

- name: Build with Gradle
run: ./gradlew build -Dopensearch.version=1.2.0
run: ./gradlew build -Dopensearch.version=1.2.1

- name: Publish to Maven Local
run: ./gradlew publishToMavenLocal -Dopensearch.version=1.2.0
run: ./gradlew publishToMavenLocal -Dopensearch.version=1.2.1

- name: Multi Nodes Integration Testing
run: ./gradlew integTest -PnumNodes=3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLAlgoParams;
Expand Down Expand Up @@ -93,7 +94,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
public void predict_WithAlgoAndInputData() {
MLInput mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.dataFrame(input)
.inputDataset(new DataFrameInputDataset(input))
.build();
assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet());
}
Expand All @@ -103,7 +104,7 @@ public void predict_WithAlgoAndParametersAndInputData() {
MLInput mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.inputDataset(new DataFrameInputDataset(input))
.build();
assertEquals(output, machineLearningClient.predict(null, mlInput).actionGet());
}
Expand All @@ -113,7 +114,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
MLInput mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.inputDataset(new DataFrameInputDataset(input))
.build();
assertEquals(output, machineLearningClient.predict("modelId", mlInput).actionGet());
}
Expand All @@ -122,7 +123,7 @@ public void predict_WithAlgoAndParametersAndInputDataAndModelId() {
public void predict_WithAlgoAndInputDataAndListener() {
MLInput mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.dataFrame(input)
.inputDataset(new DataFrameInputDataset(input))
.build();
ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
machineLearningClient.predict(null, mlInput, dataFrameActionListener);
Expand All @@ -135,7 +136,7 @@ public void predict_WithAlgoAndInputDataAndParametersAndListener() {
MLInput mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.inputDataset(new DataFrameInputDataset(input))
.build();
ArgumentCaptor<MLOutput> dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class);
machineLearningClient.predict(null, mlInput, dataFrameActionListener);
Expand All @@ -148,7 +149,7 @@ public void train() {
MLInput mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.parameters(mlParameters)
.dataFrame(input)
.inputDataset(new DataFrameInputDataset(input))
.build();
assertEquals(modekId, ((MLTrainingOutput)machineLearningClient.train(mlInput).actionGet()).getModelId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,16 @@ public class MLInput implements Input {

private int version = 1;

@Builder
@Builder(toBuilder = true)
public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset inputDataset) {
validate(algorithm);
this.algorithm = algorithm;
this.parameters = parameters;
this.inputDataset = inputDataset;
}

public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List<String> sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) {
if (algorithm == null) {
throw new IllegalArgumentException("algorithm can't be null");
}
validate(algorithm);
this.algorithm = algorithm;
this.parameters = parameters;
if (inputDataset != null) {
Expand All @@ -68,6 +73,12 @@ public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuil
}
}

private void validate(FunctionName algorithm) {
if (algorithm == null) {
throw new IllegalArgumentException("algorithm can't be null");
}
}

public MLInput(StreamInput in) throws IOException {
this.algorithm = in.readEnum(FunctionName.class);
if (in.readBoolean()) {
Expand Down Expand Up @@ -176,4 +187,12 @@ private MLInputDataset createInputDataSet(SearchSourceBuilder searchSourceBuilde
public FunctionName getFunctionName() {
return this.algorithm;
}

public DataFrame getDataFrame() {
if (inputDataset == null || !(inputDataset instanceof DataFrameInputDataset)) {
return null;
}
return ((DataFrameInputDataset)inputDataset).getDataFrame();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ public class MLPredictionTaskRequestTest {

@Before
public void setUp() {
DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {{
put("key1", 2.0D);
}}));
mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.parameters(KMeansParams.builder().centroids(1).build())
.dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {{
put("key1", 2.0D);
}})))
.inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.parameter.KMeansParams;
import org.opensearch.ml.common.parameter.FunctionName;
Expand All @@ -40,12 +42,13 @@ public class MLTrainingTaskRequestTest {

@Before
public void setUp() {
DataFrame dataFrame = DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {{
put("key1", 2.0D);
}}));
mlInput = MLInput.builder()
.algorithm(FunctionName.KMEANS)
.parameters(KMeansParams.builder().centroids(1).build())
.dataFrame(DataFrameBuilder.load(Collections.singletonList(new HashMap<String, Object>() {{
put("key1", 2.0D);
}})))
.inputDataset(DataFrameInputDataset.builder().dataFrame(dataFrame).build())
.build();
}

Expand Down
41 changes: 38 additions & 3 deletions docs/how-to-add-new-function.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,10 @@ public class SampleAlgo implements MLAlgo {
}
```

### (Optional)Step 5: register function object
You can register instance of thread-safe class in ML plugin. Refer to Example2#Step5.

### Step 5: Run and test
### Step 6: Run and test
Run `./gradlew run` and test sample algorithm.

Train with sample data
Expand Down Expand Up @@ -203,7 +205,7 @@ public class LocalSampleCalculatorOutput implements Output{
Create new class `ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/sample/LocalSampleCalculator` in `ml-algorithms` package
by implementing interface `Executable`. Override `execute` method.

Must add `@Function` annotation with new function name.
If don't register new ML function class(refer to step 5), you must add `@Function` annotation with new function name.
```
@Function(FunctionName.LOCAL_SAMPLE_CALCULATOR) // must add this annotation
public class LocalSampleCalculator implements Executable {
Expand All @@ -215,7 +217,40 @@ public class LocalSampleCalculator implements Executable {
}
```

### Step 5: Run and test
### (Optional)Step 5: register function object
If the new function class is thread-safe, you can register it. If don't register, will create new object for each request.
For example, we can register instance of `LocalSampleCalculator` in `MachineLearningPlugin` like this
```
public class MachineLearningPlugin extends Plugin implements ActionPlugin {
...
@Override
public Collection<Object> createComponents(
Client client,
ClusterService clusterService,
ThreadPool threadPool,
ResourceWatcherService resourceWatcherService,
ScriptService scriptService,
NamedXContentRegistry xContentRegistry,
Environment environment,
NodeEnvironment nodeEnvironment,
NamedWriteableRegistry namedWriteableRegistry,
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
...
Settings settings = environment.settings();
...
// Register thread-safe ML objects here.
LocalSampleCalculator localSampleCalculator = new LocalSampleCalculator(client, settings);
MLEngineClassLoader.register(FunctionName.LOCAL_SAMPLE_CALCULATOR, localSampleCalculator);
...
}
...
}
```


### Step 6: Run and test
Run `./gradlew run` and test this sample calculator.

```
Expand Down
4 changes: 2 additions & 2 deletions gradle.properties
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,5 @@
#
#

opensearch_version = 1.2.0-SNAPSHOT
opensearchBaseVersion = 1.2.0
opensearch_version = 1.2.1-SNAPSHOT
opensearchBaseVersion = 1.2.1
69 changes: 38 additions & 31 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,63 @@

import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.parameter.Input;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.MLAlgoParams;
import org.opensearch.ml.common.parameter.MLInput;
import org.opensearch.ml.common.parameter.MLOutput;
import org.opensearch.ml.common.parameter.Output;
import org.opensearch.ml.engine.annotation.Function;
import org.opensearch.ml.engine.exceptions.MetaDataException;
import org.reflections.Reflections;

import java.lang.reflect.InvocationTargetException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Set;

/**
* This is the interface to all ml algorithms.
*/
public class MLEngine {

public static MLOutput predict(FunctionName algoName, MLAlgoParams parameters, DataFrame dataFrame, Model model) {
if (algoName == null) {
throw new IllegalArgumentException("Algo name should not be null");
public static Model train(Input input) {
validateMLInput(input);
MLInput mlInput = (MLInput) input;
Trainable trainable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
if (trainable == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
}
Predictable mlAlgo = MLEngineClassLoader.initInstance(algoName, parameters, MLAlgoParams.class);
if (mlAlgo == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + algoName);
return trainable.train(mlInput.getDataFrame());
}

public static MLOutput predict(Input input, Model model) {
validateMLInput(input);
MLInput mlInput = (MLInput) input;
Predictable predictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
if (predictable == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
}
return mlAlgo.predict(dataFrame, model);
return predictable.predict(mlInput.getDataFrame(), model);
}

public static Model train(FunctionName algoName, MLAlgoParams parameters, DataFrame dataFrame) {
if (algoName == null) {
throw new IllegalArgumentException("Algo name should not be null");
public static Output execute(Input input) {
validateInput(input);
Executable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
Trainable mlAlgo = MLEngineClassLoader.initInstance(algoName, parameters, MLAlgoParams.class);
if (mlAlgo == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + algoName);
return executable.execute(input);
}

private static void validateMLInput(Input input) {
validateInput(input);
if (!(input instanceof MLInput)) {
throw new IllegalArgumentException("Input should be MLInput");
}
MLInput mlInput = (MLInput) input;
DataFrame dataFrame = mlInput.getDataFrame();
if (dataFrame == null || dataFrame.size() == 0) {
throw new IllegalArgumentException("Input data frame should not be null or empty");
}
return mlAlgo.train(dataFrame);
}

public static Output execute(Input input) {
private static void validateInput(Input input) {
if (input == null) {
throw new IllegalArgumentException("Algo name should not be null");
throw new IllegalArgumentException("Input should not be null");
}
Executable function = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
if (function == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + input.getFunctionName());
if (input.getFunctionName() == null) {
throw new IllegalArgumentException("Function name should not be null");
}
return function.execute(input);
}

}
Loading

0 comments on commit 2fed70a

Please sign in to comment.