Skip to content

Commit

Permalink
Get model group API (opensearch-project#1670)
Browse files Browse the repository at this point in the history
* Get model group API

Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Nov 27, 2023
1 parent 5be107b commit db719cf
Show file tree
Hide file tree
Showing 15 changed files with 1,113 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.opensearch.action.ActionType;

public class MLModelGroupGetAction extends ActionType<MLModelGroupGetResponse> {
public static final MLModelGroupGetAction INSTANCE = new MLModelGroupGetAction();
public static final String NAME = "cluster:admin/opensearch/ml/model_groups/get";

private MLModelGroupGetAction() { super(NAME, MLModelGroupGetResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLModelGroupGetRequest extends ActionRequest {

String modelGroupId;

@Builder
public MLModelGroupGetRequest(String modelGroupId) {
this.modelGroupId = modelGroupId;
}

public MLModelGroupGetRequest(StreamInput in) throws IOException {
super(in);
this.modelGroupId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.modelGroupId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.modelGroupId == null) {
exception = addValidationError("Model group id can't be null", exception);
}

return exception;
}

public static MLModelGroupGetRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLModelGroupGetRequest) {
return (MLModelGroupGetRequest)actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupGetRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupGetRequest", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModelGroup;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

@Getter
@ToString
public class MLModelGroupGetResponse extends ActionResponse implements ToXContentObject {

MLModelGroup mlModelGroup;

@Builder
public MLModelGroupGetResponse(MLModelGroup mlModelGroup) {
this.mlModelGroup = mlModelGroup;
}


public MLModelGroupGetResponse(StreamInput in) throws IOException {
super(in);
mlModelGroup = mlModelGroup.fromStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException{
mlModelGroup.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
return mlModelGroup.toXContent(xContentBuilder, params);
}

public static MLModelGroupGetResponse fromActionResponse(ActionResponse actionResponse) {
if (actionResponse instanceof MLModelGroupGetResponse) {
return (MLModelGroupGetResponse) actionResponse;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionResponse.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupGetResponse(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionResponse into MLModelGroupGetResponse", e);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.core.common.io.stream.StreamOutput;

import java.io.IOException;
import java.io.UncheckedIOException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;

public class MLModelGroupGetRequestTest {
private String modelGroupId;

@Before
public void setUp() {
modelGroupId = "test_id";
}

@Test
public void writeTo_Success() throws IOException {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder()
.modelGroupId(modelGroupId).build();
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
mlModelGroupGetRequest.writeTo(bytesStreamOutput);
MLModelGroupGetRequest parsedModel = new MLModelGroupGetRequest(bytesStreamOutput.bytes().streamInput());
assertEquals(parsedModel.getModelGroupId(), modelGroupId);
}

@Test
public void validate_Exception_NullmodelGroupId() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().build();

ActionRequestValidationException exception = mlModelGroupGetRequest.validate();
assertEquals("Validation Failed: 1: Model group id can't be null;", exception.getMessage());
}

@Test
public void fromActionRequest_Success() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder()
.modelGroupId(modelGroupId).build();
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
mlModelGroupGetRequest.writeTo(out);
}
};
MLModelGroupGetRequest result = MLModelGroupGetRequest.fromActionRequest(actionRequest);
assertNotSame(result, mlModelGroupGetRequest);
assertEquals(result.getModelGroupId(), mlModelGroupGetRequest.getModelGroupId());
}

@Test(expected = UncheckedIOException.class)
public void fromActionRequest_IOException() {
ActionRequest actionRequest = new ActionRequest() {
@Override
public ActionRequestValidationException validate() {
return null;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException("test");
}
};
MLModelGroupGetRequest.fromActionRequest(actionRequest);
}

@Test
public void validate_Success() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build();
ActionRequestValidationException actionRequestValidationException = mlModelGroupGetRequest.validate();
assertNull(actionRequestValidationException);
}

@Test
public void fromActionRequestWithMLModelGroupGetRequest_Success() {
MLModelGroupGetRequest mlModelGroupGetRequest = MLModelGroupGetRequest.builder().modelGroupId(modelGroupId).build();
MLModelGroupGetRequest mlModelGroupGetRequestFromActionRequest = MLModelGroupGetRequest.fromActionRequest(mlModelGroupGetRequest);
assertSame(mlModelGroupGetRequest, mlModelGroupGetRequestFromActionRequest);
assertEquals(mlModelGroupGetRequest.getModelGroupId(), mlModelGroupGetRequestFromActionRequest.getModelGroupId());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.junit.Before;
import org.junit.Test;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModelGroup;

import java.io.IOException;
import java.io.UncheckedIOException;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertSame;

public class MLModelGroupGetResponseTest {

MLModelGroup mlModelGroup;

@Before
public void setUp() {
mlModelGroup = MLModelGroup.builder()
.name("modelGroup1")
.latestVersion(1)
.description("This is an example model group")
.access("public")
.build();
}

@Test
public void writeTo_Success() throws IOException {
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLModelGroupGetResponse response = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
response.writeTo(bytesStreamOutput);
MLModelGroupGetResponse parsedResponse = new MLModelGroupGetResponse(bytesStreamOutput.bytes().streamInput());
assertNotEquals(response.mlModelGroup, parsedResponse.mlModelGroup);
assertEquals(response.mlModelGroup.getName(), parsedResponse.mlModelGroup.getName());
assertEquals(response.mlModelGroup.getDescription(), parsedResponse.mlModelGroup.getDescription());
assertEquals(response.mlModelGroup.getLatestVersion(), parsedResponse.mlModelGroup.getLatestVersion());
}

@Test
public void toXContentTest() throws IOException {
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
mlModelGroupGetResponse.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
assertEquals("{\"name\":\"modelGroup1\"," +
"\"latest_version\":1," +
"\"description\":\"This is an example model group\"," +
"\"access\":\"public\"}",
jsonStr);
}

@Test
public void fromActionResponseWithMLModelGroupGetResponse_Success() {
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(mlModelGroupGetResponse);
assertSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse);
assertEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup);
}

@Test
public void fromActionResponse_Success() {
MLModelGroupGetResponse mlModelGroupGetResponse = MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build();
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
mlModelGroupGetResponse.writeTo(out);
}
};
MLModelGroupGetResponse mlModelGroupGetResponseFromActionResponse = MLModelGroupGetResponse.fromActionResponse(actionResponse);
assertNotSame(mlModelGroupGetResponse, mlModelGroupGetResponseFromActionResponse);
assertNotEquals(mlModelGroupGetResponse.mlModelGroup, mlModelGroupGetResponseFromActionResponse.mlModelGroup);
}

@Test(expected = UncheckedIOException.class)
public void fromActionResponse_IOException() {
ActionResponse actionResponse = new ActionResponse() {
@Override
public void writeTo(StreamOutput out) throws IOException {
throw new IOException();
}
};
MLModelGroupGetResponse.fromActionResponse(actionResponse);
}
}
Loading

0 comments on commit db719cf

Please sign in to comment.