diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 2320cca8295d1..879f2b6318031 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -78,6 +78,8 @@ import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; @@ -355,6 +357,13 @@ private static void addIbmWatsonxNamedWritables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java index f809b17a0bfee..f6ab0af8be885 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java @@ -55,10 +55,9 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); - var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(truncator, truncatedInput, model); + IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java index 2ad1f388d34b2..716e10b9d44cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java @@ -9,28 +9,41 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; import java.net.URI; +import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.Objects; public class IbmWatsonxRerankRequest implements IbmWatsonxRequest { - private final Truncator truncator; - private final Truncator.TruncationResult truncationResult; + private final String query; + private final List input; + private final IbmWatsonxRerankTaskSettings taskSettings; + private final String modelId; private final IbmWatsonxRerankModel model; + private final String inferenceEntityId; - public IbmWatsonxRerankRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxRerankModel model) { - this.truncator = Objects.requireNonNull(truncator); - this.truncationResult = Objects.requireNonNull(input); - this.model = Objects.requireNonNull(model); + public IbmWatsonxRerankRequest(String query, List input, IbmWatsonxRerankModel model) { + Objects.requireNonNull(model); + + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + taskSettings = model.getTaskSettings(); + this.model = model; + this.modelId = model.getServiceSettings().modelId(); + inferenceEntityId = model.getInferenceEntityId(); } @Override @@ -39,12 +52,7 @@ public HttpRequest createHttpRequest() { ByteArrayEntity byteEntity = new ByteArrayEntity( Strings.toString( - new IbmWatsonxRerankRequestEntity( - truncationResult.input(), - model.getServiceSettings().modelId(), - model.getServiceSettings().projectId() - ) - ).getBytes(StandardCharsets.UTF_8) + new IbmWatsonxRerankRequestEntity(query, input, taskSettings, modelId)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -59,21 +67,9 @@ public void decorateWithAuth(HttpPost httpPost) { IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); } - public Truncator truncator() { - return truncator; - } - - public Truncator.TruncationResult truncationResult() { - return truncationResult; - } - - public IbmWatsonxRerankModel model() { - return model; - } - @Override public String getInferenceEntityId() { - return model.getInferenceEntityId(); + return inferenceEntityId; } @Override @@ -83,13 +79,18 @@ public URI getURI() { @Override public Request truncate() { - var truncatedInput = truncator.truncate(truncationResult.input()); - - return new IbmWatsonxRerankRequest(truncator, truncatedInput, model); + return this; // TODO? } @Override public boolean[] getTruncationInfo() { - return truncationResult.truncated().clone(); + return null; + } + + public static URI buildDefaultUri() throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(CohereUtils.HOST) + .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) + .build(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java index a07904382ea3a..cbcdd1ab2c45a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java @@ -9,29 +9,50 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; import java.io.IOException; import java.util.List; import java.util.Objects; -public record IbmWatsonxRerankRequestEntity(List inputs, String modelId, String projectId) implements ToXContentObject { +public record IbmWatsonxRerankRequestEntity(String model, String query, List documents, IbmWatsonxRerankTaskSettings taskSettings) + implements ToXContentObject { - private static final String INPUTS_FIELD = "inputs"; - private static final String MODEL_ID_FIELD = "model_id"; - private static final String PROJECT_ID_FIELD = "project_id"; + private static final String DOCUMENTS_FIELD = "documents"; + private static final String QUERY_FIELD = "query"; + private static final String MODEL_FIELD = "model"; public IbmWatsonxRerankRequestEntity { - Objects.requireNonNull(inputs); - Objects.requireNonNull(modelId); - Objects.requireNonNull(projectId); + Objects.requireNonNull(query); + Objects.requireNonNull(documents); + Objects.requireNonNull(taskSettings); + } + + public IbmWatsonxRerankRequestEntity(String query, List input, IbmWatsonxRerankTaskSettings taskSettings, String model) { + this(model, query, input, taskSettings); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(INPUTS_FIELD, inputs); - builder.field(MODEL_ID_FIELD, modelId); - builder.field(PROJECT_ID_FIELD, projectId); + + builder.field(MODEL_FIELD, model); + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, documents); + + if (taskSettings.getDoesReturnDocuments() != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); + } + + if (taskSettings.getTopNDocumentsOnly() != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); + } + + if (taskSettings.getMaxChunksPerDoc() != null) { + builder.field(CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, taskSettings.getMaxChunksPerDoc()); + } + builder.endObject(); return builder; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index e960b0b777f2b..c56feb3578945 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -40,8 +40,10 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; import java.util.EnumSet; import java.util.HashMap; @@ -126,6 +128,7 @@ private static IbmWatsonxModel createModel( secretSettings, context ); + case RERANK -> new IbmWatsonxRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; }