Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix post_process_function bug on sort results for rerank_pipeline_with_bge-rerank-m3-v2_model_deployed_on_Sagemaker.md #3277

Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,38 @@ result = predictor.predict(data={
]
})

print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2))
print(json.dumps(result, indent=2))
```

The reranking results are as follows:
The reranking results are ordered by score, highest first:
```
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
},
{
"index": 1,
"score": 0.000593021
},
{
"index": 3,
"score": 0.00012148176
}
]
```

You can also sort the results by their `index` value, for example with `sorted(result, key=lambda x: x['index'])`:

```python
# Re-order the reranker output by each item's `index` field before printing,
# so the entries appear in index order rather than in score order.
print(json.dumps(sorted(result, key=lambda x: x['index']), indent=2))
```

The results are as follows:

```
[
Expand Down Expand Up @@ -121,9 +149,51 @@ POST /_plugins/_ml/connectors/_create
"headers": {
"content-type": "application/json"
},
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"query": "${parameters.query}",
"texts": ${parameters.texts}
}
""",
"post_process_function": """
if (params.result == null || params.result.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
Expand Down Expand Up @@ -152,9 +222,51 @@ POST /_plugins/_ml/connectors/_create
"headers": {
"content-type": "application/json"
},
"request_body": "{ \"query\": \"${parameters.query}\", \"texts\": ${parameters.texts} }",
"pre_process_function": "\n def query_text = params.query_text;\n def text_docs = params.text_docs;\n def textDocsBuilder = new StringBuilder('[');\n for (int i=0; i<text_docs.length; i++) {\n textDocsBuilder.append('\"');\n textDocsBuilder.append(text_docs[i]);\n textDocsBuilder.append('\"');\n if (i<text_docs.length - 1) {\n textDocsBuilder.append(',');\n }\n }\n textDocsBuilder.append(']');\n def parameters = '{ \"query\": \"' + query_text + '\", \"texts\": ' + textDocsBuilder.toString() + ' }';\n return '{\"parameters\": ' + parameters + '}';\n",
"post_process_function": "\n \n def dataType = \"FLOAT32\";\n \n \n if (params.result == null)\n {\n return 'no result generated';\n //return params.response;\n }\n def outputs = params.result;\n \n def sorted_outputs = outputs;\n for (int i=0; i<outputs.length; i++) {\n def idx = new BigDecimal(outputs[i].index.toString()).intValue();\n sorted_outputs[idx] = outputs[i];\n }\n def resultBuilder = new StringBuilder('[');\n for (int i=0; i<outputs.length; i++) {\n resultBuilder.append(' {\"name\": \"similarity\", \"data_type\": \"FLOAT32\", \"shape\": [1],');\n //resultBuilder.append('{\"name\": \"similarity\"}');\n \n resultBuilder.append('\"data\": [');\n resultBuilder.append(outputs[i].score);\n resultBuilder.append(']}');\n if (i<outputs.length - 1) {\n resultBuilder.append(',');\n }\n }\n resultBuilder.append(']');\n \n return resultBuilder.toString();\n "
"pre_process_function": """
def query_text = params.query_text;
def text_docs = params.text_docs;
def textDocsBuilder = new StringBuilder('[');
for (int i=0; i<text_docs.length; i++) {
textDocsBuilder.append('"');
textDocsBuilder.append(text_docs[i]);
textDocsBuilder.append('"');
if (i<text_docs.length - 1) {
textDocsBuilder.append(',');
}
}
textDocsBuilder.append(']');
def parameters = '{ "query": "' + query_text + '", "texts": ' + textDocsBuilder.toString() + ' }';
return '{"parameters": ' + parameters + '}';
""",
"request_body": """
{
"query": "${parameters.query}",
"texts": ${parameters.texts}
}
""",
"post_process_function": """
if (params.result == null || params.result.length == 0) {
throw new IllegalArgumentException("Post process function input is empty.");
}
def outputs = params.result;
def scores = new Double[outputs.length];
for (int i=0; i<outputs.length; i++) {
def index = new BigDecimal(outputs[i].index.toString()).intValue();
scores[index] = outputs[i].score;
}
def resultBuilder = new StringBuilder('[');
for (int i=0; i<scores.length; i++) {
resultBuilder.append(' {"name": "similarity", "data_type": "FLOAT32", "shape": [1],');
resultBuilder.append('"data": [');
resultBuilder.append(scores[i]);
resultBuilder.append(']}');
if (i<outputs.length - 1) {
resultBuilder.append(',');
}
}
resultBuilder.append(']');
return resultBuilder.toString();
"""
}
]
}
Expand Down Expand Up @@ -188,7 +300,7 @@ POST _plugins/_ml/models/your_model_id/_predict
}
```

Each item in the `inputs` array comprises a `query_text` and a `text_docs` string, separated by a ` . `
Each item in the array comprises a `query_text` and a `text_docs` string, separated by ` . `.

Alternatively, you can test the model as follows:
```json
Expand All @@ -209,6 +321,10 @@ The connector `pre_process_function` transforms the input into the format requir
By default, the SageMaker model output has the following format:
```json
[
{
"index": 2,
"score": 0.92879725
},
{
"index": 0,
"score": 0.013636836
Expand All @@ -217,18 +333,14 @@ By default, the SageMaker model output has the following format:
"index": 1,
"score": 0.000593021
},
{
"index": 2,
"score": 0.92879725
},
{
"index": 3,
"score": 0.00012148176
}
]
```

The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret. This adapted format is as follows:
The connector `post_process_function` transforms the model's output into a format that the [Reranker processor](https://opensearch.org/docs/latest/search-plugins/search-pipelines/rerank-processor/) can interpret and orders the results by index. This adapted format is as follows:
```json
{
"inference_results": [
Expand Down
Loading