Merge 'main' into lucene_snapshot_9_9
ChrisHegarty committed Dec 4, 2023
2 parents 11185d0 + 8be0446 commit 4cf6f30
Showing 24 changed files with 519 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/102877.yaml
@@ -0,0 +1,5 @@
pr: 102877
summary: Add basic telemetry for the inference feature
area: Machine Learning
type: enhancement
issues: []
7 changes: 7 additions & 0 deletions docs/changelog/102891.yaml
@@ -0,0 +1,7 @@
pr: 102891
summary: "[Query Rules] Fix bug where combining the same metadata with text/numeric\
\ values leads to error"
area: Application
type: bug
issues:
- 102827
5 changes: 5 additions & 0 deletions docs/reference/rest-api/usage.asciidoc
@@ -197,6 +197,11 @@ GET /_xpack/usage
},
"node_count" : 1
},
"inference": {
"available" : true,
"enabled" : true,
"models" : []
},
"logstash" : {
"available" : true,
"enabled" : true
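As a rough sketch (hypothetical service name and count, not part of this diff), a populated entry under the new "models" key follows the fields written by InferenceFeatureSetUsage.ModelStats further down in this commit:

"inference" : {
  "available" : true,
  "enabled" : true,
  "models" : [
    {
      "service" : "hugging_face",
      "task_type" : "TEXT_EMBEDDING",
      "count" : 1
    }
  ]
}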
@@ -184,7 +184,8 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_PROFILE = def(8_551_00_0);
public static final TransportVersion CLUSTER_STATS_RESCORER_USAGE_ADDED = def(8_552_00_0);
public static final TransportVersion ML_INFERENCE_HF_SERVICE_ADDED = def(8_553_00_0);
public static final TransportVersion UPGRADE_TO_LUCENE_9_9 = def(8_554_00_0);
public static final TransportVersion INFERENCE_USAGE_ADDED = def(8_554_00_0);
public static final TransportVersion UPGRADE_TO_LUCENE_9_9 = def(8_555_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
@@ -16,6 +16,17 @@

public interface InferenceServiceResults extends NamedWriteable, ToXContentFragment {

/**
* Transform the result into the format required by the TransportCoordinatedInferenceAction.
* For the inference plugin's TextEmbeddingResults, {@link #transformToLegacyFormat()} produces an
* intermediate format that is used only for the plugin's own return value and does not align with what
* the TransportCoordinatedInferenceAction expects, which is the ml plugin's TextEmbeddingResults.
*
* For other results, such as SparseEmbeddingResults, this method can simply pass through to
* {@link #transformToLegacyFormat()}.
*/
List<? extends InferenceResults> transformToCoordinationFormat();

/**
* Transform the result to match the format required for versions prior to
* {@link org.elasticsearch.TransportVersions#INFERENCE_SERVICE_RESULTS_ADDED}
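For illustration, a minimal pass-through implementation of the new method is sketched below (assuming the result type's legacy format already matches what TransportCoordinatedInferenceAction expects; the SparseEmbeddingResults change later in this commit takes exactly this shape):

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
    // The legacy format already matches the coordinated-inference format, so delegate.
    return transformToLegacyFormat();
}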
1 change: 1 addition & 0 deletions x-pack/plugin/core/src/main/java/module-info.java
@@ -75,6 +75,7 @@
exports org.elasticsearch.xpack.core.indexing;
exports org.elasticsearch.xpack.core.inference.action;
exports org.elasticsearch.xpack.core.inference.results;
exports org.elasticsearch.xpack.core.inference;
exports org.elasticsearch.xpack.core.logstash;
exports org.elasticsearch.xpack.core.ml.action;
exports org.elasticsearch.xpack.core.ml.annotations;
@@ -55,6 +55,7 @@
import org.elasticsearch.xpack.core.ilm.TimeseriesLifecycleType;
import org.elasticsearch.xpack.core.ilm.UnfollowAction;
import org.elasticsearch.xpack.core.ilm.WaitForSnapshotAction;
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
import org.elasticsearch.xpack.core.logstash.LogstashFeatureSetUsage;
import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
import org.elasticsearch.xpack.core.ml.MlMetadata;
@@ -133,6 +134,8 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.LOGSTASH, LogstashFeatureSetUsage::new),
// ML
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MACHINE_LEARNING, MachineLearningFeatureSetUsage::new),
// inference
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.INFERENCE, InferenceFeatureSetUsage::new),
// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
// security
@@ -18,6 +18,8 @@ public final class XPackField {
public static final String GRAPH = "graph";
/** Name constant for the machine learning feature. */
public static final String MACHINE_LEARNING = "ml";
/** Name constant for the inference feature. */
public static final String INFERENCE = "inference";
/** Name constant for the Logstash feature. */
public static final String LOGSTASH = "logstash";
/** Name constant for the Beats feature. */
@@ -27,6 +27,7 @@ public class XPackUsageFeatureAction extends ActionType<XPackUsageFeatureRespons
public static final XPackUsageFeatureAction WATCHER = new XPackUsageFeatureAction(XPackField.WATCHER);
public static final XPackUsageFeatureAction GRAPH = new XPackUsageFeatureAction(XPackField.GRAPH);
public static final XPackUsageFeatureAction MACHINE_LEARNING = new XPackUsageFeatureAction(XPackField.MACHINE_LEARNING);
public static final XPackUsageFeatureAction INFERENCE = new XPackUsageFeatureAction(XPackField.INFERENCE);
public static final XPackUsageFeatureAction LOGSTASH = new XPackUsageFeatureAction(XPackField.LOGSTASH);
public static final XPackUsageFeatureAction EQL = new XPackUsageFeatureAction(XPackField.EQL);
public static final XPackUsageFeatureAction ESQL = new XPackUsageFeatureAction(XPackField.ESQL);
@@ -64,6 +65,7 @@ public class XPackUsageFeatureAction extends ActionType<XPackUsageFeatureRespons
FROZEN_INDICES,
GRAPH,
INDEX_LIFECYCLE,
INFERENCE,
LOGSTASH,
MACHINE_LEARNING,
MONITORING,
@@ -0,0 +1,116 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.XPackFeatureSet;
import org.elasticsearch.xpack.core.XPackField;

import java.io.IOException;
import java.util.Collection;
import java.util.Objects;

public class InferenceFeatureSetUsage extends XPackFeatureSet.Usage {

public static class ModelStats implements ToXContentObject, Writeable {

private final String service;
private final TaskType taskType;
private long count;

public ModelStats(String service, TaskType taskType) {
this(service, taskType, 0L);
}

public ModelStats(String service, TaskType taskType, long count) {
this.service = service;
this.taskType = taskType;
this.count = count;
}

public ModelStats(ModelStats stats) {
this(stats.service, stats.taskType, stats.count);
}

public ModelStats(StreamInput in) throws IOException {
this.service = in.readString();
this.taskType = in.readEnum(TaskType.class);
this.count = in.readLong();
}

public void add() {
count++;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field("service", service);
builder.field("task_type", taskType.name());
builder.field("count", count);
builder.endObject();
return builder;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(service);
out.writeEnum(taskType);
out.writeLong(count);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ModelStats that = (ModelStats) o;
return count == that.count && Objects.equals(service, that.service) && taskType == that.taskType;
}

@Override
public int hashCode() {
return Objects.hash(service, taskType, count);
}
}

private final Collection<ModelStats> modelStats;

public InferenceFeatureSetUsage(Collection<ModelStats> modelStats) {
super(XPackField.INFERENCE, true, true);
this.modelStats = modelStats;
}

public InferenceFeatureSetUsage(StreamInput in) throws IOException {
super(in);
this.modelStats = in.readCollectionAsList(ModelStats::new);
}

@Override
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
super.innerXContent(builder, params);
builder.xContentList("models", modelStats);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeCollection(modelStats);
}

@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.INFERENCE_USAGE_ADDED;
}
}
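A brief assembly sketch (hypothetical service name and count, not part of this commit) showing how the new usage object is built from per-model stats before being reported under the "inference" field:

import java.util.List;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;

// Count one hypothetical text-embedding model served by "hugging_face".
InferenceFeatureSetUsage.ModelStats stats =
    new InferenceFeatureSetUsage.ModelStats("hugging_face", TaskType.TEXT_EMBEDDING);
stats.add(); // increments count to 1
// Serializes as: "models" : [ { "service" : "hugging_face", "task_type" : "TEXT_EMBEDDING", "count" : 1 } ]
InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(List.of(stats));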
@@ -81,6 +81,11 @@ public Map<String, Object> asMap() {
return map;
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return transformToLegacyFormat();
}

@Override
public List<? extends InferenceResults> transformToLegacyFormat() {
return embeddings.stream()
@@ -78,6 +78,14 @@ public String getWriteableName() {
return NAME;
}

@Override
public List<? extends InferenceResults> transformToCoordinationFormat() {
return embeddings.stream()
.map(embedding -> embedding.values.stream().mapToDouble(value -> value).toArray())
.map(values -> new org.elasticsearch.xpack.core.ml.inference.results.TextEmbeddingResults(TEXT_EMBEDDING, values, false))
.toList();
}

@Override
@SuppressWarnings("deprecation")
public List<? extends InferenceResults> transformToLegacyFormat() {
@@ -0,0 +1,41 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.io.IOException;

public class InferenceFeatureSetUsageTests extends AbstractWireSerializingTestCase<InferenceFeatureSetUsage.ModelStats> {

@Override
protected Writeable.Reader<InferenceFeatureSetUsage.ModelStats> instanceReader() {
return InferenceFeatureSetUsage.ModelStats::new;
}

@Override
protected InferenceFeatureSetUsage.ModelStats createTestInstance() {
return new InferenceFeatureSetUsage.ModelStats(
randomIdentifier(),
TaskType.values()[randomInt(TaskType.values().length - 1)],
randomInt(10)
);
}

@Override
protected InferenceFeatureSetUsage.ModelStats mutateInstance(InferenceFeatureSetUsage.ModelStats modelStats) throws IOException {
InferenceFeatureSetUsage.ModelStats newModelStats = new InferenceFeatureSetUsage.ModelStats(modelStats);
newModelStats.add();
return newModelStats;
}
}
@@ -194,4 +194,46 @@ setup:
- match: { hits.hits.0._id: 'doc2' }
- match: { hits.hits.1._id: 'doc3' }

---
"Perform a rule query over a ruleset with combined numeric and text rule matching":

- do:
query_ruleset.put:
ruleset_id: combined-ruleset
body:
rules:
- rule_id: rule1
type: pinned
criteria:
- type: fuzzy
metadata: foo
values: [ bar ]
actions:
ids:
- 'doc1'
- rule_id: rule2
type: pinned
criteria:
- type: lte
metadata: foo
values: [ 100 ]
actions:
ids:
- 'doc2'
- do:
search:
body:
query:
rule_query:
organic:
query_string:
default_field: text
query: blah blah blah
match_criteria:
foo: baz
ruleset_id: combined-ruleset

- match: { hits.total.value: 1 }
- match: { hits.hits.0._id: 'doc1' }


@@ -294,7 +294,7 @@ public AppliedQueryRules applyRule(AppliedQueryRules appliedRules, Map<String, O
final String criteriaMetadata = criterion.criteriaMetadata();

if (criteriaType == ALWAYS || (criteriaMetadata != null && criteriaMetadata.equals(match))) {
boolean singleCriterionMatches = criterion.isMatch(matchValue, criteriaType);
boolean singleCriterionMatches = criterion.isMatch(matchValue, criteriaType, false);
isRuleMatch = (isRuleMatch == null) ? singleCriterionMatches : isRuleMatch && singleCriterionMatches;
}
}
@@ -191,12 +191,19 @@ public String toString() {
}

public boolean isMatch(Object matchValue, QueryRuleCriteriaType matchType) {
return isMatch(matchValue, matchType, true);
}

public boolean isMatch(Object matchValue, QueryRuleCriteriaType matchType, boolean throwOnInvalidInput) {
if (matchType == ALWAYS) {
return true;
}
final String matchString = matchValue.toString();
for (Object criteriaValue : criteriaValues) {
matchType.validateInput(matchValue);
boolean isValid = matchType.validateInput(matchValue, throwOnInvalidInput);
if (isValid == false) {
return false;
}
boolean matchFound = matchType.isMatch(matchString, criteriaValue);
if (matchFound) {
return true;
@@ -86,11 +86,16 @@ public boolean isMatch(Object input, Object criteriaValue) {
}
};

public void validateInput(Object input) {
public boolean validateInput(Object input, boolean throwOnInvalidInput) {
boolean isValid = isValidForInput(input);
if (isValid == false) {
if (isValid == false && throwOnInvalidInput) {
throw new IllegalArgumentException("Input [" + input + "] is not valid for CriteriaType [" + this + "]");
}
return isValid;
}

public boolean validateInput(Object input) {
return validateInput(input, true);
}

public abstract boolean isMatch(Object input, Object criteriaValue);
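A small sketch of the behavior change (hypothetical call, using the fuzzy and lte criteria types exercised in the YAML test above): with the non-throwing overload, a match value that is not valid input for a numeric criterion is reported as a non-match instead of failing the whole rule query.

// "baz" is acceptable input for a text criterion, so matching proceeds as before.
boolean fuzzyValid = QueryRuleCriteriaType.FUZZY.validateInput("baz", false); // true
// "baz" is not valid input for a numeric criterion; with throwOnInvalidInput=false
// this now yields "no match" rather than an IllegalArgumentException.
boolean lteValid = QueryRuleCriteriaType.LTE.validateInput("baz", false);     // false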