diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index da4d64c33c..1b25a03353 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -141,16 +141,21 @@ jobs: else echo "imagePresent=false" >> $GITHUB_ENV fi + - name: Generate Password For Admin + id: genpass + run: | + PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-') + echo "password={$PASSWORD}" >> $GITHUB_OUTPUT - name: Run Docker Image if: env.imagePresent == 'true' run: | cd .. - docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" opensearch-ml:test + docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" -e OPENSEARCH_INITIAL_ADMIN_PASSWORD=${{ steps.genpass.outputs.password }} opensearch-ml:test sleep 90 - name: Run MLCommons Test if: env.imagePresent == 'true' run: | - security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:admin --insecure |grep opensearch-security|wc -l` + security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:${{ steps.genpass.outputs.password }} --insecure |grep opensearch-security|wc -l` export OPENAI_KEY=$(aws secretsmanager get-secret-value --secret-id github_openai_key --query SecretString --output text) export COHERE_KEY=$(aws secretsmanager get-secret-value --secret-id github_cohere_key --query SecretString --output text) echo "::add-mask::$OPENAI_KEY" @@ -158,7 +163,7 @@ jobs: if [ $security -gt 0 ] then echo "Security plugin is available" - ./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin + ./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=${{ steps.genpass.outputs.password }} else echo "Security plugin is NOT available" ./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" diff --git a/docs/remote_inference_blueprints/open_ai_connector_completion_blueprint.md b/docs/remote_inference_blueprints/open_ai_connector_completion_blueprint.md index 3c661dffd8..b7e6269a96 100644 --- a/docs/remote_inference_blueprints/open_ai_connector_completion_blueprint.md +++ b/docs/remote_inference_blueprints/open_ai_connector_completion_blueprint.md @@ -13,7 +13,7 @@ POST /_plugins/_ml/connectors/_create "endpoint": "api.openai.com", "max_tokens": 7, "temperature": 0, - "model": "text-davinci-003" + "model": "gpt-3.5-turbo-instruct" }, "credential": { "openAI_key": "" @@ -62,7 +62,7 @@ POST /_plugins/_ml/models//_predict "id": "cmpl-7g0NPOJd8IvXTdhecdlR0VGfrLMWE", "object": "text_completion", "created": 1690245579, - "model": "text-davinci-003", + "model": "gpt-3.5-turbo-instruct", "choices": [ { "text": """ diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index d262b816ec..f7e2a63138 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -198,6 +198,7 @@ public void createConversation(String name, ActionListener listener) { public void getConversations(int from, int maxResults, ActionListener> listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(List.of()); + return; } SearchRequest request = Requests.searchRequest(META_INDEX_NAME); String userstr = getUserStrFromThreadContext(); @@ -250,6 +251,7 @@ public void getConversations(int maxResults, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); + return; } DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); String userstr = getUserStrFromThreadContext(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index b500709bc5..617c6871e5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -68,7 +68,8 @@ public String encrypt(String plainText) { initMasterKey(); final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); byte[] bytes = Base64.getDecoder().decode(masterKey); - JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding"); + // https://github.com/aws/aws-encryption-sdk-java/issues/1879 + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); final CryptoResult encryptResult = crypto .encryptData(jceMasterKey, plainText.getBytes(StandardCharsets.UTF_8)); @@ -81,7 +82,7 @@ public String decrypt(String encryptedText) { final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build(); byte[] bytes = Base64.getDecoder().decode(masterKey); - JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding"); + JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING"); final CryptoResult decryptedResult = crypto .decryptData(jceMasterKey, Base64.getDecoder().decode(encryptedText)); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java b/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java index 3e5357b114..73698939e1 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLModelGroupRestIT.java @@ -1239,21 +1239,21 @@ public void test_get_modelGroup() throws IOException { getModelGroup( user2Client, modelGroupId1, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); } ); // Admin successfully gets model group getModelGroup( client(), modelGroupId1, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); } ); } catch (IOException e) { assertNull(e); } // User2 fails to get model group try { - getModelGroup(user3Client, modelGroupId, null); + getModelGroup(user3Client, modelGroupId1, null); } catch (Exception e) { assertEquals(ResponseException.class, e.getClass()); assertTrue( @@ -1273,21 +1273,21 @@ public void test_get_modelGroup() throws IOException { getModelGroup( user1Client, modelGroupId2, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); } ); // User3 successfully gets model group getModelGroup( user3Client, modelGroupId2, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); } ); // User4 successfully gets model group getModelGroup( user4Client, modelGroupId2, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); } ); } catch (IOException e) { assertNull(e); @@ -1303,14 +1303,14 @@ public void test_get_modelGroup() throws IOException { getModelGroup( user3Client, modelGroupId3, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); } ); // Admin successfully gets model group getModelGroup( client(), modelGroupId3, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); } ); } catch (IOException e) { assertNull(e); @@ -1337,7 +1337,7 @@ public void test_get_modelGroup() throws IOException { getModelGroup( client(), modelGroupId4, - getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); } + getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup4"); } ); } catch (IOException e) { assertNull(e); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 0c8a7c779c..3197d1c25e 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -39,7 +39,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { + " \"content_type\": \"application/json\",\n" + " \"max_tokens\": 7,\n" + " \"temperature\": 0,\n" - + " \"model\": \"text-davinci-003\"\n" + + " \"model\": \"gpt-3.5-turbo-instruct\"\n" + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" @@ -265,7 +265,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { + " \"endpoint\": \"api.openai.com\",\n" + " \"auth\": \"API_Key\",\n" + " \"content_type\": \"application/json\",\n" - + " \"model\": \"text-davinci-edit-001\"\n" + + " \"model\": \"gpt-4\"\n" + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" @@ -276,18 +276,18 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { + " {\n" + " \"action_type\": \"predict\",\n" + " \"method\": \"POST\",\n" - + " \"url\": \"https://api.openai.com/v1/edits\",\n" + + " \"url\": \"https://api.openai.com/v1/chat/completions\",\n" + " \"headers\": { \n" + " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n" + " },\n" - + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n" + + " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"${parameters.input}\\\"}]}\"\n" + " }\n" + " ]\n" + "}"; Response response = createConnector(entity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); - response = registerRemoteModel("openAI-GPT-3.5 edit model", connectorId); + response = registerRemoteModel("openAI-GPT-4 edit model", connectorId); responseMap = parseResponseToMap(response); String taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); @@ -298,12 +298,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { responseMap = parseResponseToMap(response); taskId = (String) responseMap.get("task_id"); waitForTask(taskId, MLTaskState.COMPLETED); - String predictInput = "{\n" - + " \"parameters\": {\n" - + " \"input\": \"What day of the wek is it?\",\n" - + " \"instruction\": \"Fix the spelling mistakes\"\n" - + " }\n" - + "}"; + String predictInput = "{\"parameters\":{\"input\":\"What day of the wek is it?\"}}"; response = predictRemoteModel(modelId, predictInput); responseMap = parseResponseToMap(response); List responseList = (List) responseMap.get("inference_results"); @@ -317,7 +312,9 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { return; } responseMap = (Map) responseList.get(0); - assertFalse(((String) responseMap.get("text")).isEmpty()); + responseMap = (Map) responseMap.get("message"); + + assertFalse(((String) responseMap.get("content")).isEmpty()); } public void testOpenAIModerationsModel() throws IOException, InterruptedException { diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java index 2112793166..2b2f409908 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java @@ -20,10 +20,12 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Map; +import java.util.concurrent.TimeUnit; import org.apache.hc.core5.http.HttpEntity; import org.apache.hc.core5.http.HttpHeaders; import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Assert; import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.core.rest.RestStatus; @@ -120,7 +122,7 @@ public void testConversations_MorePages() throws IOException { assert (((Double) map.get("next_token")).intValue() == 1); } - public void testGetConversations_nextPage() throws IOException { + public void testGetConversations_nextPage() throws IOException, InterruptedException { Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse1 != null); assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK); @@ -128,8 +130,12 @@ public void testGetConversations_nextPage() throws IOException { String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1); Map ccmap1 = gson.fromJson(ccentityString1, Map.class); assert (ccmap1.containsKey("conversation_id")); + logger.info("ccentityString1={}", ccentityString1); String id1 = (String) ccmap1.get("conversation_id"); + // wait for 0.1s to make sure update time is different between conversation 1 and 2 + TimeUnit.MICROSECONDS.sleep(100); + Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); assert (ccresponse2 != null); assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK); @@ -159,7 +165,7 @@ public void testGetConversations_nextPage() throws IOException { ArrayList conversations1 = (ArrayList) map1.get("conversations"); assert (conversations1.size() == 1); assert (conversations1.get(0).containsKey("conversation_id")); - assert (((String) conversations1.get(0).get("conversation_id")).equals(id2)); + Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2); assert (((Double) map1.get("next_token")).intValue() == 1); Response response = TestHelper diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index ecd97c233a..db145dbf01 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -227,7 +227,7 @@ public static RestRequest getCreateConnectorRestRequest() { + " \"content_type\": \"application/json\",\n" + " \"max_tokens\": 7,\n" + " \"temperature\": 0,\n" - + " \"model\": \"text-davinci-003\"\n" + + " \"model\": \"gpt-3.5-turbo-instruct\"\n" + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"xxxxxxxx\"\n"