Skip to content

Commit

Permalink
Integrate watsonx for re-ranking task
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatsarkar056 committed Nov 19, 2024
1 parent bcd690f commit 63def28
Show file tree
Hide file tree
Showing 10 changed files with 1,132 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,31 @@
*/

package org.elasticsearch.xpack.inference.external.action.ibmwatsonx;

import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxRerankRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;

import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

/**
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type.
*/
public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {

private static final String COMPLETION_ERROR_PREFIX = "Ibm Watsonx completion";
private final Sender sender;
private final ServiceComponents serviceComponents;

public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
// TODO Batching - accept a class that can handle batching
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
}
Expand All @@ -41,6 +45,17 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
);
}

@Override
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
overriddenModel.getServiceSettings().uri(),
"Ibm Watsonx rerank"
);
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}

protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
IbmWatsonxEmbeddingsModel model,
Truncator truncator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Map;

public interface IbmWatsonxActionVisitor {
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);

ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity;

import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Objects;
import java.util.function.Supplier;

public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class);
private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler();

private static ResponseHandler createIbmWatsonxResponseHandler() {
return new IbmWatsonxResponseHandler("ibm watsonx rerank", (request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response), false);
}

public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) {
return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
}

private final IbmWatsonxRerankModel model;

private IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) {
super(threadPool, model);
this.model = model;
}

@Override
public void execute(
InferenceInputs inferenceInputs,
RequestSender requestSender,
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);
IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;

public class IbmWatsonxRerankRequest implements IbmWatsonxRequest {

private final Truncator truncator;
private final Truncator.TruncationResult truncationResult;
private final IbmWatsonxRerankModel model;

public IbmWatsonxRerankRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxRerankModel model) {
this.truncator = Objects.requireNonNull(truncator);
this.truncationResult = Objects.requireNonNull(input);
this.model = Objects.requireNonNull(model);
}

@Override
public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(model.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(
new IbmWatsonxRerankRequestEntity(
truncationResult.input(),
model.getServiceSettings().modelId(),
model.getServiceSettings().projectId()
)
).getBytes(StandardCharsets.UTF_8)
);

httpPost.setEntity(byteEntity);
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

decorateWithAuth(httpPost);

return new HttpRequest(httpPost, getInferenceEntityId());
}

public void decorateWithAuth(HttpPost httpPost) {
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
}

public Truncator truncator() {
return truncator;
}

public Truncator.TruncationResult truncationResult() {
return truncationResult;
}

public IbmWatsonxRerankModel model() {
return model;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
}

@Override
public URI getURI() {
return model.uri();
}

@Override
public Request truncate() {
var truncatedInput = truncator.truncate(truncationResult.input());

return new IbmWatsonxRerankRequest(truncator, truncatedInput, model);
}

@Override
public boolean[] getTruncationInfo() {
return truncationResult.truncated().clone();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record IbmWatsonxRerankRequestEntity(List<String> inputs, String modelId, String projectId) implements ToXContentObject {

private static final String INPUTS_FIELD = "inputs";
private static final String MODEL_ID_FIELD = "model_id";
private static final String PROJECT_ID_FIELD = "project_id";

public IbmWatsonxRerankRequestEntity {
Objects.requireNonNull(inputs);
Objects.requireNonNull(modelId);
Objects.requireNonNull(projectId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INPUTS_FIELD, inputs);
builder.field(MODEL_ID_FIELD, modelId);
builder.field(PROJECT_ID_FIELD, projectId);
builder.endObject();

return builder;
}
}
Loading

0 comments on commit 63def28

Please sign in to comment.