Skip to content

Commit

Permalink
Integrate watsonx for re-ranking task
Browse files Browse the repository at this point in the history
  • Loading branch information
saikatsarkar056 committed Nov 21, 2024
1 parent a85540a commit 2bc243d
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings;
Expand Down Expand Up @@ -355,6 +357,13 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
IbmWatsonxEmbeddingsServiceSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(ServiceSettings.class, IbmWatsonxRerankServiceSettings.NAME, IbmWatsonxRerankServiceSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ public void execute(
Supplier<Boolean> hasRequestCompletedFunction,
ActionListener<InferenceServiceResults> listener
) {
List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
var rerankInput = QueryAndDocsInputs.of(inferenceInputs);

IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(truncator, truncatedInput, model);
IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model);

execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,41 @@

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

public class IbmWatsonxRerankRequest implements IbmWatsonxRequest {

private final Truncator truncator;
private final Truncator.TruncationResult truncationResult;
private final String query;
private final List<String> input;
private final IbmWatsonxRerankTaskSettings taskSettings;
private final String modelId;
private final IbmWatsonxRerankModel model;
private final String inferenceEntityId;

public IbmWatsonxRerankRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxRerankModel model) {
this.truncator = Objects.requireNonNull(truncator);
this.truncationResult = Objects.requireNonNull(input);
this.model = Objects.requireNonNull(model);
public IbmWatsonxRerankRequest(String query, List<String> input, IbmWatsonxRerankModel model) {
Objects.requireNonNull(model);

this.input = Objects.requireNonNull(input);
this.query = Objects.requireNonNull(query);
taskSettings = model.getTaskSettings();
this.model = model;
this.modelId = model.getServiceSettings().modelId();
inferenceEntityId = model.getInferenceEntityId();
}

@Override
Expand All @@ -39,12 +52,7 @@ public HttpRequest createHttpRequest() {

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(
new IbmWatsonxRerankRequestEntity(
truncationResult.input(),
model.getServiceSettings().modelId(),
model.getServiceSettings().projectId()
)
).getBytes(StandardCharsets.UTF_8)
new IbmWatsonxRerankRequestEntity(query, input, taskSettings, modelId)).getBytes(StandardCharsets.UTF_8)
);

httpPost.setEntity(byteEntity);
Expand All @@ -59,21 +67,9 @@ public void decorateWithAuth(HttpPost httpPost) {
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
}

public Truncator truncator() {
return truncator;
}

public Truncator.TruncationResult truncationResult() {
return truncationResult;
}

public IbmWatsonxRerankModel model() {
return model;
}

@Override
public String getInferenceEntityId() {
return model.getInferenceEntityId();
return inferenceEntityId;
}

@Override
Expand All @@ -83,13 +79,18 @@ public URI getURI() {

@Override
public Request truncate() {
var truncatedInput = truncator.truncate(truncationResult.input());

return new IbmWatsonxRerankRequest(truncator, truncatedInput, model);
return this; // TODO?
}

@Override
public boolean[] getTruncationInfo() {
return truncationResult.truncated().clone();
return null;
}

/**
 * Builds a default {@code https} endpoint URI for rerank requests.
 *
 * NOTE(review): the host and both path segments are taken from {@code CohereUtils}, although
 * this method lives in the IBM watsonx rerank request class. This looks carried over from the
 * Cohere implementation — confirm the intended watsonx endpoint before relying on this default.
 *
 * @return the URI assembled from the {@code https} scheme, {@code CohereUtils.HOST}, and the
 *         {@code CohereUtils.VERSION_1} / {@code CohereUtils.RERANK_PATH} path segments
 * @throws URISyntaxException if the assembled components do not form a valid URI
 */
public static URI buildDefaultUri() throws URISyntaxException {
return new URIBuilder().setScheme("https")
.setHost(CohereUtils.HOST)
.setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,50 @@

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public record IbmWatsonxRerankRequestEntity(List<String> inputs, String modelId, String projectId) implements ToXContentObject {
public record IbmWatsonxRerankRequestEntity(String model, String query, List<String> documents, IbmWatsonxRerankTaskSettings taskSettings)
implements ToXContentObject {

private static final String INPUTS_FIELD = "inputs";
private static final String MODEL_ID_FIELD = "model_id";
private static final String PROJECT_ID_FIELD = "project_id";
private static final String DOCUMENTS_FIELD = "documents";
private static final String QUERY_FIELD = "query";
private static final String MODEL_FIELD = "model";

public IbmWatsonxRerankRequestEntity {
Objects.requireNonNull(inputs);
Objects.requireNonNull(modelId);
Objects.requireNonNull(projectId);
Objects.requireNonNull(query);
Objects.requireNonNull(documents);
Objects.requireNonNull(taskSettings);
}

/**
 * Convenience constructor that takes its arguments in
 * {@code (query, input, taskSettings, model)} order and forwards them to the canonical record
 * constructor, whose order is {@code (model, query, documents, taskSettings)}.
 *
 * Both {@code query} and {@code model} are plain {@code String}s, so the compiler cannot catch
 * a call that swaps them; check the argument order at call sites.
 *
 * @param query        forwarded as the record's {@code query} component
 * @param input        forwarded as the record's {@code documents} component
 * @param taskSettings forwarded as the record's {@code taskSettings} component
 * @param model        forwarded as the record's {@code model} component
 */
public IbmWatsonxRerankRequestEntity(String query, List<String> input, IbmWatsonxRerankTaskSettings taskSettings, String model) {
this(model, query, input, taskSettings);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(INPUTS_FIELD, inputs);
builder.field(MODEL_ID_FIELD, modelId);
builder.field(PROJECT_ID_FIELD, projectId);

builder.field(MODEL_FIELD, model);
builder.field(QUERY_FIELD, query);
builder.field(DOCUMENTS_FIELD, documents);

if (taskSettings.getDoesReturnDocuments() != null) {
builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments());
}

if (taskSettings.getTopNDocumentsOnly() != null) {
builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly());
}

if (taskSettings.getMaxChunksPerDoc() != null) {
builder.field(CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, taskSettings.getMaxChunksPerDoc());
}

builder.endObject();

return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@
import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.EnumSet;
import java.util.HashMap;
Expand Down Expand Up @@ -126,6 +128,7 @@ private static IbmWatsonxModel createModel(
secretSettings,
context
);
case RERANK -> new IbmWatsonxRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
Expand Down

0 comments on commit 2bc243d

Please sign in to comment.