From 63def2872fab0824c17db12fcf96a5fb7ace34eb Mon Sep 17 00:00:00 2001 From: Saikat Sarkar Date: Tue, 19 Nov 2024 15:52:07 -0700 Subject: [PATCH] Integrate watsonx for re-ranking task --- .../ibmwatsonx/IbmWatsonxActionCreator.java | 21 +- .../ibmwatsonx/IbmWatsonxActionVisitor.java | 3 + .../IbmWatsonxRerankRequestManager.java | 57 +++++ .../ibmwatsonx/IbmWatsonxRerankRequest.java | 95 +++++++ .../IbmWatsonxRerankRequestEntity.java | 39 +++ .../IbmWatsonxRankedResponseEntity.java | 165 ++++++++++++ .../ibmwatsonx/IbmWatsonxServiceSettings.java | 235 ++++++++++++++++++ .../rerank/IbmWatsonxRerankModel.java | 144 +++++++++++ .../IbmWatsonxRerankServiceSettings.java | 187 ++++++++++++++ .../rerank/IbmWatsonxRerankTaskSettings.java | 189 ++++++++++++++ 10 files changed, 1132 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ibmwatsonx/IbmWatsonxRankedResponseEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankTaskSettings.java diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionCreator.java index 7cad7c42bdcf1..e10ffc1aea0ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionCreator.java @@ -6,27 +6,31 @@ */ package org.elasticsearch.xpack.inference.external.action.ibmwatsonx; - import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.common.Truncator; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxRerankRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; - +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; import java.util.Map; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type. + */ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor { - + private static final String COMPLETION_ERROR_PREFIX = "Ibm Watsonx completion"; private final Sender sender; private final ServiceComponents serviceComponents; public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) { + // TODO Batching - accept a class that can handle batching this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } @@ -41,6 +45,17 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map taskSettings) { + var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings); + var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( + overriddenModel.getServiceSettings().uri(), + "Ibm Watsonx rerank" + ); + return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + } + protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager( IbmWatsonxEmbeddingsModel model, Truncator truncator, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionVisitor.java index 0a13ec2fb4645..474533040e0c3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxActionVisitor.java @@ -9,9 +9,12 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; import java.util.Map; public interface IbmWatsonxActionVisitor { ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map taskSettings); + + ExecutableAction create(IbmWatsonxRerankModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java new file mode 100644 index 0000000000000..dd09ad6484dfc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler; +import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest; +import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity; + +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager { + private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class); + private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler(); + + private static ResponseHandler createIbmWatsonxResponseHandler() { + return new IbmWatsonxResponseHandler("ibm watsonx rerank", (request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response), false); + } + + public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) { + return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final IbmWatsonxRerankModel model; + + private IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) { + super(threadPool, model); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + IbmWatsonxRerankRequest request = new IbmWatsonxRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java new file mode 100644 index 0000000000000..2ad1f388d34b2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.ibmwatsonx; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class IbmWatsonxRerankRequest implements IbmWatsonxRequest { + + private final Truncator truncator; + private final Truncator.TruncationResult truncationResult; + private final IbmWatsonxRerankModel model; + + public IbmWatsonxRerankRequest(Truncator truncator, Truncator.TruncationResult input, IbmWatsonxRerankModel model) { + this.truncator = Objects.requireNonNull(truncator); + this.truncationResult = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new IbmWatsonxRerankRequestEntity( + truncationResult.input(), + model.getServiceSettings().modelId(), + model.getServiceSettings().projectId() + ) + ).getBytes(StandardCharsets.UTF_8) + ); + + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + decorateWithAuth(httpPost); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + public void decorateWithAuth(HttpPost httpPost) { + IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); + } + + public Truncator truncator() { + return truncator; + } + + public Truncator.TruncationResult truncationResult() { + return truncationResult; + } + + public IbmWatsonxRerankModel model() { + return model; + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new IbmWatsonxRerankRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java new file mode 100644 index 0000000000000..a07904382ea3a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.ibmwatsonx; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record IbmWatsonxRerankRequestEntity(List inputs, String modelId, String projectId) implements ToXContentObject { + + private static final String INPUTS_FIELD = "inputs"; + private static final String MODEL_ID_FIELD = "model_id"; + private static final String PROJECT_ID_FIELD = "project_id"; + + public IbmWatsonxRerankRequestEntity { + Objects.requireNonNull(inputs); + Objects.requireNonNull(modelId); + Objects.requireNonNull(projectId); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUTS_FIELD, inputs); + builder.field(MODEL_ID_FIELD, modelId); + builder.field(PROJECT_ID_FIELD, projectId); + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ibmwatsonx/IbmWatsonxRankedResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ibmwatsonx/IbmWatsonxRankedResponseEntity.java new file mode 100644 index 0000000000000..33b3105992fa0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/ibmwatsonx/IbmWatsonxRankedResponseEntity.java @@ -0,0 +1,165 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.external.response.ibmwatsonx; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.io.IOException; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownField; +import static org.elasticsearch.common.xcontent.XContentParserUtils.throwUnknownToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class IbmWatsonxRankedResponseEntity { + + private static final Logger logger = LogManager.getLogger(IbmWatsonxRankedResponseEntity.class); + + /** + * Parses the Ibm Watsonx ranked response. + * + * For a request like: + * "model": "rerank-english-v2.0", + * "query": "What is the capital of the United States?", + * "return_documents": true, + * "top_n": 3, + * "documents": ["Carson City is the capital city of the American state of Nevada.", + * "The Commonwealth of the Northern Mariana ... Its capital is Saipan.", + * "Washington, D.C. (also known as simply Washington or D.C., ... It is a federal district.", + * "Capital punishment (the death penalty) ... As of 2017, capital punishment is legal in 30 of the 50 states."] + *

+ * The response will look like (without whitespace): + * { + * "id": "1983d114-a6e8-4940-b121-eb4ac3f6f703", + * "results": [ + * { + * "document": { + * "text": "Washington, D.C. is the capital of the United States. It is a federal district." + * }, + * "index": 2, + * "relevance_score": 0.98005307 + * }, + * { + * "document": { + * "text": "Capital punishment (the death penalty) As of 2017, capital punishment is legal in 30 of the 50 states." + * }, + * "index": 3, + * "relevance_score": 0.27904198 + * }, + * { + * "document": { + * "text": "Carson City is the capital city of the American state of Nevada." + * }, + * "index": 0, + * "relevance_score": 0.10194652 + * } + * ], + * "meta": { + * "api_version": { + * "version": "1" + * }, + * "billed_units": { + * "search_units": 1 + * } + * } + * + * @param response the http response from ibm watsonx + * @return the parsed response + * @throws IOException if there is an error parsing the response + */ + public static InferenceServiceResults fromResponse(HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "results", FAILED_TO_FIND_FIELD_TEMPLATE); // TODO error message + + token = jsonParser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + return new RankedDocsResults(parseList(jsonParser, IbmWatsonxRankedResponseEntity::parseRankedDocObject)); + } else { + throwUnknownToken(token, jsonParser); + } + + // This should never be reached. The above code should either return successfully or hit the throwUnknownToken + // or throw a parsing exception + throw new IllegalStateException("Reached an invalid state while parsing the Cohere response"); + } + } + + private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser parser) throws IOException { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + int index = -1; + float relevanceScore = -1; + String documentText = null; + parser.nextToken(); + while (parser.currentToken() != XContentParser.Token.END_OBJECT) { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME) { + switch (parser.currentName()) { + case "index": + parser.nextToken(); // move to VALUE_NUMBER + index = parser.intValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "relevance_score": + parser.nextToken(); // move to VALUE_NUMBER + relevanceScore = parser.floatValue(); + parser.nextToken(); // move to next FIELD_NAME or END_OBJECT + break; + case "document": + parser.nextToken(); // move to START_OBJECT; document text is wrapped in an object + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + do { + if (parser.currentToken() == XContentParser.Token.FIELD_NAME && parser.currentName().equals("text")) { + parser.nextToken(); // move to VALUE_STRING + documentText = parser.text(); + } + } while (parser.nextToken() != XContentParser.Token.END_OBJECT); + parser.nextToken();// move past END_OBJECT + // parser should now be at the next FIELD_NAME or END_OBJECT + break; + default: + throwUnknownField(parser.currentName(), parser); + } + } else { + parser.nextToken(); + } + } + + if (index == -1) { + logger.warn("Failed to find required field [index] in Cohere rerank response"); + } + if (relevanceScore == -1) { + logger.warn("Failed to find required field [relevance_score] in Cohere rerank response"); + } + // documentText may or may not be present depending on the request parameter + + return new RankedDocsResults.RankedDoc(index, relevanceScore, documentText); + } + + private IbmWatsonxRankedResponseEntity() {} + + static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere rerank response"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceSettings.java new file mode 100644 index 0000000000000..ff6ee2825531e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceSettings.java @@ -0,0 +1,235 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.*; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.*; + +public class IbmWatsonxServiceSettings extends FilteredXContentObject implements ServiceSettings, IbmWatsonxRateLimitServiceSettings { + + public static final String NAME = "ibm_watsonx_service_settings"; + public static final String OLD_MODEL_ID_FIELD = "model"; + public static final String MODEL_ID = "model_id"; + private static final Logger logger = LogManager.getLogger(IbmWatsonxServiceSettings.class); + public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); + + public static IbmWatsonxServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + String oldModelId = extractOptionalString(map, OLD_MODEL_ID_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + IbmWatsonxService.NAME, + context + ); + + String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + + if (context == ConfigurationParseContext.REQUEST && oldModelId != null) { + logger.info("The ibm watsonx [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead."); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new IbmWatsonxServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId), rateLimitSettings); + } + + private static String modelId(@Nullable String model, @Nullable String modelId) { + return modelId != null ? modelId : model; + } + + private final URI uri; + private final SimilarityMeasure similarity; + private final Integer dimensions; + private final Integer maxInputTokens; + private final String modelId; + private final RateLimitSettings rateLimitSettings; + + public IbmWatsonxServiceSettings( + @Nullable URI uri, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.uri = uri; + this.similarity = similarity; + this.dimensions = dimensions; + this.maxInputTokens = maxInputTokens; + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public IbmWatsonxServiceSettings( + @Nullable String url, + @Nullable SimilarityMeasure similarity, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings + ) { + this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings); + } + + public IbmWatsonxServiceSettings(StreamInput in) throws IOException { + uri = createOptionalUri(in.readOptionalString()); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + dimensions = in.readOptionalVInt(); + maxInputTokens = in.readOptionalVInt(); + modelId = in.readOptionalString(); + + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + rateLimitSettings = new RateLimitSettings(in); + } else { + rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + } + + // should only be used for testing, public because it's accessed outside of the package + public IbmWatsonxServiceSettings() { + this((URI) null, null, null, null, null, null); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + public URI uri() { + return uri; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragment(builder, params); + + builder.endObject(); + return builder; + } + + public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { + return toXContentFragmentOfExposedFields(builder, params); + } + + @Override + public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (uri != null) { + builder.field(URL, uri.toString()); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (modelId != null) { + builder.field(MODEL_ID, modelId); + } + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_13_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + var uriToWrite = uri != null ? uri.toString() : null; + out.writeOptionalString(uriToWrite); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(dimensions); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalString(modelId); + + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + rateLimitSettings.writeTo(out); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IbmWatsonxServiceSettings that = (IbmWatsonxServiceSettings) o; + return Objects.equals(uri, that.uri) + && Objects.equals(similarity, that.similarity) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(modelId, that.modelId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java new file mode 100644 index 0000000000000..f51d501f6bed7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankModel.java @@ -0,0 +1,144 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank; + +import org.elasticsearch.common.util.LazyInitializable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SettingsConfiguration; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; +import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.ibmwatsonx.IbmWatsonxActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings.RETURN_DOCUMENTS; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings.TOP_N_DOCS_ONLY; + +public class IbmWatsonxRerankModel extends IbmWatsonxModel { + public static IbmWatsonxRerankModel of(IbmWatsonxRerankModel model, Map taskSettings) { + var requestTaskSettings = IbmWatsonxRerankTaskSettings.fromMap(taskSettings); + return new IbmWatsonxRerankModel(model, IbmWatsonxRerankTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public IbmWatsonxRerankModel( + String modelId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + modelId, + taskType, + service, + IbmWatsonxRerankServiceSettings.fromMap(serviceSettings, context), + IbmWatsonxRerankTaskSettings.fromMap(taskSettings), + DefaultSecretSettings.fromMap(secrets) + ); + } + + // should only be used for testing + IbmWatsonxRerankModel( + String modelId, + TaskType taskType, + String service, + IbmWatsonxRerankServiceSettings serviceSettings, + IbmWatsonxRerankTaskSettings taskSettings, + @Nullable DefaultSecretSettings secretSettings + ) { + super( + new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secretSettings), + serviceSettings + ); + } + + private IbmWatsonxRerankModel(IbmWatsonxRerankModel model, IbmWatsonxRerankTaskSettings taskSettings) { + super(model, taskSettings); + } + + public IbmWatsonxRerankModel(IbmWatsonxRerankModel model, IbmWatsonxRerankServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public IbmWatsonxRerankServiceSettings getServiceSettings() { + return (IbmWatsonxRerankServiceSettings) super.getServiceSettings(); + } + + @Override + public IbmWatsonxRerankTaskSettings getTaskSettings() { + return (IbmWatsonxRerankTaskSettings) super.getTaskSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + /** + * Accepts a visitor to create an executable action. The returned action will not return documents in the response. + * @param visitor _ + * @param taskSettings _ + * @param inputType ignored for rerank task + * @return the rerank action + */ + @Override + public ExecutableAction accept(IbmWatsonxActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } + + public static class Configuration { + public static Map get() { + return configuration.getOrCompute(); + } + + private static final LazyInitializable, RuntimeException> configuration = + new LazyInitializable<>(() -> { + var configurationMap = new HashMap(); + + configurationMap.put( + RETURN_DOCUMENTS, + new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.TOGGLE) + .setLabel("Return Documents") + .setOrder(1) + .setRequired(false) + .setSensitive(false) + .setTooltip("Specify whether to return doc text within the results.") + .setType(SettingsConfigurationFieldType.BOOLEAN) + .setValue(false) + .build() + ); + configurationMap.put( + TOP_N_DOCS_ONLY, + new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.NUMERIC) + .setLabel("Top N") + .setOrder(2) + .setRequired(false) + .setSensitive(false) + .setTooltip("The number of most relevant documents to return, defaults to the number of the documents.") + .setType(SettingsConfigurationFieldType.INTEGER) + .build() + ); + + return Collections.unmodifiableMap(configurationMap); + }); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java new file mode 100644 index 0000000000000..792cd7ef02492 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankServiceSettings.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.*; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.*; +import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS; + +public class IbmWatsonxRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, IbmWatsonxRateLimitServiceSettings { + public static final String NAME = "ibm_watsonx_rerank_service_settings"; + + private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankServiceSettings.class); + + public static IbmWatsonxRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + + // We need to extract/remove those fields to avoid unknown service settings errors + extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + removeAsType(map, DIMENSIONS, Integer.class); + removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + + URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); + String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + IbmWatsonxService.NAME, + context + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new IbmWatsonxRerankServiceSettings(uri, modelId, rateLimitSettings); + } + + private final URI uri; + + private final String modelId; + + private final RateLimitSettings rateLimitSettings; + + public IbmWatsonxRerankServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this.uri = uri; + this.modelId = modelId; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public IbmWatsonxRerankServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + this(createOptionalUri(url), modelId, rateLimitSettings); + } + + public IbmWatsonxRerankServiceSettings(StreamInput in) throws IOException { + this.uri = createOptionalUri(in.readOptionalString()); + + if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED)) { + // An older node sends these fields, so we need to skip them to progress through the serialized data + in.readOptionalEnum(SimilarityMeasure.class); + in.readOptionalVInt(); + in.readOptionalVInt(); + } + + this.modelId = in.readOptionalString(); + + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + this.rateLimitSettings = new RateLimitSettings(in); + } else { + this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; + } + } + + public URI uri() { + return uri; + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + if (uri != null) { + builder.field(URL, uri.toString()); + } + + if (modelId != null) { + builder.field(MODEL_ID, modelId); + } + + rateLimitSettings.toXContent(builder, params); + + return builder; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_14_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + var uriToWrite = uri != null ? uri.toString() : null; + out.writeOptionalString(uriToWrite); + + if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED)) { + // An old node expects this data to be present, so we need to send at least the booleans + // indicating that the fields are not set + out.writeOptionalEnum(null); + out.writeOptionalVInt(null); + out.writeOptionalVInt(null); + } + + out.writeOptionalString(modelId); + + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { + rateLimitSettings.writeTo(out); + } + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + IbmWatsonxRerankServiceSettings that = (IbmWatsonxRerankServiceSettings) object; + return Objects.equals(uri, that.uri) + && Objects.equals(modelId, that.modelId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(uri, modelId, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankTaskSettings.java new file mode 100644 index 0000000000000..ba1df5a4382b1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/rerank/IbmWatsonxRerankTaskSettings.java @@ -0,0 +1,189 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; + + +public class IbmWatsonxRerankTaskSettings implements TaskSettings { + + public static final String NAME = "ibm_watsonx_rerank_task_settings"; + public static final String RETURN_DOCUMENTS = "return_documents"; + public static final String TOP_N_DOCS_ONLY = "top_n"; + public static final String MAX_CHUNKS_PER_DOC = "max_chunks_per_doc"; + + static final IbmWatsonxRerankTaskSettings EMPTY_SETTINGS = new IbmWatsonxRerankTaskSettings(null, null, null); + + public static IbmWatsonxRerankTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + if (map == null || map.isEmpty()) { + return EMPTY_SETTINGS; + } + + Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, validationException); + Integer topNDocumentsOnly = extractOptionalPositiveInteger( + map, + TOP_N_DOCS_ONLY, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Integer maxChunksPerDoc = extractOptionalPositiveInteger( + map, + MAX_CHUNKS_PER_DOC, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return of(topNDocumentsOnly, returnDocuments, maxChunksPerDoc); + } + + /** + * Creates a new {@link IbmWatsonxRerankTaskSettings} by preferring non-null fields from the request settings over the original settings. + * + * @param originalSettings the settings stored as part of the inference entity configuration + * @param requestTaskSettings the settings passed in within the task_settings field of the request + * @return a constructed {@link IbmWatsonxRerankTaskSettings} + */ + public static IbmWatsonxRerankTaskSettings of(IbmWatsonxRerankTaskSettings originalSettings, IbmWatsonxRerankTaskSettings requestTaskSettings) { + return new IbmWatsonxRerankTaskSettings( + requestTaskSettings.getTopNDocumentsOnly() != null + ? requestTaskSettings.getTopNDocumentsOnly() + : originalSettings.getTopNDocumentsOnly(), + requestTaskSettings.getReturnDocuments() != null + ? requestTaskSettings.getReturnDocuments() + : originalSettings.getReturnDocuments(), + requestTaskSettings.getMaxChunksPerDoc() != null + ? requestTaskSettings.getMaxChunksPerDoc() + : originalSettings.getMaxChunksPerDoc() + ); + } + + public static IbmWatsonxRerankTaskSettings of(Integer topNDocumentsOnly, Boolean returnDocuments, Integer maxChunksPerDoc) { + return new IbmWatsonxRerankTaskSettings(topNDocumentsOnly, returnDocuments, maxChunksPerDoc); + } + + private final Integer topNDocumentsOnly; + private final Boolean returnDocuments; + private final Integer maxChunksPerDoc; + + public IbmWatsonxRerankTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalInt(), in.readOptionalBoolean(), in.readOptionalInt()); + } + + public IbmWatsonxRerankTaskSettings( + @Nullable Integer topNDocumentsOnly, + @Nullable Boolean doReturnDocuments, + @Nullable Integer maxChunksPerDoc + ) { + this.topNDocumentsOnly = topNDocumentsOnly; + this.returnDocuments = doReturnDocuments; + this.maxChunksPerDoc = maxChunksPerDoc; + } + + @Override + public boolean isEmpty() { + return topNDocumentsOnly == null && returnDocuments == null && maxChunksPerDoc == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (topNDocumentsOnly != null) { + builder.field(TOP_N_DOCS_ONLY, topNDocumentsOnly); + } + if (returnDocuments != null) { + builder.field(RETURN_DOCUMENTS, returnDocuments); + } + if (maxChunksPerDoc != null) { + builder.field(MAX_CHUNKS_PER_DOC, maxChunksPerDoc); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.V_8_14_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalInt(topNDocumentsOnly); + out.writeOptionalBoolean(returnDocuments); + out.writeOptionalInt(maxChunksPerDoc); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IbmWatsonxRerankTaskSettings that = (IbmWatsonxRerankTaskSettings) o; + return Objects.equals(returnDocuments, that.returnDocuments) + && Objects.equals(topNDocumentsOnly, that.topNDocumentsOnly) + && Objects.equals(maxChunksPerDoc, that.maxChunksPerDoc); + } + + @Override + public int hashCode() { + return Objects.hash(returnDocuments, topNDocumentsOnly, maxChunksPerDoc); + } + + public static String invalidInputTypeMessage(InputType inputType) { + return Strings.format("received invalid input type value [%s]", inputType.toString()); + } + + public Boolean getDoesReturnDocuments() { + return returnDocuments; + } + + public Integer getTopNDocumentsOnly() { + return topNDocumentsOnly; + } + + public Boolean getReturnDocuments() { + return returnDocuments; + } + + public Integer getMaxChunksPerDoc() { + return maxChunksPerDoc; + } + + @Override + public TaskSettings updatedTaskSettings(Map newSettings) { + IbmWatsonxRerankTaskSettings updatedSettings = IbmWatsonxRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + return IbmWatsonxRerankTaskSettings.of(this, updatedSettings); + } +}