diff --git a/docs/changelog/102877.yaml b/docs/changelog/102877.yaml new file mode 100644 index 0000000000000..da2de19b19a90 --- /dev/null +++ b/docs/changelog/102877.yaml @@ -0,0 +1,5 @@ +pr: 102877 +summary: Add basic telemetry for the inference feature +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/102891.yaml b/docs/changelog/102891.yaml new file mode 100644 index 0000000000000..c5d5ed8c6758e --- /dev/null +++ b/docs/changelog/102891.yaml @@ -0,0 +1,7 @@ +pr: 102891 +summary: "[Query Rules] Fix bug where combining the same metadata with text/numeric\ + \ values leads to error" +area: Application +type: bug +issues: + - 102827 diff --git a/docs/reference/rest-api/usage.asciidoc b/docs/reference/rest-api/usage.asciidoc index 959a798378fc6..e2529de75f0e7 100644 --- a/docs/reference/rest-api/usage.asciidoc +++ b/docs/reference/rest-api/usage.asciidoc @@ -197,6 +197,11 @@ GET /_xpack/usage }, "node_count" : 1 }, + "inference": { + "available" : true, + "enabled" : true, + "models" : [] + }, "logstash" : { "available" : true, "enabled" : true diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index c315af711576a..ca79be9453cfe 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -184,7 +184,8 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_PROFILE = def(8_551_00_0); public static final TransportVersion CLUSTER_STATS_RESCORER_USAGE_ADDED = def(8_552_00_0); public static final TransportVersion ML_INFERENCE_HF_SERVICE_ADDED = def(8_553_00_0); - public static final TransportVersion UPGRADE_TO_LUCENE_9_9 = def(8_554_00_0); + public static final TransportVersion INFERENCE_USAGE_ADDED = def(8_554_00_0); + public static final TransportVersion UPGRADE_TO_LUCENE_9_9 = def(8_555_00_0); /* * STOP! READ THIS FIRST! 
No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java index 37990caeec097..ab5b74faa6530 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java @@ -16,6 +16,17 @@ public interface InferenceServiceResults extends NamedWriteable, ToXContentFragment { + /** + * Transform the result to match the format required for the TransportCoordinatedInferenceAction. + * For the inference plugin TextEmbeddingResults, the {@link #transformToLegacyFormat()} transforms the + * results into an intermediate format only used by the plugin's return value. It doesn't align with what the + * TransportCoordinatedInferenceAction expects. TransportCoordinatedInferenceAction expects an ml plugin + * TextEmbeddingResults. + * + * For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat. 
+ */ + List transformToCoordinationFormat(); + /** * Transform the result to match the format required for versions prior to * {@link org.elasticsearch.TransportVersions#INFERENCE_SERVICE_RESULTS_ADDED} diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index 4aa2e145228b8..f747d07224454 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -75,6 +75,7 @@ exports org.elasticsearch.xpack.core.indexing; exports org.elasticsearch.xpack.core.inference.action; exports org.elasticsearch.xpack.core.inference.results; + exports org.elasticsearch.xpack.core.inference; exports org.elasticsearch.xpack.core.logstash; exports org.elasticsearch.xpack.core.ml.action; exports org.elasticsearch.xpack.core.ml.annotations; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index ac16631bacb73..df19648307a0b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.core.ilm.TimeseriesLifecycleType; import org.elasticsearch.xpack.core.ilm.UnfollowAction; import org.elasticsearch.xpack.core.ilm.WaitForSnapshotAction; +import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; import org.elasticsearch.xpack.core.logstash.LogstashFeatureSetUsage; import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage; import org.elasticsearch.xpack.core.ml.MlMetadata; @@ -133,6 +134,8 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.LOGSTASH, LogstashFeatureSetUsage::new), // ML new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MACHINE_LEARNING, 
MachineLearningFeatureSetUsage::new), + // inference + new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.INFERENCE, InferenceFeatureSetUsage::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), // security diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java index c8a78af429592..801ef2c463e95 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackField.java @@ -18,6 +18,8 @@ public final class XPackField { public static final String GRAPH = "graph"; /** Name constant for the machine learning feature. */ public static final String MACHINE_LEARNING = "ml"; + /** Name constant for the inference feature. */ + public static final String INFERENCE = "inference"; /** Name constant for the Logstash feature. */ public static final String LOGSTASH = "logstash"; /** Name constant for the Beats feature. 
*/ diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/XPackUsageFeatureAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/XPackUsageFeatureAction.java index d96fd91ed3f22..c0e6d96c1569a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/XPackUsageFeatureAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/action/XPackUsageFeatureAction.java @@ -27,6 +27,7 @@ public class XPackUsageFeatureAction extends ActionType modelStats; + + public InferenceFeatureSetUsage(Collection modelStats) { + super(XPackField.INFERENCE, true, true); + this.modelStats = modelStats; + } + + public InferenceFeatureSetUsage(StreamInput in) throws IOException { + super(in); + this.modelStats = in.readCollectionAsList(ModelStats::new); + } + + @Override + protected void innerXContent(XContentBuilder builder, Params params) throws IOException { + super.innerXContent(builder, params); + builder.xContentList("models", modelStats); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(modelStats); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.INFERENCE_USAGE_ADDED; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java index 20279e82d6c09..910ea5cab214d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/SparseEmbeddingResults.java @@ -81,6 +81,11 @@ public Map asMap() { return map; } + @Override + public List transformToCoordinationFormat() { + return transformToLegacyFormat(); + } + @Override public List 
transformToLegacyFormat() { return embeddings.stream() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java index 7a7ccab2b4daa..ace5974866038 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/TextEmbeddingResults.java @@ -78,6 +78,14 @@ public String getWriteableName() { return NAME; } + @Override + public List transformToCoordinationFormat() { + return embeddings.stream() + .map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray()) + .map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false)) + .toList(); + } + @Override @SuppressWarnings("deprecation") public List transformToLegacyFormat() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsageTests.java new file mode 100644 index 0000000000000..8f64b521c64c9 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsageTests.java @@ -0,0 +1,41 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference; + +import com.carrotsearch.randomizedtesting.generators.RandomStrings; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.io.IOException; + +public class InferenceFeatureSetUsageTests extends AbstractWireSerializingTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return InferenceFeatureSetUsage.ModelStats::new; + } + + @Override + protected InferenceFeatureSetUsage.ModelStats createTestInstance() { + RandomStrings.randomAsciiLettersOfLength(random(), 10); + return new InferenceFeatureSetUsage.ModelStats( + randomIdentifier(), + TaskType.values()[randomInt(TaskType.values().length - 1)], + randomInt(10) + ); + } + + @Override + protected InferenceFeatureSetUsage.ModelStats mutateInstance(InferenceFeatureSetUsage.ModelStats modelStats) throws IOException { + InferenceFeatureSetUsage.ModelStats newModelStats = new InferenceFeatureSetUsage.ModelStats(modelStats); + newModelStats.add(); + return newModelStats; + } +} diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/260_rule_query_search.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/260_rule_query_search.yml index b41636e624674..c287209da5bed 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/260_rule_query_search.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/260_rule_query_search.yml @@ -194,4 +194,46 @@ setup: - match: { hits.hits.0._id: 'doc2' } - match: { hits.hits.1._id: 'doc3' } +--- +"Perform a rule query over a ruleset with combined numeric and text rule matching": + + - do: + query_ruleset.put: + ruleset_id: combined-ruleset + body: + rules: + - rule_id: rule1 + type: pinned + criteria: + - type: fuzzy 
+ metadata: foo + values: [ bar ] + actions: + ids: + - 'doc1' + - rule_id: rule2 + type: pinned + criteria: + - type: lte + metadata: foo + values: [ 100 ] + actions: + ids: + - 'doc2' + - do: + search: + body: + query: + rule_query: + organic: + query_string: + default_field: text + query: blah blah blah + match_criteria: + foo: baz + ruleset_id: combined-ruleset + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: 'doc1' } + diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRule.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRule.java index 9b2ce393e5b04..9cca42b0402bf 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRule.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRule.java @@ -294,7 +294,7 @@ public AppliedQueryRules applyRule(AppliedQueryRules appliedRules, Map(InferenceAction.INSTANCE, TransportInferenceAction.class), new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), - new ActionHandler<>(DeleteInferenceModelAction.INSTANCE, TransportDeleteInferenceModelAction.class) + new ActionHandler<>(DeleteInferenceModelAction.INSTANCE, TransportDeleteInferenceModelAction.class), + new ActionHandler<>(XPackUsageFeatureAction.INFERENCE, TransportInferenceUsageAction.class) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java new file mode 100644 index 0000000000000..54452d8a7ed68 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java @@ -0,0 +1,81 @@ +/* + 
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.protocol.xpack.XPackUsageRequest; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.action.XPackUsageFeatureAction; +import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse; +import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction; +import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; + +import java.util.Map; +import java.util.TreeMap; + +import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; + +public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAction { + + private final Client client; + + @Inject + public TransportInferenceUsageAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver, + Client client + ) { + super( + XPackUsageFeatureAction.INFERENCE.name(), + 
transportService, + clusterService, + threadPool, + actionFilters, + indexNameExpressionResolver + ); + this.client = new OriginSettingClient(client, ML_ORIGIN); + } + + @Override + protected void masterOperation( + Task task, + XPackUsageRequest request, + ClusterState state, + ActionListener listener + ) throws Exception { + GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY); + client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, ActionListener.wrap(response -> { + Map stats = new TreeMap<>(); + for (ModelConfigurations model : response.getModels()) { + String statKey = model.getService() + ":" + model.getTaskType().name(); + InferenceFeatureSetUsage.ModelStats stat = stats.computeIfAbsent( + statKey, + key -> new InferenceFeatureSetUsage.ModelStats(model.getService(), model.getTaskType()) + ); + stat.add(); + } + InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(stats.values()); + listener.onResponse(new XPackUsageFeatureResponse(usage)); + }, listener::onFailure)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java new file mode 100644 index 0000000000000..b0c59fe160be3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.protocol.xpack.XPackUsageRequest; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.MockUtils; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xpack.core.XPackFeatureSet; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse; +import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource; +import org.junit.After; +import org.junit.Before; + +import java.util.List; + +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.core.Is.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TransportInferenceUsageActionTests extends ESTestCase { + + private Client 
client; + private TransportInferenceUsageAction action; + + @Before + public void init() { + client = mock(Client.class); + ThreadPool threadPool = new TestThreadPool("test"); + when(client.threadPool()).thenReturn(threadPool); + + TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(mock(ThreadPool.class)); + + action = new TransportInferenceUsageAction( + transportService, + mock(ClusterService.class), + mock(ThreadPool.class), + mock(ActionFilters.class), + mock(IndexNameExpressionResolver.class), + client + ); + } + + @After + public void close() { + client.threadPool().shutdown(); + } + + public void test() throws Exception { + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onResponse( + new GetInferenceModelAction.Response( + List.of( + new ModelConfigurations("model-001", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)), + new ModelConfigurations("model-002", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)), + new ModelConfigurations("model-003", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mock(ServiceSettings.class)), + new ModelConfigurations("model-004", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)), + new ModelConfigurations("model-005", TaskType.SPARSE_EMBEDDING, "openai", mock(ServiceSettings.class)), + new ModelConfigurations("model-006", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mock(ServiceSettings.class)) + ) + ) + ); + return Void.TYPE; + }).when(client).execute(any(GetInferenceModelAction.class), any(), any()); + + PlainActionFuture future = new PlainActionFuture<>(); + action.masterOperation(mock(Task.class), mock(XPackUsageRequest.class), mock(ClusterState.class), future); + + BytesStreamOutput out = new BytesStreamOutput(); + future.get().getUsage().writeTo(out); + XPackFeatureSet.Usage usage = new InferenceFeatureSetUsage(out.bytes().streamInput()); + + 
assertThat(usage.name(), is(XPackField.INFERENCE)); + assertTrue(usage.enabled()); + assertTrue(usage.available()); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + usage.toXContent(builder, ToXContent.EMPTY_PARAMS); + XContentSource source = new XContentSource(builder); + assertThat(source.getValue("models"), hasSize(3)); + assertThat(source.getValue("models.0.service"), is("hugging_face_elser")); + assertThat(source.getValue("models.0.task_type"), is("SPARSE_EMBEDDING")); + assertThat(source.getValue("models.0.count"), is(2)); + assertThat(source.getValue("models.1.service"), is("openai")); + assertThat(source.getValue("models.1.task_type"), is("SPARSE_EMBEDDING")); + assertThat(source.getValue("models.1.count"), is(1)); + assertThat(source.getValue("models.2.service"), is("openai")); + assertThat(source.getValue("models.2.task_type"), is("TEXT_EMBEDDING")); + assertThat(source.getValue("models.2.count"), is(3)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java index 0a8bfd20caaf1..6f8fa0c453d09 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/SparseEmbeddingResultsTests.java @@ -11,12 +11,14 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import java.io.IOException; import java.util.ArrayList; import java.util.List; import java.util.Map; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD; import static 
org.hamcrest.Matchers.is; public class SparseEmbeddingResultsTests extends AbstractWireSerializingTestCase { @@ -151,6 +153,25 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I }""")); } + public void testTransformToCoordinationFormat() { + var results = createSparseResult( + List.of( + createEmbedding(List.of(new SparseEmbeddingResults.WeightedToken("token", 0.1F)), false), + createEmbedding(List.of(new SparseEmbeddingResults.WeightedToken("token2", 0.2F)), true) + ) + ).transformToCoordinationFormat(); + + assertThat( + results, + is( + List.of( + new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token", 0.1F)), false), + new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token2", 0.2F)), true) + ) + ) + ); + } + public record EmbeddingExpectation(Map tokens, boolean isTruncated) {} public static Map buildExpectation(List embeddings) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java index 71d14e09872fd..09d9894d98853 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/TextEmbeddingResultsTests.java @@ -100,6 +100,30 @@ public void testToXContent_CreatesTheRightFormatForMultipleEmbeddings() throws I }""")); } + public void testTransformToCoordinationFormat() { + var results = new TextEmbeddingResults( + List.of(new TextEmbeddingResults.Embedding(List.of(0.1F, 0.2F)), new TextEmbeddingResults.Embedding(List.of(0.3F, 0.4F))) + ).transformToCoordinationFormat(); + + assertThat( + results, + is( + List.of( + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + 
TextEmbeddingResults.TEXT_EMBEDDING, + new double[] { 0.1F, 0.2F }, + false + ), + new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults( + TextEmbeddingResults.TEXT_EMBEDDING, + new double[] { 0.3F, 0.4F }, + false + ) + ) + ) + ); + } + @Override protected Writeable.Reader instanceReader() { return TextEmbeddingResults::new; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java index d90c9ec807495..13e04772683eb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCoordinatedInferenceAction.java @@ -182,7 +182,7 @@ private void replaceErrorOnMissing( } static InferModelAction.Response translateInferenceServiceResponse(InferenceServiceResults inferenceResults) { - var legacyResults = new ArrayList(inferenceResults.transformToLegacyFormat()); + var legacyResults = new ArrayList(inferenceResults.transformToCoordinationFormat()); return new InferModelAction.Response(legacyResults, null, false); } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 5412e7d05f27f..86640e2e1a784 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -416,6 +416,7 @@ public class Constants { "cluster:monitor/xpack/usage/graph", "cluster:monitor/xpack/usage/health_api", 
"cluster:monitor/xpack/usage/ilm", + "cluster:monitor/xpack/usage/inference", "cluster:monitor/xpack/usage/logstash", "cluster:monitor/xpack/usage/ml", "cluster:monitor/xpack/usage/monitoring",