Skip to content

Commit

Permalink
Update GetDataObjectResponse to include full GetResponse parser
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jun 21, 2024
1 parent 1277830 commit 9044be1
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 180 deletions.
21 changes: 10 additions & 11 deletions common/src/main/java/org/opensearch/sdk/GetDataObjectResponse.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,21 @@

import java.util.Collections;
import java.util.Map;
import java.util.Optional;

public class GetDataObjectResponse {
private final String id;
private final Optional<XContentParser> parser;
private final XContentParser parser;
private final Map<String, Object> source;

/**
* Instantiate this request with an id and parser/map used to recreate the data object.
* <p>
* For data storage implementations other than OpenSearch, the id may be referred to as a primary key.
* @param id the document id
* @param parser an optional XContentParser that can be used to create the data object if present.
* @param parser a parser that can be used to create a GetResponse
* @param source the data object as a map
*/
public GetDataObjectResponse(String id, Optional<XContentParser> parser, Map<String, Object> source) {
public GetDataObjectResponse(String id, XContentParser parser, Map<String, Object> source) {
this.id = id;
this.parser = parser;
this.source = source;
Expand All @@ -42,10 +41,10 @@ public String id() {
}

/**
* Returns the parser optional. If present, is a representation of the data object that may be parsed.
* @return the parser optional
* Returns the parser that can be used to create a GetResponse
* @return the parser
*/
public Optional<XContentParser> parser() {
public XContentParser parser() {
return this.parser;
}

Expand All @@ -62,7 +61,7 @@ public Map<String, Object> source() {
*/
public static class Builder {
private String id = null;
private Optional<XContentParser> parser = Optional.empty();
private XContentParser parser = null;
private Map<String, Object> source = Collections.emptyMap();

/**
Expand All @@ -81,11 +80,11 @@ public Builder id(String id) {
}

/**
* Add an optional parser to this builder
* @param parser an {@link Optional} which may contain the data object parser
* Add a parser to this builder
* @param parser a parser that can be used to create a GetResponse
* @return the updated builder
*/
public Builder parser(Optional<XContentParser> parser) {
public Builder parser(XContentParser parser) {
this.parser = parser;
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import org.opensearch.core.xcontent.XContentParser;

import java.util.Map;
import java.util.Optional;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;

Expand All @@ -33,10 +31,10 @@ public void setUp() {

@Test
public void testGetDataObjectResponse() {
GetDataObjectResponse response = new GetDataObjectResponse.Builder().id(testId).parser(Optional.of(testParser)).source(testSource).build();
GetDataObjectResponse response = new GetDataObjectResponse.Builder().id(testId).parser(testParser).source(testSource).build();

assertEquals(testId, response.id());
assertEquals(testParser, response.parser().get());
assertEquals(testParser, response.parser());
assertEquals(testSource, response.source());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,21 @@

package org.opensearch.ml.action.model_group;

import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -134,35 +137,74 @@ private void handleThrowable(Throwable throwable, String modelGroupId, ActionLis
}
}

/*
if (r != null && r.isExists()) {
try (XContentParser parser = createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
modelAccessControlHelper.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if (!access) {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"User doesn't have privilege to perform this operation on this model group",
RestStatus.FORBIDDEN
)
);
} else {
wrappedListener.onResponse(MLModelGroupGetResponse.builder().mlModelGroup(mlModelGroup).build());
}
}, e -> {
log.error("Failed to validate access for Model Group " + modelGroupId, e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to parse ml model group" + r.getId(), e);
wrappedListener.onFailure(e);
}
} else {
*/

private void processResponse(
GetDataObjectResponse getDataObjectResponse,
String modelGroupId,
String tenantId,
User user,
ActionListener<MLModelGroupGetResponse> wrappedListener
) {
if (getDataObjectResponse != null && getDataObjectResponse.parser().isPresent()) {
try {
XContentParser parser = getDataObjectResponse.parser().get();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);

if (TenantAwareHelper
.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModelGroup.getTenantId(), wrappedListener)) {
validateModelGroupAccess(user, modelGroupId, mlModelGroup, wrappedListener);
try {
GetResponse r = GetResponse.fromXContent(getDataObjectResponse.parser());
if (r != null && r.isExists()) {
try (
XContentParser parser = jsonXContent
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, r.getSourceAsString())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);

if (TenantAwareHelper
.validateTenantResource(mlFeatureEnabledSetting, tenantId, mlModelGroup.getTenantId(), wrappedListener)) {
validateModelGroupAccess(user, modelGroupId, mlModelGroup, wrappedListener);
}
} catch (Exception e) {
log.error("Failed to parse ml connector {}", getDataObjectResponse.id(), e);
wrappedListener.onFailure(e);
}
} catch (Exception e) {
log.error("Failed to parse ml connector {}", getDataObjectResponse.id(), e);
wrappedListener.onFailure(e);
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model group with the provided model group id: " + modelGroupId,
RestStatus.NOT_FOUND
)
);
}
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model group with the provided model group id: " + modelGroupId,
RestStatus.NOT_FOUND
)
);
} catch (Exception e) {
wrappedListener.onFailure(e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.action.model_group;

import static org.opensearch.common.xcontent.json.JsonXContent.jsonXContent;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.plugin.MachineLearningPlugin.GENERAL_THREAD_POOL;
Expand All @@ -16,13 +17,15 @@
import org.apache.commons.lang3.StringUtils;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -128,37 +131,44 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLUpda
wrappedListener.onFailure(cause);
}
} else {
if (r != null && r.parser().isPresent()) {
try {
XContentParser parser = r.parser().get();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
if (TenantAwareHelper
.validateTenantResource(
mlFeatureEnabledSetting,
tenantId,
mlModelGroup.getTenantId(),
wrappedListener
)) {
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
try {
GetResponse gr = GetResponse.fromXContent(r.parser());
if (gr != null && gr.isExists()) {
try (
XContentParser parser = jsonXContent
.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, gr.getSourceAsString())
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLModelGroup mlModelGroup = MLModelGroup.parse(parser);
if (TenantAwareHelper
.validateTenantResource(
mlFeatureEnabledSetting,
tenantId,
mlModelGroup.getTenantId(),
wrappedListener
)) {
if (modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(user)) {
validateRequestForAccessControl(updateModelGroupInput, user, mlModelGroup);
} else {
validateSecurityDisabledOrModelAccessControlDisabled(updateModelGroupInput);
}
updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user);
}
updateModelGroup(modelGroupId, r.source(), updateModelGroupInput, wrappedListener, user);
} catch (Exception e) {
log.error("Failed to parse ml connector {}", r.id(), e);
wrappedListener.onFailure(e);
}
} catch (Exception e) {
log.error("Failed to parse ml connector {}", r.id(), e);
wrappedListener.onFailure(e);
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model group with the provided model group id: " + modelGroupId,
RestStatus.NOT_FOUND
)
);
}
} else {
wrappedListener
.onFailure(
new OpenSearchStatusException(
"Failed to find model group with the provided model group id: " + modelGroupId,
RestStatus.NOT_FOUND
)
);
} catch (Exception e) {
wrappedListener.onFailure(e);
}
}
});
Expand Down
Loading

0 comments on commit 9044be1

Please sign in to comment.