Enable pass query string to model_config in ml inference search response processor #2899

Open
wants to merge 3 commits into
base: main

mingshl
Collaborator

@mingshl mingshl commented Sep 5, 2024

Description

Enable passing the query string to model_config in the ML inference search response processor. The steps below walk through an end-to-end test with a local cross-encoder model.

Configure the cluster settings

PUT _cluster/settings
{
  "persistent": {
    "plugins": {
      "ml_commons": {
        "only_run_on_ml_node": "false",
        "model_access_control_enabled": "true",
        "native_memory_threshold": "99"
      }
    }
  }
}

Register the local cross-encoder model

POST /_plugins/_ml/models/_register
{
  "name": "huggingface/cross-encoders/ms-marco-MiniLM-L-6-v2",
  "version": "1.0.2",
  "model_format": "TORCH_SCRIPT"
}

{
  "task_id": "tQ5p1ZEB4iWlnHsIf2Xw",
  "status": "CREATED"
}

Get the register task status

GET /_plugins/_ml/tasks/tQ5p1ZEB4iWlnHsIf2Xw

{
  "model_id": "tg5p1ZEB4iWlnHsIh2U9",
  "task_type": "REGISTER_MODEL",
  "function_name": "TEXT_SIMILARITY",
  "state": "COMPLETED",
  "worker_node": [
    "AahcbpI9R2OtId7Wnt0cYA"
  ],
  "create_time": 1725862346467,
  "last_update_time": 1725862356009,
  "is_async": true
}

Deploy the local cross-encoder model

POST /_plugins/_ml/models/tg5p1ZEB4iWlnHsIh2U9/_deploy

{
  "task_id": "tw5q1ZEB4iWlnHsIo2WX",
  "task_type": "DEPLOY_MODEL",
  "status": "CREATED"
}

Get the deploy task status

GET /_plugins/_ml/tasks/tw5q1ZEB4iWlnHsIo2WX

{
  "model_id": "tg5p1ZEB4iWlnHsIh2U9",
  "task_type": "DEPLOY_MODEL",
  "function_name": "TEXT_SIMILARITY",
  "state": "RUNNING",
  "worker_node": [
    "AahcbpI9R2OtId7Wnt0cYA"
  ],
  "create_time": 1725862421392,
  "last_update_time": 1725862421555,
  "is_async": true
}

Wait until the state is COMPLETED

{
  "model_id": "tg5p1ZEB4iWlnHsIh2U9",
  "task_type": "DEPLOY_MODEL",
  "function_name": "TEXT_SIMILARITY",
  "state": "COMPLETED",
  "worker_node": [
    "AahcbpI9R2OtId7Wnt0cYA"
  ],
  "create_time": 1725862421392,
  "last_update_time": 1725862466568,
  "is_async": true
}
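
The wait step above can be scripted. This is a minimal Python sketch, not part of the PR: `wait_for_task` and the injected `fetch_task` callable are hypothetical names, and `fetch_task` stands in for `GET /_plugins/_ml/tasks/<task_id>` so the example runs without a cluster. Treating `FAILED` as a terminal error state is an assumption.

```python
import time
from typing import Callable, Dict


def wait_for_task(fetch_task: Callable[[], Dict],
                  interval_s: float = 1.0,
                  max_polls: int = 60) -> Dict:
    """Poll a task until its state is COMPLETED.

    fetch_task is a hypothetical callable that returns the parsed JSON body
    of the task status request shown above.
    """
    for _ in range(max_polls):
        task = fetch_task()
        state = task.get("state")
        if state == "COMPLETED":
            return task
        if state == "FAILED":  # assumed terminal error state
            raise RuntimeError(f"task failed: {task}")
        time.sleep(interval_s)
    raise TimeoutError("task did not complete within the polling budget")


# Canned responses that mimic the RUNNING -> COMPLETED sequence above.
_responses = iter([
    {"state": "RUNNING", "task_type": "DEPLOY_MODEL"},
    {"state": "COMPLETED", "task_type": "DEPLOY_MODEL",
     "model_id": "tg5p1ZEB4iWlnHsIh2U9"},
])
result = wait_for_task(lambda: next(_responses), interval_s=0.0)
print(result["model_id"])
```

In a real script, `fetch_task` would wrap whatever HTTP client you already use against the cluster.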

Test model prediction

POST _plugins/_ml/models/tg5p1ZEB4iWlnHsIh2U9/_predict
{
    "query_text": "today is sunny",
    "text_docs": [
        "how are you",
        "today is sunny",
        "today is july fifth",
        "it is winter"
    ]
}

{
  "inference_results": [
    {
      "output": [
        {
          "name": "similarity",
          "data_type": "FLOAT32",
          "shape": [
            1
          ],
          "data": [
            -11.055183
          ],
          "byte_buffer": {
            "array": "COIwwQ==",
            "order": "LITTLE_ENDIAN"
          }
        }
      ]
    },
    {
      "output": [
        {
          "name": "similarity",
          "data_type": "FLOAT32",
          "shape": [
            1
          ],
          "data": [
            8.969885
          ],
          "byte_buffer": {
            "array": "poQPQQ==",
            "order": "LITTLE_ENDIAN"
          }
        }
      ]
    },
    {
      "output": [
        {
          "name": "similarity",
          "data_type": "FLOAT32",
          "shape": [
            1
          ],
          "data": [
            -5.736347
          ],
          "byte_buffer": {
            "array": "KJC3wA==",
            "order": "LITTLE_ENDIAN"
          }
        }
      ]
    },
    {
      "output": [
        {
          "name": "similarity",
          "data_type": "FLOAT32",
          "shape": [
            1
          ],
          "data": [
            -10.0452175
          ],
          "byte_buffer": {
            "array": "NrkgwQ==",
            "order": "LITTLE_ENDIAN"
          }
        }
      ]
    }
  ]
}
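
Each `byte_buffer.array` in the response above is the base64 encoding of the same FLOAT32 value reported in `data`, in the stated byte order. The Python sketch below (illustrative only; `decode_float32` is a hypothetical helper name) decodes the four buffers and ranks the input documents by score.

```python
import base64
import struct


def decode_float32(b64: str, order: str = "LITTLE_ENDIAN") -> float:
    """Decode a base64-encoded 4-byte float in the given byte order."""
    raw = base64.b64decode(b64)
    fmt = "<f" if order == "LITTLE_ENDIAN" else ">f"
    return struct.unpack(fmt, raw)[0]


# text_docs and byte_buffer.array values copied from the request/response above.
docs = ["how are you", "today is sunny", "today is july fifth", "it is winter"]
buffers = ["COIwwQ==", "poQPQQ==", "KJC3wA==", "NrkgwQ=="]

scores = [decode_float32(b) for b in buffers]

# Higher similarity first.
ranked = sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
for doc, score in ranked:
    print(f"{score:12.6f}  {doc}")
```

As expected, the document identical to the query text ("today is sunny") gets the highest score.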

Index the documents

PUT /demo-index-0/_doc/1
{
  "dairy": "how are you"
}

PUT /demo-index-0/_doc/2
{
  "dairy": "today is sunny"
}

PUT /demo-index-0/_doc/3
{
  "dairy": "today is july fifth"
}

{
  "_index": "demo-index-0",
  "_id": "4",
  "_version": 1,
  "result": "created",
  "_shards": {
    "total": 2,
    "successful": 1,
    "failed": 0
  },
  "_seq_no": 3,
  "_primary_term": 1
}

Create a search pipeline that passes the query text through model_config

PUT /_search/pipeline/my_pipeline
{
  "response_processors": [
    {
      "ml_inference": {
        "tag": "ml_inference",
        "description": "This processor is going to run ml inference during search response",
        "model_id": "tg5p1ZEB4iWlnHsIh2U9",
        "model_input":"{ \"text_docs\": ${input_map.text_docs}, \"query_text\": \"${model_config.query_text}\" }",
        "function_name": "TEXT_SIMILARITY",
        "input_map": [
          {
            "text_docs": "dairy"
          }
        ],
        "output_map": [
          {
            "rank_score": "$.inference_results[*].output[*].data"
          }
        ],
        "full_response_path":false,
        "model_config": {
          "query_text": "$.query.term.dairy.value"
        },
        "ignore_missing": false,
        "ignore_failure": false
      }
    }
  ]
}
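
To make the data flow in this pipeline concrete, here is a simplified Python model of what the configuration above asks for: the `model_config.query_text` JSONPath is resolved against the search request, and the result is substituted into the `model_input` template together with the documents' `dairy` values. This is a sketch under assumptions, not the processor's actual Java implementation: `resolve_path` and `render_model_input` are hypothetical names, and the resolver handles only plain dotted paths (no wildcards or filters).

```python
import json


def resolve_path(document: dict, path: str):
    """Resolve a plain dotted JSONPath such as '$.a.b.c' against a dict."""
    if not path.startswith("$."):
        raise ValueError(f"unsupported path: {path}")
    node = document
    for key in path[2:].split("."):
        node = node[key]
    return node


def render_model_input(template: str, query: dict,
                       model_config: dict, text_docs: list) -> dict:
    """Fill the model_input template and return the parsed payload."""
    query_text = resolve_path(query, model_config["query_text"])
    rendered = (
        template
        .replace("${input_map.text_docs}", json.dumps(text_docs))
        # The template already wraps this placeholder in quotes, so insert
        # the JSON-escaped string without its surrounding quotes.
        .replace("${model_config.query_text}", json.dumps(query_text)[1:-1])
    )
    return json.loads(rendered)


template = ('{ "text_docs": ${input_map.text_docs}, '
            '"query_text": "${model_config.query_text}" }')
query = {"query": {"term": {"dairy": {"value": "today"}}}}
model_config = {"query_text": "$.query.term.dairy.value"}
text_docs = ["today is sunny", "today is july fifth"]

payload = render_model_input(template, query, model_config, text_docs)
print(json.dumps(payload, indent=2))
```

The resulting payload has the same shape as the body of the `_predict` request tested earlier.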

Search with the search pipeline; rank scores are added to the response

GET /demo-index-0/_search?search_pipeline=my_pipeline
{
  "query": {
    "term": {
      "dairy": {
        "value": "today"
      }
    }
  }
}
{
  "took": 400,
  "timed_out": false,
  "_shards": {
    "total": 1,
    "successful": 1,
    "skipped": 0,
    "failed": 0
  },
  "hits": {
    "total": {
      "value": 2,
      "relation": "eq"
    },
    "max_score": 0.71566814,
    "hits": [
      {
        "_index": "demo-index-0",
        "_id": "2",
        "_score": 0.71566814,
        "_source": {
          "dairy": "today is sunny",
          "rank_score": [
            3.6144485
          ]
        }
      },
      {
        "_index": "demo-index-0",
        "_id": "3",
        "_score": 0.6333549,
        "_source": {
          "dairy": "today is july fifth",
          "rank_score": [
            3.6144485
          ]
        }
      }
    ]
  }
}

ToDo

Currently the ML inference processor supports only a single tensor for local models; it needs to support parsing multiple tensors as well.

Related Issues

#2897
#2878

Check List

  • New functionality includes testing.
  • New functionality has been documented.
  • API changes companion pull request created.
  • Commits are signed per the DCO using --signoff.
  • Public documentation issue/PR created.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
For more information on following Developer Certificate of Origin and signing off your commits, please check here.

@ylwu-amzn
Collaborator

Add more details/examples to the description?

@ylwu-amzn
Collaborator

> Task :opensearch-ml-common:test

StringUtilsTest > testisValidJSONPath_InvalidInputs FAILED
    java.lang.AssertionError at StringUtilsTest.java:476

StringUtilsTest > testisValidJSONPath_EmptyInput FAILED
    java.lang.AssertionError at StringUtilsTest.java:490

1008 tests completed, 2 failed, 2 skipped

@mingshl
Collaborator Author

mingshl commented Sep 5, 2024

#2897

There are more details in the RFC: #2897

@mingshl
Collaborator Author

mingshl commented Sep 6, 2024

CI failed

OpenSearchConversationalMemoryHandlerTests > classMethod FAILED
Error: Exception in thread "Thread-4" java.lang.NoClassDefFoundError: Could not initialize class org.opensearch.test.OpenSearchTestCase
	at java.base/java.lang.Thread.run(Thread.java:1583)
	Suppressed: java.lang.IllegalStateException: No context information for thread: Thread[id=29, name=Thread-4, state=RUNNABLE, group=TGRP-ConversationalMemoryHandlerITTests]. Is this thread running under a class com.carrotsearch.randomizedtesting.RandomizedRunner runner context? Add @RunWith(class com.carrotsearch.randomizedtesting.RandomizedRunner.class) to your test class. Make sure your code accesses random contexts within @BeforeClass and @AfterClass boundary (for example, static test class initializers are not permitted to access random contexts).
		at com.carrotsearch.randomizedtesting.RandomizedContext.context(RandomizedContext.java:249)
		at com.carrotsearch.randomizedtesting.RandomizedContext.current(RandomizedContext.java:134)
		at com.carrotsearch.randomizedtesting.RandomizedRunner.augmentStackTrace(RandomizedRunner.java:1885)
		at com.carrotsearch.randomizedtesting.RunnerThreadGroup.uncaughtException(RunnerThreadGroup.java:20)
		at java.base/java.lang.Thread.dispatchUncaughtException(Thread.java:2901)
Caused by: java.lang.ExceptionInInitializerError: Exception java.lang.ExceptionInInitializerError [in thread "SUITE-ConversationalMemoryHandlerITTests-seed#[7B60246E9722A278]"]
	at org.opensearch.test.OpenSearchTestCase.<clinit>(OpenSearchTestCase.java:285)
	at java.base/java.lang.Class.forName0(Native Method)
	at java.base/java.lang.Class.forName(Class.java:534)
	at java.base/java.lang.Class.forName(Class.java:513)
	at com.carrotsearch.randomizedtesting.RandomizedRunner$2.run(RandomizedRunner.java:623)
    java.lang.NoClassDefFoundError at Class.java:-2
        Caused by: java.lang.ExceptionInInitializerError at OpenSearchTestCase.java:285


@mingshl mingshl force-pushed the main_add_query_text_to_response_processor branch from 350ab60 to c1d112a Compare September 6, 2024 21:09
@dhrubo-os
Collaborator

I also left a comment in your RFC.

Signed-off-by: Mingshi Liu <[email protected]>
Signed-off-by: Mingshi Liu <[email protected]>
@mingshl
Collaborator Author

mingshl commented Sep 9, 2024

flaky test

REPRODUCE WITH: ./gradlew ':opensearch-ml-plugin:test' --tests "org.opensearch.ml.rest.RestMLPredictionActionTests.testGetRequest_LocalModelInferenceDisabled" -Dtests.seed=9770CFDD00A92E2A -Dtests.security.manager=false -Dtests.locale=az-AZ -Dtests.timezone=Asia/Sakhalin -Druntime.java=21

RestMLPredictionActionTests > testGetRequest_LocalModelInferenceDisabled FAILED
    java.lang.AssertionError: 
    Expected: (an instance of java.lang.IllegalStateException and exception with message a string containing "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.")
         but: an instance of java.lang.IllegalStateException <java.lang.IllegalArgumentException: Wrong Action Type of models> is a java.lang.IllegalArgumentException
    Stacktrace was: java.lang.IllegalArgumentException: Wrong Action Type of models
    	at org.opensearch.ml.common.connector.ConnectorAction$ActionType.from(ConnectorAction.java:199)
    	at org.opensearch.ml.rest.RestMLPredictionAction.getRequest(RestMLPredictionAction.java:129)
    	at org.opensearch.ml.rest.RestMLPredictionActionTests.testGetRequest_LocalModelInferenceDisabled(RestMLPredictionActionTests.java:146)
    	at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    	at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    	at com.carrotsearch.randomizedtesting.RandomizedRunner.invoke(RandomizedRunner.java:1750)
    	at com.carrotsearch.randomizedtesting.RandomizedRunner$8.evaluate(RandomizedRunner.java:938)
    	at com.carrotsearch.randomizedtesting.RandomizedRunner$9.evaluate(RandomizedRunner.java:974)
    	at com.carrotsearch.randomizedtesting.RandomizedRunner$10.evaluate(RandomizedRunner.java:988)
    	at org.junit.rules.ExpectedException$ExpectedExceptionStatement.evaluate(ExpectedException.java:258)
    	at com.carrotsearch.randomizedtesting.rules.StatementAdapter.evaluate(StatementAdapter.java:36)
    	at org.junit.rules.RunRules.evaluate(RunRules.java:20)
    	at org.apache.lucene.tests.util.TestRuleSetupTeardownChained$1.evaluate(TestRuleSetupTeardownChained.java:48)

@@ -316,22 +321,42 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
* @param inputMapIndex the index of the input mapping to process
* @param batchPredictionListener the listener to be notified when the predictions are processed
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @param queryString
Collaborator

@ylwu-amzn ylwu-amzn Oct 3, 2024

Minor: add an explanation for this parameter.

String modelConfigValue = entry.getValue();
if (StringUtils.isValidJSONPath(modelConfigValue)) {
Object queryJson = JsonPath.parse(queryString).read("$");
Configuration configuration = Configuration
Collaborator

It is inefficient to construct the same Configuration on every iteration of the for loop. Can you move it out of the loop or make it a static variable?

Collaborator Author

Great catch, I will move the reading of the query string to line 341.

String modelConfigKey = entry.getKey();
String modelConfigValue = entry.getValue();
if (StringUtils.isValidJSONPath(modelConfigValue)) {
Object queryJson = JsonPath.parse(queryString).read("$");
Copy link
Collaborator

Why do we need to read "$" first? Can we just read the JSON path directly?

JsonPath.parse(queryString).read(modelConfigValue)
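
For illustration, the two review suggestions in this thread (do the loop-invariant work once, outside the loop, and read each path directly) can be sketched as follows. This is a Python stand-in, not the PR's Java code and not the Jayway JsonPath API: `resolve_model_config` is a hypothetical name, the `startswith("$.")` check is a simplification of `StringUtils.isValidJSONPath`, keeping non-path values as literals is an assumption, and the `note` entry is made up for the example.

```python
import json


def resolve_model_config(query_string: str, model_config: dict) -> dict:
    """Resolve JSONPath-valued model_config entries against the query."""
    # Parse the query string once, before the loop, rather than per entry.
    query = json.loads(query_string)

    resolved = {}
    for key, value in model_config.items():
        if isinstance(value, str) and value.startswith("$."):
            # Read the configured path directly; no intermediate "$" read.
            node = query
            for part in value[2:].split("."):
                node = node[part]
            resolved[key] = node
        else:
            # Assumption: non-path values are passed through unchanged.
            resolved[key] = value
    return resolved


query_string = '{"query": {"term": {"dairy": {"value": "today"}}}}'
model_config = {
    "query_text": "$.query.term.dairy.value",
    "note": "plain literal",  # hypothetical non-path entry
}

resolved = resolve_model_config(query_string, model_config)
print(resolved)
```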
