diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index 81a4ab75bb..3794b61dde 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -5,18 +5,9 @@ package org.opensearch.ml.action.register; -import static org.opensearch.ml.common.MLTask.STATE_FIELD; -import static org.opensearch.ml.common.MLTaskState.FAILED; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; -import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; -import static org.opensearch.ml.utils.MLExceptionUtils.logException; - -import java.time.Instant; -import java.util.Arrays; -import java.util.List; -import java.util.regex.Pattern; - +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import lombok.extern.log4j.Log4j2; import org.apache.logging.log4j.util.Strings; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.ActionRequest; @@ -63,10 +54,17 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; -import lombok.extern.log4j.Log4j2; +import static org.opensearch.ml.common.MLTask.STATE_FIELD; +import static org.opensearch.ml.common.MLTaskState.FAILED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; +import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT; +import static org.opensearch.ml.utils.MLExceptionUtils.logException; @Log4j2 public class TransportRegisterModelAction extends HandledTransportAction { @@ -176,10 +174,9 @@ private void checkUserAccess( if (isModelNameAlreadyExisting) { // This case handles when user is using the same pre-trained model already registered by another user on the cluster. // The only way here is for the user to first create model group and use its ID in the request - if (registerModelInput.getModelGroupId() != null - && (registerModelInput.getUrl() == null + if (registerModelInput.getUrl() == null && registerModelInput.getFunctionName() != FunctionName.REMOTE - && registerModelInput.getConnectorId() == null)) { + && registerModelInput.getConnectorId() == null) { listener .onFailure( new IllegalArgumentException( diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index 27c86987d7..26ebfce5b6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -5,23 +5,7 @@ package org.opensearch.ml.action.register; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; -import static org.opensearch.ml.utils.TestHelper.clusterSetting; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - +import com.google.common.collect.ImmutableList; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.junit.Rule; @@ -72,7 +56,22 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import com.google.common.collect.ImmutableList; +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; +import static org.opensearch.ml.utils.TestHelper.clusterSetting; public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Rule @@ -495,6 +494,36 @@ public void test_ModelNameAlreadyExists() throws IOException { verify(actionListener).onResponse(argumentCaptor.capture()); } + public void test_FailureWhenPreBuildModelNameAlreadyExists() throws IOException { + SearchResponse searchResponse = createModelGroupSearchResponse(1); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(searchResponse); + return null; + }).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onResponse(false); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + MLRegisterModelInput registerModelInput = MLRegisterModelInput + .builder() + .modelName("huggingface/sentence-transformers/all-MiniLM-L12-v2") + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .version("1") + .build(); + + transportRegisterModelAction.doExecute(task, new MLRegisterModelRequest(registerModelInput), actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Without a model group ID, the system will use the model name {huggingface/sentence-transformers/all-MiniLM-L12-v2} to create a new model group. However, this name is taken by another group {model_group_ID} you can't access. To register this pre-trained model, create a new model group and use its ID in your request.", + argumentCaptor.getValue().getMessage() + ); + } + public void test_FailureWhenSearchingModelGroupName() throws IOException { doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1);