forked from elastic/elasticsearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate watsonx for re-ranking task
- Loading branch information
1 parent
bcd690f
commit 63def28
Showing
10 changed files
with
1,132 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
...rg/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.http.sender; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest; | ||
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity; | ||
|
||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; | ||
|
||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
|
||
public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager { | ||
private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class); | ||
private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler(); | ||
|
||
private static ResponseHandler createIbmWatsonxResponseHandler() { | ||
return new IbmWatsonxResponseHandler("ibm watsonx rerank", (request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response), false); | ||
} | ||
|
||
public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) { | ||
return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); | ||
} | ||
|
||
private final IbmWatsonxRerankModel model; | ||
|
||
private IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) { | ||
super(threadPool, model); | ||
this.model = model; | ||
} | ||
|
||
@Override | ||
public void execute( | ||
InferenceInputs inferenceInputs, | ||
RequestSender requestSender, | ||
Supplier<Boolean> hasRequestCompletedFunction, | ||
ActionListener<InferenceServiceResults> listener | ||
) { | ||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs); | ||
IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); | ||
|
||
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); | ||
} | ||
} |
95 changes: 95 additions & 0 deletions
95
...rg/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx; | ||
|
||
import org.apache.http.HttpHeaders; | ||
import org.apache.http.client.methods.HttpPost; | ||
import org.apache.http.entity.ByteArrayEntity; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.inference.common.Truncator; | ||
import org.elasticsearch.xpack.inference.external.request.HttpRequest; | ||
import org.elasticsearch.xpack.inference.external.request.Request; | ||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; | ||
|
||
import java.net.URI; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.Objects; | ||
|
||
public class IbmWatsonxRerankRequest implements IbmWatsonxRequest { | ||
|
||
private final Truncator truncator; | ||
private final Truncator.TruncationResult truncationResult; | ||
private final IbmWatsonxRerankModel model; | ||
|
||
public IbmWatsonxRerankRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxRerankModel model) { | ||
this.truncator = Objects.requireNonNull(truncator); | ||
this.truncationResult = Objects.requireNonNull(input); | ||
this.model = Objects.requireNonNull(model); | ||
} | ||
|
||
@Override | ||
public HttpRequest createHttpRequest() { | ||
HttpPost httpPost = new HttpPost(model.uri()); | ||
|
||
ByteArrayEntity byteEntity = new ByteArrayEntity( | ||
Strings.toString( | ||
new IbmWatsonxRerankRequestEntity( | ||
truncationResult.input(), | ||
model.getServiceSettings().modelId(), | ||
model.getServiceSettings().projectId() | ||
) | ||
).getBytes(StandardCharsets.UTF_8) | ||
); | ||
|
||
httpPost.setEntity(byteEntity); | ||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); | ||
|
||
decorateWithAuth(httpPost); | ||
|
||
return new HttpRequest(httpPost, getInferenceEntityId()); | ||
} | ||
|
||
public void decorateWithAuth(HttpPost httpPost) { | ||
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); | ||
} | ||
|
||
public Truncator truncator() { | ||
return truncator; | ||
} | ||
|
||
public Truncator.TruncationResult truncationResult() { | ||
return truncationResult; | ||
} | ||
|
||
public IbmWatsonxRerankModel model() { | ||
return model; | ||
} | ||
|
||
@Override | ||
public String getInferenceEntityId() { | ||
return model.getInferenceEntityId(); | ||
} | ||
|
||
@Override | ||
public URI getURI() { | ||
return model.uri(); | ||
} | ||
|
||
@Override | ||
public Request truncate() { | ||
var truncatedInput = truncator.truncate(truncationResult.input()); | ||
|
||
return new IbmWatsonxRerankRequest(truncator, truncatedInput, model); | ||
} | ||
|
||
@Override | ||
public boolean[] getTruncationInfo() { | ||
return truncationResult.truncated().clone(); | ||
} | ||
} |
39 changes: 39 additions & 0 deletions
39
...sticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx; | ||
|
||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public record IbmWatsonxRerankRequestEntity(List<String> inputs, String modelId, String projectId) implements ToXContentObject { | ||
|
||
private static final String INPUTS_FIELD = "inputs"; | ||
private static final String MODEL_ID_FIELD = "model_id"; | ||
private static final String PROJECT_ID_FIELD = "project_id"; | ||
|
||
public IbmWatsonxRerankRequestEntity { | ||
Objects.requireNonNull(inputs); | ||
Objects.requireNonNull(modelId); | ||
Objects.requireNonNull(projectId); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
builder.field(INPUTS_FIELD, inputs); | ||
builder.field(MODEL_ID_FIELD, modelId); | ||
builder.field(PROJECT_ID_FIELD, projectId); | ||
builder.endObject(); | ||
|
||
return builder; | ||
} | ||
} |
Oops, something went wrong.