diff --git a/docs/changelog/117939.yaml b/docs/changelog/117939.yaml new file mode 100644 index 000000000000..d41111f099f9 --- /dev/null +++ b/docs/changelog/117939.yaml @@ -0,0 +1,5 @@ +pr: 117939 +summary: Adding default endpoint for Elastic Rerank +area: Machine Learning +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java index ba3e48e11928..068b3e1f4ce0 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java @@ -57,6 +57,9 @@ public void testGet() throws IOException { var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID); assertDefaultE5Config(e5Model); + + var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID); + assertDefaultRerankConfig(rerankModel); } @SuppressWarnings("unchecked") @@ -125,6 +128,42 @@ private static void assertDefaultE5Config(Map modelConfig) { assertDefaultChunkingSettings(modelConfig); } + @SuppressWarnings("unchecked") + public void testInferDeploysDefaultRerank() throws IOException { + var model = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID); + assertDefaultRerankConfig(model); + + var inputs = List.of("Hello World", "Goodnight moon"); + var query = "but why"; + var queryParams = Map.of("timeout", "120s"); + var results = infer(ElasticsearchInternalService.DEFAULT_RERANK_ID, TaskType.RERANK, inputs, query, queryParams); + var embeddings = (List>) results.get("rerank"); + assertThat(results.toString(), embeddings, hasSize(2)); + } + + @SuppressWarnings("unchecked") + private static void assertDefaultRerankConfig(Map modelConfig) { + assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_RERANK_ID, modelConfig.get("inference_id")); + assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service")); + assertEquals(modelConfig.toString(), TaskType.RERANK.toString(), modelConfig.get("task_type")); + + var serviceSettings = (Map) modelConfig.get("service_settings"); + assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(".rerank-v1")); + assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads")); + + var adaptiveAllocations = (Map) serviceSettings.get("adaptive_allocations"); + assertThat( + modelConfig.toString(), + adaptiveAllocations, + Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32)) + ); + + var chunkingSettings = (Map) modelConfig.get("chunking_settings"); + assertNull(chunkingSettings); + var taskSettings = (Map) modelConfig.get("task_settings"); + assertThat(modelConfig.toString(), taskSettings, Matchers.is(Map.of("return_documents", true))); + } + @SuppressWarnings("unchecked") private static void assertDefaultChunkingSettings(Map modelConfig) { var chunkingSettings = (Map) modelConfig.get("chunking_settings"); @@ -159,6 +198,7 @@ public void onFailure(Exception exception) { var request = createInferenceRequest( Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID), inputs, + null, queryParams ); client().performRequestAsync(request, listener); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 4e32ef99d06d..86c0128a3e53 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -333,7 +333,7 @@ private List getInternalAsList(String endpoint) throws IOException { protected Map infer(String modelId, List input) throws IOException { var endpoint = Strings.format("_inference/%s", modelId); - return inferInternal(endpoint, input, Map.of()); + return inferInternal(endpoint, input, null, Map.of()); } protected Deque streamInferOnMockService(String modelId, TaskType taskType, List input) throws Exception { @@ -344,7 +344,7 @@ protected Deque streamInferOnMockService(String modelId, TaskTy private Deque callAsync(String endpoint, List input) throws Exception { var responseConsumer = new AsyncInferenceResponseConsumer(); var request = new Request("POST", endpoint); - request.setJsonEntity(jsonBody(input)); + request.setJsonEntity(jsonBody(input, null)); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @@ -364,33 +364,60 @@ public void onFailure(Exception exception) { protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); - return inferInternal(endpoint, input, Map.of()); + return inferInternal(endpoint, input, null, Map.of()); } protected Map infer(String modelId, TaskType taskType, List input, Map queryParameters) throws IOException { var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId); - return inferInternal(endpoint, input, queryParameters); + return inferInternal(endpoint, input, null, queryParameters); } - protected Request createInferenceRequest(String endpoint, List input, Map queryParameters) { + protected Map infer( + String modelId, + TaskType taskType, + List input, + String query, + Map queryParameters + ) throws IOException { + var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId); + return inferInternal(endpoint, input, query, queryParameters); + } + + protected Request createInferenceRequest( + String endpoint, + List input, + @Nullable String query, + Map queryParameters + ) { var request = new Request("POST", endpoint); - request.setJsonEntity(jsonBody(input)); + request.setJsonEntity(jsonBody(input, query)); if (queryParameters.isEmpty() == false) { request.addParameters(queryParameters); } return request; } - private Map inferInternal(String endpoint, List input, Map queryParameters) throws IOException { - var request = createInferenceRequest(endpoint, input, queryParameters); + private Map inferInternal( + String endpoint, + List input, + @Nullable String query, + Map queryParameters + ) throws IOException { + var request = createInferenceRequest(endpoint, input, query, queryParameters); var response = client().performRequest(request); assertOkOrCreated(response); return entityAsMap(response); } - private String jsonBody(List input) { - var bodyBuilder = new StringBuilder("{\"input\": ["); + private String jsonBody(List input, @Nullable String query) { + final StringBuilder bodyBuilder = new StringBuilder("{"); + + if (query != null) { + bodyBuilder.append("\"query\":\"").append(query).append("\","); + } + + bodyBuilder.append("\"input\": ["); for (var in : input) { bodyBuilder.append('"').append(in).append('"').append(','); } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index f5773e73f2b2..604e1d4f553b 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -44,7 +44,7 @@ public void testCRUD() throws IOException { } var getAllModels = getAllModels(); - int numModels = 11; + int numModels = 12; assertThat(getAllModels, hasSize(numModels)); var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING); @@ -482,7 +482,7 @@ public void testSupportedStream() throws Exception { } public void testGetZeroModels() throws IOException { - var models = getModels("_all", TaskType.RERANK); + var models = getModels("_all", TaskType.COMPLETION); assertThat(models, empty()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 2320cca8295d..673b841317a3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -62,12 +62,12 @@ import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings; -import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings; +import org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings; @@ -510,9 +510,7 @@ private static void addCustomElandWriteables(final List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java index f620b15680c8..6388bb33bb78 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java @@ -17,7 +17,7 @@ import java.util.HashMap; import java.util.Map; -import static org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings.RETURN_DOCUMENTS; +import static org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings.RETURN_DOCUMENTS; public class CustomElandRerankModel extends CustomElandModel { @@ -26,7 +26,7 @@ public CustomElandRerankModel( TaskType taskType, String service, CustomElandInternalServiceSettings serviceSettings, - CustomElandRerankTaskSettings taskSettings + RerankTaskSettings taskSettings ) { super(inferenceEntityId, taskType, service, serviceSettings, taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java index 115cc9f05599..276bce6dbe8f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java @@ -9,7 +9,6 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; @@ -22,9 +21,9 @@ public ElasticRerankerModel( TaskType taskType, String service, ElasticRerankerServiceSettings serviceSettings, - ChunkingSettings chunkingSettings + RerankTaskSettings taskSettings ) { - super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings); + super(inferenceEntityId, taskType, service, serviceSettings, taskSettings); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2ec3a9d62943..0e64842f873d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -101,6 +101,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi public static final int EMBEDDING_MAX_BATCH_SIZE = 10; public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch"; public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch"; + public static final String DEFAULT_RERANK_ID = ".rerank-v1-elasticsearch"; private static final EnumSet supportedTaskTypes = EnumSet.of( TaskType.RERANK, @@ -225,7 +226,7 @@ public void parseRequestConfig( ) ); } else if (RERANKER_ID.equals(modelId)) { - rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener); + rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, taskSettingsMap, modelListener); } else { customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener); } @@ -308,7 +309,7 @@ private static CustomElandModel createCustomElandModel( taskType, NAME, elandServiceSettings(serviceSettings, context), - CustomElandRerankTaskSettings.fromMap(taskSettings) + RerankTaskSettings.fromMap(taskSettings) ); default -> throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST); }; @@ -331,7 +332,7 @@ private void rerankerCase( TaskType taskType, Map config, Map serviceSettingsMap, - ChunkingSettings chunkingSettings, + Map taskSettingsMap, ActionListener modelListener ) { @@ -346,7 +347,7 @@ private void rerankerCase( taskType, NAME, new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()), - chunkingSettings + RerankTaskSettings.fromMap(taskSettingsMap) ) ); } @@ -512,6 +513,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M ElserMlNodeTaskSettings.DEFAULT, chunkingSettings ); + } else if (modelId.equals(RERANKER_ID)) { + return new ElasticRerankerModel( + inferenceEntityId, + taskType, + NAME, + new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)), + RerankTaskSettings.fromMap(taskSettingsMap) + ); } else { return createCustomElandModel( inferenceEntityId, @@ -653,21 +662,23 @@ public void inferRerank( ) { var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout); - var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings(); - var requestSettings = CustomElandRerankTaskSettings.fromMap(requestTaskSettings); - Boolean returnDocs = CustomElandRerankTaskSettings.of(modelSettings, requestSettings).returnDocuments(); + var returnDocs = Boolean.TRUE; + if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) { + var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings); + returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments(); + } Function inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null; - client.execute( - InferModelAction.INSTANCE, - request, - listener.delegateFailureAndWrap( - (l, inferenceResult) -> l.onResponse( - textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier) - ) - ) + ActionListener mlResultsListener = listener.delegateFailureAndWrap( + (l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)) + ); + + var maybeDeployListener = mlResultsListener.delegateResponse( + (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener) ); + + client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); } public void chunkedInfer( @@ -811,7 +822,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs( public List defaultConfigIds() { return List.of( new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this), - new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this) + new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this), + new DefaultConfigId(DEFAULT_RERANK_ID, TaskType.RERANK, this) ); } @@ -904,12 +916,19 @@ private List defaultConfigs(boolean useLinuxOptimizedModel) { ), ChunkingSettingsBuilder.DEFAULT_SETTINGS ); - return List.of(defaultElser, defaultE5); + var defaultRerank = new ElasticRerankerModel( + DEFAULT_RERANK_ID, + TaskType.RERANK, + NAME, + new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)), + RerankTaskSettings.DEFAULT_SETTINGS + ); + return List.of(defaultElser, defaultE5, defaultRerank); } @Override boolean isDefaultId(String inferenceId) { - return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId); + return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId) || DEFAULT_RERANK_ID.equals(inferenceId); } static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettings.java similarity index 79% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettings.java index a0be1661b860..3c25f7a6a901 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettings.java @@ -26,14 +26,14 @@ /** * Defines the task settings for internal rerank service. */ -public class CustomElandRerankTaskSettings implements TaskSettings { +public class RerankTaskSettings implements TaskSettings { public static final String NAME = "custom_eland_rerank_task_settings"; public static final String RETURN_DOCUMENTS = "return_documents"; - static final CustomElandRerankTaskSettings DEFAULT_SETTINGS = new CustomElandRerankTaskSettings(Boolean.TRUE); + static final RerankTaskSettings DEFAULT_SETTINGS = new RerankTaskSettings(Boolean.TRUE); - public static CustomElandRerankTaskSettings defaultsFromMap(Map map) { + public static RerankTaskSettings defaultsFromMap(Map map) { ValidationException validationException = new ValidationException(); if (map == null || map.isEmpty()) { @@ -49,7 +49,7 @@ public static CustomElandRerankTaskSettings defaultsFromMap(Map returnDocuments = true; } - return new CustomElandRerankTaskSettings(returnDocuments); + return new RerankTaskSettings(returnDocuments); } /** @@ -57,13 +57,13 @@ public static CustomElandRerankTaskSettings defaultsFromMap(Map * @param map source map * @return Task settings */ - public static CustomElandRerankTaskSettings fromMap(Map map) { + public static RerankTaskSettings fromMap(Map map) { if (map == null || map.isEmpty()) { return DEFAULT_SETTINGS; } Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, new ValidationException()); - return new CustomElandRerankTaskSettings(returnDocuments); + return new RerankTaskSettings(returnDocuments); } /** @@ -74,20 +74,17 @@ public static CustomElandRerankTaskSettings fromMap(Map map) { * @param requestTaskSettings the settings passed in within the task_settings field of the request * @return Either {@code originalSettings} or {@code requestTaskSettings} */ - public static CustomElandRerankTaskSettings of( - CustomElandRerankTaskSettings originalSettings, - CustomElandRerankTaskSettings requestTaskSettings - ) { + public static RerankTaskSettings of(RerankTaskSettings originalSettings, RerankTaskSettings requestTaskSettings) { return requestTaskSettings.returnDocuments() != null ? requestTaskSettings : originalSettings; } private final Boolean returnDocuments; - public CustomElandRerankTaskSettings(StreamInput in) throws IOException { + public RerankTaskSettings(StreamInput in) throws IOException { this(in.readOptionalBoolean()); } - public CustomElandRerankTaskSettings(@Nullable Boolean doReturnDocuments) { + public RerankTaskSettings(@Nullable Boolean doReturnDocuments) { if (doReturnDocuments == null) { this.returnDocuments = true; } else { @@ -133,7 +130,7 @@ public Boolean returnDocuments() { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - CustomElandRerankTaskSettings that = (CustomElandRerankTaskSettings) o; + RerankTaskSettings that = (RerankTaskSettings) o; return Objects.equals(returnDocuments, that.returnDocuments); } @@ -144,7 +141,7 @@ public int hashCode() { @Override public TaskSettings updatedTaskSettings(Map newSettings) { - CustomElandRerankTaskSettings updatedSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>(newSettings)); + RerankTaskSettings updatedSettings = RerankTaskSettings.fromMap(new HashMap<>(newSettings)); return of(this, updatedSettings); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 306509ea60cf..17e6583f11c8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -534,16 +534,13 @@ public void testParseRequestConfig_Rerank() { ) ); var returnDocs = randomBoolean(); - settings.put( - ModelConfigurations.TASK_SETTINGS, - new HashMap<>(Map.of(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, returnDocs)) - ); + settings.put(ModelConfigurations.TASK_SETTINGS, new HashMap<>(Map.of(RerankTaskSettings.RETURN_DOCUMENTS, returnDocs))); ActionListener modelListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(CustomElandRerankModel.class)); - assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class)); + assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class)); assertThat(model.getServiceSettings(), instanceOf(CustomElandInternalServiceSettings.class)); - assertEquals(returnDocs, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments()); + assertEquals(returnDocs, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments()); }, e -> { fail("Model parsing failed " + e.getMessage()); }); service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, modelListener); @@ -583,9 +580,9 @@ public void testParseRequestConfig_Rerank_DefaultTaskSettings() { ActionListener modelListener = ActionListener.wrap(model -> { assertThat(model, instanceOf(CustomElandRerankModel.class)); - assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class)); + assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class)); assertThat(model.getServiceSettings(), instanceOf(CustomElandInternalServiceSettings.class)); - assertEquals(Boolean.TRUE, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments()); + assertEquals(Boolean.TRUE, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments()); }, e -> { fail("Model parsing failed " + e.getMessage()); }); service.parseRequestConfig(randomInferenceEntityId, TaskType.RERANK, settings, modelListener); @@ -1249,14 +1246,11 @@ public void testParsePersistedConfig_Rerank() { ); settings.put(ElasticsearchInternalServiceSettings.MODEL_ID, "foo"); var returnDocs = randomBoolean(); - settings.put( - ModelConfigurations.TASK_SETTINGS, - new HashMap<>(Map.of(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, returnDocs)) - ); + settings.put(ModelConfigurations.TASK_SETTINGS, new HashMap<>(Map.of(RerankTaskSettings.RETURN_DOCUMENTS, returnDocs))); var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.RERANK, settings); - assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class)); - assertEquals(returnDocs, ((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments()); + assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class)); + assertEquals(returnDocs, ((RerankTaskSettings) model.getTaskSettings()).returnDocuments()); } // without task settings @@ -1279,8 +1273,8 @@ public void testParsePersistedConfig_Rerank() { settings.put(ElasticsearchInternalServiceSettings.MODEL_ID, "foo"); var model = service.parsePersistedConfig(randomInferenceEntityId, TaskType.RERANK, settings); - assertThat(model.getTaskSettings(), instanceOf(CustomElandRerankTaskSettings.class)); - assertTrue(((CustomElandRerankTaskSettings) model.getTaskSettings()).returnDocuments()); + assertThat(model.getTaskSettings(), instanceOf(RerankTaskSettings.class)); + assertTrue(((RerankTaskSettings) model.getTaskSettings()).returnDocuments()); } } @@ -1335,7 +1329,7 @@ private CustomElandModel getCustomElandModel(TaskType taskType) { taskType, ElasticsearchInternalService.NAME, new CustomElandInternalServiceSettings(1, 4, "custom-model", null), - CustomElandRerankTaskSettings.DEFAULT_SETTINGS + RerankTaskSettings.DEFAULT_SETTINGS ); } else if (taskType == TaskType.TEXT_EMBEDDING) { var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model", null); @@ -1528,20 +1522,30 @@ public void testEmbeddingTypeFromTaskTypeAndSettings() { ) ); - var e = expectThrows( + var e1 = expectThrows( ElasticsearchStatusException.class, () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( TaskType.COMPLETION, new ElasticsearchInternalServiceSettings(1, 1, "foo", null) ) ); - assertThat(e.getMessage(), containsString("Chunking is not supported for task type [completion]")); + assertThat(e1.getMessage(), containsString("Chunking is not supported for task type [completion]")); + + var e2 = expectThrows( + ElasticsearchStatusException.class, + () -> ElasticsearchInternalService.embeddingTypeFromTaskTypeAndSettings( + TaskType.RERANK, + new ElasticsearchInternalServiceSettings(1, 1, "foo", null) + ) + ); + assertThat(e2.getMessage(), containsString("Chunking is not supported for task type [rerank]")); } public void testIsDefaultId() { var service = createService(mock(Client.class)); assertTrue(service.isDefaultId(".elser-2-elasticsearch")); assertTrue(service.isDefaultId(".multilingual-e5-small-elasticsearch")); + assertTrue(service.isDefaultId(".rerank-v1-elasticsearch")); assertFalse(service.isDefaultId("foo")); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettingsTests.java similarity index 53% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettingsTests.java index 4207896fc54f..255454a1ed62 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankTaskSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/RerankTaskSettingsTests.java @@ -22,7 +22,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.sameInstance; -public class CustomElandRerankTaskSettingsTests extends AbstractWireSerializingTestCase { +public class RerankTaskSettingsTests extends AbstractWireSerializingTestCase { public void testIsEmpty() { var randomSettings = createRandom(); @@ -35,9 +35,9 @@ public void testUpdatedTaskSettings() { var newSettings = createRandom(); Map newSettingsMap = new HashMap<>(); if (newSettings.returnDocuments() != null) { - newSettingsMap.put(CustomElandRerankTaskSettings.RETURN_DOCUMENTS, newSettings.returnDocuments()); + newSettingsMap.put(RerankTaskSettings.RETURN_DOCUMENTS, newSettings.returnDocuments()); } - CustomElandRerankTaskSettings updatedSettings = (CustomElandRerankTaskSettings) initialSettings.updatedTaskSettings( + RerankTaskSettings updatedSettings = (RerankTaskSettings) initialSettings.updatedTaskSettings( Collections.unmodifiableMap(newSettingsMap) ); if (newSettings.returnDocuments() == null) { @@ -48,37 +48,37 @@ public void testUpdatedTaskSettings() { } public void testDefaultsFromMap_MapIsNull_ReturnsDefaultSettings() { - var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(null); + var rerankTaskSettings = RerankTaskSettings.defaultsFromMap(null); - assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS)); } public void testDefaultsFromMap_MapIsEmpty_ReturnsDefaultSettings() { - var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>()); + var rerankTaskSettings = RerankTaskSettings.defaultsFromMap(new HashMap<>()); - assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS)); } public void testDefaultsFromMap_ExtractedReturnDocumentsNull_SetsReturnDocumentToTrue() { - var customElandRerankTaskSettings = CustomElandRerankTaskSettings.defaultsFromMap(new HashMap<>()); + var rerankTaskSettings = RerankTaskSettings.defaultsFromMap(new HashMap<>()); - assertThat(customElandRerankTaskSettings.returnDocuments(), is(Boolean.TRUE)); + assertThat(rerankTaskSettings.returnDocuments(), is(Boolean.TRUE)); } public void testFromMap_MapIsNull_ReturnsDefaultSettings() { - var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(null); + var rerankTaskSettings = RerankTaskSettings.fromMap(null); - assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS)); } public void testFromMap_MapIsEmpty_ReturnsDefaultSettings() { - var customElandRerankTaskSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>()); + var rerankTaskSettings = RerankTaskSettings.fromMap(new HashMap<>()); - assertThat(customElandRerankTaskSettings, sameInstance(CustomElandRerankTaskSettings.DEFAULT_SETTINGS)); + assertThat(rerankTaskSettings, sameInstance(RerankTaskSettings.DEFAULT_SETTINGS)); } public void testToXContent_WritesAllValues() throws IOException { - var serviceSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + var serviceSettings = new RerankTaskSettings(Boolean.TRUE); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); @@ -89,30 +89,30 @@ public void testToXContent_WritesAllValues() throws IOException { } public void testOf_PrefersNonNullRequestTaskSettings() { - var originalSettings = new CustomElandRerankTaskSettings(Boolean.FALSE); - var requestTaskSettings = new CustomElandRerankTaskSettings(Boolean.TRUE); + var originalSettings = new RerankTaskSettings(Boolean.FALSE); + var requestTaskSettings = new RerankTaskSettings(Boolean.TRUE); - var taskSettings = CustomElandRerankTaskSettings.of(originalSettings, requestTaskSettings); + var taskSettings = RerankTaskSettings.of(originalSettings, requestTaskSettings); assertThat(taskSettings, sameInstance(requestTaskSettings)); } - private static CustomElandRerankTaskSettings createRandom() { - return new CustomElandRerankTaskSettings(randomOptionalBoolean()); + private static RerankTaskSettings createRandom() { + return new RerankTaskSettings(randomOptionalBoolean()); } @Override - protected Writeable.Reader instanceReader() { - return CustomElandRerankTaskSettings::new; + protected Writeable.Reader instanceReader() { + return RerankTaskSettings::new; } @Override - protected CustomElandRerankTaskSettings createTestInstance() { + protected RerankTaskSettings createTestInstance() { return createRandom(); } @Override - protected CustomElandRerankTaskSettings mutateInstance(CustomElandRerankTaskSettings instance) throws IOException { - return randomValueOtherThan(instance, CustomElandRerankTaskSettingsTests::createRandom); + protected RerankTaskSettings mutateInstance(RerankTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, RerankTaskSettingsTests::createRandom); } }