diff --git a/.buildkite/pipelines/intake.yml b/.buildkite/pipelines/intake.yml index 6c8b8edfcbac1..4bc72aec20972 100644 --- a/.buildkite/pipelines/intake.yml +++ b/.buildkite/pipelines/intake.yml @@ -56,7 +56,7 @@ steps: timeout_in_minutes: 300 matrix: setup: - BWC_VERSION: ["8.16.2", "8.17.0", "8.18.0", "9.0.0"] + BWC_VERSION: ["8.15.6", "8.16.2", "8.17.0", "8.18.0", "9.0.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 diff --git a/.buildkite/pipelines/periodic.yml b/.buildkite/pipelines/periodic.yml index 69d11ef1dabb6..3d6095d0b9e63 100644 --- a/.buildkite/pipelines/periodic.yml +++ b/.buildkite/pipelines/periodic.yml @@ -448,7 +448,7 @@ steps: setup: ES_RUNTIME_JAVA: - openjdk21 - BWC_VERSION: ["8.16.2", "8.17.0", "8.18.0", "9.0.0"] + BWC_VERSION: ["8.15.6", "8.16.2", "8.17.0", "8.18.0", "9.0.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 @@ -490,7 +490,7 @@ steps: ES_RUNTIME_JAVA: - openjdk21 - openjdk23 - BWC_VERSION: ["8.16.2", "8.17.0", "8.18.0", "9.0.0"] + BWC_VERSION: ["8.15.6", "8.16.2", "8.17.0", "8.18.0", "9.0.0"] agents: provider: gcp image: family/elasticsearch-ubuntu-2004 diff --git a/.ci/snapshotBwcVersions b/.ci/snapshotBwcVersions index 5514fc376a285..f92881da7fea4 100644 --- a/.ci/snapshotBwcVersions +++ b/.ci/snapshotBwcVersions @@ -1,4 +1,5 @@ BWC_VERSION: + - "8.15.6" - "8.16.2" - "8.17.0" - "8.18.0" diff --git a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPluginFuncTest.groovy b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPluginFuncTest.groovy index 6d080e1c80763..bb100b6b23882 100644 --- a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPluginFuncTest.groovy +++ b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionBwcSetupPluginFuncTest.groovy @@ -9,9 +9,10 @@ package org.elasticsearch.gradle.internal +import spock.lang.Unroll + import org.elasticsearch.gradle.fixtures.AbstractGitAwareGradleFuncTest import org.gradle.testkit.runner.TaskOutcome -import spock.lang.Unroll class InternalDistributionBwcSetupPluginFuncTest extends AbstractGitAwareGradleFuncTest { @@ -23,8 +24,10 @@ class InternalDistributionBwcSetupPluginFuncTest extends AbstractGitAwareGradleF apply plugin: 'elasticsearch.internal-distribution-bwc-setup' """ execute("git branch origin/8.x", file("cloned")) + execute("git branch origin/8.3", file("cloned")) + execute("git branch origin/8.2", file("cloned")) + execute("git branch origin/8.1", file("cloned")) execute("git branch origin/7.16", file("cloned")) - execute("git branch origin/7.15", file("cloned")) } def "builds distribution from branches via archives extractedAssemble"() { @@ -48,10 +51,11 @@ class InternalDistributionBwcSetupPluginFuncTest extends AbstractGitAwareGradleF assertOutputContains(result.output, "[$bwcDistVersion] > Task :distribution:archives:darwin-tar:${expectedAssembleTaskName}") where: - bwcDistVersion | bwcProject | expectedAssembleTaskName - "8.0.0" | "minor" | "extractedAssemble" - "7.16.0" | "staged" | "extractedAssemble" - "7.15.2" | "bugfix" | "extractedAssemble" + bwcDistVersion | bwcProject | expectedAssembleTaskName + "8.4.0" | "minor" | "extractedAssemble" + "8.3.0" | "staged" | "extractedAssemble" + "8.2.1" | "bugfix" | "extractedAssemble" + "8.1.3" | "bugfix2" | "extractedAssemble" } @Unroll @@ -70,8 +74,8 @@ class InternalDistributionBwcSetupPluginFuncTest extends AbstractGitAwareGradleF where: bwcDistVersion | platform - "8.0.0" | "darwin" - "8.0.0" | "linux" + "8.4.0" | "darwin" + "8.4.0" | "linux" } def "bwc expanded distribution folder can be resolved as bwc project artifact"() { @@ -107,11 +111,11 @@ class InternalDistributionBwcSetupPluginFuncTest extends AbstractGitAwareGradleF result.task(":resolveExpandedDistribution").outcome == TaskOutcome.SUCCESS result.task(":distribution:bwc:minor:buildBwcDarwinTar").outcome == TaskOutcome.SUCCESS and: "assemble task triggered" - result.output.contains("[8.0.0] > Task :distribution:archives:darwin-tar:extractedAssemble") + result.output.contains("[8.4.0] > Task :distribution:archives:darwin-tar:extractedAssemble") result.output.contains("expandedRootPath /distribution/bwc/minor/build/bwc/checkout-8.x/" + "distribution/archives/darwin-tar/build/install") result.output.contains("nested folder /distribution/bwc/minor/build/bwc/checkout-8.x/" + - "distribution/archives/darwin-tar/build/install/elasticsearch-8.0.0-SNAPSHOT") + "distribution/archives/darwin-tar/build/install/elasticsearch-8.4.0-SNAPSHOT") } } diff --git a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionDownloadPluginFuncTest.groovy b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionDownloadPluginFuncTest.groovy index eb6185e5aed57..fc5d432a9ef9a 100644 --- a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionDownloadPluginFuncTest.groovy +++ b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/InternalDistributionDownloadPluginFuncTest.groovy @@ -57,7 +57,7 @@ class InternalDistributionDownloadPluginFuncTest extends AbstractGradleFuncTest elasticsearch_distributions { test_distro { - version = "8.0.0" + version = "8.4.0" type = "archive" platform = "linux" architecture = Architecture.current(); @@ -87,7 +87,7 @@ class InternalDistributionDownloadPluginFuncTest extends AbstractGradleFuncTest elasticsearch_distributions { test_distro { - version = "8.0.0" + version = "8.4.0" type = "archive" platform = "linux" architecture = Architecture.current(); diff --git a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/test/rest/LegacyYamlRestCompatTestPluginFuncTest.groovy b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/test/rest/LegacyYamlRestCompatTestPluginFuncTest.groovy index e3efe3d7ffbf7..15b057a05e039 100644 --- a/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/test/rest/LegacyYamlRestCompatTestPluginFuncTest.groovy +++ b/build-tools-internal/src/integTest/groovy/org/elasticsearch/gradle/internal/test/rest/LegacyYamlRestCompatTestPluginFuncTest.groovy @@ -40,7 +40,7 @@ class LegacyYamlRestCompatTestPluginFuncTest extends AbstractRestResourcesFuncTe given: internalBuild() - subProject(":distribution:bwc:staged") << """ + subProject(":distribution:bwc:minor") << """ configurations { checkout } artifacts { checkout(new File(projectDir, "checkoutDir")) @@ -61,11 +61,11 @@ class LegacyYamlRestCompatTestPluginFuncTest extends AbstractRestResourcesFuncTe result.task(transformTask).outcome == TaskOutcome.NO_SOURCE } - def "yamlRestCompatTest executes and copies api and transforms tests from :bwc:staged"() { + def "yamlRestCompatTest executes and copies api and transforms tests from :bwc:minor"() { given: internalBuild() - subProject(":distribution:bwc:staged") << """ + subProject(":distribution:bwc:minor") << """ configurations { checkout } artifacts { checkout(new File(projectDir, "checkoutDir")) @@ -98,8 +98,8 @@ class LegacyYamlRestCompatTestPluginFuncTest extends AbstractRestResourcesFuncTe String api = "foo.json" String test = "10_basic.yml" //add the compatible test and api files, these are the prior version's normal yaml rest tests - file("distribution/bwc/staged/checkoutDir/rest-api-spec/src/main/resources/rest-api-spec/api/" + api) << "" - file("distribution/bwc/staged/checkoutDir/src/yamlRestTest/resources/rest-api-spec/test/" + test) << "" + file("distribution/bwc/minor/checkoutDir/rest-api-spec/src/main/resources/rest-api-spec/api/" + api) << "" + file("distribution/bwc/minor/checkoutDir/src/yamlRestTest/resources/rest-api-spec/test/" + test) << "" when: def result = gradleRunner("yamlRestCompatTest").build() @@ -145,7 +145,7 @@ class LegacyYamlRestCompatTestPluginFuncTest extends AbstractRestResourcesFuncTe given: internalBuild() withVersionCatalogue() - subProject(":distribution:bwc:staged") << """ + subProject(":distribution:bwc:minor") << """ configurations { checkout } artifacts { checkout(new File(projectDir, "checkoutDir")) @@ -186,7 +186,7 @@ class LegacyYamlRestCompatTestPluginFuncTest extends AbstractRestResourcesFuncTe given: internalBuild() - subProject(":distribution:bwc:staged") << """ + subProject(":distribution:bwc:minor") << """ configurations { checkout } artifacts { checkout(new File(projectDir, "checkoutDir")) @@ -230,7 +230,7 @@ class LegacyYamlRestCompatTestPluginFuncTest extends AbstractRestResourcesFuncTe setupRestResources([], []) - file("distribution/bwc/staged/checkoutDir/src/yamlRestTest/resources/rest-api-spec/test/test.yml" ) << """ + file("distribution/bwc/minor/checkoutDir/src/yamlRestTest/resources/rest-api-spec/test/test.yml" ) << """ "one": - do: do_.some.key_to_replace: diff --git a/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/distribution/bwc/bugfix2/build.gradle b/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/distribution/bwc/bugfix2/build.gradle new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/distribution/bwc/maintenance/build.gradle b/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/distribution/bwc/maintenance/build.gradle new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/settings.gradle b/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/settings.gradle index 8c321294b585f..e931537fcd6e9 100644 --- a/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/settings.gradle +++ b/build-tools-internal/src/integTest/resources/org/elasticsearch/gradle/internal/fake_git/remote/settings.gradle @@ -10,9 +10,11 @@ rootProject.name = "root" include ":distribution:bwc:bugfix" +include ":distribution:bwc:bugfix2" include ":distribution:bwc:minor" include ":distribution:bwc:major" include ":distribution:bwc:staged" +include ":distribution:bwc:maintenance" include ":distribution:archives:darwin-tar" include ":distribution:archives:oss-darwin-tar" include ":distribution:archives:darwin-aarch64-tar" diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcVersions.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcVersions.java index 93c2623a23d31..37b28389ad97b 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcVersions.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/BwcVersions.java @@ -21,14 +21,15 @@ import java.util.Optional; import java.util.Set; import java.util.TreeMap; -import java.util.TreeSet; import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.regex.Matcher; import java.util.regex.Pattern; +import static java.util.Collections.reverseOrder; import static java.util.Collections.unmodifiableList; +import static java.util.Comparator.comparing; /** * A container for elasticsearch supported version information used in BWC testing. @@ -73,11 +74,11 @@ public class BwcVersions implements Serializable { private final transient List versions; private final Map unreleased; - public BwcVersions(List versionLines) { - this(versionLines, Version.fromString(VersionProperties.getElasticsearch())); + public BwcVersions(List versionLines, List developmentBranches) { + this(versionLines, Version.fromString(VersionProperties.getElasticsearch()), developmentBranches); } - public BwcVersions(Version currentVersionProperty, List allVersions) { + public BwcVersions(Version currentVersionProperty, List allVersions, List developmentBranches) { if (allVersions.isEmpty()) { throw new IllegalArgumentException("Could not parse any versions"); } @@ -86,12 +87,12 @@ public BwcVersions(Version currentVersionProperty, List allVersions) { this.currentVersion = allVersions.get(allVersions.size() - 1); assertCurrentVersionMatchesParsed(currentVersionProperty); - this.unreleased = computeUnreleased(); + this.unreleased = computeUnreleased(developmentBranches); } // Visible for testing - BwcVersions(List versionLines, Version currentVersionProperty) { - this(currentVersionProperty, parseVersionLines(versionLines)); + BwcVersions(List versionLines, Version currentVersionProperty, List developmentBranches) { + this(currentVersionProperty, parseVersionLines(versionLines), developmentBranches); } private static List parseVersionLines(List versionLines) { @@ -126,58 +127,77 @@ public void forPreviousUnreleased(Consumer consumer) { getUnreleased().stream().filter(version -> version.equals(currentVersion) == false).map(unreleased::get).forEach(consumer); } - private String getBranchFor(Version version) { - if (version.equals(currentVersion)) { - // Just assume the current branch is 'main'. It's actually not important, we never check out the current branch. - return "main"; - } else { + private String getBranchFor(Version version, List developmentBranches) { + // If the current version matches a specific feature freeze branch, use that + if (developmentBranches.contains(version.getMajor() + "." + version.getMinor())) { return version.getMajor() + "." + version.getMinor(); + } else if (developmentBranches.contains(version.getMajor() + ".x")) { // Otherwise if an n.x branch exists and we are that major + return version.getMajor() + ".x"; + } else { // otherwise we're the main branch + return "main"; } } - private Map computeUnreleased() { - Set unreleased = new TreeSet<>(); - // The current version is being worked, is always unreleased - unreleased.add(currentVersion); - // Recurse for all unreleased versions starting from the current version - addUnreleased(unreleased, currentVersion, 0); + private Map computeUnreleased(List developmentBranches) { + Map result = new TreeMap<>(); - // Grab the latest version from the previous major if necessary as well, this is going to be a maintenance release - Version maintenance = versions.stream() - .filter(v -> v.getMajor() == currentVersion.getMajor() - 1) - .max(Comparator.naturalOrder()) - .orElseThrow(); - // This is considered the maintenance release only if we haven't yet encountered it - boolean hasMaintenanceRelease = unreleased.add(maintenance); + // The current version is always in development + String currentBranch = getBranchFor(currentVersion, developmentBranches); + result.put(currentVersion, new UnreleasedVersionInfo(currentVersion, currentBranch, ":distribution")); + + // Check for an n.x branch as well + if (currentBranch.equals("main") && developmentBranches.stream().anyMatch(s -> s.endsWith(".x"))) { + // This should correspond to the latest new minor + Version version = versions.stream() + .sorted(Comparator.reverseOrder()) + .filter(v -> v.getMajor() == (currentVersion.getMajor() - 1) && v.getRevision() == 0) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Unable to determine development version for branch")); + String branch = getBranchFor(version, developmentBranches); + assert branch.equals(currentVersion.getMajor() - 1 + ".x") : "Expected branch does not match development branch"; + + result.put(version, new UnreleasedVersionInfo(version, branch, ":distribution:bwc:minor")); + } - List unreleasedList = unreleased.stream().sorted(Comparator.reverseOrder()).toList(); - Map result = new TreeMap<>(); - boolean newMinor = false; - for (int i = 0; i < unreleasedList.size(); i++) { - Version esVersion = unreleasedList.get(i); - // This is either a new minor or staged release - if (currentVersion.equals(esVersion)) { - result.put(esVersion, new UnreleasedVersionInfo(esVersion, getBranchFor(esVersion), ":distribution")); - } else if (esVersion.getRevision() == 0) { - // If there are two upcoming unreleased minors then this one is the new minor - if (newMinor == false && unreleasedList.get(i + 1).getRevision() == 0) { - result.put(esVersion, new UnreleasedVersionInfo(esVersion, esVersion.getMajor() + ".x", ":distribution:bwc:minor")); - newMinor = true; - } else if (newMinor == false - && unreleasedList.stream().filter(v -> v.getMajor() == esVersion.getMajor() && v.getRevision() == 0).count() == 1) { - // This is the only unreleased new minor which means we've not yet staged it for release - result.put(esVersion, new UnreleasedVersionInfo(esVersion, esVersion.getMajor() + ".x", ":distribution:bwc:minor")); - newMinor = true; - } else { - result.put(esVersion, new UnreleasedVersionInfo(esVersion, getBranchFor(esVersion), ":distribution:bwc:staged")); - } - } else { - // If this is the oldest unreleased version and we have a maintenance release - if (i == unreleasedList.size() - 1 && hasMaintenanceRelease) { - result.put(esVersion, new UnreleasedVersionInfo(esVersion, getBranchFor(esVersion), ":distribution:bwc:maintenance")); - } else { - result.put(esVersion, new UnreleasedVersionInfo(esVersion, getBranchFor(esVersion), ":distribution:bwc:bugfix")); - } + // Now handle all the feature freeze branches + List featureFreezeBranches = developmentBranches.stream() + .filter(b -> Pattern.matches("[0-9]+\\.[0-9]+", b)) + .sorted(reverseOrder(comparing(s -> Version.fromString(s, Version.Mode.RELAXED)))) + .toList(); + + boolean existingBugfix = false; + for (int i = 0; i < featureFreezeBranches.size(); i++) { + String branch = featureFreezeBranches.get(i); + Version version = versions.stream() + .sorted(Comparator.reverseOrder()) + .filter(v -> v.toString().startsWith(branch)) + .findFirst() + .orElse(null); + + // If we don't know about this version we can ignore it + if (version == null) { + continue; + } + + // If this is the current version we can ignore as we've already handled it + if (version.equals(currentVersion)) { + continue; + } + + // We only maintain compatibility back one major so ignore anything older + if (currentVersion.getMajor() - version.getMajor() > 1) { + continue; + } + + // This is the maintenance version + if (i == featureFreezeBranches.size() - 1) { + result.put(version, new UnreleasedVersionInfo(version, branch, ":distribution:bwc:maintenance")); + } else if (version.getRevision() == 0) { // This is the next staged minor + result.put(version, new UnreleasedVersionInfo(version, branch, ":distribution:bwc:staged")); + } else { // This is a bugfix + String project = existingBugfix ? "bugfix2" : "bugfix"; + result.put(version, new UnreleasedVersionInfo(version, branch, ":distribution:bwc:" + project)); + existingBugfix = true; } } diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/info/GlobalBuildInfoPlugin.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/info/GlobalBuildInfoPlugin.java index 0535026b2594e..27d2a66feb206 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/info/GlobalBuildInfoPlugin.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/info/GlobalBuildInfoPlugin.java @@ -8,6 +8,9 @@ */ package org.elasticsearch.gradle.internal.info; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + import org.apache.commons.io.IOUtils; import org.elasticsearch.gradle.VersionProperties; import org.elasticsearch.gradle.internal.BwcVersions; @@ -44,11 +47,13 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; +import java.io.InputStream; import java.io.InputStreamReader; import java.io.UncheckedIOException; import java.nio.file.Files; import java.time.ZoneOffset; import java.time.ZonedDateTime; +import java.util.ArrayList; import java.util.List; import java.util.Locale; import java.util.Random; @@ -68,6 +73,7 @@ public class GlobalBuildInfoPlugin implements Plugin { private final JavaInstallationRegistry javaInstallationRegistry; private final JvmMetadataDetector metadataDetector; private final ProviderFactory providers; + private final ObjectMapper objectMapper; private JavaToolchainService toolChainService; private Project project; @@ -82,7 +88,7 @@ public GlobalBuildInfoPlugin( this.javaInstallationRegistry = javaInstallationRegistry; this.metadataDetector = new ErrorTraceMetadataDetector(metadataDetector); this.providers = providers; - + this.objectMapper = new ObjectMapper(); } @Override @@ -190,12 +196,27 @@ private BwcVersions resolveBwcVersions() { ); try (var is = new FileInputStream(versionsFilePath)) { List versionLines = IOUtils.readLines(is, "UTF-8"); - return new BwcVersions(versionLines); + return new BwcVersions(versionLines, getDevelopmentBranches()); } catch (IOException e) { throw new IllegalStateException("Unable to resolve to resolve bwc versions from versionsFile.", e); } } + private List getDevelopmentBranches() { + List branches = new ArrayList<>(); + File branchesFile = new File(Util.locateElasticsearchWorkspace(project.getGradle()), "branches.json"); + try (InputStream is = new FileInputStream(branchesFile)) { + JsonNode json = objectMapper.readTree(is); + for (JsonNode node : json.get("branches")) { + branches.add(node.get("branch").asText()); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + return branches; + } + private void logGlobalBuildInfo(BuildParameterExtension buildParams) { final String osName = System.getProperty("os.name"); final String osVersion = System.getProperty("os.version"); diff --git a/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/BwcVersionsSpec.groovy b/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/BwcVersionsSpec.groovy index 9c7d20d84a670..4d033564a42b4 100644 --- a/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/BwcVersionsSpec.groovy +++ b/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/BwcVersionsSpec.groovy @@ -17,8 +17,9 @@ import org.elasticsearch.gradle.internal.BwcVersions.UnreleasedVersionInfo class BwcVersionsSpec extends Specification { List versionLines = [] - def "current version is next minor with next major and last minor both staged"() { + def "current version is next major"() { given: + addVersion('7.17.10', '8.9.0') addVersion('8.14.0', '9.9.0') addVersion('8.14.1', '9.9.0') addVersion('8.14.2', '9.9.0') @@ -29,25 +30,25 @@ class BwcVersionsSpec extends Specification { addVersion('8.16.1', '9.10.0') addVersion('8.17.0', '9.10.0') addVersion('9.0.0', '10.0.0') - addVersion('9.1.0', '10.1.0') when: - def bwc = new BwcVersions(versionLines, v('9.1.0')) + def bwc = new BwcVersions(versionLines, v('9.0.0'), ['main', '8.x', '8.16', '8.15', '7.17']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ + (v('8.15.2')): new UnreleasedVersionInfo(v('8.15.2'), '8.15', ':distribution:bwc:bugfix2'), (v('8.16.1')): new UnreleasedVersionInfo(v('8.16.1'), '8.16', ':distribution:bwc:bugfix'), - (v('8.17.0')): new UnreleasedVersionInfo(v('8.17.0'), '8.17', ':distribution:bwc:staged'), - (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), '9.x', ':distribution:bwc:minor'), - (v('9.1.0')): new UnreleasedVersionInfo(v('9.1.0'), 'main', ':distribution') + (v('8.17.0')): new UnreleasedVersionInfo(v('8.17.0'), '8.x', ':distribution:bwc:minor'), + (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), 'main', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('9.0.0'), v('9.1.0')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('9.0.0'), v('9.1.0')] + bwc.wireCompatible == [v('8.17.0'), v('9.0.0')] + bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('9.0.0')] } - def "current is next minor with upcoming minor staged"() { + def "current version is next major with staged minor"() { given: + addVersion('7.17.10', '8.9.0') addVersion('8.14.0', '9.9.0') addVersion('8.14.1', '9.9.0') addVersion('8.14.2', '9.9.0') @@ -57,53 +58,106 @@ class BwcVersionsSpec extends Specification { addVersion('8.16.0', '9.10.0') addVersion('8.16.1', '9.10.0') addVersion('8.17.0', '9.10.0') - addVersion('8.17.1', '9.10.0') + addVersion('8.18.0', '9.10.0') addVersion('9.0.0', '10.0.0') - addVersion('9.1.0', '10.1.0') when: - def bwc = new BwcVersions(versionLines, v('9.1.0')) + def bwc = new BwcVersions(versionLines, v('9.0.0'), ['main', '8.x', '8.17', '8.16', '8.15', '7.17']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ - (v('8.17.1')): new UnreleasedVersionInfo(v('8.17.1'), '8.17', ':distribution:bwc:bugfix'), + (v('8.15.2')): new UnreleasedVersionInfo(v('8.15.2'), '8.15', ':distribution:bwc:bugfix2'), + (v('8.16.1')): new UnreleasedVersionInfo(v('8.16.1'), '8.16', ':distribution:bwc:bugfix'), + (v('8.17.0')): new UnreleasedVersionInfo(v('8.17.0'), '8.17', ':distribution:bwc:staged'), + (v('8.18.0')): new UnreleasedVersionInfo(v('8.18.0'), '8.x', ':distribution:bwc:minor'), + (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), 'main', ':distribution'), + ] + bwc.wireCompatible == [v('8.18.0'), v('9.0.0')] + bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.18.0'), v('9.0.0')] + } + + def "current version is first new minor in major series"() { + given: + addVersion('7.17.10', '8.9.0') + addVersion('8.16.0', '9.10.0') + addVersion('8.16.1', '9.10.0') + addVersion('8.17.0', '9.10.0') + addVersion('8.18.0', '9.10.0') + addVersion('9.0.0', '10.0.0') + addVersion('9.1.0', '10.0.0') + + when: + def bwc = new BwcVersions(versionLines, v('9.1.0'), ['main', '9.0', '8.18']) + def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } + + then: + unreleased == [ + (v('8.18.0')): new UnreleasedVersionInfo(v('8.18.0'), '8.18', ':distribution:bwc:maintenance'), (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), '9.0', ':distribution:bwc:staged'), - (v('9.1.0')): new UnreleasedVersionInfo(v('9.1.0'), 'main', ':distribution') + (v('9.1.0')): new UnreleasedVersionInfo(v('9.1.0'), 'main', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('8.17.1'), v('9.0.0'), v('9.1.0')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.17.1'), v('9.0.0'), v('9.1.0')] + bwc.wireCompatible == [v('8.18.0'), v('9.0.0'), v('9.1.0')] + bwc.indexCompatible == [v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.18.0'), v('9.0.0'), v('9.1.0')] } - def "current version is staged major"() { + def "current version is new minor with single bugfix"() { given: - addVersion('8.14.0', '9.9.0') - addVersion('8.14.1', '9.9.0') - addVersion('8.14.2', '9.9.0') - addVersion('8.15.0', '9.9.0') - addVersion('8.15.1', '9.9.0') - addVersion('8.15.2', '9.9.0') + addVersion('7.17.10', '8.9.0') addVersion('8.16.0', '9.10.0') addVersion('8.16.1', '9.10.0') addVersion('8.17.0', '9.10.0') - addVersion('8.17.1', '9.10.0') + addVersion('8.18.0', '9.10.0') addVersion('9.0.0', '10.0.0') + addVersion('9.0.1', '10.0.0') + addVersion('9.1.0', '10.0.0') when: - def bwc = new BwcVersions(versionLines, v('9.0.0')) + def bwc = new BwcVersions(versionLines, v('9.1.0'), ['main', '9.0', '8.18']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ - (v('8.17.1')): new UnreleasedVersionInfo(v('8.17.1'), '8.17', ':distribution:bwc:bugfix'), - (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), 'main', ':distribution'), + (v('8.18.0')): new UnreleasedVersionInfo(v('8.18.0'), '8.18', ':distribution:bwc:maintenance'), + (v('9.0.1')): new UnreleasedVersionInfo(v('9.0.1'), '9.0', ':distribution:bwc:bugfix'), + (v('9.1.0')): new UnreleasedVersionInfo(v('9.1.0'), 'main', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('8.17.1'), v('9.0.0')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.17.1'), v('9.0.0')] + bwc.wireCompatible == [v('8.18.0'), v('9.0.0'), v('9.0.1'), v('9.1.0')] + bwc.indexCompatible == [v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.18.0'), v('9.0.0'), v('9.0.1'), v('9.1.0')] } - def "current version is major with unreleased next minor"() { + def "current version is new minor with single bugfix and staged minor"() { given: + addVersion('7.17.10', '8.9.0') + addVersion('8.16.0', '9.10.0') + addVersion('8.16.1', '9.10.0') + addVersion('8.17.0', '9.10.0') + addVersion('8.18.0', '9.10.0') + addVersion('9.0.0', '10.0.0') + addVersion('9.0.1', '10.0.0') + addVersion('9.1.0', '10.0.0') + addVersion('9.2.0', '10.0.0') + + when: + def bwc = new BwcVersions(versionLines, v('9.2.0'), ['main', '9.1', '9.0', '8.18']) + def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } + + then: + unreleased == [ + (v('8.18.0')): new UnreleasedVersionInfo(v('8.18.0'), '8.18', ':distribution:bwc:maintenance'), + (v('9.0.1')): new UnreleasedVersionInfo(v('9.0.1'), '9.0', ':distribution:bwc:bugfix'), + (v('9.1.0')): new UnreleasedVersionInfo(v('9.1.0'), '9.1', ':distribution:bwc:staged'), + (v('9.2.0')): new UnreleasedVersionInfo(v('9.2.0'), 'main', ':distribution'), + ] + bwc.wireCompatible == [v('8.18.0'), v('9.0.0'), v('9.0.1'), v('9.1.0'), v('9.2.0')] + bwc.indexCompatible == [v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.18.0'), v('9.0.0'), v('9.0.1'), v('9.1.0'), v('9.2.0')] + } + + def "current version is next minor"() { + given: + addVersion('7.16.3', '8.9.0') + addVersion('7.17.0', '8.9.0') + addVersion('7.17.1', '8.9.0') addVersion('8.14.0', '9.9.0') addVersion('8.14.1', '9.9.0') addVersion('8.14.2', '9.9.0') @@ -113,24 +167,29 @@ class BwcVersionsSpec extends Specification { addVersion('8.16.0', '9.10.0') addVersion('8.16.1', '9.10.0') addVersion('8.17.0', '9.10.0') - addVersion('9.0.0', '10.0.0') + addVersion('8.17.1', '9.10.0') + addVersion('8.18.0', '9.10.0') when: - def bwc = new BwcVersions(versionLines, v('9.0.0')) + def bwc = new BwcVersions(versionLines, v('8.18.0'), ['main', '8.x', '8.17', '8.16', '7.17']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ - (v('8.16.1')): new UnreleasedVersionInfo(v('8.16.1'), '8.16', ':distribution:bwc:bugfix'), - (v('8.17.0')): new UnreleasedVersionInfo(v('8.17.0'), '8.x', ':distribution:bwc:minor'), - (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), 'main', ':distribution'), + (v('7.17.1')): new UnreleasedVersionInfo(v('7.17.1'), '7.17', ':distribution:bwc:maintenance'), + (v('8.16.1')): new UnreleasedVersionInfo(v('8.16.1'), '8.16', ':distribution:bwc:bugfix2'), + (v('8.17.1')): new UnreleasedVersionInfo(v('8.17.1'), '8.17', ':distribution:bwc:bugfix'), + (v('8.18.0')): new UnreleasedVersionInfo(v('8.18.0'), '8.x', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('9.0.0')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('9.0.0')] + bwc.wireCompatible == [v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.17.1'), v('8.18.0')] + bwc.indexCompatible == [v('7.16.3'), v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.17.1'), v('8.18.0')] } - def "current version is major with staged next minor"() { + def "current version is new minor with staged minor"() { given: + addVersion('7.16.3', '8.9.0') + addVersion('7.17.0', '8.9.0') + addVersion('7.17.1', '8.9.0') addVersion('8.14.0', '9.9.0') addVersion('8.14.1', '9.9.0') addVersion('8.14.2', '9.9.0') @@ -138,26 +197,31 @@ class BwcVersionsSpec extends Specification { addVersion('8.15.1', '9.9.0') addVersion('8.15.2', '9.9.0') addVersion('8.16.0', '9.10.0') + addVersion('8.16.1', '9.10.0') addVersion('8.17.0', '9.10.0') - addVersion('9.0.0', '10.0.0') + addVersion('8.18.0', '9.10.0') when: - def bwc = new BwcVersions(versionLines, v('9.0.0')) + def bwc = new BwcVersions(versionLines, v('8.18.0'), ['main', '8.x', '8.17', '8.16', '8.15', '7.17']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ - (v('8.15.2')): new UnreleasedVersionInfo(v('8.15.2'), '8.15', ':distribution:bwc:bugfix'), - (v('8.16.0')): new UnreleasedVersionInfo(v('8.16.0'), '8.16', ':distribution:bwc:staged'), - (v('8.17.0')): new UnreleasedVersionInfo(v('8.17.0'), '8.x', ':distribution:bwc:minor'), - (v('9.0.0')): new UnreleasedVersionInfo(v('9.0.0'), 'main', ':distribution'), + (v('7.17.1')): new UnreleasedVersionInfo(v('7.17.1'), '7.17', ':distribution:bwc:maintenance'), + (v('8.15.2')): new UnreleasedVersionInfo(v('8.15.2'), '8.15', ':distribution:bwc:bugfix2'), + (v('8.16.1')): new UnreleasedVersionInfo(v('8.16.1'), '8.16', ':distribution:bwc:bugfix'), + (v('8.17.0')): new UnreleasedVersionInfo(v('8.17.0'), '8.17', ':distribution:bwc:staged'), + (v('8.18.0')): new UnreleasedVersionInfo(v('8.18.0'), '8.x', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('9.0.0')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.17.0'), v('9.0.0')] + bwc.wireCompatible == [v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.18.0')] + bwc.indexCompatible == [v('7.16.3'), v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.18.0')] } - def "current version is next bugfix"() { + def "current version is first bugfix"() { given: + addVersion('7.16.3', '8.9.0') + addVersion('7.17.0', '8.9.0') + addVersion('7.17.1', '8.9.0') addVersion('8.14.0', '9.9.0') addVersion('8.14.1', '9.9.0') addVersion('8.14.2', '9.9.0') @@ -166,52 +230,44 @@ class BwcVersionsSpec extends Specification { addVersion('8.15.2', '9.9.0') addVersion('8.16.0', '9.10.0') addVersion('8.16.1', '9.10.0') - addVersion('8.17.0', '9.10.0') - addVersion('8.17.1', '9.10.0') - addVersion('9.0.0', '10.0.0') - addVersion('9.0.1', '10.0.0') when: - def bwc = new BwcVersions(versionLines, v('9.0.1')) + def bwc = new BwcVersions(versionLines, v('8.16.1'), ['main', '8.x', '8.17', '8.16', '8.15', '7.17']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ - (v('8.17.1')): new UnreleasedVersionInfo(v('8.17.1'), '8.17', ':distribution:bwc:maintenance'), - (v('9.0.1')): new UnreleasedVersionInfo(v('9.0.1'), 'main', ':distribution'), + (v('7.17.1')): new UnreleasedVersionInfo(v('7.17.1'), '7.17', ':distribution:bwc:maintenance'), + (v('8.15.2')): new UnreleasedVersionInfo(v('8.15.2'), '8.15', ':distribution:bwc:bugfix'), + (v('8.16.1')): new UnreleasedVersionInfo(v('8.16.1'), '8.16', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('8.17.1'), v('9.0.0'), v('9.0.1')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.17.1'), v('9.0.0'), v('9.0.1')] + bwc.wireCompatible == [v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1')] + bwc.indexCompatible == [v('7.16.3'), v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1')] } - def "current version is next minor with no staged releases"() { + def "current version is second bugfix"() { given: + addVersion('7.16.3', '8.9.0') + addVersion('7.17.0', '8.9.0') + addVersion('7.17.1', '8.9.0') addVersion('8.14.0', '9.9.0') addVersion('8.14.1', '9.9.0') addVersion('8.14.2', '9.9.0') addVersion('8.15.0', '9.9.0') addVersion('8.15.1', '9.9.0') addVersion('8.15.2', '9.9.0') - addVersion('8.16.0', '9.10.0') - addVersion('8.16.1', '9.10.0') - addVersion('8.17.0', '9.10.0') - addVersion('8.17.1', '9.10.0') - addVersion('9.0.0', '10.0.0') - addVersion('9.0.1', '10.0.0') - addVersion('9.1.0', '10.1.0') when: - def bwc = new BwcVersions(versionLines, v('9.1.0')) + def bwc = new BwcVersions(versionLines, v('8.15.2'), ['main', '8.x', '8.17', '8.16', '8.15', '7.17']) def unreleased = bwc.unreleased.collectEntries { [it, bwc.unreleasedInfo(it)] } then: unreleased == [ - (v('8.17.1')): new UnreleasedVersionInfo(v('8.17.1'), '8.17', ':distribution:bwc:maintenance'), - (v('9.0.1')): new UnreleasedVersionInfo(v('9.0.1'), '9.0', ':distribution:bwc:bugfix'), - (v('9.1.0')): new UnreleasedVersionInfo(v('9.1.0'), 'main', ':distribution') + (v('7.17.1')): new UnreleasedVersionInfo(v('7.17.1'), '7.17', ':distribution:bwc:maintenance'), + (v('8.15.2')): new UnreleasedVersionInfo(v('8.15.2'), '8.15', ':distribution'), ] - bwc.wireCompatible == [v('8.17.0'), v('8.17.1'), v('9.0.0'), v('9.0.1'), v('9.1.0')] - bwc.indexCompatible == [v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2'), v('8.16.0'), v('8.16.1'), v('8.17.0'), v('8.17.1'), v('9.0.0'), v('9.0.1'), v('9.1.0')] + bwc.wireCompatible == [v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2')] + bwc.indexCompatible == [v('7.16.3'), v('7.17.0'), v('7.17.1'), v('8.14.0'), v('8.14.1'), v('8.14.2'), v('8.15.0'), v('8.15.1'), v('8.15.2')] } private void addVersion(String elasticsearch, String lucene) { diff --git a/build-tools-internal/src/test/java/org/elasticsearch/gradle/AbstractDistributionDownloadPluginTests.java b/build-tools-internal/src/test/java/org/elasticsearch/gradle/AbstractDistributionDownloadPluginTests.java index 639dec280ae9a..7512fa20814c6 100644 --- a/build-tools-internal/src/test/java/org/elasticsearch/gradle/AbstractDistributionDownloadPluginTests.java +++ b/build-tools-internal/src/test/java/org/elasticsearch/gradle/AbstractDistributionDownloadPluginTests.java @@ -16,6 +16,7 @@ import java.io.File; import java.util.Arrays; +import java.util.List; public class AbstractDistributionDownloadPluginTests { protected static Project rootProject; @@ -28,22 +29,27 @@ public class AbstractDistributionDownloadPluginTests { protected static final Version BWC_STAGED_VERSION = Version.fromString("1.0.0"); protected static final Version BWC_BUGFIX_VERSION = Version.fromString("1.0.1"); protected static final Version BWC_MAINTENANCE_VERSION = Version.fromString("0.90.1"); + protected static final List DEVELOPMENT_BRANCHES = Arrays.asList("main", "1.1", "1.0", "0.90"); protected static final BwcVersions BWC_MINOR = new BwcVersions( BWC_MAJOR_VERSION, - Arrays.asList(BWC_BUGFIX_VERSION, BWC_MINOR_VERSION, BWC_MAJOR_VERSION) + Arrays.asList(BWC_BUGFIX_VERSION, BWC_MINOR_VERSION, BWC_MAJOR_VERSION), + DEVELOPMENT_BRANCHES ); protected static final BwcVersions BWC_STAGED = new BwcVersions( BWC_MAJOR_VERSION, - Arrays.asList(BWC_MAINTENANCE_VERSION, BWC_STAGED_VERSION, BWC_MINOR_VERSION, BWC_MAJOR_VERSION) + Arrays.asList(BWC_MAINTENANCE_VERSION, BWC_STAGED_VERSION, BWC_MINOR_VERSION, BWC_MAJOR_VERSION), + DEVELOPMENT_BRANCHES ); protected static final BwcVersions BWC_BUGFIX = new BwcVersions( BWC_MAJOR_VERSION, - Arrays.asList(BWC_BUGFIX_VERSION, BWC_MINOR_VERSION, BWC_MAJOR_VERSION) + Arrays.asList(BWC_BUGFIX_VERSION, BWC_MINOR_VERSION, BWC_MAJOR_VERSION), + DEVELOPMENT_BRANCHES ); protected static final BwcVersions BWC_MAINTENANCE = new BwcVersions( BWC_MINOR_VERSION, - Arrays.asList(BWC_MAINTENANCE_VERSION, BWC_BUGFIX_VERSION, BWC_MINOR_VERSION) + Arrays.asList(BWC_MAINTENANCE_VERSION, BWC_BUGFIX_VERSION, BWC_MINOR_VERSION), + DEVELOPMENT_BRANCHES ); protected static String projectName(String base, boolean bundledJdk) { diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties index 29c5bc16a8c4a..aaf654a37dd22 100644 --- a/build-tools-internal/version.properties +++ b/build-tools-internal/version.properties @@ -17,6 +17,8 @@ jna = 5.12.1 netty = 4.1.115.Final commons_lang3 = 3.9 google_oauth_client = 1.34.1 +awsv1sdk = 1.12.270 +awsv2sdk = 2.28.13 antlr4 = 4.13.1 # bouncy castle version for non-fips. fips jars use a different version diff --git a/build-tools/src/testFixtures/groovy/org/elasticsearch/gradle/fixtures/AbstractGradleFuncTest.groovy b/build-tools/src/testFixtures/groovy/org/elasticsearch/gradle/fixtures/AbstractGradleFuncTest.groovy index f3f8e4703eba2..07214b5fbf845 100644 --- a/build-tools/src/testFixtures/groovy/org/elasticsearch/gradle/fixtures/AbstractGradleFuncTest.groovy +++ b/build-tools/src/testFixtures/groovy/org/elasticsearch/gradle/fixtures/AbstractGradleFuncTest.groovy @@ -156,12 +156,12 @@ abstract class AbstractGradleFuncTest extends Specification { File internalBuild( List extraPlugins = [], - String bugfix = "7.15.2", - String bugfixLucene = "8.9.0", - String staged = "7.16.0", - String stagedLucene = "8.10.0", - String minor = "8.0.0", - String minorLucene = "9.0.0" + String maintenance = "7.16.10", + String bugfix2 = "8.1.3", + String bugfix = "8.2.1", + String staged = "8.3.0", + String minor = "8.4.0", + String current = "9.0.0" ) { buildFile << """plugins { id 'elasticsearch.global-build-info' @@ -172,15 +172,17 @@ abstract class AbstractGradleFuncTest extends Specification { import org.elasticsearch.gradle.internal.BwcVersions import org.elasticsearch.gradle.Version - Version currentVersion = Version.fromString("8.1.0") + Version currentVersion = Version.fromString("${current}") def versionList = [ + Version.fromString("$maintenance"), + Version.fromString("$bugfix2"), Version.fromString("$bugfix"), Version.fromString("$staged"), Version.fromString("$minor"), currentVersion ] - BwcVersions versions = new BwcVersions(currentVersion, versionList) + BwcVersions versions = new BwcVersions(currentVersion, versionList, ['main', '8.x', '8.3', '8.2', '8.1', '7.16']) buildParams.getBwcVersionsProperty().set(versions) """ } diff --git a/distribution/bwc/bugfix2/build.gradle b/distribution/bwc/bugfix2/build.gradle new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/distribution/src/config/log4j2.properties b/distribution/src/config/log4j2.properties index 36b5b03d9a110..bde4d9d17fc17 100644 --- a/distribution/src/config/log4j2.properties +++ b/distribution/src/config/log4j2.properties @@ -63,7 +63,7 @@ appender.deprecation_rolling.name = deprecation_rolling appender.deprecation_rolling.fileName = ${sys:es.logs.base_path}${sys:file.separator}${sys:es.logs.cluster_name}_deprecation.json appender.deprecation_rolling.layout.type = ECSJsonLayout # Intentionally follows a different pattern to above -appender.deprecation_rolling.layout.dataset = deprecation.elasticsearch +appender.deprecation_rolling.layout.dataset = elasticsearch.deprecation appender.deprecation_rolling.filter.rate_limit.type = RateLimitingFilter appender.deprecation_rolling.filePattern = ${sys:es.logs.base_path}${sys:file.separator}${sys:es.logs.cluster_name}_deprecation-%i.json.gz diff --git a/distribution/tools/plugin-cli/build.gradle b/distribution/tools/plugin-cli/build.gradle index 57750f2162a71..dc2bcd96b8d9f 100644 --- a/distribution/tools/plugin-cli/build.gradle +++ b/distribution/tools/plugin-cli/build.gradle @@ -25,8 +25,8 @@ dependencies { implementation project(":libs:plugin-api") implementation project(":libs:plugin-scanner") // TODO: asm is picked up from the plugin scanner, we should consolidate so it is not defined twice - implementation 'org.ow2.asm:asm:9.7' - implementation 'org.ow2.asm:asm-tree:9.7' + implementation 'org.ow2.asm:asm:9.7.1' + implementation 'org.ow2.asm:asm-tree:9.7.1' api "org.bouncycastle:bcpg-fips:1.0.7.1" api "org.bouncycastle:bc-fips:1.0.2.5" diff --git a/docs/changelog/111104.yaml b/docs/changelog/111104.yaml new file mode 100644 index 0000000000000..a7dffdd0be221 --- /dev/null +++ b/docs/changelog/111104.yaml @@ -0,0 +1,6 @@ +pr: 111104 +summary: "ESQL: Enable async get to support formatting" +area: ES|QL +type: feature +issues: + - 110926 diff --git a/docs/changelog/114445.yaml b/docs/changelog/114445.yaml new file mode 100644 index 0000000000000..afbc080d1e0b9 --- /dev/null +++ b/docs/changelog/114445.yaml @@ -0,0 +1,6 @@ +pr: 114445 +summary: Wrap jackson exception on malformed json string +area: Infra/Core +type: bug +issues: + - 114142 diff --git a/docs/changelog/117359.yaml b/docs/changelog/117359.yaml new file mode 100644 index 0000000000000..87d2d828ace54 --- /dev/null +++ b/docs/changelog/117359.yaml @@ -0,0 +1,5 @@ +pr: 117359 +summary: Term query for ES|QL +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/changelog/117589.yaml b/docs/changelog/117589.yaml new file mode 100644 index 0000000000000..e6880fd9477b5 --- /dev/null +++ b/docs/changelog/117589.yaml @@ -0,0 +1,5 @@ +pr: 117589 +summary: "Add Inference Unified API for chat completions for OpenAI" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/117657.yaml b/docs/changelog/117657.yaml new file mode 100644 index 0000000000000..0a72e9dabe9e8 --- /dev/null +++ b/docs/changelog/117657.yaml @@ -0,0 +1,5 @@ +pr: 117657 +summary: Ignore cancellation exceptions +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/117701.yaml b/docs/changelog/117701.yaml new file mode 100644 index 0000000000000..5a72bdeb143e6 --- /dev/null +++ b/docs/changelog/117701.yaml @@ -0,0 +1,6 @@ +pr: 117701 +summary: Watcher history index has too many indexed fields - +area: Watcher +type: bug +issues: + - 71479 diff --git a/docs/changelog/117792.yaml b/docs/changelog/117792.yaml new file mode 100644 index 0000000000000..2d7ddda1ace40 --- /dev/null +++ b/docs/changelog/117792.yaml @@ -0,0 +1,6 @@ +pr: 117792 +summary: Address mapping and compute engine runtime field issues +area: Mapping +type: bug +issues: + - 117644 diff --git a/docs/changelog/117898.yaml b/docs/changelog/117898.yaml new file mode 100644 index 0000000000000..c60061abc49ff --- /dev/null +++ b/docs/changelog/117898.yaml @@ -0,0 +1,5 @@ +pr: 117898 +summary: Limit size of query +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/117914.yaml b/docs/changelog/117914.yaml new file mode 100644 index 0000000000000..da58ed7bb04b7 --- /dev/null +++ b/docs/changelog/117914.yaml @@ -0,0 +1,5 @@ +pr: 117914 +summary: Fix for propagating filters from compound to inner retrievers +area: Ranking +type: bug +issues: [] diff --git a/docs/changelog/117917.yaml b/docs/changelog/117917.yaml new file mode 100644 index 0000000000000..b6dc90f6b903d --- /dev/null +++ b/docs/changelog/117917.yaml @@ -0,0 +1,5 @@ +pr: 117917 +summary: Add option to store `sparse_vector` outside `_source` +area: Mapping +type: feature +issues: [] diff --git a/docs/changelog/117920.yaml b/docs/changelog/117920.yaml new file mode 100644 index 0000000000000..1bfddabd4462d --- /dev/null +++ b/docs/changelog/117920.yaml @@ -0,0 +1,6 @@ +pr: 117920 +summary: Wait for the worker service to shutdown before closing task processor +area: Machine Learning +type: bug +issues: + - 117563 diff --git a/docs/changelog/117933.yaml b/docs/changelog/117933.yaml new file mode 100644 index 0000000000000..92ae31afa30dd --- /dev/null +++ b/docs/changelog/117933.yaml @@ -0,0 +1,18 @@ +pr: 117933 +summary: Change `deprecation.elasticsearch` keyword to `elasticsearch.deprecation` +area: Infra/Logging +type: bug +issues: + - 83251 +breaking: + title: Deprecation logging value change for "data_stream.dataset" and "event.dataset" + area: Logging + details: |- + This change modifies the "data_stream.dataset" and "event.dataset" value for deprecation logging + to use the value `elasticsearch.deprecation` instead of `deprecation.elasticsearch`. This is now + consistent with other values where the name of the service is the first part of the key. + impact: |- + If you are directly consuming deprecation logs for "data_stream.dataset" and "event.dataset" and filtering on + this value, you will need to update your filters to use `elasticsearch.deprecation` instead of + `deprecation.elasticsearch`. + notable: false diff --git a/docs/changelog/117953.yaml b/docs/changelog/117953.yaml new file mode 100644 index 0000000000000..62f0218b1cdc7 --- /dev/null +++ b/docs/changelog/117953.yaml @@ -0,0 +1,5 @@ +pr: 117953 +summary: Acquire stats searcher for data stream stats +area: Data streams +type: bug +issues: [] diff --git a/docs/changelog/117963.yaml b/docs/changelog/117963.yaml new file mode 100644 index 0000000000000..4a50dc175786b --- /dev/null +++ b/docs/changelog/117963.yaml @@ -0,0 +1,5 @@ +pr: 117963 +summary: '`SearchServiceTests.testParseSourceValidation` failure' +area: Search +type: bug +issues: [] diff --git a/docs/changelog/118027.yaml b/docs/changelog/118027.yaml new file mode 100644 index 0000000000000..161c156b56a65 --- /dev/null +++ b/docs/changelog/118027.yaml @@ -0,0 +1,6 @@ +pr: 118027 +summary: Esql compare nanos and millis +area: ES|QL +type: enhancement +issues: + - 116281 diff --git a/docs/changelog/118064.yaml b/docs/changelog/118064.yaml new file mode 100644 index 0000000000000..7d12f365bf142 --- /dev/null +++ b/docs/changelog/118064.yaml @@ -0,0 +1,5 @@ +pr: 118064 +summary: Add Highlighter for Semantic Text Fields +area: Highlighting +type: feature +issues: [] diff --git a/docs/changelog/118094.yaml b/docs/changelog/118094.yaml new file mode 100644 index 0000000000000..a8866543fa7d2 --- /dev/null +++ b/docs/changelog/118094.yaml @@ -0,0 +1,5 @@ +pr: 118094 +summary: Update ASM 9.7 -> 9.7.1 to support JDK 24 +area: Infra/Core +type: upgrade +issues: [] diff --git a/docs/internal/DistributedArchitectureGuide.md b/docs/internal/DistributedArchitectureGuide.md index 793d38e3d73b3..11a2c860eb326 100644 --- a/docs/internal/DistributedArchitectureGuide.md +++ b/docs/internal/DistributedArchitectureGuide.md @@ -386,6 +386,9 @@ The tasks infrastructure is used to track currently executing operations in the Each individual task is local to a node, but can be related to other tasks, on the same node or other nodes, via a parent-child relationship. +> [!NOTE] +> The Task management API is experimental/beta, its status and outstanding issues can be tracked [here](https://github.com/elastic/elasticsearch/issues/51628). + ### Task tracking and registration Tasks are tracked in-memory on each node in the node's [TaskManager], new tasks are registered via one of the [TaskManager#register] methods. diff --git a/docs/plugins/analysis-kuromoji.asciidoc b/docs/plugins/analysis-kuromoji.asciidoc index 0a167bf3f0240..217d88f361223 100644 --- a/docs/plugins/analysis-kuromoji.asciidoc +++ b/docs/plugins/analysis-kuromoji.asciidoc @@ -750,3 +750,39 @@ Which results in: ] } -------------------------------------------------- + +[[analysis-kuromoji-completion]] +==== `kuromoji_completion` token filter + +The `kuromoji_completion` token filter adds Japanese romanized tokens to the term attributes along with the original tokens (surface forms). + +[source,console] +-------------------------------------------------- +GET _analyze +{ + "analyzer": "kuromoji_completion", + "text": "寿司" <1> +} +-------------------------------------------------- + +<1> Returns `寿司`, `susi` (Kunrei-shiki) and `sushi` (Hepburn-shiki). + +The `kuromoji_completion` token filter accepts the following settings: + +`mode`:: ++ +-- + +The tokenization mode determines how the tokenizer handles compound and +unknown words. It can be set to: + +`index`:: + + Simple romanization. Expected to be used when indexing. + +`query`:: + + Input Method aware romanization. Expected to be used when querying. + +Defaults to `index`. +-- diff --git a/docs/reference/connector/docs/connectors-box.asciidoc b/docs/reference/connector/docs/connectors-box.asciidoc index 07e4308d67c20..3e95f15d16ccd 100644 --- a/docs/reference/connector/docs/connectors-box.asciidoc +++ b/docs/reference/connector/docs/connectors-box.asciidoc @@ -54,7 +54,7 @@ For additional operations, see <>. ====== Box Free Account [discrete#es-connectors-box-create-oauth-custom-app] -======= Create Box User Authentication (OAuth 2.0) Custom App +*Create Box User Authentication (OAuth 2.0) Custom App* You'll need to create an OAuth app in the Box developer console by following these steps: @@ -64,7 +64,7 @@ You'll need to create an OAuth app in the Box developer console by following the 4. Once the app is created, *Client ID* and *Client secret* values are available in the configuration tab. Keep these handy. [discrete#es-connectors-box-connector-generate-a-refresh-token] -======= Generate a refresh Token +*Generate a refresh Token* To generate a refresh token, follow these steps: @@ -97,7 +97,7 @@ Save the refresh token from the response. You'll need this for the connector con ====== Box Enterprise Account [discrete#es-connectors-box-connector-create-box-server-authentication-client-credentials-grant-custom-app] -======= Create Box Server Authentication (Client Credentials Grant) Custom App +*Create Box Server Authentication (Client Credentials Grant) Custom App* 1. Register a new app in the https://app.box.com/developers/console[Box dev console] with custom App and select Server Authentication (Client Credentials Grant). 2. Check following permissions: @@ -224,7 +224,7 @@ For additional operations, see <>. ====== Box Free Account [discrete#es-connectors-box-client-create-oauth-custom-app] -======= Create Box User Authentication (OAuth 2.0) Custom App +*Create Box User Authentication (OAuth 2.0) Custom App* You'll need to create an OAuth app in the Box developer console by following these steps: @@ -234,7 +234,7 @@ You'll need to create an OAuth app in the Box developer console by following the 4. Once the app is created, *Client ID* and *Client secret* values are available in the configuration tab. Keep these handy. [discrete#es-connectors-box-client-connector-generate-a-refresh-token] -======= Generate a refresh Token +*Generate a refresh Token* To generate a refresh token, follow these steps: @@ -267,7 +267,7 @@ Save the refresh token from the response. You'll need this for the connector con ====== Box Enterprise Account [discrete#es-connectors-box-client-connector-create-box-server-authentication-client-credentials-grant-custom-app] -======= Create Box Server Authentication (Client Credentials Grant) Custom App +*Create Box Server Authentication (Client Credentials Grant) Custom App* 1. Register a new app in the https://app.box.com/developers/console[Box dev console] with custom App and select Server Authentication (Client Credentials Grant). 2. Check following permissions: diff --git a/docs/reference/connector/docs/connectors-content-extraction.asciidoc b/docs/reference/connector/docs/connectors-content-extraction.asciidoc index 5d2a9550a7c3c..a87d38c9bf531 100644 --- a/docs/reference/connector/docs/connectors-content-extraction.asciidoc +++ b/docs/reference/connector/docs/connectors-content-extraction.asciidoc @@ -183,7 +183,7 @@ Be aware that the self-managed connector will download files with randomized fil For that reason, we recommend using a dedicated directory for self-hosted extraction. [discrete#es-connectors-content-extraction-data-extraction-service-file-pointers-configuration-example] -======= Example +*Example* 1. For this example, we will be using `/app/files` as both our local directory and our container directory. When you run the extraction service docker container, you can mount the directory as a volume using the command-line option `-v /app/files:/app/files`. @@ -228,7 +228,7 @@ When using self-hosted extraction from a dockerized self-managed connector, ther * The self-managed connector and the extraction service will also need to share a volume. You can decide what directory inside these docker containers the volume will be mounted onto, but the directory must be the same for both docker containers. [discrete#es-connectors-content-extraction-data-extraction-service-file-pointers-configuration-dockerized-example] -======= Example +*Example* 1. First, set up a volume for the two docker containers to share. This will be where files are downloaded into and then extracted from. diff --git a/docs/reference/connector/docs/connectors-dropbox.asciidoc b/docs/reference/connector/docs/connectors-dropbox.asciidoc index 1f80a0ab4e952..295b7e2936625 100644 --- a/docs/reference/connector/docs/connectors-dropbox.asciidoc +++ b/docs/reference/connector/docs/connectors-dropbox.asciidoc @@ -190,7 +190,7 @@ When both are provided, priority is given to `file_categories`. We have some examples below for illustration. [discrete#es-connectors-dropbox-sync-rules-advanced-example-1] -======= Example: Query only +*Example: Query only* [source,js] ---- @@ -206,7 +206,7 @@ We have some examples below for illustration. // NOTCONSOLE [discrete#es-connectors-dropbox-sync-rules-advanced-example-2] -======= Example: Query with file extension filter +*Example: Query with file extension filter* [source,js] ---- @@ -225,7 +225,7 @@ We have some examples below for illustration. // NOTCONSOLE [discrete#es-connectors-dropbox-sync-rules-advanced-example-3] -======= Example: Query with file category filter +*Example: Query with file category filter* [source,js] ---- @@ -248,7 +248,7 @@ We have some examples below for illustration. // NOTCONSOLE [discrete#es-connectors-dropbox-sync-rules-advanced-limitations] -======= Limitations +*Limitations* * Content extraction is not supported for Dropbox *Paper* files when advanced sync rules are enabled. @@ -474,7 +474,7 @@ When both are provided, priority is given to `file_categories`. We have some examples below for illustration. [discrete#es-connectors-dropbox-client-sync-rules-advanced-example-1] -======= Example: Query only +*Example: Query only* [source,js] ---- @@ -490,7 +490,7 @@ We have some examples below for illustration. // NOTCONSOLE [discrete#es-connectors-dropbox-client-sync-rules-advanced-example-2] -======= Example: Query with file extension filter +*Example: Query with file extension filter* [source,js] ---- @@ -509,7 +509,7 @@ We have some examples below for illustration. // NOTCONSOLE [discrete#es-connectors-dropbox-client-sync-rules-advanced-example-3] -======= Example: Query with file category filter +*Example: Query with file category filter* [source,js] ---- @@ -532,7 +532,7 @@ We have some examples below for illustration. // NOTCONSOLE [discrete#es-connectors-dropbox-client-sync-rules-advanced-limitations] -======= Limitations +*Limitations* * Content extraction is not supported for Dropbox *Paper* files when advanced sync rules are enabled. diff --git a/docs/reference/connector/docs/connectors-github.asciidoc b/docs/reference/connector/docs/connectors-github.asciidoc index aa683e4bb0829..df577d83e8121 100644 --- a/docs/reference/connector/docs/connectors-github.asciidoc +++ b/docs/reference/connector/docs/connectors-github.asciidoc @@ -210,7 +210,7 @@ Advanced sync rules are defined through a source-specific DSL JSON snippet. The following sections provide examples of advanced sync rules for this connector. [discrete#es-connectors-github-sync-rules-advanced-branch] -======= Indexing document and files based on branch name configured via branch key +*Indexing document and files based on branch name configured via branch key* [source,js] ---- @@ -226,7 +226,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-github-sync-rules-advanced-issue-key] -======= Indexing document based on issue query related to bugs via issue key +*Indexing document based on issue query related to bugs via issue key* [source,js] ---- @@ -242,7 +242,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-github-sync-rules-advanced-pr-key] -======= Indexing document based on PR query related to open PR's via PR key +*Indexing document based on PR query related to open PR's via PR key* [source,js] ---- @@ -258,7 +258,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-github-sync-rules-advanced-issue-query-branch-name] -======= Indexing document and files based on queries and branch name +*Indexing document and files based on queries and branch name* [source,js] ---- @@ -283,7 +283,7 @@ Check the Elasticsearch index for the actual document count. ==== [discrete#es-connectors-github-sync-rules-advanced-overlapping] -======= Advanced rules for overlapping +*Advanced rules for overlapping* [source,js] ---- @@ -550,7 +550,7 @@ Advanced sync rules are defined through a source-specific DSL JSON snippet. The following sections provide examples of advanced sync rules for this connector. [discrete#es-connectors-github-client-sync-rules-advanced-branch] -======= Indexing document and files based on branch name configured via branch key +*Indexing document and files based on branch name configured via branch key* [source,js] ---- @@ -566,7 +566,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-github-client-sync-rules-advanced-issue-key] -======= Indexing document based on issue query related to bugs via issue key +*Indexing document based on issue query related to bugs via issue key* [source,js] ---- @@ -582,7 +582,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-github-client-sync-rules-advanced-pr-key] -======= Indexing document based on PR query related to open PR's via PR key +*Indexing document based on PR query related to open PR's via PR key* [source,js] ---- @@ -598,7 +598,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-github-client-sync-rules-advanced-issue-query-branch-name] -======= Indexing document and files based on queries and branch name +*Indexing document and files based on queries and branch name* [source,js] ---- @@ -623,7 +623,7 @@ Check the Elasticsearch index for the actual document count. ==== [discrete#es-connectors-github-client-sync-rules-advanced-overlapping] -======= Advanced rules for overlapping +*Advanced rules for overlapping* [source,js] ---- diff --git a/docs/reference/connector/docs/connectors-ms-sql.asciidoc b/docs/reference/connector/docs/connectors-ms-sql.asciidoc index 47fb282b16877..d706af8ca8043 100644 --- a/docs/reference/connector/docs/connectors-ms-sql.asciidoc +++ b/docs/reference/connector/docs/connectors-ms-sql.asciidoc @@ -196,7 +196,7 @@ Here are a few examples of advanced sync rules for this connector. ==== [discrete#es-connectors-ms-sql-sync-rules-advanced-queries] -======= Example: Two queries +*Example: Two queries* These rules fetch all records from both the `employee` and `customer` tables. The data from these tables will be synced separately to Elasticsearch. @@ -220,7 +220,7 @@ These rules fetch all records from both the `employee` and `customer` tables. Th // NOTCONSOLE [discrete#es-connectors-ms-sql-sync-rules-example-one-where] -======= Example: One WHERE query +*Example: One WHERE query* This rule fetches only the records from the `employee` table where the `emp_id` is greater than 5. Only these filtered records will be synced to Elasticsearch. @@ -236,7 +236,7 @@ This rule fetches only the records from the `employee` table where the `emp_id` // NOTCONSOLE [discrete#es-connectors-ms-sql-sync-rules-example-one-join] -======= Example: One JOIN query +*Example: One JOIN query* This rule fetches records by performing an INNER JOIN between the `employee` and `customer` tables on the condition that the `emp_id` in `employee` matches the `c_id` in `customer`. The result of this combined data will be synced to Elasticsearch. @@ -484,7 +484,7 @@ Here are a few examples of advanced sync rules for this connector. ==== [discrete#es-connectors-ms-sql-client-sync-rules-advanced-queries] -======= Example: Two queries +*Example: Two queries* These rules fetch all records from both the `employee` and `customer` tables. The data from these tables will be synced separately to Elasticsearch. @@ -508,7 +508,7 @@ These rules fetch all records from both the `employee` and `customer` tables. Th // NOTCONSOLE [discrete#es-connectors-ms-sql-client-sync-rules-example-one-where] -======= Example: One WHERE query +*Example: One WHERE query* This rule fetches only the records from the `employee` table where the `emp_id` is greater than 5. Only these filtered records will be synced to Elasticsearch. @@ -524,7 +524,7 @@ This rule fetches only the records from the `employee` table where the `emp_id` // NOTCONSOLE [discrete#es-connectors-ms-sql-client-sync-rules-example-one-join] -======= Example: One JOIN query +*Example: One JOIN query* This rule fetches records by performing an INNER JOIN between the `employee` and `customer` tables on the condition that the `emp_id` in `employee` matches the `c_id` in `customer`. The result of this combined data will be synced to Elasticsearch. diff --git a/docs/reference/connector/docs/connectors-network-drive.asciidoc b/docs/reference/connector/docs/connectors-network-drive.asciidoc index 91c9d3b28c385..909e3440c9f02 100644 --- a/docs/reference/connector/docs/connectors-network-drive.asciidoc +++ b/docs/reference/connector/docs/connectors-network-drive.asciidoc @@ -174,7 +174,7 @@ Advanced sync rules for this connector use *glob patterns*. The following sections provide examples of advanced sync rules for this connector. [discrete#es-connectors-network-drive-indexing-files-and-folders-recursively-within-folders] -======= Indexing files and folders recursively within folders +*Indexing files and folders recursively within folders* [source,js] ---- @@ -190,7 +190,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-network-drive-indexing-files-and-folders-directly-inside-folder] -======= Indexing files and folders directly inside folder +*Indexing files and folders directly inside folder* [source,js] ---- @@ -203,7 +203,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-network-drive-indexing-files-and-folders-directly-inside-a-set-of-folders] -======= Indexing files and folders directly inside a set of folders +*Indexing files and folders directly inside a set of folders* [source,js] ---- @@ -216,7 +216,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-network-drive-excluding-files-and-folders-that-match-a-pattern] -======= Excluding files and folders that match a pattern +*Excluding files and folders that match a pattern* [source,js] ---- @@ -432,7 +432,7 @@ Advanced sync rules for this connector use *glob patterns*. The following sections provide examples of advanced sync rules for this connector. [discrete#es-connectors-network-drive-client-indexing-files-and-folders-recursively-within-folders] -======= Indexing files and folders recursively within folders +*Indexing files and folders recursively within folders* [source,js] ---- @@ -448,7 +448,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-network-drive-client-indexing-files-and-folders-directly-inside-folder] -======= Indexing files and folders directly inside folder +*Indexing files and folders directly inside folder* [source,js] ---- @@ -461,7 +461,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-network-drive-client-indexing-files-and-folders-directly-inside-a-set-of-folders] -======= Indexing files and folders directly inside a set of folders +*Indexing files and folders directly inside a set of folders* [source,js] ---- @@ -474,7 +474,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-network-drive-client-excluding-files-and-folders-that-match-a-pattern] -======= Excluding files and folders that match a pattern +*Excluding files and folders that match a pattern* [source,js] ---- diff --git a/docs/reference/connector/docs/connectors-notion.asciidoc b/docs/reference/connector/docs/connectors-notion.asciidoc index 2d7a71bff20de..7c08c5d81e032 100644 --- a/docs/reference/connector/docs/connectors-notion.asciidoc +++ b/docs/reference/connector/docs/connectors-notion.asciidoc @@ -140,7 +140,7 @@ Advanced sync rules for Notion take the following parameters: ====== Examples [discrete] -======= Example 1 +*Example 1* Indexing every page where the title contains `Demo Page`: @@ -160,7 +160,7 @@ Indexing every page where the title contains `Demo Page`: // NOTCONSOLE [discrete] -======= Example 2 +*Example 2* Indexing every database where the title contains `Demo Database`: @@ -180,7 +180,7 @@ Indexing every database where the title contains `Demo Database`: // NOTCONSOLE [discrete] -======= Example 3 +*Example 3* Indexing every database where the title contains `Demo Database` and every page where the title contains `Demo Page`: @@ -206,7 +206,7 @@ Indexing every database where the title contains `Demo Database` and every page // NOTCONSOLE [discrete] -======= Example 4 +*Example 4* Indexing all pages in the workspace: @@ -226,7 +226,7 @@ Indexing all pages in the workspace: // NOTCONSOLE [discrete] -======= Example 5 +*Example 5* Indexing all the pages and databases connected to the workspace: @@ -243,7 +243,7 @@ Indexing all the pages and databases connected to the workspace: // NOTCONSOLE [discrete] -======= Example 6 +*Example 6* Indexing all the rows of a database where the record is `true` for the column `Task completed` and its property(datatype) is a checkbox: @@ -266,7 +266,7 @@ Indexing all the rows of a database where the record is `true` for the column `T // NOTCONSOLE [discrete] -======= Example 7 +*Example 7* Indexing all rows of a specific database: @@ -283,7 +283,7 @@ Indexing all rows of a specific database: // NOTCONSOLE [discrete] -======= Example 8 +*Example 8* Indexing all blocks defined in `searches` and `database_query_filters`: @@ -498,7 +498,7 @@ Advanced sync rules for Notion take the following parameters: ====== Examples [discrete] -======= Example 1 +*Example 1* Indexing every page where the title contains `Demo Page`: @@ -518,7 +518,7 @@ Indexing every page where the title contains `Demo Page`: // NOTCONSOLE [discrete] -======= Example 2 +*Example 2* Indexing every database where the title contains `Demo Database`: @@ -538,7 +538,7 @@ Indexing every database where the title contains `Demo Database`: // NOTCONSOLE [discrete] -======= Example 3 +*Example 3* Indexing every database where the title contains `Demo Database` and every page where the title contains `Demo Page`: @@ -564,7 +564,7 @@ Indexing every database where the title contains `Demo Database` and every page // NOTCONSOLE [discrete] -======= Example 4 +*Example 4* Indexing all pages in the workspace: @@ -584,7 +584,7 @@ Indexing all pages in the workspace: // NOTCONSOLE [discrete] -======= Example 5 +*Example 5* Indexing all the pages and databases connected to the workspace: @@ -601,7 +601,7 @@ Indexing all the pages and databases connected to the workspace: // NOTCONSOLE [discrete] -======= Example 6 +*Example 6* Indexing all the rows of a database where the record is `true` for the column `Task completed` and its property(datatype) is a checkbox: @@ -624,7 +624,7 @@ Indexing all the rows of a database where the record is `true` for the column `T // NOTCONSOLE [discrete] -======= Example 7 +*Example 7* Indexing all rows of a specific database: @@ -641,7 +641,7 @@ Indexing all rows of a specific database: // NOTCONSOLE [discrete] -======= Example 8 +*Example 8* Indexing all blocks defined in `searches` and `database_query_filters`: diff --git a/docs/reference/connector/docs/connectors-onedrive.asciidoc b/docs/reference/connector/docs/connectors-onedrive.asciidoc index 7d1a21aeb78db..44ac96e2ad99d 100644 --- a/docs/reference/connector/docs/connectors-onedrive.asciidoc +++ b/docs/reference/connector/docs/connectors-onedrive.asciidoc @@ -160,7 +160,7 @@ A <> is required for advanced sync rul Here are a few examples of advanced sync rules for this connector. [discrete#es-connectors-onedrive-sync-rules-advanced-examples-1] -======= Example 1 +*Example 1* This rule skips indexing for files with `.xlsx` and `.docx` extensions. All other files and folders will be indexed. @@ -176,7 +176,7 @@ All other files and folders will be indexed. // NOTCONSOLE [discrete#es-connectors-onedrive-sync-rules-advanced-examples-2] -======= Example 2 +*Example 2* This rule focuses on indexing files and folders owned by `user1-domain@onmicrosoft.com` and `user2-domain@onmicrosoft.com` but excludes files with `.py` extension. @@ -192,7 +192,7 @@ This rule focuses on indexing files and folders owned by `user1-domain@onmicroso // NOTCONSOLE [discrete#es-connectors-onedrive-sync-rules-advanced-examples-3] -======= Example 3 +*Example 3* This rule indexes only the files and folders directly inside the root folder, excluding any `.md` files. @@ -208,7 +208,7 @@ This rule indexes only the files and folders directly inside the root folder, ex // NOTCONSOLE [discrete#es-connectors-onedrive-sync-rules-advanced-examples-4] -======= Example 4 +*Example 4* This rule indexes files and folders owned by `user1-domain@onmicrosoft.com` and `user3-domain@onmicrosoft.com` that are directly inside the `abc` folder, which is a subfolder of any folder under the `hello` directory in the root. Files with extensions `.pdf` and `.py` are excluded. @@ -225,7 +225,7 @@ This rule indexes files and folders owned by `user1-domain@onmicrosoft.com` and // NOTCONSOLE [discrete#es-connectors-onedrive-sync-rules-advanced-examples-5] -======= Example 5 +*Example 5* This example contains two rules. The first rule indexes all files and folders owned by `user1-domain@onmicrosoft.com` and `user2-domain@onmicrosoft.com`. @@ -245,7 +245,7 @@ The second rule indexes files for all other users, but skips files with a `.py` // NOTCONSOLE [discrete#es-connectors-onedrive-sync-rules-advanced-examples-6] -======= Example 6 +*Example 6* This example contains two rules. The first rule indexes all files owned by `user1-domain@onmicrosoft.com` and `user2-domain@onmicrosoft.com`, excluding `.md` files. @@ -449,7 +449,7 @@ A <> is required for advanced sync rul Here are a few examples of advanced sync rules for this connector. [discrete#es-connectors-onedrive-client-sync-rules-advanced-examples-1] -======= Example 1 +*Example 1* This rule skips indexing for files with `.xlsx` and `.docx` extensions. All other files and folders will be indexed. @@ -465,7 +465,7 @@ All other files and folders will be indexed. // NOTCONSOLE [discrete#es-connectors-onedrive-client-sync-rules-advanced-examples-2] -======= Example 2 +*Example 2* This rule focuses on indexing files and folders owned by `user1-domain@onmicrosoft.com` and `user2-domain@onmicrosoft.com` but excludes files with `.py` extension. @@ -481,7 +481,7 @@ This rule focuses on indexing files and folders owned by `user1-domain@onmicroso // NOTCONSOLE [discrete#es-connectors-onedrive-client-sync-rules-advanced-examples-3] -======= Example 3 +*Example 3* This rule indexes only the files and folders directly inside the root folder, excluding any `.md` files. @@ -497,7 +497,7 @@ This rule indexes only the files and folders directly inside the root folder, ex // NOTCONSOLE [discrete#es-connectors-onedrive-client-sync-rules-advanced-examples-4] -======= Example 4 +*Example 4* This rule indexes files and folders owned by `user1-domain@onmicrosoft.com` and `user3-domain@onmicrosoft.com` that are directly inside the `abc` folder, which is a subfolder of any folder under the `hello` directory in the root. Files with extensions `.pdf` and `.py` are excluded. @@ -514,7 +514,7 @@ This rule indexes files and folders owned by `user1-domain@onmicrosoft.com` and // NOTCONSOLE [discrete#es-connectors-onedrive-client-sync-rules-advanced-examples-5] -======= Example 5 +*Example 5* This example contains two rules. The first rule indexes all files and folders owned by `user1-domain@onmicrosoft.com` and `user2-domain@onmicrosoft.com`. @@ -534,7 +534,7 @@ The second rule indexes files for all other users, but skips files with a `.py` // NOTCONSOLE [discrete#es-connectors-onedrive-client-sync-rules-advanced-examples-6] -======= Example 6 +*Example 6* This example contains two rules. The first rule indexes all files owned by `user1-domain@onmicrosoft.com` and `user2-domain@onmicrosoft.com`, excluding `.md` files. diff --git a/docs/reference/connector/docs/connectors-postgresql.asciidoc b/docs/reference/connector/docs/connectors-postgresql.asciidoc index 1fe28f867337c..aa6cb7f29e633 100644 --- a/docs/reference/connector/docs/connectors-postgresql.asciidoc +++ b/docs/reference/connector/docs/connectors-postgresql.asciidoc @@ -188,7 +188,7 @@ Advanced sync rules are defined through a source-specific DSL JSON snippet. Here is some example data that will be used in the following examples. [discrete#connectors-postgresql-sync-rules-advanced-example-data-1] -======= `employee` table +*`employee` table* [cols="3*", options="header"] |=== @@ -199,7 +199,7 @@ Here is some example data that will be used in the following examples. |=== [discrete#connectors-postgresql-sync-rules-advanced-example-2] -======= `customer` table +*`customer` table* [cols="3*", options="header"] |=== @@ -213,7 +213,7 @@ Here is some example data that will be used in the following examples. ====== Advanced sync rules examples [discrete#connectors-postgresql-sync-rules-advanced-examples-1] -======= Multiple table queries +*Multiple table queries* [source,js] ---- @@ -235,7 +235,7 @@ Here is some example data that will be used in the following examples. // NOTCONSOLE [discrete#connectors-postgresql-sync-rules-advanced-examples-1-id-columns] -======= Multiple table queries with `id_columns` +*Multiple table queries with `id_columns`* In 8.15.0, we added a new optional `id_columns` field in our advanced sync rules for the PostgreSQL connector. Use the `id_columns` field to ingest tables which do not have a primary key. Include the names of unique fields so that the connector can use them to generate unique IDs for documents. @@ -264,7 +264,7 @@ Use the `id_columns` field to ingest tables which do not have a primary key. Inc This example uses the `id_columns` field to specify the unique fields `emp_id` and `c_id` for the `employee` and `customer` tables, respectively. [discrete#connectors-postgresql-sync-rules-advanced-examples-2] -======= Filtering data with `WHERE` clause +*Filtering data with `WHERE` clause* [source,js] ---- @@ -278,7 +278,7 @@ This example uses the `id_columns` field to specify the unique fields `emp_id` a // NOTCONSOLE [discrete#connectors-postgresql-sync-rules-advanced-examples-3] -======= `JOIN` operations +*`JOIN` operations* [source,js] ---- @@ -494,7 +494,7 @@ Advanced sync rules are defined through a source-specific DSL JSON snippet. Here is some example data that will be used in the following examples. [discrete#es-connectors-postgresql-client-sync-rules-advanced-example-data-1] -======= `employee` table +*`employee` table* [cols="3*", options="header"] |=== @@ -505,7 +505,7 @@ Here is some example data that will be used in the following examples. |=== [discrete#es-connectors-postgresql-client-sync-rules-advanced-example-2] -======= `customer` table +*`customer` table* [cols="3*", options="header"] |=== @@ -519,7 +519,7 @@ Here is some example data that will be used in the following examples. ====== Advanced sync rules examples [discrete#es-connectors-postgresql-client-sync-rules-advanced-examples-1] -======== Multiple table queries +*Multiple table queries* [source,js] ---- @@ -541,7 +541,7 @@ Here is some example data that will be used in the following examples. // NOTCONSOLE [discrete#es-connectors-postgresql-client-sync-rules-advanced-examples-1-id-columns] -======== Multiple table queries with `id_columns` +*Multiple table queries with `id_columns`* In 8.15.0, we added a new optional `id_columns` field in our advanced sync rules for the PostgreSQL connector. Use the `id_columns` field to ingest tables which do not have a primary key. Include the names of unique fields so that the connector can use them to generate unique IDs for documents. @@ -570,7 +570,7 @@ Use the `id_columns` field to ingest tables which do not have a primary key. Inc This example uses the `id_columns` field to specify the unique fields `emp_id` and `c_id` for the `employee` and `customer` tables, respectively. [discrete#es-connectors-postgresql-client-sync-rules-advanced-examples-2] -======== Filtering data with `WHERE` clause +*Filtering data with `WHERE` clause* [source,js] ---- @@ -584,7 +584,7 @@ This example uses the `id_columns` field to specify the unique fields `emp_id` a // NOTCONSOLE [discrete#es-connectors-postgresql-client-sync-rules-advanced-examples-3] -======== `JOIN` operations +*`JOIN` operations* [source,js] ---- diff --git a/docs/reference/connector/docs/connectors-s3.asciidoc b/docs/reference/connector/docs/connectors-s3.asciidoc index b4d08d3884631..90c070f7b8044 100644 --- a/docs/reference/connector/docs/connectors-s3.asciidoc +++ b/docs/reference/connector/docs/connectors-s3.asciidoc @@ -118,7 +118,7 @@ The connector will fetch file and folder data that matches the string. Defaults to `""` (syncs all bucket objects). [discrete#es-connectors-s3-sync-rules-advanced-examples] -======= Advanced sync rules examples +*Advanced sync rules examples* *Fetching files and folders recursively by prefix* @@ -336,7 +336,7 @@ The connector will fetch file and folder data that matches the string. Defaults to `""` (syncs all bucket objects). [discrete#es-connectors-s3-client-sync-rules-advanced-examples] -======= Advanced sync rules examples +*Advanced sync rules examples* *Fetching files and folders recursively by prefix* diff --git a/docs/reference/connector/docs/connectors-salesforce.asciidoc b/docs/reference/connector/docs/connectors-salesforce.asciidoc index 3676f7663089c..c640751de92c0 100644 --- a/docs/reference/connector/docs/connectors-salesforce.asciidoc +++ b/docs/reference/connector/docs/connectors-salesforce.asciidoc @@ -227,7 +227,7 @@ They take the following parameters: Allowed values are *SOQL* and *SOSL*. [discrete#es-connectors-salesforce-sync-rules-advanced-fetch-query-language] -======= Fetch documents based on the query and language specified +*Fetch documents based on the query and language specified* **Example**: Fetch documents using SOQL query @@ -256,7 +256,7 @@ Allowed values are *SOQL* and *SOSL*. // NOTCONSOLE [discrete#es-connectors-salesforce-sync-rules-advanced-fetch-objects] -======= Fetch standard and custom objects using SOQL and SOSL queries +*Fetch standard and custom objects using SOQL and SOSL queries* **Example**: Fetch documents for standard objects via SOQL and SOSL query. @@ -293,7 +293,7 @@ Allowed values are *SOQL* and *SOSL*. // NOTCONSOLE [discrete#es-connectors-salesforce-sync-rules-advanced-fetch-standard-custom-fields] -======= Fetch documents with standard and custom fields +*Fetch documents with standard and custom fields* **Example**: Fetch documents with all standard and custom fields for Account object. @@ -626,7 +626,7 @@ They take the following parameters: Allowed values are *SOQL* and *SOSL*. [discrete#es-connectors-salesforce-client-sync-rules-advanced-fetch-query-language] -======= Fetch documents based on the query and language specified +*Fetch documents based on the query and language specified* **Example**: Fetch documents using SOQL query @@ -655,7 +655,7 @@ Allowed values are *SOQL* and *SOSL*. // NOTCONSOLE [discrete#es-connectors-salesforce-client-sync-rules-advanced-fetch-objects] -======= Fetch standard and custom objects using SOQL and SOSL queries +*Fetch standard and custom objects using SOQL and SOSL queries* **Example**: Fetch documents for standard objects via SOQL and SOSL query. @@ -692,7 +692,7 @@ Allowed values are *SOQL* and *SOSL*. // NOTCONSOLE [discrete#es-connectors-salesforce-client-sync-rules-advanced-fetch-standard-custom-fields] -======= Fetch documents with standard and custom fields +*Fetch documents with standard and custom fields* **Example**: Fetch documents with all standard and custom fields for Account object. diff --git a/docs/reference/connector/docs/connectors-servicenow.asciidoc b/docs/reference/connector/docs/connectors-servicenow.asciidoc index a02c418f11d74..3dc98ed9a44c9 100644 --- a/docs/reference/connector/docs/connectors-servicenow.asciidoc +++ b/docs/reference/connector/docs/connectors-servicenow.asciidoc @@ -167,7 +167,7 @@ Advanced sync rules are defined through a source-specific DSL JSON snippet. The following sections provide examples of advanced sync rules for this connector. [discrete#es-connectors-servicenow-sync-rules-number-incident-service] -======= Indexing document based on incident number for Incident service +*Indexing document based on incident number for Incident service* [source,js] ---- @@ -181,7 +181,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-servicenow-sync-rules-active-false-user-service] -======= Indexing document based on user activity state for User service +*Indexing document based on user activity state for User service* [source,js] ---- @@ -195,7 +195,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-servicenow-sync-rules-author-administrator-knowledge-service] -======= Indexing document based on author name for Knowledge service +*Indexing document based on author name for Knowledge service* [source,js] ---- @@ -407,7 +407,7 @@ Advanced sync rules are defined through a source-specific DSL JSON snippet. The following sections provide examples of advanced sync rules for this connector. [discrete#es-connectors-servicenow-client-sync-rules-number-incident-service] -======= Indexing document based on incident number for Incident service +*Indexing document based on incident number for Incident service* [source,js] ---- @@ -421,7 +421,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-servicenow-client-sync-rules-active-false-user-service] -======= Indexing document based on user activity state for User service +*Indexing document based on user activity state for User service* [source,js] ---- @@ -435,7 +435,7 @@ The following sections provide examples of advanced sync rules for this connecto // NOTCONSOLE [discrete#es-connectors-servicenow-client-sync-rules-author-administrator-knowledge-service] -======= Indexing document based on author name for Knowledge service +*Indexing document based on author name for Knowledge service* [source,js] ---- diff --git a/docs/reference/connector/docs/connectors-sharepoint-online.asciidoc b/docs/reference/connector/docs/connectors-sharepoint-online.asciidoc index 21d0890e436c5..02f598c16f63c 100644 --- a/docs/reference/connector/docs/connectors-sharepoint-online.asciidoc +++ b/docs/reference/connector/docs/connectors-sharepoint-online.asciidoc @@ -277,7 +277,7 @@ Example: This rule will not extract content of any drive items (files in document libraries) that haven't been modified for 60 days or more. [discrete#es-connectors-sharepoint-online-sync-rules-limitations] -======= Limitations of sync rules with incremental syncs +*Limitations of sync rules with incremental syncs* Changing sync rules after Sharepoint Online content has already been indexed can bring unexpected results, when using <>. @@ -288,7 +288,7 @@ Incremental syncs ensure _updates_ from 3rd-party system, but do not modify exis Let's take a look at several examples where incremental syncs might lead to inconsistent data on your index. [discrete#es-connectors-sharepoint-online-sync-rules-limitations-restrictive-added] -======== Example: Restrictive basic sync rule added after a full sync +*Example: Restrictive basic sync rule added after a full sync* Imagine your Sharepoint Online drive contains the following drive items: @@ -322,7 +322,7 @@ If no files were changed, incremental sync will not receive information about ch After a *full sync*, the index will be updated and files that are excluded by sync rules will be removed. [discrete#es-connectors-sharepoint-online-sync-rules-limitations-restrictive-removed] -======== Example: Restrictive basic sync rules removed after a full sync +*Example: Restrictive basic sync rules removed after a full sync* Imagine that Sharepoint Online drive has the following drive items: @@ -354,7 +354,7 @@ Afterwards, we can remove the filtering rule and run an incremental sync. If no Only a *full sync* will include the items previously ignored by the sync rule. [discrete#es-connectors-sharepoint-online-sync-rules-limitations-restrictive-changed] -======== Example: Advanced sync rules edge case +*Example: Advanced sync rules edge case* Advanced sync rules can be applied to limit which documents will have content extracted. For example, it's possible to set a rule so that documents older than 180 days won't have content extracted. @@ -763,7 +763,7 @@ Example: This rule will not extract content of any drive items (files in document libraries) that haven't been modified for 60 days or more. [discrete#es-connectors-sharepoint-online-client-sync-rules-limitations] -======= Limitations of sync rules with incremental syncs +*Limitations of sync rules with incremental syncs* Changing sync rules after Sharepoint Online content has already been indexed can bring unexpected results, when using <>. @@ -774,7 +774,7 @@ Incremental syncs ensure _updates_ from 3rd-party system, but do not modify exis Let's take a look at several examples where incremental syncs might lead to inconsistent data on your index. [discrete#es-connectors-sharepoint-online-client-sync-rules-limitations-restrictive-added] -======== Example: Restrictive basic sync rule added after a full sync +*Example: Restrictive basic sync rule added after a full sync* Imagine your Sharepoint Online drive contains the following drive items: @@ -808,7 +808,7 @@ If no files were changed, incremental sync will not receive information about ch After a *full sync*, the index will be updated and files that are excluded by sync rules will be removed. [discrete#es-connectors-sharepoint-online-client-sync-rules-limitations-restrictive-removed] -======== Example: Restrictive basic sync rules removed after a full sync +*Example: Restrictive basic sync rules removed after a full sync* Imagine that Sharepoint Online drive has the following drive items: @@ -840,7 +840,7 @@ Afterwards, we can remove the filtering rule and run an incremental sync. If no Only a *full sync* will include the items previously ignored by the sync rule. [discrete#es-connectors-sharepoint-online-client-sync-rules-limitations-restrictive-changed] -======== Example: Advanced sync rules edge case +*Example: Advanced sync rules edge case* Advanced sync rules can be applied to limit which documents will have content extracted. For example, it's possible to set a rule so that documents older than 180 days won't have content extracted. diff --git a/docs/reference/esql/esql-async-query-get-api.asciidoc b/docs/reference/esql/esql-async-query-get-api.asciidoc index ec68313b2c490..82a6ae5b28b51 100644 --- a/docs/reference/esql/esql-async-query-get-api.asciidoc +++ b/docs/reference/esql/esql-async-query-get-api.asciidoc @@ -39,6 +39,10 @@ parameter is `true`. [[esql-async-query-get-api-query-params]] ==== {api-query-parms-title} +The API accepts the same parameters as the synchronous +<>, along with the following +parameters: + `wait_for_completion_timeout`:: (Optional, <>) Timeout duration to wait for the request to finish. Defaults to no timeout, diff --git a/docs/reference/esql/functions/description/term.asciidoc b/docs/reference/esql/functions/description/term.asciidoc new file mode 100644 index 0000000000000..c43aeb25a0ef7 --- /dev/null +++ b/docs/reference/esql/functions/description/term.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +Performs a Term query on the specified field. Returns true if the provided term matches the row. diff --git a/docs/reference/esql/functions/examples/term.asciidoc b/docs/reference/esql/functions/examples/term.asciidoc new file mode 100644 index 0000000000000..b9d57f366294b --- /dev/null +++ b/docs/reference/esql/functions/examples/term.asciidoc @@ -0,0 +1,13 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Example* + +[source.merge.styled,esql] +---- +include::{esql-specs}/term-function.csv-spec[tag=term-with-field] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/term-function.csv-spec[tag=term-with-field-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/equals.json b/docs/reference/esql/functions/kibana/definition/equals.json index 885d949f4b20f..40f3d54ba597a 100644 --- a/docs/reference/esql/functions/kibana/definition/equals.json +++ b/docs/reference/esql/functions/kibana/definition/equals.json @@ -77,6 +77,42 @@ "variadic" : false, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/kibana/definition/greater_than.json b/docs/reference/esql/functions/kibana/definition/greater_than.json index cf6e30a0a4547..ea2c0fb1212c7 100644 --- a/docs/reference/esql/functions/kibana/definition/greater_than.json +++ b/docs/reference/esql/functions/kibana/definition/greater_than.json @@ -23,6 +23,42 @@ "variadic" : false, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/kibana/definition/greater_than_or_equal.json b/docs/reference/esql/functions/kibana/definition/greater_than_or_equal.json index 2535c68af6acf..7e1feb37e87b0 100644 --- a/docs/reference/esql/functions/kibana/definition/greater_than_or_equal.json +++ b/docs/reference/esql/functions/kibana/definition/greater_than_or_equal.json @@ -23,6 +23,42 @@ "variadic" : false, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/kibana/definition/less_than.json b/docs/reference/esql/functions/kibana/definition/less_than.json index a73754d200d46..71aae4d759ecf 100644 --- a/docs/reference/esql/functions/kibana/definition/less_than.json +++ b/docs/reference/esql/functions/kibana/definition/less_than.json @@ -23,6 +23,42 @@ "variadic" : false, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/kibana/definition/less_than_or_equal.json b/docs/reference/esql/functions/kibana/definition/less_than_or_equal.json index 7af477db32a34..f119b7ab2eb12 100644 --- a/docs/reference/esql/functions/kibana/definition/less_than_or_equal.json +++ b/docs/reference/esql/functions/kibana/definition/less_than_or_equal.json @@ -23,6 +23,42 @@ "variadic" : false, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/kibana/definition/not_equals.json b/docs/reference/esql/functions/kibana/definition/not_equals.json index 24f31115cbc37..d35a5b43ec238 100644 --- a/docs/reference/esql/functions/kibana/definition/not_equals.json +++ b/docs/reference/esql/functions/kibana/definition/not_equals.json @@ -77,6 +77,42 @@ "variadic" : false, "returnType" : "boolean" }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "lhs", + "type" : "date_nanos", + "optional" : false, + "description" : "An expression." + }, + { + "name" : "rhs", + "type" : "date", + "optional" : false, + "description" : "An expression." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, { "params" : [ { diff --git a/docs/reference/esql/functions/kibana/definition/term.json b/docs/reference/esql/functions/kibana/definition/term.json new file mode 100644 index 0000000000000..d8bb61fd596a1 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/term.json @@ -0,0 +1,85 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "eval", + "name" : "term", + "description" : "Performs a Term query on the specified field. Returns true if the provided term matches the row.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "keyword", + "optional" : false, + "description" : "Field that the query will target." + }, + { + "name" : "query", + "type" : "keyword", + "optional" : false, + "description" : "Term you wish to find in the provided field." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field", + "type" : "keyword", + "optional" : false, + "description" : "Field that the query will target." + }, + { + "name" : "query", + "type" : "text", + "optional" : false, + "description" : "Term you wish to find in the provided field." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field", + "type" : "text", + "optional" : false, + "description" : "Field that the query will target." + }, + { + "name" : "query", + "type" : "keyword", + "optional" : false, + "description" : "Term you wish to find in the provided field." + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field", + "type" : "text", + "optional" : false, + "description" : "Field that the query will target." + }, + { + "name" : "query", + "type" : "text", + "optional" : false, + "description" : "Term you wish to find in the provided field." + } + ], + "variadic" : false, + "returnType" : "boolean" + } + ], + "examples" : [ + "from books \n| where term(author, \"gabriel\") \n| keep book_no, title\n| limit 3;" + ], + "preview" : true, + "snapshot_only" : true +} diff --git a/docs/reference/esql/functions/kibana/docs/term.md b/docs/reference/esql/functions/kibana/docs/term.md new file mode 100644 index 0000000000000..83e61a949208d --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/term.md @@ -0,0 +1,13 @@ + + +### TERM +Performs a Term query on the specified field. Returns true if the provided term matches the row. + +``` +from books +| where term(author, "gabriel") +| keep book_no, title +| limit 3; +``` diff --git a/docs/reference/esql/functions/layout/term.asciidoc b/docs/reference/esql/functions/layout/term.asciidoc new file mode 100644 index 0000000000000..1fe94491bed04 --- /dev/null +++ b/docs/reference/esql/functions/layout/term.asciidoc @@ -0,0 +1,17 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-term]] +=== `TERM` + +preview::["Do not use on production environments. This functionality is in technical preview and may be changed or removed in a future release. Elastic will work to fix any issues, but features in technical preview are not subject to the support SLA of official GA features."] + +*Syntax* + +[.text-center] +image::esql/functions/signature/term.svg[Embedded,opts=inline] + +include::../parameters/term.asciidoc[] +include::../description/term.asciidoc[] +include::../types/term.asciidoc[] +include::../examples/term.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/term.asciidoc b/docs/reference/esql/functions/parameters/term.asciidoc new file mode 100644 index 0000000000000..edba8625d04c5 --- /dev/null +++ b/docs/reference/esql/functions/parameters/term.asciidoc @@ -0,0 +1,9 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`field`:: +Field that the query will target. + +`query`:: +Term you wish to find in the provided field. diff --git a/docs/reference/esql/functions/signature/term.svg b/docs/reference/esql/functions/signature/term.svg new file mode 100644 index 0000000000000..955dd7fa215ab --- /dev/null +++ b/docs/reference/esql/functions/signature/term.svg @@ -0,0 +1 @@ +TERM(field,query) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/equals.asciidoc b/docs/reference/esql/functions/types/equals.asciidoc index 8d48b7ebf084a..1bb8bf2122b35 100644 --- a/docs/reference/esql/functions/types/equals.asciidoc +++ b/docs/reference/esql/functions/types/equals.asciidoc @@ -9,6 +9,8 @@ boolean | boolean | boolean cartesian_point | cartesian_point | boolean cartesian_shape | cartesian_shape | boolean date | date | boolean +date | date_nanos | boolean +date_nanos | date | boolean date_nanos | date_nanos | boolean double | double | boolean double | integer | boolean diff --git a/docs/reference/esql/functions/types/greater_than.asciidoc b/docs/reference/esql/functions/types/greater_than.asciidoc index 8000fd34c8507..39253ac445f42 100644 --- a/docs/reference/esql/functions/types/greater_than.asciidoc +++ b/docs/reference/esql/functions/types/greater_than.asciidoc @@ -6,6 +6,8 @@ |=== lhs | rhs | result date | date | boolean +date | date_nanos | boolean +date_nanos | date | boolean date_nanos | date_nanos | boolean double | double | boolean double | integer | boolean diff --git a/docs/reference/esql/functions/types/greater_than_or_equal.asciidoc b/docs/reference/esql/functions/types/greater_than_or_equal.asciidoc index 8000fd34c8507..39253ac445f42 100644 --- a/docs/reference/esql/functions/types/greater_than_or_equal.asciidoc +++ b/docs/reference/esql/functions/types/greater_than_or_equal.asciidoc @@ -6,6 +6,8 @@ |=== lhs | rhs | result date | date | boolean +date | date_nanos | boolean +date_nanos | date | boolean date_nanos | date_nanos | boolean double | double | boolean double | integer | boolean diff --git a/docs/reference/esql/functions/types/less_than.asciidoc b/docs/reference/esql/functions/types/less_than.asciidoc index 8000fd34c8507..39253ac445f42 100644 --- a/docs/reference/esql/functions/types/less_than.asciidoc +++ b/docs/reference/esql/functions/types/less_than.asciidoc @@ -6,6 +6,8 @@ |=== lhs | rhs | result date | date | boolean +date | date_nanos | boolean +date_nanos | date | boolean date_nanos | date_nanos | boolean double | double | boolean double | integer | boolean diff --git a/docs/reference/esql/functions/types/less_than_or_equal.asciidoc b/docs/reference/esql/functions/types/less_than_or_equal.asciidoc index 8000fd34c8507..39253ac445f42 100644 --- a/docs/reference/esql/functions/types/less_than_or_equal.asciidoc +++ b/docs/reference/esql/functions/types/less_than_or_equal.asciidoc @@ -6,6 +6,8 @@ |=== lhs | rhs | result date | date | boolean +date | date_nanos | boolean +date_nanos | date | boolean date_nanos | date_nanos | boolean double | double | boolean double | integer | boolean diff --git a/docs/reference/esql/functions/types/not_equals.asciidoc b/docs/reference/esql/functions/types/not_equals.asciidoc index 8d48b7ebf084a..1bb8bf2122b35 100644 --- a/docs/reference/esql/functions/types/not_equals.asciidoc +++ b/docs/reference/esql/functions/types/not_equals.asciidoc @@ -9,6 +9,8 @@ boolean | boolean | boolean cartesian_point | cartesian_point | boolean cartesian_shape | cartesian_shape | boolean date | date | boolean +date | date_nanos | boolean +date_nanos | date | boolean date_nanos | date_nanos | boolean double | double | boolean double | integer | boolean diff --git a/docs/reference/esql/functions/types/term.asciidoc b/docs/reference/esql/functions/types/term.asciidoc new file mode 100644 index 0000000000000..7523b29c62b1d --- /dev/null +++ b/docs/reference/esql/functions/types/term.asciidoc @@ -0,0 +1,12 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +field | query | result +keyword | keyword | boolean +keyword | text | boolean +text | keyword | boolean +text | text | boolean +|=== diff --git a/docs/reference/inference/inference-apis.asciidoc b/docs/reference/inference/inference-apis.asciidoc index 037d7abeb2a36..c7b779a994a05 100644 --- a/docs/reference/inference/inference-apis.asciidoc +++ b/docs/reference/inference/inference-apis.asciidoc @@ -35,6 +35,19 @@ Elastic –, then create an {infer} endpoint by the <>. Now use <> to perform <> on your data. +[discrete] +[[adaptive-allocations]] +=== Adaptive allocations + +Adaptive allocations allow inference services to dynamically adjust the number of model allocations based on the current load. + +When adaptive allocations are enabled: + +* The number of allocations scales up automatically when the load increases. +- Allocations scale down to a minimum of 0 when the load decreases, saving resources. + +For more information about adaptive allocations and resources, refer to the {ml-docs}/ml-nlp-auto-scale.html[trained model autoscaling] documentation. + //[discrete] //[[default-enpoints]] //=== Default {infer} endpoints diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index e7e25ec98b49d..4f82889f562d8 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -10,7 +10,6 @@ Creates an {infer} endpoint to perform an {infer} task. * For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. ==== - [discrete] [[put-inference-api-request]] ==== {api-request-title} @@ -47,6 +46,14 @@ Refer to the service list in the <> API. In the response, look for `"state": "fully_allocated"` and ensure the `"allocation_count"` matches the `"target_allocation_count"`. +* Avoid creating multiple endpoints for the same model unless required, as each endpoint consumes significant resources. +==== + + The following services are available through the {infer} API. You can find the available task types next to the service name. Click the links to review the configuration details of the services: @@ -67,4 +74,17 @@ Click the links to review the configuration details of the services: * <> (`text_embedding`) The {es} and ELSER services run on a {ml} node in your {es} cluster. The rest of -the services connect to external providers. \ No newline at end of file +the services connect to external providers. + +[discrete] +[[adaptive-allocations-put-inference]] +==== Adaptive allocations + +Adaptive allocations allow inference services to dynamically adjust the number of model allocations based on the current load. + +When adaptive allocations are enabled: + +- The number of allocations scales up automatically when the load increases. +- Allocations scale down to a minimum of 0 when the load decreases, saving resources. + +For more information about adaptive allocations and resources, refer to the {ml-docs}/ml-nlp-auto-scale.html[trained model autoscaling] documentation. \ No newline at end of file diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc index 9c6f0592a1d91..e079b9d665290 100644 --- a/docs/reference/ingest/processors/inference.asciidoc +++ b/docs/reference/ingest/processors/inference.asciidoc @@ -735,3 +735,70 @@ You can also specify the target field as follows: In this case, {feat-imp} is exposed in the `my_field.foo.feature_importance` field. + + +[discrete] +[[inference-processor-examples]] +==== {infer-cap} processor examples + +The following example uses an <> in an {infer} processor named `query_helper_pipeline` to perform a chat completion task. +The processor generates an {es} query from natural language input using a prompt designed for a completion task type. +Refer to <> for the {infer} service you use and check the corresponding examples of setting up an endpoint with the chat completion task type. + + +[source,console] +-------------------------------------------------- +PUT _ingest/pipeline/query_helper_pipeline +{ + "processors": [ + { + "script": { + "source": "ctx.prompt = 'Please generate an elasticsearch search query on index `articles_index` for the following natural language query. Dates are in the field `@timestamp`, document types are in the field `type` (options are `news`, `publication`), categories in the field `category` and can be multiple (options are `medicine`, `pharmaceuticals`, `technology`), and document names are in the field `title` which should use a fuzzy match. Ignore fields which cannot be determined from the natural language query context: ' + ctx.content" <1> + } + }, + { + "inference": { + "model_id": "openai_chat_completions", <2> + "input_output": { + "input_field": "prompt", + "output_field": "query" + } + } + }, + { + "remove": { + "field": "prompt" + } + } + ] +} +-------------------------------------------------- +// TEST[skip: An inference endpoint is required.] +<1> The `prompt` field contains the prompt used for the completion task, created with <>. +`+ ctx.content` appends the natural language input to the prompt. +<2> The ID of the pre-configured {infer} endpoint, which utilizes the <> with the `completion` task type. + +The following API request will simulate running a document through the ingest pipeline created previously: + +[source,console] +-------------------------------------------------- +POST _ingest/pipeline/query_helper_pipeline/_simulate +{ + "docs": [ + { + "_source": { + "content": "artificial intelligence in medicine articles published in the last 12 months" <1> + } + } + ] +} +-------------------------------------------------- +// TEST[skip: An inference processor with an inference endpoint is required.] +<1> The natural language query used to generate an {es} query within the prompt created by the {infer} processor. + + +[discrete] +[[infer-proc-readings]] +==== Further readings + +* https://www.elastic.co/search-labs/blog/openwebcrawler-llms-semantic-text-resume-job-search[Which job is the best for you? Using LLMs and semantic_text to match resumes to jobs] \ No newline at end of file diff --git a/docs/reference/mapping/types/semantic-text.asciidoc b/docs/reference/mapping/types/semantic-text.asciidoc index f76a9352c2fe8..96dc402e10c60 100644 --- a/docs/reference/mapping/types/semantic-text.asciidoc +++ b/docs/reference/mapping/types/semantic-text.asciidoc @@ -12,13 +12,14 @@ Long passages are <> to smaller secti The `semantic_text` field type specifies an inference endpoint identifier that will be used to generate embeddings. You can create the inference endpoint by using the <>. -This field type and the <> type make it simpler to perform semantic search on your data. -If you don't specify an inference endpoint, the <> is used by default. +This field type and the <> type make it simpler to perform semantic search on your data. + +If you don’t specify an inference endpoint, the `inference_id` field defaults to `.elser-2-elasticsearch`, a preconfigured endpoint for the elasticsearch service. Using `semantic_text`, you won't need to specify how to generate embeddings for your data, or how to index it. The {infer} endpoint automatically determines the embedding generation, indexing, and query to use. -If you use the ELSER service, you can set up `semantic_text` with the following API request: +If you use the preconfigured `.elser-2-elasticsearch` endpoint, you can set up `semantic_text` with the following API request: [source,console] ------------------------------------------------------------ @@ -34,7 +35,7 @@ PUT my-index-000001 } ------------------------------------------------------------ -If you use a service other than ELSER, you must create an {infer} endpoint using the <> and reference it when setting up `semantic_text` as the following example demonstrates: +To use a custom {infer} endpoint instead of the default `.elser-2-elasticsearch`, you must <> and specify its `inference_id` when setting up the `semantic_text` field type. [source,console] ------------------------------------------------------------ @@ -53,8 +54,7 @@ PUT my-index-000002 // TEST[skip:Requires inference endpoint] <1> The `inference_id` of the {infer} endpoint to use to generate embeddings. - -The recommended way to use semantic_text is by having dedicated {infer} endpoints for ingestion and search. +The recommended way to use `semantic_text` is by having dedicated {infer} endpoints for ingestion and search. This ensures that search speed remains unaffected by ingestion workloads, and vice versa. After creating dedicated {infer} endpoints for both, you can reference them using the `inference_id` and `search_inference_id` parameters when setting up the index mapping for an index that uses the `semantic_text` field. @@ -82,10 +82,11 @@ PUT my-index-000003 `inference_id`:: (Required, string) -{infer-cap} endpoint that will be used to generate the embeddings for the field. +{infer-cap} endpoint that will be used to generate embeddings for the field. +By default, `.elser-2-elasticsearch` is used. This parameter cannot be updated. Use the <> to create the endpoint. -If `search_inference_id` is specified, the {infer} endpoint defined by `inference_id` will only be used at index time. +If `search_inference_id` is specified, the {infer} endpoint will only be used at index time. `search_inference_id`:: (Optional, string) @@ -112,50 +113,43 @@ Trying to <> that is used on a {infer-cap} endpoints have a limit on the amount of text they can process. To allow for large amounts of text to be used in semantic search, `semantic_text` automatically generates smaller passages if needed, called _chunks_. -Each chunk will include the text subpassage and the corresponding embedding generated from it. +Each chunk refers to a passage of the text and the corresponding embedding generated from it. When querying, the individual passages will be automatically searched for each document, and the most relevant passage will be used to compute a score. For more details on chunking and how to configure chunking settings, see <> in the Inference API documentation. +Refer to <> to learn more about +semantic search using `semantic_text` and the `semantic` query. [discrete] -[[semantic-text-structure]] -==== `semantic_text` structure +[[semantic-text-highlighting]] +==== Extracting Relevant Fragments from Semantic Text -Once a document is ingested, a `semantic_text` field will have the following structure: +You can extract the most relevant fragments from a semantic text field by using the <> in the <>. -[source,console-result] +[source,console] ------------------------------------------------------------ -"inference_field": { - "text": "these are not the droids you're looking for", <1> - "inference": { - "inference_id": "my-elser-endpoint", <2> - "model_settings": { <3> - "task_type": "sparse_embedding" +PUT test-index +{ + "query": { + "semantic": { + "field": "my_semantic_field" + } }, - "chunks": [ <4> - { - "text": "these are not the droids you're looking for", - "embeddings": { - (...) + "highlight": { + "fields": { + "my_semantic_field": { + "type": "semantic", + "number_of_fragments": 2, <1> + "order": "score" <2> + } } - } - ] - } + } } ------------------------------------------------------------ -// TEST[skip:TBD] -<1> The field will become an object structure to accommodate both the original -text and the inference results. -<2> The `inference_id` used to generate the embeddings. -<3> Model settings, including the task type and dimensions/similarity if -applicable. -<4> Inference results will be grouped in chunks, each with its corresponding -text and embeddings. - -Refer to <> to learn more about -semantic search using `semantic_text` and the `semantic` query. - +// TEST[skip:Requires inference endpoint] +<1> Specifies the maximum number of fragments to return. +<2> Sorts highlighted fragments by score when set to `score`. By default, fragments will be output in the order they appear in the field (order: none). [discrete] [[custom-indexing]] @@ -208,7 +202,7 @@ PUT test-index "properties": { "infer_field": { "type": "semantic_text", - "inference_id": "my-elser-endpoint" + "inference_id": ".elser-2-elasticsearch" }, "source_field": { "type": "text", diff --git a/docs/reference/mapping/types/sparse-vector.asciidoc b/docs/reference/mapping/types/sparse-vector.asciidoc index b24f65fcf97ca..22d4644ede490 100644 --- a/docs/reference/mapping/types/sparse-vector.asciidoc +++ b/docs/reference/mapping/types/sparse-vector.asciidoc @@ -26,6 +26,23 @@ PUT my-index See <> for a complete example on adding documents to a `sparse_vector` mapped field using ELSER. +[[sparse-vectors-params]] +==== Parameters for `sparse_vector` fields + +The following parameters are accepted by `sparse_vector` fields: + +[horizontal] + +<>:: + +Indicates whether the field value should be stored and retrievable independently of the <> field. +Accepted values: true or false (default). +The field's data is stored using term vectors, a disk-efficient structure compared to the original JSON input. +The input map can be retrieved during a search request via the <>. +To benefit from reduced disk usage, you must either: + * Exclude the field from <>. + * Use <>. + [[index-multi-value-sparse-vectors]] ==== Multi-value sparse vectors diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index d01047eac9815..4948db48664ed 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -18,7 +18,8 @@ end::adaptive-allocation-max-number[] tag::adaptive-allocation-min-number[] Specifies the minimum number of allocations to scale to. -If set, it must be greater than or equal to `1`. +If set, it must be greater than or equal to `0`. +If not defined, the deployment scales to `0`. end::adaptive-allocation-min-number[] tag::aggregations[] diff --git a/docs/reference/search/retriever.asciidoc b/docs/reference/search/retriever.asciidoc index b90b7e312c790..cb04d4fb6fbf1 100644 --- a/docs/reference/search/retriever.asciidoc +++ b/docs/reference/search/retriever.asciidoc @@ -765,11 +765,11 @@ clauses in a <>. [[retriever-restrictions]] ==== Restrictions on search parameters when specifying a retriever -When a retriever is specified as part of a search, the following elements are not allowed at the top-level. -Instead they are only allowed as elements of specific retrievers: +When a retriever is specified as part of a search, the following elements are not allowed at the top-level: * <> * <> * <> * <> * <> +* <> diff --git a/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc b/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc index ba9c81db21384..3448940b6fad7 100644 --- a/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc +++ b/docs/reference/search/search-your-data/semantic-search-semantic-text.asciidoc @@ -14,15 +14,15 @@ You don't need to define model related settings and parameters, or create {infer The recommended way to use <> in the {stack} is following the `semantic_text` workflow. When you need more control over indexing and query settings, you can still use the complete {infer} workflow (refer to <> to review the process). -This tutorial uses the <> for demonstration, but you can use any service and their supported models offered by the {infer-cap} API. +This tutorial uses the <> for demonstration, but you can use any service and their supported models offered by the {infer-cap} API. [discrete] [[semantic-text-requirements]] ==== Requirements -This tutorial uses the <> for demonstration, which is created automatically as needed. -To use the `semantic_text` field type with an {infer} service other than ELSER, you must create an inference endpoint using the <>. +This tutorial uses the <> for demonstration, which is created automatically as needed. +To use the `semantic_text` field type with an {infer} service other than `elasticsearch` service, you must create an inference endpoint using the <>. [discrete] @@ -48,7 +48,7 @@ PUT semantic-embeddings // TEST[skip:TBD] <1> The name of the field to contain the generated embeddings. <2> The field to contain the embeddings is a `semantic_text` field. -Since no `inference_id` is provided, the <> is used by default. +Since no `inference_id` is provided, the default endpoint `.elser-2-elasticsearch` for the <> is used. To use a different {infer} service, you must create an {infer} endpoint first using the <> and then specify it in the `semantic_text` field mapping using the `inference_id` parameter. [NOTE] diff --git a/docs/reference/search/search-your-data/semantic-text-hybrid-search b/docs/reference/search/search-your-data/semantic-text-hybrid-search index c56b283434df5..4b49a7c3155db 100644 --- a/docs/reference/search/search-your-data/semantic-text-hybrid-search +++ b/docs/reference/search/search-your-data/semantic-text-hybrid-search @@ -8,47 +8,12 @@ This tutorial demonstrates how to perform hybrid search, combining semantic sear In hybrid search, semantic search retrieves results based on the meaning of the text, while full-text search focuses on exact word matches. By combining both methods, hybrid search delivers more relevant results, particularly in cases where relying on a single approach may not be sufficient. -The recommended way to use hybrid search in the {stack} is following the `semantic_text` workflow. This tutorial uses the <> for demonstration, but you can use any service and its supported models offered by the {infer-cap} API. - -[discrete] -[[semantic-text-hybrid-infer-endpoint]] -==== Create the {infer} endpoint - -Create an inference endpoint by using the <>: - -[source,console] ------------------------------------------------------------- -PUT _inference/sparse_embedding/my-elser-endpoint <1> -{ - "service": "elser", <2> - "service_settings": { - "adaptive_allocations": { <3> - "enabled": true, - "min_number_of_allocations": 3, - "max_number_of_allocations": 10 - }, - "num_threads": 1 - } -} ------------------------------------------------------------- -// TEST[skip:TBD] -<1> The task type is `sparse_embedding` in the path as the `elser` service will -be used and ELSER creates sparse vectors. The `inference_id` is -`my-elser-endpoint`. -<2> The `elser` service is used in this example. -<3> This setting enables and configures adaptive allocations. -Adaptive allocations make it possible for ELSER to automatically scale up or down resources based on the current load on the process. - -[NOTE] -==== -You might see a 502 bad gateway error in the response when using the {kib} Console. -This error usually just reflects a timeout, while the model downloads in the background. -You can check the download progress in the {ml-app} UI. -==== +The recommended way to use hybrid search in the {stack} is following the `semantic_text` workflow. +This tutorial uses the <> for demonstration, but you can use any service and their supported models offered by the {infer-cap} API. [discrete] [[hybrid-search-create-index-mapping]] -==== Create an index mapping for hybrid search +==== Create an index mapping The destination index will contain both the embeddings for semantic search and the original text field for full-text search. This structure enables the combination of semantic search and full-text search. @@ -60,11 +25,10 @@ PUT semantic-embeddings "properties": { "semantic_text": { <1> "type": "semantic_text", - "inference_id": "my-elser-endpoint" <2> }, - "content": { <3> + "content": { <2> "type": "text", - "copy_to": "semantic_text" <4> + "copy_to": "semantic_text" <3> } } } @@ -72,9 +36,8 @@ PUT semantic-embeddings ------------------------------------------------------------ // TEST[skip:TBD] <1> The name of the field to contain the generated embeddings for semantic search. -<2> The identifier of the inference endpoint that generates the embeddings based on the input text. -<3> The name of the field to contain the original text for lexical search. -<4> The textual data stored in the `content` field will be copied to `semantic_text` and processed by the {infer} endpoint. +<2> The name of the field to contain the original text for lexical search. +<3> The textual data stored in the `content` field will be copied to `semantic_text` and processed by the {infer} endpoint. [NOTE] ==== diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index 37178fd9439d0..9189d2a27f3f3 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -4383,11 +4383,6 @@ - - - - - @@ -4408,9 +4403,9 @@ - - - + + + @@ -4433,11 +4428,6 @@ - - - - - @@ -4478,11 +4468,6 @@ - - - - - @@ -4493,11 +4478,11 @@ - - - - - + + + + + diff --git a/libs/entitlement/asm-provider/build.gradle b/libs/entitlement/asm-provider/build.gradle index 5f968629fe557..dcec0579a5bae 100644 --- a/libs/entitlement/asm-provider/build.gradle +++ b/libs/entitlement/asm-provider/build.gradle @@ -11,10 +11,10 @@ apply plugin: 'elasticsearch.build' dependencies { compileOnly project(':libs:entitlement') - implementation 'org.ow2.asm:asm:9.7' + implementation 'org.ow2.asm:asm:9.7.1' testImplementation project(":test:framework") testImplementation project(":libs:entitlement:bridge") - testImplementation 'org.ow2.asm:asm-util:9.7' + testImplementation 'org.ow2.asm:asm-util:9.7.1' } tasks.named('test').configure { diff --git a/libs/entitlement/entitlements-loading.svg b/libs/entitlement/entitlements-loading.svg new file mode 100644 index 0000000000000..4f0213b853bee --- /dev/null +++ b/libs/entitlement/entitlements-loading.svg @@ -0,0 +1,4 @@ + + + +
ES main
ES main
Boot Loader
Boot Loader
Platform Loader
Platform Loader
System Loader
System Loader
reflection
reflection
Agent Jar
Agent Jar
Server
Server
(Instrumented)
JDK classes
(Instrumented)...
agent main
(in unnamed module)
agent main...
entitlements ready
entitlements ready
reflection
reflection
Bridge
(patched into java.base)
Bridge...
Entitlements
Entitlements
Entitlements bootstrap
Entitlements bootstrap
  • Grant access to unnamed module
  • Set (static, protected) init arguments
  • Load agent
Grant access to unnamed modu...
(reflectively) call 
entitlements init
with Instrumentation
(reflectively) call...
Entitlements init
Entitlements init
  • Load plugin policies
  • Load server policy
  • Create entitlements manager
    • Policies
    • Method to lookup plugin by Module
  • Set entitlements manager in static (accessible by bridge)
  • Instrument jdk classes
  • run self test (force bridge to capture entitlements manager)
Load plugin policiesLoad server policyCreate e...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/libs/entitlement/src/main/java/module-info.java b/libs/entitlement/src/main/java/module-info.java index 54075ba60bbef..b8a125b98e641 100644 --- a/libs/entitlement/src/main/java/module-info.java +++ b/libs/entitlement/src/main/java/module-info.java @@ -17,6 +17,7 @@ requires static org.elasticsearch.entitlement.bridge; // At runtime, this will be in java.base exports org.elasticsearch.entitlement.runtime.api; + exports org.elasticsearch.entitlement.runtime.policy; exports org.elasticsearch.entitlement.instrumentation; exports org.elasticsearch.entitlement.bootstrap to org.elasticsearch.server; exports org.elasticsearch.entitlement.initialization to java.base; diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java index 0ffab5f93969f..fb694308466c6 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/initialization/EntitlementInitialization.java @@ -18,6 +18,8 @@ import org.elasticsearch.entitlement.instrumentation.MethodKey; import org.elasticsearch.entitlement.instrumentation.Transformer; import org.elasticsearch.entitlement.runtime.api.ElasticsearchEntitlementChecker; +import org.elasticsearch.entitlement.runtime.policy.CreateClassLoaderEntitlement; +import org.elasticsearch.entitlement.runtime.policy.ExitVMEntitlement; import org.elasticsearch.entitlement.runtime.policy.Policy; import org.elasticsearch.entitlement.runtime.policy.PolicyManager; import org.elasticsearch.entitlement.runtime.policy.PolicyParser; @@ -86,9 +88,11 @@ private static Class internalNameToClass(String internalName) { private static PolicyManager createPolicyManager() throws IOException { Map pluginPolicies = createPluginPolicies(EntitlementBootstrap.bootstrapArgs().pluginData()); - // TODO: What should the name be? // TODO(ES-10031): Decide what goes in the elasticsearch default policy and extend it - var serverPolicy = new Policy("server", List.of()); + var serverPolicy = new Policy( + "server", + List.of(new Scope("org.elasticsearch.server", List.of(new ExitVMEntitlement(), new CreateClassLoaderEntitlement()))) + ); return new PolicyManager(serverPolicy, pluginPolicies, EntitlementBootstrap.bootstrapArgs().pluginResolver()); } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java index 28a080470c043..aa63b630ed7cd 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/api/ElasticsearchEntitlementChecker.java @@ -10,7 +10,6 @@ package org.elasticsearch.entitlement.runtime.api; import org.elasticsearch.entitlement.bridge.EntitlementChecker; -import org.elasticsearch.entitlement.runtime.policy.FlagEntitlementType; import org.elasticsearch.entitlement.runtime.policy.PolicyManager; import java.net.URL; @@ -30,27 +29,27 @@ public ElasticsearchEntitlementChecker(PolicyManager policyManager) { @Override public void check$java_lang_System$exit(Class callerClass, int status) { - policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.SYSTEM_EXIT); + policyManager.checkExitVM(callerClass); } @Override public void check$java_net_URLClassLoader$(Class callerClass, URL[] urls) { - policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER); + policyManager.checkCreateClassLoader(callerClass); } @Override public void check$java_net_URLClassLoader$(Class callerClass, URL[] urls, ClassLoader parent) { - policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER); + policyManager.checkCreateClassLoader(callerClass); } @Override public void check$java_net_URLClassLoader$(Class callerClass, URL[] urls, ClassLoader parent, URLStreamHandlerFactory factory) { - policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER); + policyManager.checkCreateClassLoader(callerClass); } @Override public void check$java_net_URLClassLoader$(Class callerClass, String name, URL[] urls, ClassLoader parent) { - policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER); + policyManager.checkCreateClassLoader(callerClass); } @Override @@ -61,6 +60,6 @@ public ElasticsearchEntitlementChecker(PolicyManager policyManager) { ClassLoader parent, URLStreamHandlerFactory factory ) { - policyManager.checkFlagEntitlement(callerClass, FlagEntitlementType.CREATE_CLASSLOADER); + policyManager.checkCreateClassLoader(callerClass); } } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/CreateClassLoaderEntitlement.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/CreateClassLoaderEntitlement.java new file mode 100644 index 0000000000000..138515be9ffcb --- /dev/null +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/CreateClassLoaderEntitlement.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.runtime.policy; + +public class CreateClassLoaderEntitlement implements Entitlement { + @ExternalEntitlement + public CreateClassLoaderEntitlement() {} +} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/ExitVMEntitlement.java similarity index 79% rename from libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java rename to libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/ExitVMEntitlement.java index d40235ee12166..c4a8fc6833581 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FlagEntitlementType.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/ExitVMEntitlement.java @@ -9,7 +9,7 @@ package org.elasticsearch.entitlement.runtime.policy; -public enum FlagEntitlementType { - SYSTEM_EXIT, - CREATE_CLASSLOADER; -} +/** + * Internal policy type (not-parseable -- not available to plugins). + */ +public class ExitVMEntitlement implements Entitlement {} diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FileEntitlement.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FileEntitlement.java index 8df199591d3e4..d0837bc096183 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FileEntitlement.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/FileEntitlement.java @@ -20,6 +20,9 @@ public class FileEntitlement implements Entitlement { public static final int READ_ACTION = 0x1; public static final int WRITE_ACTION = 0x2; + public static final String READ = "read"; + public static final String WRITE = "write"; + private final String path; private final int actions; @@ -29,12 +32,12 @@ public FileEntitlement(String path, List actionsList) { int actionsInt = 0; for (String actionString : actionsList) { - if ("read".equals(actionString)) { + if (READ.equals(actionString)) { if ((actionsInt & READ_ACTION) == READ_ACTION) { throw new IllegalArgumentException("file action [read] specified multiple times"); } actionsInt |= READ_ACTION; - } else if ("write".equals(actionString)) { + } else if (WRITE.equals(actionString)) { if ((actionsInt & WRITE_ACTION) == WRITE_ACTION) { throw new IllegalArgumentException("file action [write] specified multiple times"); } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java index b3fb5b75a1d5a..a77c86d5ffd04 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyManager.java @@ -17,17 +17,45 @@ import java.lang.module.ModuleFinder; import java.lang.module.ModuleReference; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; public class PolicyManager { private static final Logger logger = LogManager.getLogger(ElasticsearchEntitlementChecker.class); + static class ModuleEntitlements { + public static final ModuleEntitlements NONE = new ModuleEntitlements(List.of()); + private final IdentityHashMap, List> entitlementsByType; + + ModuleEntitlements(List entitlements) { + this.entitlementsByType = entitlements.stream() + .collect(Collectors.toMap(Entitlement::getClass, e -> new ArrayList<>(List.of(e)), (a, b) -> { + a.addAll(b); + return a; + }, IdentityHashMap::new)); + } + + public boolean hasEntitlement(Class entitlementClass) { + return entitlementsByType.containsKey(entitlementClass); + } + + public Stream getEntitlements(Class entitlementClass) { + return entitlementsByType.get(entitlementClass).stream().map(entitlementClass::cast); + } + } + + final Map moduleEntitlementsMap = new HashMap<>(); + protected final Policy serverPolicy; protected final Map pluginPolicies; private final Function, String> pluginResolver; @@ -56,27 +84,110 @@ public PolicyManager(Policy defaultPolicy, Map pluginPolicies, F this.pluginResolver = pluginResolver; } - public void checkFlagEntitlement(Class callerClass, FlagEntitlementType type) { + private static List lookupEntitlementsForModule(Policy policy, String moduleName) { + for (int i = 0; i < policy.scopes.size(); ++i) { + var scope = policy.scopes.get(i); + if (scope.name.equals(moduleName)) { + return scope.entitlements; + } + } + return null; + } + + public void checkExitVM(Class callerClass) { + checkEntitlementPresent(callerClass, ExitVMEntitlement.class); + } + + public void checkCreateClassLoader(Class callerClass) { + checkEntitlementPresent(callerClass, CreateClassLoaderEntitlement.class); + } + + private void checkEntitlementPresent(Class callerClass, Class entitlementClass) { var requestingModule = requestingModule(callerClass); if (isTriviallyAllowed(requestingModule)) { return; } - // TODO: real policy check. For now, we only allow our hardcoded System.exit policy for server. - // TODO: this will be checked using policies - if (requestingModule.isNamed() - && requestingModule.getName().equals("org.elasticsearch.server") - && (type == FlagEntitlementType.SYSTEM_EXIT || type == FlagEntitlementType.CREATE_CLASSLOADER)) { - logger.debug("Allowed: caller [{}] in module [{}] has entitlement [{}]", callerClass, requestingModule.getName(), type); + ModuleEntitlements entitlements = getEntitlementsOrThrow(callerClass, requestingModule); + if (entitlements.hasEntitlement(entitlementClass)) { + logger.debug( + () -> Strings.format( + "Entitled: caller [%s], module [%s], type [%s]", + callerClass, + requestingModule.getName(), + entitlementClass.getSimpleName() + ) + ); return; } - - // TODO: plugins policy check using pluginResolver and pluginPolicies throw new NotEntitledException( - Strings.format("Missing entitlement [%s] for caller [%s] in module [%s]", type, callerClass, requestingModule.getName()) + Strings.format( + "Missing entitlement: caller [%s], module [%s], type [%s]", + callerClass, + requestingModule.getName(), + entitlementClass.getSimpleName() + ) ); } + ModuleEntitlements getEntitlementsOrThrow(Class callerClass, Module requestingModule) { + ModuleEntitlements cachedEntitlement = moduleEntitlementsMap.get(requestingModule); + if (cachedEntitlement != null) { + if (cachedEntitlement == ModuleEntitlements.NONE) { + throw new NotEntitledException(buildModuleNoPolicyMessage(callerClass, requestingModule) + "[CACHED]"); + } + return cachedEntitlement; + } + + if (isServerModule(requestingModule)) { + var scopeName = requestingModule.getName(); + return getModuleEntitlementsOrThrow(callerClass, requestingModule, serverPolicy, scopeName); + } + + // plugins + var pluginName = pluginResolver.apply(callerClass); + if (pluginName != null) { + var pluginPolicy = pluginPolicies.get(pluginName); + if (pluginPolicy != null) { + final String scopeName; + if (requestingModule.isNamed() == false) { + scopeName = ALL_UNNAMED; + } else { + scopeName = requestingModule.getName(); + } + return getModuleEntitlementsOrThrow(callerClass, requestingModule, pluginPolicy, scopeName); + } + } + + moduleEntitlementsMap.put(requestingModule, ModuleEntitlements.NONE); + throw new NotEntitledException(buildModuleNoPolicyMessage(callerClass, requestingModule)); + } + + private static String buildModuleNoPolicyMessage(Class callerClass, Module requestingModule) { + return Strings.format("Missing entitlement policy: caller [%s], module [%s]", callerClass, requestingModule.getName()); + } + + private ModuleEntitlements getModuleEntitlementsOrThrow(Class callerClass, Module module, Policy policy, String moduleName) { + var entitlements = lookupEntitlementsForModule(policy, moduleName); + if (entitlements == null) { + // Module without entitlements - remember we don't have any + moduleEntitlementsMap.put(module, ModuleEntitlements.NONE); + throw new NotEntitledException(buildModuleNoPolicyMessage(callerClass, module)); + } + // We have a policy for this module + var classEntitlements = createClassEntitlements(entitlements); + moduleEntitlementsMap.put(module, classEntitlements); + return classEntitlements; + } + + private static boolean isServerModule(Module requestingModule) { + return requestingModule.isNamed() && requestingModule.getLayer() == ModuleLayer.boot(); + } + + private ModuleEntitlements createClassEntitlements(List entitlements) { + return new ModuleEntitlements(entitlements); + } + private static Module requestingModule(Class callerClass) { if (callerClass != null) { Module callerModule = callerClass.getModule(); @@ -102,10 +213,10 @@ private static Module requestingModule(Class callerClass) { private static boolean isTriviallyAllowed(Module requestingModule) { if (requestingModule == null) { - logger.debug("Trivially allowed: entire call stack is in composed of classes in system modules"); + logger.debug("Entitlement trivially allowed: entire call stack is in composed of classes in system modules"); return true; } - logger.trace("Not trivially allowed"); + logger.trace("Entitlement not trivially allowed"); return false; } diff --git a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyParser.java b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyParser.java index ea6603af99925..0d1a7c14ece4b 100644 --- a/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyParser.java +++ b/libs/entitlement/src/main/java/org/elasticsearch/entitlement/runtime/policy/PolicyParser.java @@ -19,22 +19,43 @@ import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; - -import static org.elasticsearch.entitlement.runtime.policy.PolicyParserException.newPolicyParserException; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * A parser to parse policy files for entitlements. */ public class PolicyParser { - protected static final String entitlementPackageName = Entitlement.class.getPackage().getName(); + private static final Map> EXTERNAL_ENTITLEMENTS = Stream.of(FileEntitlement.class, CreateClassLoaderEntitlement.class) + .collect(Collectors.toUnmodifiableMap(PolicyParser::getEntitlementTypeName, Function.identity())); protected final XContentParser policyParser; protected final String policyName; + static String getEntitlementTypeName(Class entitlementClass) { + var entitlementClassName = entitlementClass.getSimpleName(); + + if (entitlementClassName.endsWith("Entitlement") == false) { + throw new IllegalArgumentException( + entitlementClassName + " is not a valid Entitlement class name. A valid class name must end with 'Entitlement'" + ); + } + + var strippedClassName = entitlementClassName.substring(0, entitlementClassName.indexOf("Entitlement")); + return Arrays.stream(strippedClassName.split("(?=\\p{Lu})")) + .filter(Predicate.not(String::isEmpty)) + .map(s -> s.toLowerCase(Locale.ROOT)) + .collect(Collectors.joining("_")); + } + public PolicyParser(InputStream inputStream, String policyName) throws IOException { this.policyParser = YamlXContent.yamlXContent.createParser(XContentParserConfiguration.EMPTY, Objects.requireNonNull(inputStream)); this.policyName = policyName; @@ -67,18 +88,23 @@ protected Scope parseScope(String scopeName) throws IOException { } List entitlements = new ArrayList<>(); while (policyParser.nextToken() != XContentParser.Token.END_ARRAY) { - if (policyParser.currentToken() != XContentParser.Token.START_OBJECT) { - throw newPolicyParserException(scopeName, "expected object "); - } - if (policyParser.nextToken() != XContentParser.Token.FIELD_NAME) { + if (policyParser.currentToken() == XContentParser.Token.VALUE_STRING) { + String entitlementType = policyParser.text(); + Entitlement entitlement = parseEntitlement(scopeName, entitlementType); + entitlements.add(entitlement); + } else if (policyParser.currentToken() == XContentParser.Token.START_OBJECT) { + if (policyParser.nextToken() != XContentParser.Token.FIELD_NAME) { + throw newPolicyParserException(scopeName, "expected object "); + } + String entitlementType = policyParser.currentName(); + Entitlement entitlement = parseEntitlement(scopeName, entitlementType); + entitlements.add(entitlement); + if (policyParser.nextToken() != XContentParser.Token.END_OBJECT) { + throw newPolicyParserException(scopeName, "expected closing object"); + } + } else { throw newPolicyParserException(scopeName, "expected object "); } - String entitlementType = policyParser.currentName(); - Entitlement entitlement = parseEntitlement(scopeName, entitlementType); - entitlements.add(entitlement); - if (policyParser.nextToken() != XContentParser.Token.END_OBJECT) { - throw newPolicyParserException(scopeName, "expected closing object"); - } } return new Scope(scopeName, entitlements); } catch (IOException ioe) { @@ -87,34 +113,29 @@ protected Scope parseScope(String scopeName) throws IOException { } protected Entitlement parseEntitlement(String scopeName, String entitlementType) throws IOException { - Class entitlementClass; - try { - entitlementClass = Class.forName( - entitlementPackageName - + "." - + Character.toUpperCase(entitlementType.charAt(0)) - + entitlementType.substring(1) - + "Entitlement" - ); - } catch (ClassNotFoundException cnfe) { - throw newPolicyParserException(scopeName, "unknown entitlement type [" + entitlementType + "]"); - } - if (Entitlement.class.isAssignableFrom(entitlementClass) == false) { + Class entitlementClass = EXTERNAL_ENTITLEMENTS.get(entitlementType); + + if (entitlementClass == null) { throw newPolicyParserException(scopeName, "unknown entitlement type [" + entitlementType + "]"); } + Constructor entitlementConstructor = entitlementClass.getConstructors()[0]; ExternalEntitlement entitlementMetadata = entitlementConstructor.getAnnotation(ExternalEntitlement.class); if (entitlementMetadata == null) { throw newPolicyParserException(scopeName, "unknown entitlement type [" + entitlementType + "]"); } - if (policyParser.nextToken() != XContentParser.Token.START_OBJECT) { - throw newPolicyParserException(scopeName, entitlementType, "expected entitlement parameters"); + Class[] parameterTypes = entitlementConstructor.getParameterTypes(); + String[] parametersNames = entitlementMetadata.parameterNames(); + + if (parameterTypes.length != 0 || parametersNames.length != 0) { + if (policyParser.nextToken() != XContentParser.Token.START_OBJECT) { + throw newPolicyParserException(scopeName, entitlementType, "expected entitlement parameters"); + } } + Map parsedValues = policyParser.map(); - Class[] parameterTypes = entitlementConstructor.getParameterTypes(); - String[] parametersNames = entitlementMetadata.parameterNames(); Object[] parameterValues = new Object[parameterTypes.length]; for (int parameterIndex = 0; parameterIndex < parameterTypes.length; ++parameterIndex) { String parameterName = parametersNames[parameterIndex]; diff --git a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java new file mode 100644 index 0000000000000..45bdf2e457824 --- /dev/null +++ b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyManagerTests.java @@ -0,0 +1,247 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.entitlement.runtime.policy; + +import org.elasticsearch.entitlement.runtime.api.NotEntitledException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.compiler.InMemoryJavaCompiler; +import org.elasticsearch.test.jar.JarUtils; + +import java.io.IOException; +import java.lang.module.Configuration; +import java.lang.module.ModuleFinder; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Map.entry; +import static org.elasticsearch.entitlement.runtime.policy.PolicyManager.ALL_UNNAMED; +import static org.elasticsearch.test.LambdaMatchers.transformedMatch; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +@ESTestCase.WithoutSecurityManager +public class PolicyManagerTests extends ESTestCase { + + public void testGetEntitlementsThrowsOnMissingPluginUnnamedModule() { + var policyManager = new PolicyManager( + createEmptyTestServerPolicy(), + Map.of("plugin1", createPluginPolicy("plugin.module")), + c -> "plugin1" + ); + + // Any class from the current module (unnamed) will do + var callerClass = this.getClass(); + var requestingModule = callerClass.getModule(); + + var ex = assertThrows( + "No policy for the unnamed module", + NotEntitledException.class, + () -> policyManager.getEntitlementsOrThrow(callerClass, requestingModule) + ); + + assertEquals( + "Missing entitlement policy: caller [class org.elasticsearch.entitlement.runtime.policy.PolicyManagerTests], module [null]", + ex.getMessage() + ); + assertThat(policyManager.moduleEntitlementsMap, hasEntry(requestingModule, PolicyManager.ModuleEntitlements.NONE)); + } + + public void testGetEntitlementsThrowsOnMissingPolicyForPlugin() { + var policyManager = new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "plugin1"); + + // Any class from the current module (unnamed) will do + var callerClass = this.getClass(); + var requestingModule = callerClass.getModule(); + + var ex = assertThrows( + "No policy for this plugin", + NotEntitledException.class, + () -> policyManager.getEntitlementsOrThrow(callerClass, requestingModule) + ); + + assertEquals( + "Missing entitlement policy: caller [class org.elasticsearch.entitlement.runtime.policy.PolicyManagerTests], module [null]", + ex.getMessage() + ); + assertThat(policyManager.moduleEntitlementsMap, hasEntry(requestingModule, PolicyManager.ModuleEntitlements.NONE)); + } + + public void testGetEntitlementsFailureIsCached() { + var policyManager = new PolicyManager(createEmptyTestServerPolicy(), Map.of(), c -> "plugin1"); + + // Any class from the current module (unnamed) will do + var callerClass = this.getClass(); + var requestingModule = callerClass.getModule(); + + assertThrows(NotEntitledException.class, () -> policyManager.getEntitlementsOrThrow(callerClass, requestingModule)); + assertThat(policyManager.moduleEntitlementsMap, hasEntry(requestingModule, PolicyManager.ModuleEntitlements.NONE)); + + // A second time + var ex = assertThrows(NotEntitledException.class, () -> policyManager.getEntitlementsOrThrow(callerClass, requestingModule)); + + assertThat(ex.getMessage(), endsWith("[CACHED]")); + // Nothing new in the map + assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1)); + } + + public void testGetEntitlementsReturnsEntitlementsForPluginUnnamedModule() { + var policyManager = new PolicyManager( + createEmptyTestServerPolicy(), + Map.ofEntries(entry("plugin2", createPluginPolicy(ALL_UNNAMED))), + c -> "plugin2" + ); + + // Any class from the current module (unnamed) will do + var callerClass = this.getClass(); + var requestingModule = callerClass.getModule(); + + var entitlements = policyManager.getEntitlementsOrThrow(callerClass, requestingModule); + assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); + } + + public void testGetEntitlementsThrowsOnMissingPolicyForServer() throws ClassNotFoundException { + var policyManager = new PolicyManager(createTestServerPolicy("example"), Map.of(), c -> null); + + // Tests do not run modular, so we cannot use a server class. + // But we know that in production code the server module and its classes are in the boot layer. + // So we use a random module in the boot layer, and a random class from that module (not java.base -- it is + // loaded too early) to mimic a class that would be in the server module. + var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer"); + var requestingModule = mockServerClass.getModule(); + + var ex = assertThrows( + "No policy for this module in server", + NotEntitledException.class, + () -> policyManager.getEntitlementsOrThrow(mockServerClass, requestingModule) + ); + + assertEquals( + "Missing entitlement policy: caller [class com.sun.net.httpserver.HttpServer], module [jdk.httpserver]", + ex.getMessage() + ); + assertThat(policyManager.moduleEntitlementsMap, hasEntry(requestingModule, PolicyManager.ModuleEntitlements.NONE)); + } + + public void testGetEntitlementsReturnsEntitlementsForServerModule() throws ClassNotFoundException { + var policyManager = new PolicyManager(createTestServerPolicy("jdk.httpserver"), Map.of(), c -> null); + + // Tests do not run modular, so we cannot use a server class. + // But we know that in production code the server module and its classes are in the boot layer. + // So we use a random module in the boot layer, and a random class from that module (not java.base -- it is + // loaded too early) to mimic a class that would be in the server module. + var mockServerClass = ModuleLayer.boot().findLoader("jdk.httpserver").loadClass("com.sun.net.httpserver.HttpServer"); + var requestingModule = mockServerClass.getModule(); + + var entitlements = policyManager.getEntitlementsOrThrow(mockServerClass, requestingModule); + assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); + assertThat(entitlements.hasEntitlement(ExitVMEntitlement.class), is(true)); + } + + public void testGetEntitlementsReturnsEntitlementsForPluginModule() throws IOException, ClassNotFoundException { + final Path home = createTempDir(); + + Path jar = creteMockPluginJar(home); + + var policyManager = new PolicyManager( + createEmptyTestServerPolicy(), + Map.of("mock-plugin", createPluginPolicy("org.example.plugin")), + c -> "mock-plugin" + ); + + var layer = createLayerForJar(jar, "org.example.plugin"); + var mockPluginClass = layer.findLoader("org.example.plugin").loadClass("q.B"); + var requestingModule = mockPluginClass.getModule(); + + var entitlements = policyManager.getEntitlementsOrThrow(mockPluginClass, requestingModule); + assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); + assertThat( + entitlements.getEntitlements(FileEntitlement.class).toList(), + contains(transformedMatch(FileEntitlement::toString, containsString("/test/path"))) + ); + } + + public void testGetEntitlementsResultIsCached() { + var policyManager = new PolicyManager( + createEmptyTestServerPolicy(), + Map.ofEntries(entry("plugin2", createPluginPolicy(ALL_UNNAMED))), + c -> "plugin2" + ); + + // Any class from the current module (unnamed) will do + var callerClass = this.getClass(); + var requestingModule = callerClass.getModule(); + + var entitlements = policyManager.getEntitlementsOrThrow(callerClass, requestingModule); + assertThat(entitlements.hasEntitlement(CreateClassLoaderEntitlement.class), is(true)); + assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1)); + var cachedResult = policyManager.moduleEntitlementsMap.values().stream().findFirst().get(); + var entitlementsAgain = policyManager.getEntitlementsOrThrow(callerClass, requestingModule); + + // Nothing new in the map + assertThat(policyManager.moduleEntitlementsMap, aMapWithSize(1)); + assertThat(entitlementsAgain, sameInstance(cachedResult)); + } + + private static Policy createEmptyTestServerPolicy() { + return new Policy("server", List.of()); + } + + private static Policy createTestServerPolicy(String scopeName) { + return new Policy("server", List.of(new Scope(scopeName, List.of(new ExitVMEntitlement(), new CreateClassLoaderEntitlement())))); + } + + private static Policy createPluginPolicy(String... pluginModules) { + return new Policy( + "plugin", + Arrays.stream(pluginModules) + .map( + name -> new Scope( + name, + List.of(new FileEntitlement("/test/path", List.of(FileEntitlement.READ)), new CreateClassLoaderEntitlement()) + ) + ) + .toList() + ); + } + + private static Path creteMockPluginJar(Path home) throws IOException { + Path jar = home.resolve("mock-plugin.jar"); + + Map sources = Map.ofEntries( + entry("module-info", "module org.example.plugin { exports q; }"), + entry("q.B", "package q; public class B { }") + ); + + var classToBytes = InMemoryJavaCompiler.compile(sources); + JarUtils.createJarWithEntries( + jar, + Map.ofEntries(entry("module-info.class", classToBytes.get("module-info")), entry("q/B.class", classToBytes.get("q.B"))) + ); + return jar; + } + + private static ModuleLayer createLayerForJar(Path jar, String moduleName) { + Configuration cf = ModuleLayer.boot().configuration().resolve(ModuleFinder.of(jar), ModuleFinder.of(), Set.of(moduleName)); + var moduleController = ModuleLayer.defineModulesWithOneLoader( + cf, + List.of(ModuleLayer.boot()), + ClassLoader.getPlatformClassLoader() + ); + return moduleController.layer(); + } +} diff --git a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserFailureTests.java b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserFailureTests.java index de8280ea87fe5..7eb2b1fb476b3 100644 --- a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserFailureTests.java +++ b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserFailureTests.java @@ -12,7 +12,6 @@ import org.elasticsearch.test.ESTestCase; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.nio.charset.StandardCharsets; public class PolicyParserFailureTests extends ESTestCase { @@ -26,7 +25,7 @@ public void testParserSyntaxFailures() { assertEquals("[1:1] policy parsing error for [test-failure-policy.yaml]: expected object ", ppe.getMessage()); } - public void testEntitlementDoesNotExist() throws IOException { + public void testEntitlementDoesNotExist() { PolicyParserException ppe = expectThrows(PolicyParserException.class, () -> new PolicyParser(new ByteArrayInputStream(""" entitlement-module-name: - does_not_exist: {} @@ -38,7 +37,7 @@ public void testEntitlementDoesNotExist() throws IOException { ); } - public void testEntitlementMissingParameter() throws IOException { + public void testEntitlementMissingParameter() { PolicyParserException ppe = expectThrows(PolicyParserException.class, () -> new PolicyParser(new ByteArrayInputStream(""" entitlement-module-name: - file: {} @@ -61,7 +60,7 @@ public void testEntitlementMissingParameter() throws IOException { ); } - public void testEntitlementExtraneousParameter() throws IOException { + public void testEntitlementExtraneousParameter() { PolicyParserException ppe = expectThrows(PolicyParserException.class, () -> new PolicyParser(new ByteArrayInputStream(""" entitlement-module-name: - file: diff --git a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserTests.java b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserTests.java index 40016b2e3027e..a514cfe418895 100644 --- a/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserTests.java +++ b/libs/entitlement/src/test/java/org/elasticsearch/entitlement/runtime/policy/PolicyParserTests.java @@ -11,11 +11,31 @@ import org.elasticsearch.test.ESTestCase; +import java.io.ByteArrayInputStream; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.List; +import static org.elasticsearch.test.LambdaMatchers.transformedMatch; +import static org.hamcrest.Matchers.both; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + public class PolicyParserTests extends ESTestCase { + private static class TestWrongEntitlementName implements Entitlement {} + + public void testGetEntitlementTypeName() { + assertEquals("create_class_loader", PolicyParser.getEntitlementTypeName(CreateClassLoaderEntitlement.class)); + + var ex = expectThrows(IllegalArgumentException.class, () -> PolicyParser.getEntitlementTypeName(TestWrongEntitlementName.class)); + assertThat( + ex.getMessage(), + equalTo("TestWrongEntitlementName is not a valid Entitlement class name. A valid class name must end with 'Entitlement'") + ); + } + public void testPolicyBuilder() throws IOException { Policy parsedPolicy = new PolicyParser(PolicyParserTests.class.getResourceAsStream("test-policy.yaml"), "test-policy.yaml") .parsePolicy(); @@ -25,4 +45,23 @@ public void testPolicyBuilder() throws IOException { ); assertEquals(parsedPolicy, builtPolicy); } + + public void testParseCreateClassloader() throws IOException { + Policy parsedPolicy = new PolicyParser(new ByteArrayInputStream(""" + entitlement-module-name: + - create_class_loader + """.getBytes(StandardCharsets.UTF_8)), "test-policy.yaml").parsePolicy(); + Policy builtPolicy = new Policy( + "test-policy.yaml", + List.of(new Scope("entitlement-module-name", List.of(new CreateClassLoaderEntitlement()))) + ); + assertThat( + parsedPolicy.scopes, + contains( + both(transformedMatch((Scope scope) -> scope.name, equalTo("entitlement-module-name"))).and( + transformedMatch(scope -> scope.entitlements, contains(instanceOf(CreateClassLoaderEntitlement.class))) + ) + ) + ); + } } diff --git a/libs/entitlement/tools/securitymanager-scanner/build.gradle b/libs/entitlement/tools/securitymanager-scanner/build.gradle index 8d035c9e847c6..ebb671e5487ef 100644 --- a/libs/entitlement/tools/securitymanager-scanner/build.gradle +++ b/libs/entitlement/tools/securitymanager-scanner/build.gradle @@ -47,8 +47,8 @@ repositories { dependencies { compileOnly(project(':libs:core')) - implementation 'org.ow2.asm:asm:9.7' - implementation 'org.ow2.asm:asm-util:9.7' + implementation 'org.ow2.asm:asm:9.7.1' + implementation 'org.ow2.asm:asm-util:9.7.1' implementation(project(':libs:entitlement:tools:common')) } diff --git a/libs/plugin-scanner/build.gradle b/libs/plugin-scanner/build.gradle index d04af0624b3b1..44e6853140a5b 100644 --- a/libs/plugin-scanner/build.gradle +++ b/libs/plugin-scanner/build.gradle @@ -20,8 +20,8 @@ dependencies { api project(':libs:plugin-api') api project(":libs:x-content") - api 'org.ow2.asm:asm:9.7' - api 'org.ow2.asm:asm-tree:9.7' + api 'org.ow2.asm:asm:9.7.1' + api 'org.ow2.asm:asm-tree:9.7.1' testImplementation "junit:junit:${versions.junit}" testImplementation(project(":test:framework")) { diff --git a/libs/x-content/impl/src/main/java/org/elasticsearch/xcontent/provider/json/JsonXContentParser.java b/libs/x-content/impl/src/main/java/org/elasticsearch/xcontent/provider/json/JsonXContentParser.java index d42c56845d03f..38ef8bc2e4ef0 100644 --- a/libs/x-content/impl/src/main/java/org/elasticsearch/xcontent/provider/json/JsonXContentParser.java +++ b/libs/x-content/impl/src/main/java/org/elasticsearch/xcontent/provider/json/JsonXContentParser.java @@ -108,7 +108,11 @@ public String text() throws IOException { if (currentToken().isValue() == false) { throwOnNoText(); } - return parser.getText(); + try { + return parser.getText(); + } catch (JsonParseException e) { + throw newXContentParseException(e); + } } private void throwOnNoText() { diff --git a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/SynonymTokenFilterFactory.java b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/SynonymTokenFilterFactory.java index 9e31fdde4330b..9dc3478994f1f 100644 --- a/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/SynonymTokenFilterFactory.java +++ b/modules/analysis-common/src/main/java/org/elasticsearch/analysis/common/SynonymTokenFilterFactory.java @@ -13,7 +13,6 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.synonym.SynonymFilter; import org.apache.lucene.analysis.synonym.SynonymMap; -import org.elasticsearch.common.logging.DeprecationCategory; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.env.Environment; @@ -152,15 +151,6 @@ public static SynonymsSource fromSettings(Settings settings) { super(name, settings); this.settings = settings; - if (settings.get("ignore_case") != null) { - DEPRECATION_LOGGER.warn( - DeprecationCategory.ANALYSIS, - "synonym_ignore_case_option", - "The ignore_case option on the synonym_graph filter is deprecated. " - + "Instead, insert a lowercase filter in the filter chain before the synonym_graph filter." - ); - } - this.synonymsSource = SynonymsSource.fromSettings(settings); this.expand = settings.getAsBoolean("expand", true); this.format = settings.get("format", ""); diff --git a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/ResolveClusterDataStreamIT.java b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/ResolveClusterDataStreamIT.java index 4c85958498da0..aa6ecf35e06fa 100644 --- a/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/ResolveClusterDataStreamIT.java +++ b/modules/data-streams/src/internalClusterTest/java/org/elasticsearch/datastreams/ResolveClusterDataStreamIT.java @@ -78,7 +78,7 @@ public class ResolveClusterDataStreamIT extends AbstractMultiClustersTestCase { private static long LATEST_TIMESTAMP = 1691348820000L; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER_1, REMOTE_CLUSTER_2); } diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/DataStreamsStatsTransportAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/DataStreamsStatsTransportAction.java index 1b0b0aa6abebe..1d3b1b676282a 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/DataStreamsStatsTransportAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/DataStreamsStatsTransportAction.java @@ -31,6 +31,7 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexService; import org.elasticsearch.index.engine.Engine; +import org.elasticsearch.index.engine.ReadOnlyEngine; import org.elasticsearch.index.shard.IndexShard; import org.elasticsearch.index.store.StoreStats; import org.elasticsearch.indices.IndicesService; @@ -130,7 +131,7 @@ protected void shardOperation( DataStream dataStream = indexAbstraction.getParentDataStream(); assert dataStream != null; long maxTimestamp = 0L; - try (Engine.Searcher searcher = indexShard.acquireSearcher("data_stream_stats")) { + try (Engine.Searcher searcher = indexShard.acquireSearcher(ReadOnlyEngine.FIELD_RANGE_SEARCH_SOURCE)) { IndexReader indexReader = searcher.getIndexReader(); byte[] maxPackedValue = PointValues.getMaxPackedValue(indexReader, DataStream.TIMESTAMP_FIELD_NAME); if (maxPackedValue != null) { diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpTaskState.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpTaskState.java index c4d0aef0183ed..c128af69009be 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpTaskState.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/EnterpriseGeoIpTaskState.java @@ -123,7 +123,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ENTERPRISE_GEOIP_DOWNLOADER; + return TransportVersions.V_8_16_0; } @Override diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java index 47ca79e3cb3b9..96525d427d3e8 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java @@ -44,7 +44,7 @@ public class GeoIpTaskState implements PersistentTaskState, VersionedNamedWriteable { private static boolean includeSha256(TransportVersion version) { - return version.isPatchFrom(TransportVersions.V_8_15_0) || version.onOrAfter(TransportVersions.ENTERPRISE_GEOIP_DOWNLOADER); + return version.onOrAfter(TransportVersions.V_8_15_0); } private static final ParseField DATABASES = new ParseField("databases"); diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java index b6e73f3f33f7c..a50fe7dee9008 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/IngestGeoIpMetadata.java @@ -69,7 +69,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ENTERPRISE_GEOIP_DOWNLOADER; + return TransportVersions.V_8_16_0; } public Map getDatabases() { @@ -138,7 +138,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ENTERPRISE_GEOIP_DOWNLOADER; + return TransportVersions.V_8_16_0; } } diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/DatabaseConfiguration.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/DatabaseConfiguration.java index a26364f9305e1..aa48c73cf1d73 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/DatabaseConfiguration.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/direct/DatabaseConfiguration.java @@ -138,7 +138,7 @@ public DatabaseConfiguration(StreamInput in) throws IOException { } private static Provider readProvider(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.INGEST_GEO_DATABASE_PROVIDERS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { return in.readNamedWriteable(Provider.class); } else { // prior to the above version, everything was always a maxmind, so this half of the if is logical @@ -154,7 +154,7 @@ public static DatabaseConfiguration parse(XContentParser parser, String id) { public void writeTo(StreamOutput out) throws IOException { out.writeString(id); out.writeString(name); - if (out.getTransportVersion().onOrAfter(TransportVersions.INGEST_GEO_DATABASE_PROVIDERS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeNamedWriteable(provider); } else { if (provider instanceof Maxmind maxmind) { diff --git a/modules/lang-painless/src/internalClusterTest/java/org/elasticsearch/painless/action/CrossClusterPainlessExecuteIT.java b/modules/lang-painless/src/internalClusterTest/java/org/elasticsearch/painless/action/CrossClusterPainlessExecuteIT.java index 4669ab25f5d8c..b21cabad9290c 100644 --- a/modules/lang-painless/src/internalClusterTest/java/org/elasticsearch/painless/action/CrossClusterPainlessExecuteIT.java +++ b/modules/lang-painless/src/internalClusterTest/java/org/elasticsearch/painless/action/CrossClusterPainlessExecuteIT.java @@ -54,7 +54,7 @@ public class CrossClusterPainlessExecuteIT extends AbstractMultiClustersTestCase private static final String KEYWORD_FIELD = "my_field"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/modules/reindex/src/internalClusterTest/java/org/elasticsearch/index/reindex/CrossClusterReindexIT.java b/modules/reindex/src/internalClusterTest/java/org/elasticsearch/index/reindex/CrossClusterReindexIT.java index 8b94337141243..4624393e9fb60 100644 --- a/modules/reindex/src/internalClusterTest/java/org/elasticsearch/index/reindex/CrossClusterReindexIT.java +++ b/modules/reindex/src/internalClusterTest/java/org/elasticsearch/index/reindex/CrossClusterReindexIT.java @@ -36,7 +36,7 @@ protected boolean reuseClusters() { } @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/modules/repository-s3/build.gradle b/modules/repository-s3/build.gradle index 2cfb5d23db4ff..f0dc1ca714958 100644 --- a/modules/repository-s3/build.gradle +++ b/modules/repository-s3/build.gradle @@ -18,15 +18,11 @@ esplugin { classname 'org.elasticsearch.repositories.s3.S3RepositoryPlugin' } -versions << [ - 'aws': '1.12.270' -] - dependencies { - api "com.amazonaws:aws-java-sdk-s3:${versions.aws}" - api "com.amazonaws:aws-java-sdk-core:${versions.aws}" - api "com.amazonaws:aws-java-sdk-sts:${versions.aws}" - api "com.amazonaws:jmespath-java:${versions.aws}" + api "com.amazonaws:aws-java-sdk-s3:${versions.awsv1sdk}" + api "com.amazonaws:aws-java-sdk-core:${versions.awsv1sdk}" + api "com.amazonaws:aws-java-sdk-sts:${versions.awsv1sdk}" + api "com.amazonaws:jmespath-java:${versions.awsv1sdk}" api "org.apache.httpcomponents:httpclient:${versions.httpclient}" api "org.apache.httpcomponents:httpcore:${versions.httpcore}" api "commons-logging:commons-logging:${versions.commonslogging}" diff --git a/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java b/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java index 2199a64521759..67ada622efeea 100644 --- a/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java +++ b/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.test.rest.ObjectPath; import java.io.Closeable; import java.io.IOException; @@ -27,7 +28,6 @@ import java.util.function.UnaryOperator; import java.util.stream.Collectors; -import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -152,10 +152,9 @@ private void testNonexistentBucket(Boolean readonly) throws Exception { final var responseException = expectThrows(ResponseException.class, () -> client().performRequest(registerRequest)); assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), responseException.getResponse().getStatusLine().getStatusCode()); - assertThat( - responseException.getMessage(), - allOf(containsString("repository_verification_exception"), containsString("is not accessible on master node")) - ); + final var responseObjectPath = ObjectPath.createFromResponse(responseException.getResponse()); + assertThat(responseObjectPath.evaluate("error.type"), equalTo("repository_verification_exception")); + assertThat(responseObjectPath.evaluate("error.reason"), containsString("is not accessible on master node")); } public void testNonexistentClient() throws Exception { @@ -181,15 +180,11 @@ private void testNonexistentClient(Boolean readonly) throws Exception { final var responseException = expectThrows(ResponseException.class, () -> client().performRequest(registerRequest)); assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), responseException.getResponse().getStatusLine().getStatusCode()); - assertThat( - responseException.getMessage(), - allOf( - containsString("repository_verification_exception"), - containsString("is not accessible on master node"), - containsString("illegal_argument_exception"), - containsString("Unknown s3 client name") - ) - ); + final var responseObjectPath = ObjectPath.createFromResponse(responseException.getResponse()); + assertThat(responseObjectPath.evaluate("error.type"), equalTo("repository_verification_exception")); + assertThat(responseObjectPath.evaluate("error.reason"), containsString("is not accessible on master node")); + assertThat(responseObjectPath.evaluate("error.caused_by.type"), equalTo("illegal_argument_exception")); + assertThat(responseObjectPath.evaluate("error.caused_by.reason"), containsString("Unknown s3 client name")); } public void testNonexistentSnapshot() throws Exception { @@ -212,7 +207,8 @@ private void testNonexistentSnapshot(Boolean readonly) throws Exception { final var getSnapshotRequest = new Request("GET", "/_snapshot/" + repositoryName + "/" + randomIdentifier()); final var getSnapshotException = expectThrows(ResponseException.class, () -> client().performRequest(getSnapshotRequest)); assertEquals(RestStatus.NOT_FOUND.getStatus(), getSnapshotException.getResponse().getStatusLine().getStatusCode()); - assertThat(getSnapshotException.getMessage(), containsString("snapshot_missing_exception")); + final var getResponseObjectPath = ObjectPath.createFromResponse(getSnapshotException.getResponse()); + assertThat(getResponseObjectPath.evaluate("error.type"), equalTo("snapshot_missing_exception")); final var restoreRequest = new Request("POST", "/_snapshot/" + repositoryName + "/" + randomIdentifier() + "/_restore"); if (randomBoolean()) { @@ -220,13 +216,15 @@ private void testNonexistentSnapshot(Boolean readonly) throws Exception { } final var restoreException = expectThrows(ResponseException.class, () -> client().performRequest(restoreRequest)); assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), restoreException.getResponse().getStatusLine().getStatusCode()); - assertThat(restoreException.getMessage(), containsString("snapshot_restore_exception")); + final var restoreResponseObjectPath = ObjectPath.createFromResponse(restoreException.getResponse()); + assertThat(restoreResponseObjectPath.evaluate("error.type"), equalTo("snapshot_restore_exception")); if (readonly != Boolean.TRUE) { final var deleteRequest = new Request("DELETE", "/_snapshot/" + repositoryName + "/" + randomIdentifier()); final var deleteException = expectThrows(ResponseException.class, () -> client().performRequest(deleteRequest)); assertEquals(RestStatus.NOT_FOUND.getStatus(), deleteException.getResponse().getStatusLine().getStatusCode()); - assertThat(deleteException.getMessage(), containsString("snapshot_missing_exception")); + final var deleteResponseObjectPath = ObjectPath.createFromResponse(deleteException.getResponse()); + assertThat(deleteResponseObjectPath.evaluate("error.type"), equalTo("snapshot_missing_exception")); } } } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 3fd5cc44a3403..1d39b993cef92 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -40,6 +40,7 @@ import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; +import org.apache.http.ConnectionClosedException; import org.apache.http.HttpHost; import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; @@ -48,6 +49,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.bulk.IncrementalBulkService; import org.elasticsearch.action.support.ActionTestUtils; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.client.Request; import org.elasticsearch.client.RestClient; @@ -100,6 +102,7 @@ import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.concurrent.CancellationException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -110,6 +113,7 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.getRandom; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_SERVER_SHUTDOWN_GRACE_PERIOD; import static org.elasticsearch.rest.RestStatus.BAD_REQUEST; import static org.elasticsearch.rest.RestStatus.OK; import static org.elasticsearch.rest.RestStatus.UNAUTHORIZED; @@ -1039,8 +1043,16 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th } } - public void testRespondAfterClose() throws Exception { - final String url = "/thing"; + public void testRespondAfterServiceCloseWithClientCancel() throws Exception { + runRespondAfterServiceCloseTest(true); + } + + public void testRespondAfterServiceCloseWithServerCancel() throws Exception { + runRespondAfterServiceCloseTest(false); + } + + private void runRespondAfterServiceCloseTest(boolean clientCancel) throws Exception { + final String url = "/" + randomIdentifier(); final CountDownLatch responseReleasedLatch = new CountDownLatch(1); final SubscribableListener transportClosedFuture = new SubscribableListener<>(); final CountDownLatch handlingRequestLatch = new CountDownLatch(1); @@ -1066,7 +1078,9 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th try ( Netty4HttpServerTransport transport = new Netty4HttpServerTransport( - Settings.EMPTY, + clientCancel + ? Settings.EMPTY + : Settings.builder().put(SETTING_HTTP_SERVER_SHUTDOWN_GRACE_PERIOD.getKey(), TimeValue.timeValueMillis(1)).build(), networkService, threadPool, xContentRegistry(), @@ -1082,11 +1096,24 @@ public void dispatchBadRequest(final RestChannel channel, final ThreadContext th transport.start(); final var address = randomFrom(transport.boundAddress().boundAddresses()).address(); try (var client = RestClient.builder(new HttpHost(address.getAddress(), address.getPort())).build()) { - client.performRequestAsync(new Request("GET", url), ActionTestUtils.wrapAsRestResponseListener(ActionListener.noop())); + final var responseExceptionFuture = new PlainActionFuture(); + final var cancellable = client.performRequestAsync( + new Request("GET", url), + ActionTestUtils.wrapAsRestResponseListener(ActionTestUtils.assertNoSuccessListener(responseExceptionFuture::onResponse)) + ); safeAwait(handlingRequestLatch); + if (clientCancel) { + threadPool.generic().execute(cancellable::cancel); + } transport.close(); transportClosedFuture.onResponse(null); safeAwait(responseReleasedLatch); + final var responseException = safeGet(responseExceptionFuture); + if (clientCancel) { + assertThat(responseException, instanceOf(CancellationException.class)); + } else { + assertThat(responseException, instanceOf(ConnectionClosedException.class)); + } } } } diff --git a/muted-tests.yml b/muted-tests.yml index 17a7b26d1c091..dcaa415a67966 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -2,12 +2,6 @@ tests: - class: "org.elasticsearch.client.RestClientSingleHostIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/102717" method: "testRequestResetAndAbort" -- class: org.elasticsearch.xpack.restart.FullClusterRestartIT - method: testSingleDoc {cluster=UPGRADED} - issue: https://github.com/elastic/elasticsearch/issues/111434 -- class: org.elasticsearch.xpack.restart.FullClusterRestartIT - method: testDataStreams {cluster=UPGRADED} - issue: https://github.com/elastic/elasticsearch/issues/111448 - class: org.elasticsearch.smoketest.WatcherYamlRestIT method: test {p0=watcher/usage/10_basic/Test watcher usage stats output} issue: https://github.com/elastic/elasticsearch/issues/112189 @@ -103,9 +97,6 @@ tests: - class: org.elasticsearch.search.StressSearchServiceReaperIT method: testStressReaper issue: https://github.com/elastic/elasticsearch/issues/115816 -- class: org.elasticsearch.search.SearchServiceTests - method: testParseSourceValidation - issue: https://github.com/elastic/elasticsearch/issues/115936 - class: org.elasticsearch.xpack.application.connector.ConnectorIndexServiceTests issue: https://github.com/elastic/elasticsearch/issues/116087 - class: org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT @@ -120,18 +111,12 @@ tests: - class: org.elasticsearch.action.search.SearchPhaseControllerTests method: testProgressListener issue: https://github.com/elastic/elasticsearch/issues/116149 -- class: org.elasticsearch.xpack.test.rest.XPackRestIT - method: test {p0=terms_enum/10_basic/Test security} - issue: https://github.com/elastic/elasticsearch/issues/116178 - class: org.elasticsearch.search.basic.SearchWithRandomDisconnectsIT method: testSearchWithRandomDisconnects issue: https://github.com/elastic/elasticsearch/issues/116175 - class: org.elasticsearch.xpack.deprecation.DeprecationHttpIT method: testDeprecatedSettingsReturnWarnings issue: https://github.com/elastic/elasticsearch/issues/108628 -- class: org.elasticsearch.action.search.SearchQueryThenFetchAsyncActionTests - method: testBottomFieldSort - issue: https://github.com/elastic/elasticsearch/issues/116249 - class: org.elasticsearch.xpack.shutdown.NodeShutdownIT method: testAllocationPreventedForRemoval issue: https://github.com/elastic/elasticsearch/issues/116363 @@ -141,9 +126,6 @@ tests: - class: org.elasticsearch.reservedstate.service.FileSettingsServiceTests method: testInvalidJSON issue: https://github.com/elastic/elasticsearch/issues/116521 -- class: org.elasticsearch.xpack.searchablesnapshots.SearchableSnapshotsCanMatchOnCoordinatorIntegTests - method: testSearchableSnapshotShardsAreSkippedBySearchRequestWithoutQueryingAnyNodeWhenTheyAreOutsideOfTheQueryRange - issue: https://github.com/elastic/elasticsearch/issues/116523 - class: org.elasticsearch.reservedstate.service.RepositoriesFileSettingsIT method: testSettingsApplied issue: https://github.com/elastic/elasticsearch/issues/116694 @@ -175,9 +157,6 @@ tests: - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=snapshot/10_basic/Create a source only snapshot and then restore it} issue: https://github.com/elastic/elasticsearch/issues/117295 -- class: org.elasticsearch.xpack.searchablesnapshots.RetrySearchIntegTests - method: testRetryPointInTime - issue: https://github.com/elastic/elasticsearch/issues/117116 - class: org.elasticsearch.xpack.inference.DefaultEndPointsIT method: testInferDeploysDefaultElser issue: https://github.com/elastic/elasticsearch/issues/114913 @@ -221,11 +200,6 @@ tests: issue: https://github.com/elastic/elasticsearch/issues/117815 - class: org.elasticsearch.xpack.ml.integration.DatafeedJobsRestIT issue: https://github.com/elastic/elasticsearch/issues/111319 -- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT - issue: https://github.com/elastic/elasticsearch/issues/117893 -- class: org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilderTests - method: testToQuery - issue: https://github.com/elastic/elasticsearch/issues/117904 - class: org.elasticsearch.packaging.test.ArchiveGenerateInitialCredentialsTests method: test20NoAutoGenerationWhenAutoConfigurationDisabled issue: https://github.com/elastic/elasticsearch/issues/117891 @@ -235,15 +209,91 @@ tests: - class: org.elasticsearch.xpack.esql.plugin.ClusterRequestTests method: testFallbackIndicesOptions issue: https://github.com/elastic/elasticsearch/issues/117937 -- class: org.elasticsearch.xpack.esql.qa.single_node.RequestIndexFilteringIT - method: testFieldExistsFilter_KeepWildcard - issue: https://github.com/elastic/elasticsearch/issues/117935 -- class: org.elasticsearch.xpack.esql.qa.multi_node.RequestIndexFilteringIT - method: testFieldExistsFilter_KeepWildcard - issue: https://github.com/elastic/elasticsearch/issues/117935 - class: org.elasticsearch.xpack.ml.integration.RegressionIT method: testTwoJobsWithSameRandomizeSeedUseSameTrainingSet issue: https://github.com/elastic/elasticsearch/issues/117805 +- class: org.elasticsearch.packaging.test.ArchiveGenerateInitialCredentialsTests + method: test30NoAutogenerationWhenDaemonized + issue: https://github.com/elastic/elasticsearch/issues/117956 +- class: org.elasticsearch.packaging.test.CertGenCliTests + method: test40RunWithCert + issue: https://github.com/elastic/elasticsearch/issues/117955 +- class: org.elasticsearch.upgrades.QueryBuilderBWCIT + method: testQueryBuilderBWC {cluster=UPGRADED} + issue: https://github.com/elastic/elasticsearch/issues/116990 +- class: org.elasticsearch.xpack.restart.QueryBuilderBWCIT + method: testQueryBuilderBWC {p0=UPGRADED} + issue: https://github.com/elastic/elasticsearch/issues/116989 +- class: org.elasticsearch.index.reindex.ReindexNodeShutdownIT + method: testReindexWithShutdown + issue: https://github.com/elastic/elasticsearch/issues/118040 +- class: org.elasticsearch.packaging.test.ConfigurationTests + method: test20HostnameSubstitution + issue: https://github.com/elastic/elasticsearch/issues/118028 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test40AutoconfigurationNotTriggeredWhenNodeIsMeantToJoinExistingCluster + issue: https://github.com/elastic/elasticsearch/issues/118029 +- class: org.elasticsearch.packaging.test.ConfigurationTests + method: test30SymlinkedDataPath + issue: https://github.com/elastic/elasticsearch/issues/118111 +- class: org.elasticsearch.packaging.test.KeystoreManagementTests + method: test30KeystorePasswordFromFile + issue: https://github.com/elastic/elasticsearch/issues/118123 +- class: org.elasticsearch.packaging.test.KeystoreManagementTests + method: test31WrongKeystorePasswordFromFile + issue: https://github.com/elastic/elasticsearch/issues/118123 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test41AutoconfigurationNotTriggeredWhenNodeCannotContainData + issue: https://github.com/elastic/elasticsearch/issues/118110 +- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT + method: test {lookup-join.LookupMessageFromIndexKeepReordered SYNC} + issue: https://github.com/elastic/elasticsearch/issues/118150 +- class: org.elasticsearch.xpack.esql.qa.multi_node.EsqlSpecIT + method: test {lookup-join.LookupMessageFromIndexKeepReordered ASYNC} + issue: https://github.com/elastic/elasticsearch/issues/118151 +- class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS2UnavailableRemotesIT + method: testEsqlRcs2UnavailableRemoteScenarios + issue: https://github.com/elastic/elasticsearch/issues/117419 +- class: org.elasticsearch.packaging.test.DebPreservationTests + method: test40RestartOnUpgrade + issue: https://github.com/elastic/elasticsearch/issues/118170 +- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT + method: testInferDeploysDefaultRerank + issue: https://github.com/elastic/elasticsearch/issues/118184 +- class: org.elasticsearch.xpack.esql.action.EsqlActionTaskIT + method: testCancelRequestWhenFailingFetchingPages + issue: https://github.com/elastic/elasticsearch/issues/118193 +- class: org.elasticsearch.packaging.test.MemoryLockingTests + method: test20MemoryLockingEnabled + issue: https://github.com/elastic/elasticsearch/issues/118195 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test42AutoconfigurationNotTriggeredWhenNodeCannotBecomeMaster + issue: https://github.com/elastic/elasticsearch/issues/118196 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test43AutoconfigurationNotTriggeredWhenTlsAlreadyConfigured + issue: https://github.com/elastic/elasticsearch/issues/118202 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test44AutoConfigurationNotTriggeredOnNotWriteableConfDir + issue: https://github.com/elastic/elasticsearch/issues/118208 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test51AutoConfigurationWithPasswordProtectedKeystore + issue: https://github.com/elastic/elasticsearch/issues/118212 +- class: org.elasticsearch.xpack.inference.InferenceCrudIT + method: testUnifiedCompletionInference + issue: https://github.com/elastic/elasticsearch/issues/118210 +- class: org.elasticsearch.ingest.common.IngestCommonClientYamlTestSuiteIT + issue: https://github.com/elastic/elasticsearch/issues/118215 +- class: org.elasticsearch.datastreams.DataStreamsClientYamlTestSuiteIT + method: test {p0=data_stream/120_data_streams_stats/Multiple data stream} + issue: https://github.com/elastic/elasticsearch/issues/118217 +- class: org.elasticsearch.xpack.security.operator.OperatorPrivilegesIT + method: testEveryActionIsEitherOperatorOnlyOrNonOperator + issue: https://github.com/elastic/elasticsearch/issues/118220 +- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT + issue: https://github.com/elastic/elasticsearch/issues/118224 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test60StartAndStop + issue: https://github.com/elastic/elasticsearch/issues/118216 # Examples: # diff --git a/plugins/discovery-ec2/build.gradle b/plugins/discovery-ec2/build.gradle index 980e2467206d7..a4321a2d61f98 100644 --- a/plugins/discovery-ec2/build.gradle +++ b/plugins/discovery-ec2/build.gradle @@ -14,13 +14,9 @@ esplugin { classname 'org.elasticsearch.discovery.ec2.Ec2DiscoveryPlugin' } -versions << [ - 'aws': '1.12.270' -] - dependencies { - api "com.amazonaws:aws-java-sdk-ec2:${versions.aws}" - api "com.amazonaws:aws-java-sdk-core:${versions.aws}" + api "com.amazonaws:aws-java-sdk-ec2:${versions.awsv1sdk}" + api "com.amazonaws:aws-java-sdk-core:${versions.awsv1sdk}" api "org.apache.httpcomponents:httpclient:${versions.httpclient}" api "org.apache.httpcomponents:httpcore:${versions.httpcore}" api "commons-logging:commons-logging:${versions.commonslogging}" diff --git a/qa/logging-config/src/test/java/org/elasticsearch/common/logging/JsonLoggerTests.java b/qa/logging-config/src/test/java/org/elasticsearch/common/logging/JsonLoggerTests.java index 1066bf1360e41..ed6205c7a5208 100644 --- a/qa/logging-config/src/test/java/org/elasticsearch/common/logging/JsonLoggerTests.java +++ b/qa/logging-config/src/test/java/org/elasticsearch/common/logging/JsonLoggerTests.java @@ -125,14 +125,14 @@ public void testDeprecatedMessageWithoutXOpaqueId() throws IOException { jsonLogs, contains( allOf( - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), hasEntry("elasticsearch.cluster.name", "elasticsearch"), hasEntry("elasticsearch.node.name", "sample-name"), hasEntry("message", "deprecated message1"), hasEntry("data_stream.type", "logs"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasKey("ecs.version"), hasEntry(DeprecatedMessage.KEY_FIELD_NAME, "a key"), @@ -168,8 +168,8 @@ public void testCompatibleLog() throws Exception { contains( allOf( hasEntry("log.level", "CRITICAL"), - hasEntry("event.dataset", "deprecation.elasticsearch"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), @@ -186,8 +186,8 @@ public void testCompatibleLog() throws Exception { allOf( hasEntry("log.level", "CRITICAL"), // event.dataset and data_stream.dataset have to be the same across the data stream - hasEntry("event.dataset", "deprecation.elasticsearch"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), @@ -240,8 +240,8 @@ public void testParseFieldEmittingDeprecatedLogs() throws Exception { // deprecation log for field deprecated_name allOf( hasEntry("log.level", "WARN"), - hasEntry("event.dataset", "deprecation.elasticsearch"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasEntry("log.logger", "org.elasticsearch.deprecation.xcontent.ParseField"), @@ -258,8 +258,8 @@ public void testParseFieldEmittingDeprecatedLogs() throws Exception { // deprecation log for field deprecated_name2 (note it is not being throttled) allOf( hasEntry("log.level", "WARN"), - hasEntry("event.dataset", "deprecation.elasticsearch"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasEntry("log.logger", "org.elasticsearch.deprecation.xcontent.ParseField"), @@ -276,8 +276,8 @@ public void testParseFieldEmittingDeprecatedLogs() throws Exception { // compatible log line allOf( hasEntry("log.level", "CRITICAL"), - hasEntry("event.dataset", "deprecation.elasticsearch"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasEntry("log.logger", "org.elasticsearch.deprecation.xcontent.ParseField"), @@ -327,14 +327,14 @@ public void testDeprecatedMessage() throws Exception { jsonLogs, contains( allOf( - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "WARN"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), hasEntry("elasticsearch.cluster.name", "elasticsearch"), hasEntry("elasticsearch.node.name", "sample-name"), hasEntry("message", "deprecated message1"), hasEntry("data_stream.type", "logs"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasKey("ecs.version"), hasEntry(DeprecatedMessage.KEY_FIELD_NAME, "someKey"), @@ -579,7 +579,7 @@ public void testDuplicateLogMessages() throws Exception { jsonLogs, contains( allOf( - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), hasEntry("elasticsearch.cluster.name", "elasticsearch"), @@ -612,7 +612,7 @@ public void testDuplicateLogMessages() throws Exception { jsonLogs, contains( allOf( - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), hasEntry("elasticsearch.cluster.name", "elasticsearch"), @@ -622,7 +622,7 @@ public void testDuplicateLogMessages() throws Exception { hasEntry("elasticsearch.event.category", "other") ), allOf( - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasEntry("log.logger", "org.elasticsearch.deprecation.test"), hasEntry("elasticsearch.cluster.name", "elasticsearch"), diff --git a/qa/logging-config/src/test/resources/org/elasticsearch/common/logging/json_layout/log4j2.properties b/qa/logging-config/src/test/resources/org/elasticsearch/common/logging/json_layout/log4j2.properties index 46baac4f1433c..b00caca66d03c 100644 --- a/qa/logging-config/src/test/resources/org/elasticsearch/common/logging/json_layout/log4j2.properties +++ b/qa/logging-config/src/test/resources/org/elasticsearch/common/logging/json_layout/log4j2.properties @@ -15,14 +15,13 @@ appender.deprecated.name = deprecated appender.deprecated.fileName = ${sys:es.logs.base_path}${sys:file.separator}${sys:es.logs.cluster_name}_deprecated.json # Intentionally follows a different pattern to above appender.deprecated.layout.type = ECSJsonLayout -appender.deprecated.layout.dataset = deprecation.elasticsearch +appender.deprecated.layout.dataset = elasticsearch.deprecation appender.deprecated.filter.rate_limit.type = RateLimitingFilter appender.deprecatedconsole.type = Console appender.deprecatedconsole.name = deprecatedconsole appender.deprecatedconsole.layout.type = ECSJsonLayout -# Intentionally follows a different pattern to above -appender.deprecatedconsole.layout.dataset = deprecation.elasticsearch +appender.deprecatedconsole.layout.dataset = elasticsearch.deprecation appender.deprecatedconsole.filter.rate_limit.type = RateLimitingFilter diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BulkRestIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BulkRestIT.java index 369d0824bdb28..3faa88339f0a3 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BulkRestIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/BulkRestIT.java @@ -74,8 +74,7 @@ public void testBulkInvalidIndexNameString() throws IOException { ResponseException responseException = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(BAD_REQUEST.getStatus())); - assertThat(responseException.getMessage(), containsString("could not parse bulk request body")); - assertThat(responseException.getMessage(), containsString("json_parse_exception")); + assertThat(responseException.getMessage(), containsString("x_content_parse_exception")); assertThat(responseException.getMessage(), containsString("Invalid UTF-8")); } diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/migrate.reindex.json b/rest-api-spec/src/main/resources/rest-api-spec/api/migrate.reindex.json new file mode 100644 index 0000000000000..149a90bc198b0 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/migrate.reindex.json @@ -0,0 +1,29 @@ +{ + "migrate.reindex":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/data-stream-reindex.html", + "description":"This API reindexes all legacy backing indices for a data stream. It does this in a persistent task. The persistent task id is returned immediately, and the reindexing work is completed in that task" + }, + "stability":"experimental", + "visibility":"private", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_migration/reindex", + "methods":[ + "POST" + ] + } + ] + }, + "body":{ + "description":"The body contains the fields `mode` and `source.index, where the only mode currently supported is `upgrade`, and the `source.index` must be a data stream name", + "required":true + } + } +} + diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/90_sparse_vector.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/90_sparse_vector.yml index 2505e6d7e353b..0b65a69bf500e 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/90_sparse_vector.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/90_sparse_vector.yml @@ -472,3 +472,120 @@ - match: _source.ml.tokens: {} + +--- +"stored sparse_vector": + + - requires: + cluster_features: [ "mapper.sparse_vector.store_support" ] + reason: "sparse_vector supports store parameter" + + - do: + indices.create: + index: test + body: + mappings: + properties: + ml.tokens: + type: sparse_vector + store: true + + - match: { acknowledged: true } + - do: + index: + index: test + id: "1" + body: + ml: + tokens: + running: 2 + good: 3 + run: 5 + race: 7 + for: 9 + + - match: { result: "created" } + + - do: + indices.refresh: { } + + - do: + search: + index: test + body: + fields: [ "ml.tokens" ] + + - length: { hits.hits.0.fields.ml\\.tokens: 1 } + - length: { hits.hits.0.fields.ml\\.tokens.0: 5 } + - match: { hits.hits.0.fields.ml\\.tokens.0.running: 2.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.good: 3.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.run: 5.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.race: 7.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.for: 9.0 } + +--- +"stored sparse_vector synthetic source": + + - requires: + cluster_features: [ "mapper.source.mode_from_index_setting", "mapper.sparse_vector.store_support" ] + reason: "sparse_vector supports store parameter" + + - do: + indices.create: + index: test + body: + settings: + index: + mapping.source.mode: synthetic + mappings: + properties: + ml.tokens: + type: sparse_vector + store: true + + - match: { acknowledged: true } + + - do: + index: + index: test + id: "1" + body: + ml: + tokens: + running: 2 + good: 3 + run: 5 + race: 7 + for: 9 + + - match: { result: "created" } + + - do: + indices.refresh: { } + + - do: + search: + index: test + body: + fields: [ "ml.tokens" ] + + - match: + hits.hits.0._source: { + ml: { + tokens: { + running: 2.0, + good: 3.0, + run: 5.0, + race: 7.0, + for: 9.0 + } + } + } + + - length: { hits.hits.0.fields.ml\\.tokens: 1 } + - length: { hits.hits.0.fields.ml\\.tokens.0: 5 } + - match: { hits.hits.0.fields.ml\\.tokens.0.running: 2.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.good: 3.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.run: 5.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.race: 7.0 } + - match: { hits.hits.0.fields.ml\\.tokens.0.for: 9.0 } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/remote/RemoteInfoIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/remote/RemoteInfoIT.java index 25678939cb375..9e578faaac70c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/remote/RemoteInfoIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/remote/RemoteInfoIT.java @@ -15,7 +15,6 @@ import org.elasticsearch.test.InternalTestCluster; import org.elasticsearch.test.NodeRoles; -import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -24,7 +23,7 @@ public class RemoteInfoIT extends AbstractMultiClustersTestCase { @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { if (randomBoolean()) { return List.of(); } else { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsRemoteIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsRemoteIT.java index 6cc9824245247..5f4315abff405 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsRemoteIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsRemoteIT.java @@ -23,7 +23,6 @@ import org.elasticsearch.test.InternalTestCluster; import org.junit.Assert; -import java.util.Collection; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -51,7 +50,7 @@ protected boolean reuseClusters() { } @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE1, REMOTE2); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/search/CCSPointInTimeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/search/CCSPointInTimeIT.java index ed92e7704f4ba..7a75313d44189 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/search/CCSPointInTimeIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/search/CCSPointInTimeIT.java @@ -44,7 +44,7 @@ public class CCSPointInTimeIT extends AbstractMultiClustersTestCase { public static final String REMOTE_CLUSTER = "remote_cluster"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/NoMasterNodeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/NoMasterNodeIT.java index 13515d34ec65f..545b38f30ba94 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/NoMasterNodeIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/NoMasterNodeIT.java @@ -261,10 +261,11 @@ public void testNoMasterActionsWriteMasterBlock() throws Exception { GetResponse getResponse = clientToMasterlessNode.prepareGet("test1", "1").get(); assertExists(getResponse); - assertHitCount(clientToMasterlessNode.prepareSearch("test1").setAllowPartialSearchResults(true).setSize(0), 1L); - - logger.info("--> here 3"); - assertHitCount(clientToMasterlessNode.prepareSearch("test1").setAllowPartialSearchResults(true), 1L); + assertHitCount( + 1L, + clientToMasterlessNode.prepareSearch("test1").setAllowPartialSearchResults(true).setSize(0), + clientToMasterlessNode.prepareSearch("test1").setAllowPartialSearchResults(true) + ); assertResponse(clientToMasterlessNode.prepareSearch("test2").setAllowPartialSearchResults(true).setSize(0), countResponse -> { assertThat(countResponse.getTotalShards(), equalTo(3)); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesOptionsIntegrationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesOptionsIntegrationIT.java index f41277c5b80ca..545ed83bb79c8 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesOptionsIntegrationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesOptionsIntegrationIT.java @@ -398,8 +398,11 @@ public void testWildcardBehaviourSnapshotRestore() throws Exception { public void testAllMissingLenient() throws Exception { createIndex("test1"); prepareIndex("test1").setId("1").setSource("k", "v").setRefreshPolicy(IMMEDIATE).get(); - assertHitCount(prepareSearch("test2").setIndicesOptions(IndicesOptions.lenientExpandOpen()).setQuery(matchAllQuery()), 0L); - assertHitCount(prepareSearch("test2", "test3").setQuery(matchAllQuery()).setIndicesOptions(IndicesOptions.lenientExpandOpen()), 0L); + assertHitCount( + 0L, + prepareSearch("test2").setIndicesOptions(IndicesOptions.lenientExpandOpen()).setQuery(matchAllQuery()), + prepareSearch("test2", "test3").setQuery(matchAllQuery()).setIndicesOptions(IndicesOptions.lenientExpandOpen()) + ); // you should still be able to run empty searches without things blowing up assertHitCount(prepareSearch().setIndicesOptions(IndicesOptions.lenientExpandOpen()).setQuery(matchAllQuery()), 1L); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ResolveClusterIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ResolveClusterIT.java index 1a6674edc5147..4bdc5d63f4a2f 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ResolveClusterIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/cluster/ResolveClusterIT.java @@ -28,7 +28,6 @@ import org.elasticsearch.transport.RemoteClusterAware; import java.io.IOException; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -54,7 +53,7 @@ public class ResolveClusterIT extends AbstractMultiClustersTestCase { private static long LATEST_TIMESTAMP = 1691348820000L; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER_1, REMOTE_CLUSTER_2); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexPrimaryRelocationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexPrimaryRelocationIT.java index 581145d949cf9..debcf5c06a7d6 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexPrimaryRelocationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexPrimaryRelocationIT.java @@ -98,11 +98,11 @@ public void run() { finished.set(true); indexingThread.join(); refresh("test"); - ElasticsearchAssertions.assertHitCount(prepareSearch("test").setTrackTotalHits(true), numAutoGenDocs.get()); ElasticsearchAssertions.assertHitCount( + numAutoGenDocs.get(), + prepareSearch("test").setTrackTotalHits(true), prepareSearch("test").setTrackTotalHits(true)// extra paranoia ;) - .setQuery(QueryBuilders.termQuery("auto", true)), - numAutoGenDocs.get() + .setQuery(QueryBuilders.termQuery("auto", true)) ); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/template/SimpleIndexTemplateIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/template/SimpleIndexTemplateIT.java index de9e3f28a2109..8496180e85d4e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/template/SimpleIndexTemplateIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/template/SimpleIndexTemplateIT.java @@ -500,9 +500,7 @@ public void testIndexTemplateWithAliases() throws Exception { refresh(); - assertHitCount(prepareSearch("test_index"), 5L); - assertHitCount(prepareSearch("simple_alias"), 5L); - assertHitCount(prepareSearch("templated_alias-test_index"), 5L); + assertHitCount(5L, prepareSearch("test_index"), prepareSearch("simple_alias"), prepareSearch("templated_alias-test_index")); assertResponse(prepareSearch("filtered_alias"), response -> { assertHitCount(response, 1L); @@ -584,8 +582,7 @@ public void testIndexTemplateWithAliasesSource() { prepareIndex("test_index").setId("2").setSource("field", "value2").get(); refresh(); - assertHitCount(prepareSearch("test_index"), 2L); - assertHitCount(prepareSearch("alias1"), 2L); + assertHitCount(2L, prepareSearch("test_index"), prepareSearch("alias1")); assertResponse(prepareSearch("alias2"), response -> { assertHitCount(response, 1L); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSCanMatchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSCanMatchIT.java index 3f354baace85a..ce898d9be15ca 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSCanMatchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSCanMatchIT.java @@ -55,7 +55,7 @@ public class CCSCanMatchIT extends AbstractMultiClustersTestCase { static final String REMOTE_CLUSTER = "cluster_a"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of("cluster_a"); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java index c9d34dbf14015..9c1daccd2cc9e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CCSUsageTelemetryIT.java @@ -11,16 +11,19 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.stats.CCSTelemetrySnapshot; import org.elasticsearch.action.admin.cluster.stats.CCSUsageTelemetry.Result; +import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.search.ClosePointInTimeRequest; import org.elasticsearch.action.search.OpenPointInTimeRequest; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.TransportClosePointInTimeAction; import org.elasticsearch.action.search.TransportOpenPointInTimeAction; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; @@ -78,7 +81,7 @@ protected boolean reuseClusters() { } @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE1, REMOTE2); } @@ -126,12 +129,9 @@ private CCSTelemetrySnapshot getTelemetryFromFailedSearch(SearchRequest searchRe // We want to send search to a specific node (we don't care which one) so that we could // collect the CCS telemetry from it later String nodeName = cluster(LOCAL_CLUSTER).getRandomNodeName(); - PlainActionFuture queryFuture = new PlainActionFuture<>(); - cluster(LOCAL_CLUSTER).client(nodeName).search(searchRequest, queryFuture); - assertBusy(() -> assertTrue(queryFuture.isDone())); // We expect failure, but we don't care too much which failure it is in this test - ExecutionException ee = expectThrows(ExecutionException.class, queryFuture::get); + ExecutionException ee = expectThrows(ExecutionException.class, cluster(LOCAL_CLUSTER).client(nodeName).search(searchRequest)::get); assertNotNull(ee.getCause()); return getTelemetrySnapshot(nodeName); @@ -637,56 +637,62 @@ private CCSTelemetrySnapshot getTelemetrySnapshot(String nodeName) { return usage.getCcsUsageHolder().getCCSTelemetrySnapshot(); } - private Map setupClusters() { + private Map setupClusters() throws ExecutionException, InterruptedException { String localIndex = "demo"; + String remoteIndex = "prod"; int numShardsLocal = randomIntBetween(2, 10); Settings localSettings = indexSettings(numShardsLocal, randomIntBetween(0, 1)).build(); - assertAcked( + final PlainActionFuture future = new PlainActionFuture<>(); + try (RefCountingListener refCountingListener = new RefCountingListener(future)) { client(LOCAL_CLUSTER).admin() .indices() .prepareCreate(localIndex) .setSettings(localSettings) .setMapping("@timestamp", "type=date", "f", "type=text") - ); - indexDocs(client(LOCAL_CLUSTER), localIndex); - - String remoteIndex = "prod"; - int numShardsRemote = randomIntBetween(2, 10); - for (String clusterAlias : remoteClusterAlias()) { - final InternalTestCluster remoteCluster = cluster(clusterAlias); - remoteCluster.ensureAtLeastNumDataNodes(randomIntBetween(2, 3)); - assertAcked( + .execute(refCountingListener.acquire(r -> { + assertAcked(r); + indexDocs(client(LOCAL_CLUSTER), localIndex, refCountingListener.acquire()); + })); + + int numShardsRemote = randomIntBetween(2, 10); + var remotes = remoteClusterAlias(); + runInParallel(remotes.size(), i -> { + final String clusterAlias = remotes.get(i); + final InternalTestCluster remoteCluster = cluster(clusterAlias); + remoteCluster.ensureAtLeastNumDataNodes(randomIntBetween(2, 3)); client(clusterAlias).admin() .indices() .prepareCreate(remoteIndex) .setSettings(indexSettings(numShardsRemote, randomIntBetween(0, 1))) .setMapping("@timestamp", "type=date", "f", "type=text") - ); - assertFalse( - client(clusterAlias).admin() - .cluster() - .prepareHealth(TEST_REQUEST_TIMEOUT, remoteIndex) - .setWaitForYellowStatus() - .setTimeout(TimeValue.timeValueSeconds(10)) - .get() - .isTimedOut() - ); - indexDocs(client(clusterAlias), remoteIndex); + .execute(refCountingListener.acquire(r -> { + assertAcked(r); + client(clusterAlias).admin() + .cluster() + .prepareHealth(TEST_REQUEST_TIMEOUT, remoteIndex) + .setWaitForYellowStatus() + .setTimeout(TimeValue.timeValueSeconds(10)) + .execute(refCountingListener.acquire(healthResponse -> { + assertFalse(healthResponse.isTimedOut()); + indexDocs(client(clusterAlias), remoteIndex, refCountingListener.acquire()); + })); + })); + }); } - + future.get(); Map clusterInfo = new HashMap<>(); clusterInfo.put("local.index", localIndex); clusterInfo.put("remote.index", remoteIndex); return clusterInfo; } - private int indexDocs(Client client, String index) { + private void indexDocs(Client client, String index, ActionListener listener) { int numDocs = between(5, 20); + final BulkRequestBuilder bulkRequest = client.prepareBulk(); for (int i = 0; i < numDocs; i++) { - client.prepareIndex(index).setSource("f", "v", "@timestamp", randomNonNegativeLong()).get(); + bulkRequest.add(client.prepareIndex(index).setSource("f", "v", "@timestamp", randomNonNegativeLong())); } - client.admin().indices().prepareRefresh(index).get(); - return numDocs; + bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).execute(listener.safeMap(r -> null)); } /** diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterIT.java index cb4d0681cdb23..57a9f8131ac2d 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterIT.java @@ -86,7 +86,7 @@ public class CrossClusterIT extends AbstractMultiClustersTestCase { @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of("cluster_a"); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java index 63eece88a53fc..823d3198bc7a2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchIT.java @@ -60,7 +60,7 @@ public class CrossClusterSearchIT extends AbstractMultiClustersTestCase { private static long LATEST_TIMESTAMP = 1691348820000L; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java index 8b493782d55b5..e8a3df353a01e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java @@ -38,7 +38,7 @@ public class CrossClusterSearchLeakIT extends AbstractMultiClustersTestCase { @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of("cluster_a"); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java index 56b34f9b1dfec..f29cff98c6495 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java @@ -34,7 +34,7 @@ public class CCSFieldCapabilitiesIT extends AbstractMultiClustersTestCase { @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of("remote_cluster"); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java index 8225386ed02d2..acfc55a740f1e 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/nested/SimpleNestedIT.java @@ -53,6 +53,7 @@ import static org.hamcrest.Matchers.startsWith; public class SimpleNestedIT extends ESIntegTestCase { + public void testSimpleNested() throws Exception { assertAcked(prepareCreate("test").setMapping("nested1", "type=nested")); ensureGreen(); @@ -87,21 +88,20 @@ public void testSimpleNested() throws Exception { // check the numDocs assertDocumentCount("test", 3); - assertHitCount(prepareSearch("test").setQuery(termQuery("n_field1", "n_value1_1")), 0L); - - // search for something that matches the nested doc, and see that we don't find the nested doc - assertHitCount(prepareSearch("test"), 1L); - assertHitCount(prepareSearch("test").setQuery(termQuery("n_field1", "n_value1_1")), 0L); + assertHitCount( + 0L, + prepareSearch("test").setQuery(termQuery("n_field1", "n_value1_1")), + prepareSearch("test").setQuery(termQuery("n_field1", "n_value1_1")) + ); - // now, do a nested query - assertHitCountAndNoFailures( + assertHitCount( + 1L, + // search for something that matches the nested doc, and see that we don't find the nested doc + prepareSearch("test"), + // now, do a nested query prepareSearch("test").setQuery(nestedQuery("nested1", termQuery("nested1.n_field1", "n_value1_1"), ScoreMode.Avg)), - 1L - ); - assertHitCountAndNoFailures( prepareSearch("test").setQuery(nestedQuery("nested1", termQuery("nested1.n_field1", "n_value1_1"), ScoreMode.Avg)) - .setSearchType(SearchType.DFS_QUERY_THEN_FETCH), - 1L + .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) ); // add another doc, one that would match if it was not nested... diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/MinimalCompoundRetrieverIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/MinimalCompoundRetrieverIT.java index 97aa428822fae..8dc37bad675e8 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/MinimalCompoundRetrieverIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/MinimalCompoundRetrieverIT.java @@ -26,7 +26,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -43,7 +42,7 @@ public class MinimalCompoundRetrieverIT extends AbstractMultiClustersTestCase { private static final String REMOTE_CLUSTER = "cluster_a"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/scroll/SearchScrollIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/scroll/SearchScrollIT.java index 7ac24b77a4b6d..a54e19b839ad3 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/scroll/SearchScrollIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/scroll/SearchScrollIT.java @@ -206,11 +206,17 @@ public void testScrollAndUpdateIndex() throws Exception { indicesAdmin().prepareRefresh().get(); - assertHitCount(prepareSearch().setSize(0).setQuery(matchAllQuery()), 500); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "test")), 500); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "test")), 500); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "update")), 0); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "update")), 0); + assertHitCount( + 500, + prepareSearch().setSize(0).setQuery(matchAllQuery()), + prepareSearch().setSize(0).setQuery(termQuery("message", "test")), + prepareSearch().setSize(0).setQuery(termQuery("message", "test")) + ); + assertHitCount( + 0, + prepareSearch().setSize(0).setQuery(termQuery("message", "update")), + prepareSearch().setSize(0).setQuery(termQuery("message", "update")) + ); SearchResponse searchResponse = prepareSearch().setQuery(queryStringQuery("user:kimchy")) .setSize(35) @@ -229,11 +235,17 @@ public void testScrollAndUpdateIndex() throws Exception { } while (searchResponse.getHits().getHits().length > 0); indicesAdmin().prepareRefresh().get(); - assertHitCount(prepareSearch().setSize(0).setQuery(matchAllQuery()), 500); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "test")), 0); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "test")), 0); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "update")), 500); - assertHitCount(prepareSearch().setSize(0).setQuery(termQuery("message", "update")), 500); + assertHitCount( + 500, + prepareSearch().setSize(0).setQuery(matchAllQuery()), + prepareSearch().setSize(0).setQuery(termQuery("message", "update")), + prepareSearch().setSize(0).setQuery(termQuery("message", "update")) + ); + assertHitCount( + 0, + prepareSearch().setSize(0).setQuery(termQuery("message", "test")), + prepareSearch().setSize(0).setQuery(termQuery("message", "test")) + ); } finally { clearScroll(searchResponse.getScrollId()); searchResponse.decRef(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java index e87c4790aa665..5a9be73d92268 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/simple/SimpleSearchIT.java @@ -147,16 +147,22 @@ public void testIpCidr() throws Exception { prepareIndex("test").setId("5").setSource("ip", "2001:db8::ff00:42:8329").get(); refresh(); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.0.1"))), 1L); - assertHitCount(prepareSearch().setQuery(queryStringQuery("ip: 192.168.0.1")), 1L); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.0.1/32"))), 1L); + assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.1.5/32"))), 0L); + assertHitCount( + 1L, + prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.0.1"))), + prepareSearch().setQuery(queryStringQuery("ip: 192.168.0.1")), + prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.0.1/32"))), + prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "2001:db8::ff00:42:8329/128"))), + prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "2001:db8::/64"))) + ); assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.0.0/24"))), 3L); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.0.0.0/8"))), 4L); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "0.0.0.0/0"))), 4L); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "2001:db8::ff00:42:8329/128"))), 1L); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "2001:db8::/64"))), 1L); + assertHitCount( + 4L, + prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.0.0.0/8"))), + prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "0.0.0.0/0"))) + ); assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "::/0"))), 5L); - assertHitCount(prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "192.168.1.5/32"))), 0L); assertFailures( prepareSearch().setQuery(boolQuery().must(QueryBuilders.termQuery("ip", "0/0/0/0/0"))), @@ -170,8 +176,11 @@ public void testSimpleId() { prepareIndex("test").setId("XXX1").setSource("field", "value").setRefreshPolicy(IMMEDIATE).get(); // id is not indexed, but lets see that we automatically convert to - assertHitCount(prepareSearch().setQuery(QueryBuilders.termQuery("_id", "XXX1")), 1L); - assertHitCount(prepareSearch().setQuery(QueryBuilders.queryStringQuery("_id:XXX1")), 1L); + assertHitCount( + 1L, + prepareSearch().setQuery(QueryBuilders.termQuery("_id", "XXX1")), + prepareSearch().setQuery(QueryBuilders.queryStringQuery("_id:XXX1")) + ); } public void testSimpleDateRange() throws Exception { @@ -324,12 +333,12 @@ public void testLargeFromAndSizeSucceeds() throws Exception { createIndex("idx"); indexRandom(true, prepareIndex("idx").setSource("{}", XContentType.JSON)); - assertHitCount(prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) - 10), 1); - assertHitCount(prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), 1); assertHitCount( + 1, + prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) - 10), + prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) / 2) - .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) / 2 - 1), - 1 + .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) / 2 - 1) ); } @@ -340,12 +349,12 @@ public void testTooLargeFromAndSizeOkBySetting() throws Exception { ).get(); indexRandom(true, prepareIndex("idx").setSource("{}", XContentType.JSON)); - assertHitCount(prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), 1); - assertHitCount(prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) + 1), 1); assertHitCount( + 1, + prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), + prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) + 1), prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)) - .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), - 1 + .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)) ); } @@ -358,12 +367,12 @@ public void testTooLargeFromAndSizeOkByDynamicSetting() throws Exception { ); indexRandom(true, prepareIndex("idx").setSource("{}", XContentType.JSON)); - assertHitCount(prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), 1); - assertHitCount(prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) + 1), 1); assertHitCount( + 1, + prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), + prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) + 1), prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)) - .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)), - 1 + .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY)) ); } @@ -371,12 +380,12 @@ public void testTooLargeFromAndSizeBackwardsCompatibilityRecommendation() throws prepareCreate("idx").setSettings(Settings.builder().put(IndexSettings.MAX_RESULT_WINDOW_SETTING.getKey(), Integer.MAX_VALUE)).get(); indexRandom(true, prepareIndex("idx").setSource("{}", XContentType.JSON)); - assertHitCount(prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10), 1); - assertHitCount(prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10), 1); assertHitCount( + 1, + prepareSearch("idx").setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10), + prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10), prepareSearch("idx").setSize(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10) - .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10), - 1 + .setFrom(IndexSettings.MAX_RESULT_WINDOW_SETTING.get(Settings.EMPTY) * 10) ); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java index 87665c3d784f1..bf7a315040caa 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/sort/FieldSortIT.java @@ -202,7 +202,6 @@ public void testIssue6614() throws InterruptedException { response -> { for (int j = 0; j < response.getHits().getHits().length; j++) { assertThat( - response.toString() + "\n vs. \n" + allDocsResponse.toString(), response.getHits().getHits()[j].getId(), equalTo(allDocsResponse.getHits().getHits()[j].getId()) ); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java index fe83073eeb780..b490c7efd52cd 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/RestoreSnapshotIT.java @@ -678,9 +678,12 @@ public void testChangeSettingsOnRestore() throws Exception { indexRandom(true, builders); flushAndRefresh(); - assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "foo")), numdocs); + assertHitCount( + numdocs, + client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "foo")), + client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "bar")) + ); assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "Foo")), 0); - assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "bar")), numdocs); createSnapshot("test-repo", "test-snap", Collections.singletonList("test-idx")); @@ -736,8 +739,11 @@ public void testChangeSettingsOnRestore() throws Exception { assertThat(getSettingsResponse.getSetting("test-idx", SETTING_NUMBER_OF_SHARDS), equalTo("" + numberOfShards)); assertThat(getSettingsResponse.getSetting("test-idx", "index.analysis.analyzer.my_analyzer.type"), equalTo("standard")); - assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "Foo")), numdocs); - assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "bar")), numdocs); + assertHitCount( + numdocs, + client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "Foo")), + client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "bar")) + ); logger.info("--> delete the index and recreate it while deleting all index settings"); cluster().wipeIndices("test-idx"); @@ -758,8 +764,11 @@ public void testChangeSettingsOnRestore() throws Exception { // Make sure that number of shards didn't change assertThat(getSettingsResponse.getSetting("test-idx", SETTING_NUMBER_OF_SHARDS), equalTo("" + numberOfShards)); - assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "Foo")), numdocs); - assertHitCount(client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "bar")), numdocs); + assertHitCount( + numdocs, + client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "Foo")), + client.prepareSearch("test-idx").setSize(0).setQuery(matchQuery("field1", "bar")) + ); } public void testRestoreChangeIndexMode() { diff --git a/server/src/main/java/org/elasticsearch/ElasticsearchException.java b/server/src/main/java/org/elasticsearch/ElasticsearchException.java index 3c5c365654206..fcb5c20c28162 100644 --- a/server/src/main/java/org/elasticsearch/ElasticsearchException.java +++ b/server/src/main/java/org/elasticsearch/ElasticsearchException.java @@ -1947,13 +1947,13 @@ private enum ElasticsearchExceptionHandle { org.elasticsearch.ingest.IngestPipelineException.class, org.elasticsearch.ingest.IngestPipelineException::new, 182, - TransportVersions.INGEST_PIPELINE_EXCEPTION_ADDED + TransportVersions.V_8_16_0 ), INDEX_RESPONSE_WRAPPER_EXCEPTION( IndexDocFailureStoreStatus.ExceptionWithFailureStoreStatus.class, IndexDocFailureStoreStatus.ExceptionWithFailureStoreStatus::new, 183, - TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE + TransportVersions.V_8_16_0 ); final Class exceptionClass; diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2e4842912dfae..40a209c5f0f14 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -54,11 +54,9 @@ static TransportVersion def(int id) { public static final TransportVersion ZERO = def(0); public static final TransportVersion V_7_0_0 = def(7_00_00_99); public static final TransportVersion V_7_0_1 = def(7_00_01_99); - public static final TransportVersion V_7_1_0 = def(7_01_00_99); public static final TransportVersion V_7_2_0 = def(7_02_00_99); public static final TransportVersion V_7_2_1 = def(7_02_01_99); public static final TransportVersion V_7_3_0 = def(7_03_00_99); - public static final TransportVersion V_7_3_2 = def(7_03_02_99); public static final TransportVersion V_7_4_0 = def(7_04_00_99); public static final TransportVersion V_7_5_0 = def(7_05_00_99); public static final TransportVersion V_7_6_0 = def(7_06_00_99); @@ -104,78 +102,7 @@ static TransportVersion def(int id) { public static final TransportVersion V_8_14_0 = def(8_636_00_1); public static final TransportVersion V_8_15_0 = def(8_702_00_2); public static final TransportVersion V_8_15_2 = def(8_702_00_3); - public static final TransportVersion QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_15 = def(8_702_00_4); - public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_703_00_0); - public static final TransportVersion INFERENCE_ADAPTIVE_ALLOCATIONS = def(8_704_00_0); - public static final TransportVersion INDEX_REQUEST_UPDATE_BY_SCRIPT_ORIGIN = def(8_705_00_0); - public static final TransportVersion ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED = def(8_706_00_0); - public static final TransportVersion ENRICH_CACHE_STATS_SIZE_ADDED = def(8_707_00_0); - public static final TransportVersion ENTERPRISE_GEOIP_DOWNLOADER = def(8_708_00_0); - public static final TransportVersion NODES_STATS_ENUM_SET = def(8_709_00_0); - public static final TransportVersion MASTER_NODE_METRICS = def(8_710_00_0); - public static final TransportVersion SEGMENT_LEVEL_FIELDS_STATS = def(8_711_00_0); - public static final TransportVersion ML_ADD_DETECTION_RULE_PARAMS = def(8_712_00_0); - public static final TransportVersion FIX_VECTOR_SIMILARITY_INNER_HITS = def(8_713_00_0); - public static final TransportVersion INDEX_REQUEST_UPDATE_BY_DOC_ORIGIN = def(8_714_00_0); - public static final TransportVersion ESQL_ATTRIBUTE_CACHED_SERIALIZATION = def(8_715_00_0); - public static final TransportVersion REGISTER_SLM_STATS = def(8_716_00_0); - public static final TransportVersion ESQL_NESTED_UNSUPPORTED = def(8_717_00_0); - public static final TransportVersion ESQL_SINGLE_VALUE_QUERY_SOURCE = def(8_718_00_0); - public static final TransportVersion ESQL_ORIGINAL_INDICES = def(8_719_00_0); - public static final TransportVersion ML_INFERENCE_EIS_INTEGRATION_ADDED = def(8_720_00_0); - public static final TransportVersion INGEST_PIPELINE_EXCEPTION_ADDED = def(8_721_00_0); - public static final TransportVersion ZDT_NANOS_SUPPORT_BROKEN = def(8_722_00_0); - public static final TransportVersion REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES = def(8_723_00_0); - public static final TransportVersion RANDOM_RERANKER_RETRIEVER = def(8_724_00_0); - public static final TransportVersion ESQL_PROFILE_SLEEPS = def(8_725_00_0); - public static final TransportVersion ZDT_NANOS_SUPPORT = def(8_726_00_0); - public static final TransportVersion LTR_SERVERLESS_RELEASE = def(8_727_00_0); - public static final TransportVersion ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT = def(8_728_00_0); - public static final TransportVersion RANK_DOCS_RETRIEVER = def(8_729_00_0); - public static final TransportVersion ESQL_ES_FIELD_CACHED_SERIALIZATION = def(8_730_00_0); - public static final TransportVersion ADD_MANAGE_ROLES_PRIVILEGE = def(8_731_00_0); - public static final TransportVersion REPOSITORIES_TELEMETRY = def(8_732_00_0); - public static final TransportVersion ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED = def(8_733_00_0); - public static final TransportVersion FIELD_CAPS_RESPONSE_INDEX_MODE = def(8_734_00_0); - public static final TransportVersion GET_DATA_STREAMS_VERBOSE = def(8_735_00_0); - public static final TransportVersion ESQL_ADD_INDEX_MODE_CONCRETE_INDICES = def(8_736_00_0); - public static final TransportVersion UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH = def(8_737_00_0); - public static final TransportVersion ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS = def(8_738_00_0); - public static final TransportVersion CCS_TELEMETRY_STATS = def(8_739_00_0); - public static final TransportVersion GLOBAL_RETENTION_TELEMETRY = def(8_740_00_0); - public static final TransportVersion ROUTING_TABLE_VERSION_REMOVED = def(8_741_00_0); - public static final TransportVersion ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION = def(8_742_00_0); - public static final TransportVersion SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS = def(8_743_00_0); - public static final TransportVersion ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED = def(8_744_00_0); - public static final TransportVersion BULK_INCREMENTAL_STATE = def(8_745_00_0); - public static final TransportVersion FAILURE_STORE_STATUS_IN_INDEX_RESPONSE = def(8_746_00_0); - public static final TransportVersion ESQL_AGGREGATION_OPERATOR_STATUS_FINISH_NANOS = def(8_747_00_0); - public static final TransportVersion ML_TELEMETRY_MEMORY_ADDED = def(8_748_00_0); - public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0); - public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0); - public static final TransportVersion ML_INFERENCE_CHUNKING_SETTINGS = def(8_751_00_0); - public static final TransportVersion SEMANTIC_QUERY_INNER_HITS = def(8_752_00_0); - public static final TransportVersion RETAIN_ILM_STEP_INFO = def(8_753_00_0); - public static final TransportVersion ADD_DATA_STREAM_OPTIONS = def(8_754_00_0); - public static final TransportVersion CCS_REMOTE_TELEMETRY_STATS = def(8_755_00_0); - public static final TransportVersion ESQL_CCS_EXECUTION_INFO = def(8_756_00_0); - public static final TransportVersion REGEX_AND_RANGE_INTERVAL_QUERIES = def(8_757_00_0); - public static final TransportVersion RRF_QUERY_REWRITE = def(8_758_00_0); - public static final TransportVersion SEARCH_FAILURE_STATS = def(8_759_00_0); - public static final TransportVersion INGEST_GEO_DATABASE_PROVIDERS = def(8_760_00_0); - public static final TransportVersion DATE_TIME_DOC_VALUES_LOCALES = def(8_761_00_0); - public static final TransportVersion FAST_REFRESH_RCO = def(8_762_00_0); - public static final TransportVersion TEXT_SIMILARITY_RERANKER_QUERY_REWRITE = def(8_763_00_0); - public static final TransportVersion SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS = def(8_764_00_0); - public static final TransportVersion RETRIEVERS_TELEMETRY_ADDED = def(8_765_00_0); - public static final TransportVersion ESQL_CACHED_STRING_SERIALIZATION = def(8_766_00_0); - public static final TransportVersion CHUNK_SENTENCE_OVERLAP_SETTING_ADDED = def(8_767_00_0); - public static final TransportVersion OPT_IN_ESQL_CCS_EXECUTION_INFO = def(8_768_00_0); - public static final TransportVersion QUERY_RULE_TEST_API = def(8_769_00_0); - public static final TransportVersion ESQL_PER_AGGREGATE_FILTER = def(8_770_00_0); - public static final TransportVersion ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT = def(8_771_00_0); - public static final TransportVersion CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY = def(8_772_00_0); - public static final TransportVersion INFERENCE_DONT_PERSIST_ON_READ_BACKPORT_8_16 = def(8_772_00_1); + public static final TransportVersion V_8_16_0 = def(8_772_00_1); public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO_BACKPORT_8_16 = def(8_772_00_2); public static final TransportVersion SKIP_INNER_HITS_SEARCH_SOURCE_BACKPORT_8_16 = def(8_772_00_3); public static final TransportVersion QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16 = def(8_772_00_4); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java index e14f229f17acf..d929fb457d5d1 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportGetAllocationStatsAction.java @@ -118,7 +118,7 @@ public Request(TimeValue masterNodeTimeout, TaskId parentTaskId, EnumSet public Request(StreamInput in) throws IOException { super(in); - this.metrics = in.getTransportVersion().onOrAfter(TransportVersions.MASTER_NODE_METRICS) + this.metrics = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readEnumSet(Metric.class) : EnumSet.of(Metric.ALLOCATIONS, Metric.FS); } @@ -127,7 +127,7 @@ public Request(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { assert out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0); super.writeTo(out); - if (out.getTransportVersion().onOrAfter(TransportVersions.MASTER_NODE_METRICS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeEnumSet(metrics); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParameters.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParameters.java index d34bc3ec0dc2f..c5e8f37ed3a96 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParameters.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParameters.java @@ -117,7 +117,7 @@ public static Metric get(String name) { } public static void writeSetTo(StreamOutput out, EnumSet metrics) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.NODES_STATS_ENUM_SET)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeEnumSet(metrics); } else { out.writeCollection(metrics, (output, metric) -> output.writeString(metric.metricName)); @@ -125,7 +125,7 @@ public static void writeSetTo(StreamOutput out, EnumSet metrics) throws } public static EnumSet readSetFrom(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.NODES_STATS_ENUM_SET)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { return in.readEnumSet(Metric.class); } else { return in.readCollection((i) -> EnumSet.noneOf(Metric.class), (is, out) -> { diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/create/CreateSnapshotRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/create/CreateSnapshotRequest.java index 9c9467db40de3..b6ced06623306 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/create/CreateSnapshotRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/snapshots/create/CreateSnapshotRequest.java @@ -118,7 +118,7 @@ public CreateSnapshotRequest(StreamInput in) throws IOException { waitForCompletion = in.readBoolean(); partial = in.readBoolean(); userMetadata = in.readGenericMap(); - uuid = in.getTransportVersion().onOrAfter(TransportVersions.REGISTER_SLM_STATS) ? in.readOptionalString() : null; + uuid = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null; } @Override @@ -136,7 +136,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(waitForCompletion); out.writeBoolean(partial); out.writeGenericMap(userMetadata); - if (out.getTransportVersion().onOrAfter(TransportVersions.REGISTER_SLM_STATS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalString(uuid); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java index f99baa855404c..abeb73e5d8c3e 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/ClusterStatsNodeResponse.java @@ -44,14 +44,11 @@ public ClusterStatsNodeResponse(StreamInput in) throws IOException { } else { searchUsageStats = new SearchUsageStats(); } - if (in.getTransportVersion().onOrAfter(TransportVersions.REPOSITORIES_TELEMETRY)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { repositoryUsageStats = RepositoryUsageStats.readFrom(in); - } else { - repositoryUsageStats = RepositoryUsageStats.EMPTY; - } - if (in.getTransportVersion().onOrAfter(TransportVersions.CCS_TELEMETRY_STATS)) { ccsMetrics = new CCSTelemetrySnapshot(in); } else { + repositoryUsageStats = RepositoryUsageStats.EMPTY; ccsMetrics = new CCSTelemetrySnapshot(); } } @@ -118,12 +115,10 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_6_0)) { searchUsageStats.writeTo(out); } - if (out.getTransportVersion().onOrAfter(TransportVersions.REPOSITORIES_TELEMETRY)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { repositoryUsageStats.writeTo(out); - } // else just drop these stats, ok for bwc - if (out.getTransportVersion().onOrAfter(TransportVersions.CCS_TELEMETRY_STATS)) { ccsMetrics.writeTo(out); - } + } // else just drop these stats, ok for bwc } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RemoteClusterStatsRequest.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RemoteClusterStatsRequest.java index 47843a91351ee..6c3c5cbb50ece 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RemoteClusterStatsRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/RemoteClusterStatsRequest.java @@ -36,9 +36,9 @@ public ActionRequestValidationException validate() { @Override public void writeTo(StreamOutput out) throws IOException { - assert out.getTransportVersion().onOrAfter(TransportVersions.CCS_REMOTE_TELEMETRY_STATS) + assert out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) : "RemoteClusterStatsRequest is not supported by the remote cluster"; - if (out.getTransportVersion().before(TransportVersions.CCS_REMOTE_TELEMETRY_STATS)) { + if (out.getTransportVersion().before(TransportVersions.V_8_16_0)) { throw new UnsupportedOperationException("RemoteClusterStatsRequest is not supported by the remote cluster"); } super.writeTo(out); diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java index 0f6c56fd21bd7..a6e80b5efd08c 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStats.java @@ -22,8 +22,8 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.TransportVersions.RETRIEVERS_TELEMETRY_ADDED; import static org.elasticsearch.TransportVersions.V_8_12_0; +import static org.elasticsearch.TransportVersions.V_8_16_0; /** * Holds a snapshot of the search usage statistics. @@ -71,7 +71,7 @@ public SearchUsageStats(StreamInput in) throws IOException { this.sections = in.readMap(StreamInput::readLong); this.totalSearchCount = in.readVLong(); this.rescorers = in.getTransportVersion().onOrAfter(V_8_12_0) ? in.readMap(StreamInput::readLong) : Map.of(); - this.retrievers = in.getTransportVersion().onOrAfter(RETRIEVERS_TELEMETRY_ADDED) ? in.readMap(StreamInput::readLong) : Map.of(); + this.retrievers = in.getTransportVersion().onOrAfter(V_8_16_0) ? in.readMap(StreamInput::readLong) : Map.of(); } @Override @@ -83,7 +83,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(V_8_12_0)) { out.writeMap(rescorers, StreamOutput::writeLong); } - if (out.getTransportVersion().onOrAfter(RETRIEVERS_TELEMETRY_ADDED)) { + if (out.getTransportVersion().onOrAfter(V_8_16_0)) { out.writeMap(retrievers, StreamOutput::writeLong); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java index 97585ea9a1024..2c20daa5d7afb 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/stats/TransportClusterStatsAction.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.lucene.store.AlreadyClosedException; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; @@ -72,8 +73,6 @@ import java.util.function.BooleanSupplier; import java.util.stream.Collectors; -import static org.elasticsearch.TransportVersions.CCS_REMOTE_TELEMETRY_STATS; - /** * Transport action implementing _cluster/stats API. */ @@ -450,7 +449,7 @@ protected void sendItemRequest(String clusterAlias, ActionListener { - if (connection.getTransportVersion().before(CCS_REMOTE_TELEMETRY_STATS)) { + if (connection.getTransportVersion().before(TransportVersions.V_8_16_0)) { responseListener.onResponse(null); } else { remoteClusterClient.execute(connection, TransportRemoteClusterStatsAction.REMOTE_TYPE, remoteRequest, responseListener); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComponentTemplateAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComponentTemplateAction.java index c6d990e5a1d62..f729455edcc24 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComponentTemplateAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComponentTemplateAction.java @@ -131,8 +131,7 @@ public Response(StreamInput in) throws IOException { } else { rolloverConfiguration = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) - && in.getTransportVersion().before(TransportVersions.REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES)) { + if (in.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { in.readOptionalWriteable(DataStreamGlobalRetention::read); } } @@ -190,8 +189,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) { out.writeOptionalWriteable(rolloverConfiguration); } - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) - && out.getTransportVersion().before(TransportVersions.REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES)) { + if (out.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(null); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComposableIndexTemplateAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComposableIndexTemplateAction.java index a47f89030cc60..67f87476ea6a5 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComposableIndexTemplateAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/template/get/GetComposableIndexTemplateAction.java @@ -132,8 +132,7 @@ public Response(StreamInput in) throws IOException { } else { rolloverConfiguration = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) - && in.getTransportVersion().before(TransportVersions.REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES)) { + if (in.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { in.readOptionalWriteable(DataStreamGlobalRetention::read); } } @@ -191,8 +190,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) { out.writeOptionalWriteable(rolloverConfiguration); } - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) - && out.getTransportVersion().before(TransportVersions.REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES)) { + if (out.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(null); } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/SimulateIndexTemplateResponse.java b/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/SimulateIndexTemplateResponse.java index 064c24cf4afa3..80e6fbfe051a4 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/SimulateIndexTemplateResponse.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/template/post/SimulateIndexTemplateResponse.java @@ -82,8 +82,7 @@ public SimulateIndexTemplateResponse(StreamInput in) throws IOException { rolloverConfiguration = in.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X) ? in.readOptionalWriteable(RolloverConfiguration::new) : null; - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) - && in.getTransportVersion().before(TransportVersions.REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES)) { + if (in.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { in.readOptionalWriteable(DataStreamGlobalRetention::read); } } @@ -104,8 +103,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) { out.writeOptionalWriteable(rolloverConfiguration); } - if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_14_0) - && out.getTransportVersion().before(TransportVersions.REMOVE_GLOBAL_RETENTION_FROM_TEMPLATES)) { + if (out.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(null); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkItemResponse.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkItemResponse.java index d5931c85bb2e1..1ff970de7525e 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkItemResponse.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkItemResponse.java @@ -200,7 +200,7 @@ public Failure(StreamInput in) throws IOException { seqNo = in.readZLong(); term = in.readVLong(); aborted = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus = IndexDocFailureStoreStatus.read(in); } else { failureStoreStatus = IndexDocFailureStoreStatus.NOT_APPLICABLE_OR_UNKNOWN; @@ -218,7 +218,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeZLong(seqNo); out.writeVLong(term); out.writeBoolean(aborted); - if (out.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus.writeTo(out); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java index f62b2f48fa2fd..91caebc420ffb 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequest.java @@ -98,7 +98,7 @@ public BulkRequest(StreamInput in) throws IOException { for (DocWriteRequest request : requests) { indices.add(Objects.requireNonNull(request.index(), "request index must not be null")); } - if (in.getTransportVersion().onOrAfter(TransportVersions.BULK_INCREMENTAL_STATE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { incrementalState = new BulkRequest.IncrementalState(in); } else { incrementalState = BulkRequest.IncrementalState.EMPTY; @@ -454,7 +454,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(requests, DocWriteRequest::writeDocumentRequest); refreshPolicy.writeTo(out); out.writeTimeValue(timeout); - if (out.getTransportVersion().onOrAfter(TransportVersions.BULK_INCREMENTAL_STATE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { incrementalState.writeTo(out); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java index ec7a08007de93..12d3aa67ca9bb 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkResponse.java @@ -46,7 +46,7 @@ public BulkResponse(StreamInput in) throws IOException { responses = in.readArray(BulkItemResponse::new, BulkItemResponse[]::new); tookInMillis = in.readVLong(); ingestTookInMillis = in.readZLong(); - if (in.getTransportVersion().onOrAfter(TransportVersions.BULK_INCREMENTAL_STATE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { incrementalState = new BulkRequest.IncrementalState(in); } else { incrementalState = BulkRequest.IncrementalState.EMPTY; @@ -151,7 +151,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeArray(responses); out.writeVLong(tookInMillis); out.writeZLong(ingestTookInMillis); - if (out.getTransportVersion().onOrAfter(TransportVersions.BULK_INCREMENTAL_STATE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { incrementalState.writeTo(out); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/IndexDocFailureStoreStatus.java b/server/src/main/java/org/elasticsearch/action/bulk/IndexDocFailureStoreStatus.java index cb83d693a415b..7367dfa1d53fd 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/IndexDocFailureStoreStatus.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/IndexDocFailureStoreStatus.java @@ -124,7 +124,7 @@ public ExceptionWithFailureStoreStatus(BulkItemResponse.Failure failure) { public ExceptionWithFailureStoreStatus(StreamInput in) throws IOException { super(in); - if (in.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus = IndexDocFailureStoreStatus.fromId(in.readByte()); } else { failureStoreStatus = NOT_APPLICABLE_OR_UNKNOWN; @@ -134,7 +134,7 @@ public ExceptionWithFailureStoreStatus(StreamInput in) throws IOException { @Override protected void writeTo(StreamOutput out, Writer nestedExceptionsWriter) throws IOException { super.writeTo(out, nestedExceptionsWriter); - if (out.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeByte(failureStoreStatus.getId()); } } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java index cc7fd431d8097..290d342e9dc12 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/SimulateBulkRequest.java @@ -135,14 +135,11 @@ public SimulateBulkRequest( public SimulateBulkRequest(StreamInput in) throws IOException { super(in); this.pipelineSubstitutions = (Map>) in.readGenericValue(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.componentTemplateSubstitutions = (Map>) in.readGenericValue(); - } else { - componentTemplateSubstitutions = Map.of(); - } - if (in.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS)) { this.indexTemplateSubstitutions = (Map>) in.readGenericValue(); } else { + componentTemplateSubstitutions = Map.of(); indexTemplateSubstitutions = Map.of(); } if (in.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_MAPPING_ADDITION)) { @@ -156,10 +153,8 @@ public SimulateBulkRequest(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeGenericValue(pipelineSubstitutions); - if (out.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_COMPONENT_TEMPLATES_SUBSTITUTIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeGenericValue(componentTemplateSubstitutions); - } - if (out.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_INDEX_TEMPLATES_SUBSTITUTIONS)) { out.writeGenericValue(indexTemplateSubstitutions); } if (out.getTransportVersion().onOrAfter(TransportVersions.SIMULATE_MAPPING_ADDITION)) { diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java index c1cf0fa7aab42..93c40ad18cc8a 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java @@ -112,7 +112,7 @@ public Request(StreamInput in) throws IOException { } else { this.includeDefaults = false; } - if (in.getTransportVersion().onOrAfter(TransportVersions.GET_DATA_STREAMS_VERBOSE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.verbose = in.readBoolean(); } else { this.verbose = false; @@ -127,7 +127,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_9_X)) { out.writeBoolean(includeDefaults); } - if (out.getTransportVersion().onOrAfter(TransportVersions.GET_DATA_STREAMS_VERBOSE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeBoolean(verbose); } } @@ -275,7 +275,7 @@ public DataStreamInfo( in.getTransportVersion().onOrAfter(TransportVersions.V_8_3_0) ? in.readOptionalWriteable(TimeSeries::new) : null, in.getTransportVersion().onOrAfter(V_8_11_X) ? in.readMap(Index::new, IndexProperties::new) : Map.of(), in.getTransportVersion().onOrAfter(V_8_11_X) ? in.readBoolean() : true, - in.getTransportVersion().onOrAfter(TransportVersions.GET_DATA_STREAMS_VERBOSE) ? in.readOptionalVLong() : null + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalVLong() : null ); } @@ -328,7 +328,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeMap(indexSettingsValues); out.writeBoolean(templatePreferIlmValue); } - if (out.getTransportVersion().onOrAfter(TransportVersions.GET_DATA_STREAMS_VERBOSE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalVLong(maximumTimestamp); } } diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesIndexResponse.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesIndexResponse.java index d16100a64713e..6f510ad26f5ec 100644 --- a/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesIndexResponse.java +++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/FieldCapabilitiesIndexResponse.java @@ -62,7 +62,7 @@ public FieldCapabilitiesIndexResponse( } else { this.indexMappingHash = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.FIELD_CAPS_RESPONSE_INDEX_MODE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.indexMode = IndexMode.readFrom(in); } else { this.indexMode = IndexMode.STANDARD; @@ -77,7 +77,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(MAPPING_HASH_VERSION)) { out.writeOptionalString(indexMappingHash); } - if (out.getTransportVersion().onOrAfter(TransportVersions.FIELD_CAPS_RESPONSE_INDEX_MODE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { IndexMode.writeTo(indexMode, out); } } @@ -105,7 +105,7 @@ static List readList(StreamInput input) throws I private static void collectCompressedResponses(StreamInput input, int groups, ArrayList responses) throws IOException { final CompressedGroup[] compressedGroups = new CompressedGroup[groups]; - final boolean readIndexMode = input.getTransportVersion().onOrAfter(TransportVersions.FIELD_CAPS_RESPONSE_INDEX_MODE); + final boolean readIndexMode = input.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0); for (int i = 0; i < groups; i++) { final String[] indices = input.readStringArray(); final IndexMode indexMode = readIndexMode ? IndexMode.readFrom(input) : IndexMode.STANDARD; @@ -179,7 +179,7 @@ private static void writeCompressedResponses(StreamOutput output, Map { o.writeCollection(fieldCapabilitiesIndexResponses, (oo, r) -> oo.writeString(r.indexName)); var first = fieldCapabilitiesIndexResponses.get(0); - if (output.getTransportVersion().onOrAfter(TransportVersions.FIELD_CAPS_RESPONSE_INDEX_MODE)) { + if (output.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { IndexMode.writeTo(first.indexMode, o); } o.writeString(first.indexMappingHash); diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java index c0811e7424b0d..5254c6fd06db7 100644 --- a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java @@ -205,10 +205,8 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOExceptio if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { in.readZLong(); // obsolete normalisedBytesParsed } - if (in.getTransportVersion().onOrAfter(TransportVersions.INDEX_REQUEST_UPDATE_BY_SCRIPT_ORIGIN)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { in.readBoolean(); // obsolete originatesFromUpdateByScript - } - if (in.getTransportVersion().onOrAfter(TransportVersions.INDEX_REQUEST_UPDATE_BY_DOC_ORIGIN)) { in.readBoolean(); // obsolete originatesFromUpdateByDoc } } @@ -789,10 +787,8 @@ private void writeBody(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { out.writeZLong(-1); // obsolete normalisedBytesParsed } - if (out.getTransportVersion().onOrAfter(TransportVersions.INDEX_REQUEST_UPDATE_BY_SCRIPT_ORIGIN)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeBoolean(false); // obsolete originatesFromUpdateByScript - } - if (out.getTransportVersion().onOrAfter(TransportVersions.INDEX_REQUEST_UPDATE_BY_DOC_ORIGIN)) { out.writeBoolean(false); // obsolete originatesFromUpdateByDoc } } diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexResponse.java b/server/src/main/java/org/elasticsearch/action/index/IndexResponse.java index 8d1bdf227e24d..7c45de8905174 100644 --- a/server/src/main/java/org/elasticsearch/action/index/IndexResponse.java +++ b/server/src/main/java/org/elasticsearch/action/index/IndexResponse.java @@ -46,7 +46,7 @@ public IndexResponse(ShardId shardId, StreamInput in) throws IOException { } else { executedPipelines = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus = IndexDocFailureStoreStatus.read(in); } else { failureStoreStatus = IndexDocFailureStoreStatus.NOT_APPLICABLE_OR_UNKNOWN; @@ -60,7 +60,7 @@ public IndexResponse(StreamInput in) throws IOException { } else { executedPipelines = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus = IndexDocFailureStoreStatus.read(in); } else { failureStoreStatus = IndexDocFailureStoreStatus.NOT_APPLICABLE_OR_UNKNOWN; @@ -126,7 +126,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { out.writeOptionalCollection(executedPipelines, StreamOutput::writeString); } - if (out.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus.writeTo(out); } } @@ -137,7 +137,7 @@ public void writeThin(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { out.writeOptionalCollection(executedPipelines, StreamOutput::writeString); } - if (out.getTransportVersion().onOrAfter(TransportVersions.FAILURE_STORE_STATUS_IN_INDEX_RESPONSE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { failureStoreStatus.writeTo(out); } } diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 09fb70fb06ba4..800193e258dba 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -739,7 +739,7 @@ void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connecti * @see #onShardFailure(int, SearchShardTarget, Exception) * @see #onShardResult(SearchPhaseResult, SearchShardIterator) */ - final void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() + private void onPhaseDone() { // as a tribute to @kimchy aka. finishHim() executeNextPhase(this, this::getNextPhase); } @@ -762,13 +762,6 @@ public final void execute(Runnable command) { executor.execute(command); } - /** - * Notifies the top-level listener of the provided exception - */ - public void onFailure(Exception e) { - listener.onFailure(e); - } - /** * Builds an request for the initial search phase. * diff --git a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java index 8feed2aea00b0..e8d94c32bdcc7 100644 --- a/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/ExpandSearchPhase.java @@ -102,7 +102,7 @@ private void doRun() { for (InnerHitBuilder innerHitBuilder : innerHitBuilders) { MultiSearchResponse.Item item = it.next(); if (item.isFailure()) { - context.onPhaseFailure(this, "failed to expand hits", item.getFailure()); + phaseFailure(item.getFailure()); return; } SearchHits innerHits = item.getResponse().getHits(); @@ -119,7 +119,11 @@ private void doRun() { } } onPhaseDone(); - }, context::onFailure)); + }, this::phaseFailure)); + } + + private void phaseFailure(Exception ex) { + context.onPhaseFailure(this, "failed to expand hits", ex); } private static SearchSourceBuilder buildExpandSearchSourceBuilder(InnerHitBuilder options, CollapseBuilder innerCollapseBuilder) { diff --git a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java index 969ba2ad983ce..d68e2ce1b02b7 100644 --- a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeRequest.java @@ -63,7 +63,7 @@ public OpenPointInTimeRequest(StreamInput in) throws IOException { if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { this.indexFilter = in.readOptionalNamedWriteable(QueryBuilder.class); } - if (in.getTransportVersion().onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.allowPartialSearchResults = in.readBoolean(); } } @@ -82,7 +82,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { out.writeOptionalWriteable(indexFilter); } - if (out.getTransportVersion().onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeBoolean(allowPartialSearchResults); } else if (allowPartialSearchResults) { throw new IOException("[allow_partial_search_results] is not supported on nodes with version " + out.getTransportVersion()); diff --git a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java index 3c830c8ed9dc1..b3ffc564d848c 100644 --- a/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java +++ b/server/src/main/java/org/elasticsearch/action/search/OpenPointInTimeResponse.java @@ -47,7 +47,7 @@ public OpenPointInTimeResponse( @Override public void writeTo(StreamOutput out) throws IOException { out.writeBytesReference(pointInTimeId); - if (out.getTransportVersion().onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(totalShards); out.writeVInt(successfulShards); out.writeVInt(failedShards); diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java index ca810bb88653f..c2f1510341fb0 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchContextId.java @@ -63,14 +63,14 @@ public static BytesReference encode( TransportVersion version, ShardSearchFailure[] shardFailures ) { - assert shardFailures.length == 0 || version.onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT) + assert shardFailures.length == 0 || version.onOrAfter(TransportVersions.V_8_16_0) : "[allow_partial_search_results] cannot be enabled on a cluster that has not been fully upgraded to version [" - + TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT + + TransportVersions.V_8_16_0.toReleaseVersion() + "] or higher."; try (var out = new BytesStreamOutput()) { out.setTransportVersion(version); TransportVersion.writeVersion(version, out); - boolean allowNullContextId = out.getTransportVersion().onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT); + boolean allowNullContextId = out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0); int shardSize = searchPhaseResults.size() + (allowNullContextId ? shardFailures.length : 0); out.writeVInt(shardSize); for (var searchResult : searchPhaseResults) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchContextIdForNode.java b/server/src/main/java/org/elasticsearch/action/search/SearchContextIdForNode.java index 7509a7b0fed04..f91a9d09f4bb4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchContextIdForNode.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchContextIdForNode.java @@ -37,7 +37,7 @@ public final class SearchContextIdForNode implements Writeable { } SearchContextIdForNode(StreamInput in) throws IOException { - boolean allowNull = in.getTransportVersion().onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT); + boolean allowNull = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0); this.node = allowNull ? in.readOptionalString() : in.readString(); this.clusterAlias = in.readOptionalString(); this.searchContextId = allowNull ? in.readOptionalWriteable(ShardSearchContextId::new) : new ShardSearchContextId(in); @@ -45,7 +45,7 @@ public final class SearchContextIdForNode implements Writeable { @Override public void writeTo(StreamOutput out) throws IOException { - boolean allowNull = out.getTransportVersion().onOrAfter(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT); + boolean allowNull = out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0); if (allowNull) { out.writeOptionalString(node); } else { @@ -53,7 +53,7 @@ public void writeTo(StreamOutput out) throws IOException { // We should never set a null node if the cluster is not fully upgraded to a version that can handle it. throw new IOException( "Cannot write null node value to a node in version " - + out.getTransportVersion() + + out.getTransportVersion().toReleaseVersion() + ". The target node must be specified to retrieve the ShardSearchContextId." ); } @@ -67,7 +67,7 @@ public void writeTo(StreamOutput out) throws IOException { // We should never set a null search context id if the cluster is not fully upgraded to a version that can handle it. throw new IOException( "Cannot write null search context ID to a node in version " - + out.getTransportVersion() + + out.getTransportVersion().toReleaseVersion() + ". A valid search context ID is required to identify the shard's search context in this version." ); } diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index 9e60eedbad6a2..36ca0fba94372 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -104,8 +104,7 @@ public TransportOpenPointInTimeAction( protected void doExecute(Task task, OpenPointInTimeRequest request, ActionListener listener) { final ClusterState clusterState = clusterService.state(); // Check if all the nodes in this cluster know about the service - if (request.allowPartialSearchResults() - && clusterState.getMinTransportVersion().before(TransportVersions.ALLOW_PARTIAL_SEARCH_RESULTS_IN_PIT)) { + if (request.allowPartialSearchResults() && clusterState.getMinTransportVersion().before(TransportVersions.V_8_16_0)) { listener.onFailure( new ElasticsearchStatusException( format( diff --git a/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java b/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java index 85889d8398cb1..ebbd47336e3da 100644 --- a/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java +++ b/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java @@ -982,12 +982,11 @@ public void writeIndicesOptions(StreamOutput out) throws IOException { states.add(WildcardStates.HIDDEN); } out.writeEnumSet(states); - if (out.getTransportVersion() - .between(TransportVersions.V_8_14_0, TransportVersions.CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY)) { + if (out.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { out.writeBoolean(includeRegularIndices()); out.writeBoolean(includeFailureIndices()); } - if (out.getTransportVersion().onOrAfter(TransportVersions.CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { selectorOptions.writeTo(out); } } @@ -1010,8 +1009,7 @@ public static IndicesOptions readIndicesOptions(StreamInput in) throws IOExcepti .ignoreThrottled(options.contains(Option.IGNORE_THROTTLED)) .build(); SelectorOptions selectorOptions = SelectorOptions.DEFAULT; - if (in.getTransportVersion() - .between(TransportVersions.V_8_14_0, TransportVersions.CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY)) { + if (in.getTransportVersion().between(TransportVersions.V_8_14_0, TransportVersions.V_8_16_0)) { // Reading from an older node, which will be sending two booleans that we must read out and ignore. var includeData = in.readBoolean(); var includeFailures = in.readBoolean(); @@ -1023,7 +1021,7 @@ public static IndicesOptions readIndicesOptions(StreamInput in) throws IOExcepti selectorOptions = SelectorOptions.FAILURES; } } - if (in.getTransportVersion().onOrAfter(TransportVersions.CONVERT_FAILURE_STORE_OPTIONS_TO_SELECTOR_OPTIONS_INTERNALLY)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { selectorOptions = SelectorOptions.read(in); } return new IndicesOptions( diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java index c06ea9305aef8..27cbb39c05d38 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java @@ -42,9 +42,7 @@ import org.elasticsearch.nativeaccess.NativeAccess; import org.elasticsearch.node.Node; import org.elasticsearch.node.NodeValidationException; -import org.elasticsearch.plugins.PluginBundle; import org.elasticsearch.plugins.PluginsLoader; -import org.elasticsearch.plugins.PluginsUtils; import java.io.IOException; import java.io.InputStream; @@ -54,10 +52,8 @@ import java.nio.file.Path; import java.security.Permission; import java.security.Security; -import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -208,21 +204,17 @@ private static void initPhase2(Bootstrap bootstrap) throws IOException { // load the plugin Java modules and layers now for use in entitlements var pluginsLoader = PluginsLoader.createPluginsLoader(nodeEnv.modulesFile(), nodeEnv.pluginsFile()); bootstrap.setPluginsLoader(pluginsLoader); + var pluginsResolver = PluginsResolver.create(pluginsLoader); if (Boolean.parseBoolean(System.getProperty("es.entitlements.enabled"))) { LogManager.getLogger(Elasticsearch.class).info("Bootstrapping Entitlements"); - List> pluginData = new ArrayList<>(); - Set moduleBundles = PluginsUtils.getModuleBundles(nodeEnv.modulesFile()); - for (PluginBundle moduleBundle : moduleBundles) { - pluginData.add(Tuple.tuple(moduleBundle.getDir(), moduleBundle.pluginDescriptor().isModular())); - } - Set pluginBundles = PluginsUtils.getPluginBundles(nodeEnv.pluginsFile()); - for (PluginBundle pluginBundle : pluginBundles) { - pluginData.add(Tuple.tuple(pluginBundle.getDir(), pluginBundle.pluginDescriptor().isModular())); - } - // TODO: add a functor to map module to plugin name - EntitlementBootstrap.bootstrap(pluginData, callerClass -> null); + List> pluginData = pluginsLoader.allBundles() + .stream() + .map(bundle -> Tuple.tuple(bundle.getDir(), bundle.pluginDescriptor().isModular())) + .toList(); + + EntitlementBootstrap.bootstrap(pluginData, pluginsResolver::resolveClassToPluginName); } else { // install SM after natives, shutdown hooks, etc. LogManager.getLogger(Elasticsearch.class).info("Bootstrapping java SecurityManager"); diff --git a/server/src/main/java/org/elasticsearch/bootstrap/PluginsResolver.java b/server/src/main/java/org/elasticsearch/bootstrap/PluginsResolver.java new file mode 100644 index 0000000000000..256e91cbee16d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/bootstrap/PluginsResolver.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.plugins.PluginsLoader; + +import java.util.HashMap; +import java.util.Map; + +class PluginsResolver { + private final Map pluginNameByModule; + + private PluginsResolver(Map pluginNameByModule) { + this.pluginNameByModule = pluginNameByModule; + } + + public static PluginsResolver create(PluginsLoader pluginsLoader) { + Map pluginNameByModule = new HashMap<>(); + + pluginsLoader.pluginLayers().forEach(pluginLayer -> { + var pluginName = pluginLayer.pluginBundle().pluginDescriptor().getName(); + if (pluginLayer.pluginModuleLayer() != null && pluginLayer.pluginModuleLayer() != ModuleLayer.boot()) { + // This plugin is a Java Module + for (var module : pluginLayer.pluginModuleLayer().modules()) { + pluginNameByModule.put(module, pluginName); + } + } else { + // This plugin is not modularized + pluginNameByModule.put(pluginLayer.pluginClassLoader().getUnnamedModule(), pluginName); + } + }); + + return new PluginsResolver(pluginNameByModule); + } + + public String resolveClassToPluginName(Class clazz) { + var module = clazz.getModule(); + return pluginNameByModule.get(module); + } +} diff --git a/server/src/main/java/org/elasticsearch/cluster/health/ClusterIndexHealth.java b/server/src/main/java/org/elasticsearch/cluster/health/ClusterIndexHealth.java index b6c1defe91a75..9cf567c219660 100644 --- a/server/src/main/java/org/elasticsearch/cluster/health/ClusterIndexHealth.java +++ b/server/src/main/java/org/elasticsearch/cluster/health/ClusterIndexHealth.java @@ -111,7 +111,7 @@ public ClusterIndexHealth(final StreamInput in) throws IOException { unassignedShards = in.readVInt(); status = ClusterHealthStatus.readFrom(in); shards = in.readMapValues(ClusterShardHealth::new, ClusterShardHealth::getShardId); - if (in.getTransportVersion().onOrAfter(TransportVersions.UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { unassignedPrimaryShards = in.readVInt(); } else { unassignedPrimaryShards = 0; @@ -203,7 +203,7 @@ public void writeTo(final StreamOutput out) throws IOException { out.writeVInt(unassignedShards); out.writeByte(status.value()); out.writeMapValues(shards); - if (out.getTransportVersion().onOrAfter(TransportVersions.UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(unassignedPrimaryShards); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java b/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java index 63863542564cd..f512acb6e04d0 100644 --- a/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java +++ b/server/src/main/java/org/elasticsearch/cluster/health/ClusterShardHealth.java @@ -96,7 +96,7 @@ public ClusterShardHealth(final StreamInput in) throws IOException { initializingShards = in.readVInt(); unassignedShards = in.readVInt(); primaryActive = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { unassignedPrimaryShards = in.readVInt(); } else { unassignedPrimaryShards = 0; @@ -167,7 +167,7 @@ public void writeTo(final StreamOutput out) throws IOException { out.writeVInt(initializingShards); out.writeVInt(unassignedShards); out.writeBoolean(primaryActive); - if (out.getTransportVersion().onOrAfter(TransportVersions.UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(unassignedPrimaryShards); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/health/ClusterStateHealth.java b/server/src/main/java/org/elasticsearch/cluster/health/ClusterStateHealth.java index 579429b5d51dd..31f275e29c368 100644 --- a/server/src/main/java/org/elasticsearch/cluster/health/ClusterStateHealth.java +++ b/server/src/main/java/org/elasticsearch/cluster/health/ClusterStateHealth.java @@ -120,7 +120,7 @@ public ClusterStateHealth(final StreamInput in) throws IOException { status = ClusterHealthStatus.readFrom(in); indices = in.readMapValues(ClusterIndexHealth::new, ClusterIndexHealth::getIndex); activeShardsPercent = in.readDouble(); - if (in.getTransportVersion().onOrAfter(TransportVersions.UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { unassignedPrimaryShards = in.readVInt(); } else { unassignedPrimaryShards = 0; @@ -212,7 +212,7 @@ public void writeTo(final StreamOutput out) throws IOException { out.writeByte(status.value()); out.writeMapValues(indices); out.writeDouble(activeShardsPercent); - if (out.getTransportVersion().onOrAfter(TransportVersions.UNASSIGNED_PRIMARY_COUNT_ON_CLUSTER_HEALTH)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(unassignedPrimaryShards); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java index 4dcc7c73c280e..979434950cf7a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java @@ -71,6 +71,7 @@ public final class DataStream implements SimpleDiffable, ToXContentO public static final FeatureFlag FAILURE_STORE_FEATURE_FLAG = new FeatureFlag("failure_store"); public static final TransportVersion ADDED_FAILURE_STORE_TRANSPORT_VERSION = TransportVersions.V_8_12_0; public static final TransportVersion ADDED_AUTO_SHARDING_EVENT_VERSION = TransportVersions.V_8_14_0; + public static final TransportVersion ADD_DATA_STREAM_OPTIONS_VERSION = TransportVersions.V_8_16_0; public static boolean isFailureStoreFeatureFlagEnabled() { return FAILURE_STORE_FEATURE_FLAG.isEnabled(); @@ -200,9 +201,7 @@ public static DataStream read(StreamInput in) throws IOException { : null; // This boolean flag has been moved in data stream options var failureStoreEnabled = in.getTransportVersion() - .between(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION, TransportVersions.ADD_DATA_STREAM_OPTIONS) - ? in.readBoolean() - : false; + .between(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION, TransportVersions.V_8_16_0) ? in.readBoolean() : false; var failureIndices = in.getTransportVersion().onOrAfter(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION) ? readIndices(in) : List.of(); @@ -216,7 +215,7 @@ public static DataStream read(StreamInput in) throws IOException { .setAutoShardingEvent(in.readOptionalWriteable(DataStreamAutoShardingEvent::new)); } DataStreamOptions dataStreamOptions; - if (in.getTransportVersion().onOrAfter(TransportVersions.ADD_DATA_STREAM_OPTIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { dataStreamOptions = in.readOptionalWriteable(DataStreamOptions::read); } else { // We cannot distinguish if failure store was explicitly disabled or not. Given that failure store @@ -1077,7 +1076,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(lifecycle); } if (out.getTransportVersion() - .between(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION, TransportVersions.ADD_DATA_STREAM_OPTIONS)) { + .between(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION, DataStream.ADD_DATA_STREAM_OPTIONS_VERSION)) { out.writeBoolean(isFailureStoreEnabled()); } if (out.getTransportVersion().onOrAfter(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION)) { @@ -1093,7 +1092,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeBoolean(failureIndices.rolloverOnWrite); out.writeOptionalWriteable(failureIndices.autoShardingEvent); } - if (out.getTransportVersion().onOrAfter(TransportVersions.ADD_DATA_STREAM_OPTIONS)) { + if (out.getTransportVersion().onOrAfter(DataStream.ADD_DATA_STREAM_OPTIONS_VERSION)) { out.writeOptionalWriteable(dataStreamOptions.isEmpty() ? null : dataStreamOptions); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 271c60e829a87..8917d5a9cbbb5 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -9,6 +9,7 @@ package org.elasticsearch.cluster.metadata; +import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.SimpleDiffable; import org.elasticsearch.common.io.stream.StreamInput; @@ -23,8 +24,6 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_SEARCH_INFERENCE_ID; - /** * Contains inference field data for fields. * As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need @@ -56,7 +55,7 @@ public InferenceFieldMetadata(String name, String inferenceId, String searchInfe public InferenceFieldMetadata(StreamInput input) throws IOException { this.name = input.readString(); this.inferenceId = input.readString(); - if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_SEARCH_INFERENCE_ID)) { + if (input.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.searchInferenceId = input.readString(); } else { this.searchInferenceId = this.inferenceId; @@ -68,7 +67,7 @@ public InferenceFieldMetadata(StreamInput input) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeString(inferenceId); - if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_SEARCH_INFERENCE_ID)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeString(searchInferenceId); } out.writeStringArray(sourceFields); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java index 790b8e4ab75fa..60cf6b10417fa 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingTable.java @@ -317,7 +317,7 @@ public static Diff readDiffFrom(StreamInput in) throws IOException public static RoutingTable readFrom(StreamInput in) throws IOException { Builder builder = new Builder(); - if (in.getTransportVersion().before(TransportVersions.ROUTING_TABLE_VERSION_REMOVED)) { + if (in.getTransportVersion().before(TransportVersions.V_8_16_0)) { in.readLong(); // previously 'version', unused in all applicable versions so any number will do } int size = in.readVInt(); @@ -331,7 +331,7 @@ public static RoutingTable readFrom(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().before(TransportVersions.ROUTING_TABLE_VERSION_REMOVED)) { + if (out.getTransportVersion().before(TransportVersions.V_8_16_0)) { out.writeLong(0); // previously 'version', unused in all applicable versions so any number will do } out.writeCollection(indicesRouting.values()); @@ -349,7 +349,7 @@ private static class RoutingTableDiff implements Diff { new DiffableUtils.DiffableValueReader<>(IndexRoutingTable::readFrom, IndexRoutingTable::readDiffFrom); RoutingTableDiff(StreamInput in) throws IOException { - if (in.getTransportVersion().before(TransportVersions.ROUTING_TABLE_VERSION_REMOVED)) { + if (in.getTransportVersion().before(TransportVersions.V_8_16_0)) { in.readLong(); // previously 'version', unused in all applicable versions so any number will do } indicesRouting = DiffableUtils.readImmutableOpenMapDiff(in, DiffableUtils.getStringKeySerializer(), DIFF_VALUE_READER); @@ -366,7 +366,7 @@ public RoutingTable apply(RoutingTable part) { @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().before(TransportVersions.ROUTING_TABLE_VERSION_REMOVED)) { + if (out.getTransportVersion().before(TransportVersions.V_8_16_0)) { out.writeLong(0); // previously 'version', unused in all applicable versions so any number will do } indicesRouting.writeTo(out); diff --git a/server/src/main/java/org/elasticsearch/common/hash/Murmur3Hasher.java b/server/src/main/java/org/elasticsearch/common/hash/Murmur3Hasher.java index 817587771d795..aec28484138fb 100644 --- a/server/src/main/java/org/elasticsearch/common/hash/Murmur3Hasher.java +++ b/server/src/main/java/org/elasticsearch/common/hash/Murmur3Hasher.java @@ -40,7 +40,12 @@ public void update(byte[] inputBytes) { update(inputBytes, 0, inputBytes.length); } - private void update(byte[] inputBytes, int offset, int length) { + /** + * Similar to {@link #update(byte[])}, but processes a specific portion of the input bytes + * starting from the given {@code offset} for the specified {@code length}. + * @see #update(byte[]) + */ + public void update(byte[] inputBytes, int offset, int length) { if (remainderLength + length >= remainder.length) { if (remainderLength > 0) { // fill rest of remainder from inputBytes and hash remainder diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java index 644cc6bb69927..e07861ba05433 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java @@ -908,11 +908,8 @@ public final Instant readOptionalInstant() throws IOException { private ZonedDateTime readZonedDateTime() throws IOException { final String timeZoneId = readString(); final Instant instant; - if (getTransportVersion().onOrAfter(TransportVersions.ZDT_NANOS_SUPPORT_BROKEN)) { - // epoch seconds can be negative, but it was incorrectly first written as vlong - boolean zlong = getTransportVersion().onOrAfter(TransportVersions.ZDT_NANOS_SUPPORT); - long seconds = zlong ? readZLong() : readVLong(); - instant = Instant.ofEpochSecond(seconds, readInt()); + if (getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { + instant = Instant.ofEpochSecond(readZLong(), readInt()); } else { instant = Instant.ofEpochMilli(readLong()); } diff --git a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java index d724e5ea25ca6..6738af32f04d6 100644 --- a/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java +++ b/server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java @@ -768,13 +768,8 @@ public final void writeOptionalInstant(@Nullable Instant instant) throws IOExcep final ZonedDateTime zonedDateTime = (ZonedDateTime) v; o.writeString(zonedDateTime.getZone().getId()); Instant instant = zonedDateTime.toInstant(); - if (o.getTransportVersion().onOrAfter(TransportVersions.ZDT_NANOS_SUPPORT_BROKEN)) { - // epoch seconds can be negative, but it was incorrectly first written as vlong - if (o.getTransportVersion().onOrAfter(TransportVersions.ZDT_NANOS_SUPPORT)) { - o.writeZLong(instant.getEpochSecond()); - } else { - o.writeVLong(instant.getEpochSecond()); - } + if (o.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { + o.writeZLong(instant.getEpochSecond()); o.writeInt(instant.getNano()); } else { o.writeLong(instant.toEpochMilli()); diff --git a/server/src/main/java/org/elasticsearch/common/logging/DeprecatedMessage.java b/server/src/main/java/org/elasticsearch/common/logging/DeprecatedMessage.java index 0bcde14fcf19a..ca89313e59de2 100644 --- a/server/src/main/java/org/elasticsearch/common/logging/DeprecatedMessage.java +++ b/server/src/main/java/org/elasticsearch/common/logging/DeprecatedMessage.java @@ -57,7 +57,7 @@ private static ESLogMessage getEsLogMessage( String messagePattern, Object[] args ) { - ESLogMessage esLogMessage = new ESLogMessage(messagePattern, args).field("data_stream.dataset", "deprecation.elasticsearch") + ESLogMessage esLogMessage = new ESLogMessage(messagePattern, args).field("data_stream.dataset", "elasticsearch.deprecation") .field("data_stream.type", "logs") .field("data_stream.namespace", "default") .field(KEY_FIELD_NAME, key) diff --git a/server/src/main/java/org/elasticsearch/common/time/DateUtils.java b/server/src/main/java/org/elasticsearch/common/time/DateUtils.java index 9f642734ba832..72306b6ed675e 100644 --- a/server/src/main/java/org/elasticsearch/common/time/DateUtils.java +++ b/server/src/main/java/org/elasticsearch/common/time/DateUtils.java @@ -293,6 +293,37 @@ public static long toMilliSeconds(long nanoSecondsSinceEpoch) { return nanoSecondsSinceEpoch / 1_000_000; } + /** + * Compare an epoch nanosecond date (such as returned by {@link DateUtils#toLong} + * to an epoch millisecond date (such as returned by {@link Instant#toEpochMilli()}}. + *

+ * NB: This function does not implement {@link java.util.Comparator} in + * order to avoid performance costs of autoboxing the input longs. + * + * @param nanos Epoch date represented as a long number of nanoseconds. + * Note that Elasticsearch does not support nanosecond dates + * before Epoch, so this number should never be negative. + * @param millis Epoch date represented as a long number of milliseconds. + * This parameter does not have to be constrained to the + * range of long nanosecond dates. + * @return -1 if the nanosecond date is before the millisecond date, + * 0 if the two dates represent the same instant, + * 1 if the nanosecond date is after the millisecond date + */ + public static int compareNanosToMillis(long nanos, long millis) { + assert nanos >= 0; + if (millis < 0) { + return 1; + } + if (millis > MAX_NANOSECOND_IN_MILLIS) { + return -1; + } + // This can't overflow, because we know millis is between 0 and MAX_NANOSECOND_IN_MILLIS, + // and MAX_NANOSECOND_IN_MILLIS * 1_000_000 doesn't overflow. + long diff = nanos - (millis * 1_000_000); + return diff == 0 ? 0 : diff < 0 ? -1 : 1; + } + /** * Rounds the given utc milliseconds sicne the epoch down to the next unit millis * diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java index 2e78cc6f516b1..6a5aa2943de92 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.xcontent.ToXContent; +import java.util.Collections; import java.util.Iterator; public enum ChunkedToXContentHelper { @@ -53,6 +54,14 @@ public static Iterator field(String name, String value) { return Iterators.single(((builder, params) -> builder.field(name, value))); } + public static Iterator optionalField(String name, String value) { + if (value == null) { + return Collections.emptyIterator(); + } else { + return field(name, value); + } + } + /** * Creates an Iterator of a single ToXContent object that serializes the given object as a single chunk. Just wraps {@link * Iterators#single}, but still useful because it avoids any type ambiguity. diff --git a/server/src/main/java/org/elasticsearch/index/engine/CommitStats.java b/server/src/main/java/org/elasticsearch/index/engine/CommitStats.java index a871524b45e9e..520174a4b3638 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/CommitStats.java +++ b/server/src/main/java/org/elasticsearch/index/engine/CommitStats.java @@ -46,7 +46,7 @@ public CommitStats(SegmentInfos segmentInfos) { generation = in.readLong(); id = in.readOptionalString(); numDocs = in.readInt(); - numLeaves = in.getTransportVersion().onOrAfter(TransportVersions.SEGMENT_LEVEL_FIELDS_STATS) ? in.readVInt() : 0; + numLeaves = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readVInt() : 0; } @Override @@ -100,7 +100,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(generation); out.writeOptionalString(id); out.writeInt(numDocs); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEGMENT_LEVEL_FIELDS_STATS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(numLeaves); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java index e00e7b2320000..9ddb6f0d496a0 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParser.java @@ -946,7 +946,9 @@ public Query termQuery(Object value, SearchExecutionContext context) { protected void parseCreateField(DocumentParserContext context) { // Run-time fields are mapped to this mapper, so it needs to handle storing values for use in synthetic source. // #parseValue calls this method once the run-time field is created. - if (context.dynamic() == ObjectMapper.Dynamic.RUNTIME && context.canAddIgnoredField()) { + var fieldType = context.mappingLookup().getFieldType(path); + boolean isRuntimeField = fieldType instanceof AbstractScriptFieldType; + if ((context.dynamic() == ObjectMapper.Dynamic.RUNTIME || isRuntimeField) && context.canAddIgnoredField()) { try { context.addIgnoredField( IgnoredSourceFieldMapper.NameValue.fromContext(context, path, context.encodeFlattenedToken()) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java index ffb38d229078e..276d3e151361c 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MapperFeatures.java @@ -62,6 +62,7 @@ public Set getFeatures() { ); public static final NodeFeature META_FETCH_FIELDS_ERROR_CODE_CHANGED = new NodeFeature("meta_fetch_fields_error_code_changed"); + public static final NodeFeature SPARSE_VECTOR_STORE_SUPPORT = new NodeFeature("mapper.sparse_vector.store_support"); @Override public Set getTestFeatures() { @@ -75,7 +76,8 @@ public Set getTestFeatures() { MapperService.LOGSDB_DEFAULT_IGNORE_DYNAMIC_BEYOND_LIMIT, DocumentParser.FIX_PARSING_SUBOBJECTS_FALSE_DYNAMIC_FALSE, CONSTANT_KEYWORD_SYNTHETIC_SOURCE_WRITE_FIX, - META_FETCH_FIELDS_ERROR_CODE_CHANGED + META_FETCH_FIELDS_ERROR_CODE_CHANGED, + SPARSE_VECTOR_STORE_SUPPORT ); } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/NodeMappingStats.java b/server/src/main/java/org/elasticsearch/index/mapper/NodeMappingStats.java index 56210a292995c..10b0856540399 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/NodeMappingStats.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/NodeMappingStats.java @@ -52,7 +52,7 @@ public NodeMappingStats() { public NodeMappingStats(StreamInput in) throws IOException { totalCount = in.readVLong(); totalEstimatedOverhead = in.readVLong(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEGMENT_LEVEL_FIELDS_STATS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { totalSegments = in.readVLong(); totalSegmentFields = in.readVLong(); } @@ -93,7 +93,7 @@ public long getTotalSegmentFields() { public void writeTo(StreamOutput out) throws IOException { out.writeVLong(totalCount); out.writeVLong(totalEstimatedOverhead); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEGMENT_LEVEL_FIELDS_STATS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVLong(totalSegments); out.writeVLong(totalSegmentFields); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java index d0a8dfae4f242..552e66336005d 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapper.java @@ -11,6 +11,12 @@ import org.apache.lucene.document.FeatureField; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.TermVectors; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; @@ -25,14 +31,22 @@ import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.mapper.SourceValueFetcher; import org.elasticsearch.index.mapper.TextSearchInfo; import org.elasticsearch.index.mapper.ValueFetcher; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.fetch.StoredFieldsSpec; +import org.elasticsearch.search.lookup.Source; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser.Token; import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; +import java.util.stream.Stream; import static org.elasticsearch.index.query.AbstractQueryBuilder.DEFAULT_BOOST; @@ -52,8 +66,12 @@ public class SparseVectorFieldMapper extends FieldMapper { static final IndexVersion NEW_SPARSE_VECTOR_INDEX_VERSION = IndexVersions.NEW_SPARSE_VECTOR; static final IndexVersion SPARSE_VECTOR_IN_FIELD_NAMES_INDEX_VERSION = IndexVersions.SPARSE_VECTOR_IN_FIELD_NAMES_SUPPORT; - public static class Builder extends FieldMapper.Builder { + private static SparseVectorFieldMapper toType(FieldMapper in) { + return (SparseVectorFieldMapper) in; + } + public static class Builder extends FieldMapper.Builder { + private final Parameter stored = Parameter.storeParam(m -> toType(m).fieldType().isStored(), false); private final Parameter> meta = Parameter.metaParam(); public Builder(String name) { @@ -62,14 +80,14 @@ public Builder(String name) { @Override protected Parameter[] getParameters() { - return new Parameter[] { meta }; + return new Parameter[] { stored, meta }; } @Override public SparseVectorFieldMapper build(MapperBuilderContext context) { return new SparseVectorFieldMapper( leafName(), - new SparseVectorFieldType(context.buildFullName(leafName()), meta.getValue()), + new SparseVectorFieldType(context.buildFullName(leafName()), stored.getValue(), meta.getValue()), builderParams(this, context) ); } @@ -87,8 +105,8 @@ public SparseVectorFieldMapper build(MapperBuilderContext context) { public static final class SparseVectorFieldType extends MappedFieldType { - public SparseVectorFieldType(String name, Map meta) { - super(name, true, false, false, TextSearchInfo.SIMPLE_MATCH_ONLY, meta); + public SparseVectorFieldType(String name, boolean isStored, Map meta) { + super(name, true, isStored, false, TextSearchInfo.SIMPLE_MATCH_ONLY, meta); } @Override @@ -103,6 +121,9 @@ public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext @Override public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + if (isStored()) { + return new SparseVectorValueFetcher(name()); + } return SourceValueFetcher.identity(name(), context, format); } @@ -135,6 +156,14 @@ private SparseVectorFieldMapper(String simpleName, MappedFieldType mappedFieldTy super(simpleName, mappedFieldType, builderParams); } + @Override + protected SyntheticSourceSupport syntheticSourceSupport() { + if (fieldType().isStored()) { + return new SyntheticSourceSupport.Native(new SparseVectorSyntheticFieldLoader(fullPath(), leafName())); + } + return super.syntheticSourceSupport(); + } + @Override public Map indexAnalyzers() { return Map.of(mappedFieldType.name(), Lucene.KEYWORD_ANALYZER); @@ -189,9 +218,9 @@ public void parse(DocumentParserContext context) throws IOException { // based on recommendations from this paper: https://arxiv.org/pdf/2305.18494.pdf IndexableField currentField = context.doc().getByKey(key); if (currentField == null) { - context.doc().addWithKey(key, new FeatureField(fullPath(), feature, value)); - } else if (currentField instanceof FeatureField && ((FeatureField) currentField).getFeatureValue() < value) { - ((FeatureField) currentField).setFeatureValue(value); + context.doc().addWithKey(key, new XFeatureField(fullPath(), feature, value, fieldType().isStored())); + } else if (currentField instanceof XFeatureField && ((XFeatureField) currentField).getFeatureValue() < value) { + ((XFeatureField) currentField).setFeatureValue(value); } } else { throw new IllegalArgumentException( @@ -219,4 +248,114 @@ protected String contentType() { return CONTENT_TYPE; } + private static class SparseVectorValueFetcher implements ValueFetcher { + private final String fieldName; + private TermVectors termVectors; + + private SparseVectorValueFetcher(String fieldName) { + this.fieldName = fieldName; + } + + @Override + public void setNextReader(LeafReaderContext context) { + try { + termVectors = context.reader().termVectors(); + } catch (IOException exc) { + throw new UncheckedIOException(exc); + } + } + + @Override + public List fetchValues(Source source, int doc, List ignoredValues) throws IOException { + if (termVectors == null) { + return List.of(); + } + var terms = termVectors.get(doc, fieldName); + if (terms == null) { + return List.of(); + } + + var termsEnum = terms.iterator(); + PostingsEnum postingsScratch = null; + Map result = new LinkedHashMap<>(); + while (termsEnum.next() != null) { + postingsScratch = termsEnum.postings(postingsScratch); + postingsScratch.nextDoc(); + result.put(termsEnum.term().utf8ToString(), XFeatureField.decodeFeatureValue(postingsScratch.freq())); + assert postingsScratch.nextDoc() == DocIdSetIterator.NO_MORE_DOCS; + } + return List.of(result); + } + + @Override + public StoredFieldsSpec storedFieldsSpec() { + return StoredFieldsSpec.NO_REQUIREMENTS; + } + } + + private static class SparseVectorSyntheticFieldLoader implements SourceLoader.SyntheticFieldLoader { + private final String fullPath; + private final String leafName; + + private TermsEnum termsDocEnum; + + private SparseVectorSyntheticFieldLoader(String fullPath, String leafName) { + this.fullPath = fullPath; + this.leafName = leafName; + } + + @Override + public Stream> storedFieldLoaders() { + return Stream.of(); + } + + @Override + public DocValuesLoader docValuesLoader(LeafReader leafReader, int[] docIdsInLeaf) throws IOException { + var fieldInfos = leafReader.getFieldInfos().fieldInfo(fullPath); + if (fieldInfos == null || fieldInfos.hasTermVectors() == false) { + return null; + } + return docId -> { + var terms = leafReader.termVectors().get(docId, fullPath); + if (terms == null) { + return false; + } + termsDocEnum = terms.iterator(); + if (termsDocEnum.next() == null) { + termsDocEnum = null; + return false; + } + return true; + }; + } + + @Override + public boolean hasValue() { + return termsDocEnum != null; + } + + @Override + public void write(XContentBuilder b) throws IOException { + assert termsDocEnum != null; + PostingsEnum reuse = null; + b.startObject(leafName); + do { + reuse = termsDocEnum.postings(reuse); + reuse.nextDoc(); + b.field(termsDocEnum.term().utf8ToString(), XFeatureField.decodeFeatureValue(reuse.freq())); + } while (termsDocEnum.next() != null); + b.endObject(); + } + + @Override + public String fieldName() { + return leafName; + } + + @Override + public void reset() { + termsDocEnum = null; + } + } + } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/XFeatureField.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/XFeatureField.java new file mode 100644 index 0000000000000..5f4afb4a86acc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/XFeatureField.java @@ -0,0 +1,177 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.elasticsearch.index.mapper.vectors; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.IndexOptions; + +/** + * This class is forked from the Lucene {@link FeatureField} implementation to enable support for storing term vectors. + * It should be removed once apache/lucene#14034 becomes available. + */ +public final class XFeatureField extends Field { + private static final FieldType FIELD_TYPE = new FieldType(); + private static final FieldType FIELD_TYPE_STORE_TERM_VECTORS = new FieldType(); + + static { + FIELD_TYPE.setTokenized(false); + FIELD_TYPE.setOmitNorms(true); + FIELD_TYPE.setIndexOptions(IndexOptions.DOCS_AND_FREQS); + + FIELD_TYPE_STORE_TERM_VECTORS.setTokenized(false); + FIELD_TYPE_STORE_TERM_VECTORS.setOmitNorms(true); + FIELD_TYPE_STORE_TERM_VECTORS.setIndexOptions(IndexOptions.DOCS_AND_FREQS); + FIELD_TYPE_STORE_TERM_VECTORS.setStoreTermVectors(true); + } + + private float featureValue; + + /** + * Create a feature. + * + * @param fieldName The name of the field to store the information into. All features may be + * stored in the same field. + * @param featureName The name of the feature, eg. 'pagerank`. It will be indexed as a term. + * @param featureValue The value of the feature, must be a positive, finite, normal float. + */ + public XFeatureField(String fieldName, String featureName, float featureValue) { + this(fieldName, featureName, featureValue, false); + } + + /** + * Create a feature. + * + * @param fieldName The name of the field to store the information into. All features may be + * stored in the same field. + * @param featureName The name of the feature, eg. 'pagerank`. It will be indexed as a term. + * @param featureValue The value of the feature, must be a positive, finite, normal float. + */ + public XFeatureField(String fieldName, String featureName, float featureValue, boolean storeTermVectors) { + super(fieldName, featureName, storeTermVectors ? FIELD_TYPE_STORE_TERM_VECTORS : FIELD_TYPE); + setFeatureValue(featureValue); + } + + /** + * Update the feature value of this field. + */ + public void setFeatureValue(float featureValue) { + if (Float.isFinite(featureValue) == false) { + throw new IllegalArgumentException( + "featureValue must be finite, got: " + featureValue + " for feature " + fieldsData + " on field " + name + ); + } + if (featureValue < Float.MIN_NORMAL) { + throw new IllegalArgumentException( + "featureValue must be a positive normal float, got: " + + featureValue + + " for feature " + + fieldsData + + " on field " + + name + + " which is less than the minimum positive normal float: " + + Float.MIN_NORMAL + ); + } + this.featureValue = featureValue; + } + + @Override + public TokenStream tokenStream(Analyzer analyzer, TokenStream reuse) { + FeatureTokenStream stream; + if (reuse instanceof FeatureTokenStream) { + stream = (FeatureTokenStream) reuse; + } else { + stream = new FeatureTokenStream(); + } + + int freqBits = Float.floatToIntBits(featureValue); + stream.setValues((String) fieldsData, freqBits >>> 15); + return stream; + } + + /** + * This is useful if you have multiple features sharing a name and you want to take action to + * deduplicate them. + * + * @return the feature value of this field. + */ + public float getFeatureValue() { + return featureValue; + } + + private static final class FeatureTokenStream extends TokenStream { + private final CharTermAttribute termAttribute = addAttribute(CharTermAttribute.class); + private final TermFrequencyAttribute freqAttribute = addAttribute(TermFrequencyAttribute.class); + private boolean used = true; + private String value = null; + private int freq = 0; + + private FeatureTokenStream() {} + + /** + * Sets the values + */ + void setValues(String value, int freq) { + this.value = value; + this.freq = freq; + } + + @Override + public boolean incrementToken() { + if (used) { + return false; + } + clearAttributes(); + termAttribute.append(value); + freqAttribute.setTermFrequency(freq); + used = true; + return true; + } + + @Override + public void reset() { + used = false; + } + + @Override + public void close() { + value = null; + } + } + + static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15; + + static float decodeFeatureValue(float freq) { + if (freq > MAX_FREQ) { + // This is never used in practice but callers of the SimScorer API might + // occasionally call it on eg. Float.MAX_VALUE to compute the max score + // so we need to be consistent. + return Float.MAX_VALUE; + } + int tf = (int) freq; // lossless + int featureBits = tf << 15; + return Float.intBitsToFloat(featureBits); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java b/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java index 647e45d1beda1..6ae0c4872cfa5 100644 --- a/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java +++ b/server/src/main/java/org/elasticsearch/index/query/IntervalsSourceProvider.java @@ -825,7 +825,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.REGEX_AND_RANGE_INTERVAL_QUERIES; + return TransportVersions.V_8_16_0; } @Override @@ -1129,7 +1129,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.REGEX_AND_RANGE_INTERVAL_QUERIES; + return TransportVersions.V_8_16_0; } @Override diff --git a/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java index 83bca7d27aeeb..503b2adf756f5 100644 --- a/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/NestedQueryBuilder.java @@ -321,8 +321,7 @@ public static Query toQuery( // ToParentBlockJoinQuery requires that the inner query only matches documents // in its child space - NestedHelper nestedHelper = new NestedHelper(context.nestedLookup(), context::isFieldMapped); - if (nestedHelper.mightMatchNonNestedDocs(innerQuery, path)) { + if (NestedHelper.mightMatchNonNestedDocs(innerQuery, path, context)) { innerQuery = Queries.filtered(innerQuery, mapper.nestedTypeFilter()); } diff --git a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java index 33077697a2ce6..889fa40b79aa1 100644 --- a/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/index/query/RankDocsQueryBuilder.java @@ -25,8 +25,6 @@ import java.util.Map; import java.util.Objects; -import static org.elasticsearch.TransportVersions.RRF_QUERY_REWRITE; - public class RankDocsQueryBuilder extends AbstractQueryBuilder { public static final String NAME = "rank_docs_query"; @@ -44,7 +42,7 @@ public RankDocsQueryBuilder(RankDoc[] rankDocs, QueryBuilder[] queryBuilders, bo public RankDocsQueryBuilder(StreamInput in) throws IOException { super(in); this.rankDocs = in.readArray(c -> c.readNamedWriteable(RankDoc.class), RankDoc[]::new); - if (in.getTransportVersion().onOrAfter(RRF_QUERY_REWRITE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.queryBuilders = in.readOptionalArray(c -> c.readNamedWriteable(QueryBuilder.class), QueryBuilder[]::new); this.onlyRankDocs = in.readBoolean(); } else { @@ -85,7 +83,7 @@ public RankDoc[] rankDocs() { @Override protected void doWriteTo(StreamOutput out) throws IOException { out.writeArray(StreamOutput::writeNamedWriteable, rankDocs); - if (out.getTransportVersion().onOrAfter(RRF_QUERY_REWRITE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalArray(StreamOutput::writeNamedWriteable, queryBuilders); out.writeBoolean(onlyRankDocs); } @@ -145,6 +143,6 @@ protected int doHashCode() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.RANK_DOCS_RETRIEVER; + return TransportVersions.V_8_16_0; } } diff --git a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java index b07112440d3c2..d5e48a6a54daa 100644 --- a/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java +++ b/server/src/main/java/org/elasticsearch/index/query/SearchExecutionContext.java @@ -493,14 +493,18 @@ public boolean containsBrokenAnalysis(String field) { */ public SearchLookup lookup() { if (this.lookup == null) { - SourceProvider sourceProvider = isSourceSynthetic() - ? SourceProvider.fromSyntheticSource(mappingLookup.getMapping(), mapperMetrics.sourceFieldMetrics()) - : SourceProvider.fromStoredFields(); + var sourceProvider = createSourceProvider(); setLookupProviders(sourceProvider, LeafFieldLookupProvider.fromStoredFields()); } return this.lookup; } + public SourceProvider createSourceProvider() { + return isSourceSynthetic() + ? SourceProvider.fromSyntheticSource(mappingLookup.getMapping(), mapperMetrics.sourceFieldMetrics()) + : SourceProvider.fromStoredFields(); + } + /** * Replace the standard source provider and field lookup provider on the SearchLookup * diff --git a/server/src/main/java/org/elasticsearch/index/search/NestedHelper.java b/server/src/main/java/org/elasticsearch/index/search/NestedHelper.java index 96e8ac35c8e32..a04f930e052b9 100644 --- a/server/src/main/java/org/elasticsearch/index/search/NestedHelper.java +++ b/server/src/main/java/org/elasticsearch/index/search/NestedHelper.java @@ -21,29 +21,21 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TermInSetQuery; import org.apache.lucene.search.TermQuery; -import org.elasticsearch.index.mapper.NestedLookup; import org.elasticsearch.index.mapper.NestedObjectMapper; - -import java.util.function.Predicate; +import org.elasticsearch.index.query.SearchExecutionContext; /** Utility class to filter parent and children clauses when building nested * queries. */ public final class NestedHelper { - private final NestedLookup nestedLookup; - private final Predicate isMappedFieldPredicate; - - public NestedHelper(NestedLookup nestedLookup, Predicate isMappedFieldPredicate) { - this.nestedLookup = nestedLookup; - this.isMappedFieldPredicate = isMappedFieldPredicate; - } + private NestedHelper() {} /** Returns true if the given query might match nested documents. */ - public boolean mightMatchNestedDocs(Query query) { + public static boolean mightMatchNestedDocs(Query query, SearchExecutionContext searchExecutionContext) { if (query instanceof ConstantScoreQuery) { - return mightMatchNestedDocs(((ConstantScoreQuery) query).getQuery()); + return mightMatchNestedDocs(((ConstantScoreQuery) query).getQuery(), searchExecutionContext); } else if (query instanceof BoostQuery) { - return mightMatchNestedDocs(((BoostQuery) query).getQuery()); + return mightMatchNestedDocs(((BoostQuery) query).getQuery(), searchExecutionContext); } else if (query instanceof MatchAllDocsQuery) { return true; } else if (query instanceof MatchNoDocsQuery) { @@ -51,17 +43,17 @@ public boolean mightMatchNestedDocs(Query query) { } else if (query instanceof TermQuery) { // We only handle term(s) queries and range queries, which should already // cover a high majority of use-cases - return mightMatchNestedDocs(((TermQuery) query).getTerm().field()); + return mightMatchNestedDocs(((TermQuery) query).getTerm().field(), searchExecutionContext); } else if (query instanceof TermInSetQuery tis) { if (tis.getTermsCount() > 0) { - return mightMatchNestedDocs(tis.getField()); + return mightMatchNestedDocs(tis.getField(), searchExecutionContext); } else { return false; } } else if (query instanceof PointRangeQuery) { - return mightMatchNestedDocs(((PointRangeQuery) query).getField()); + return mightMatchNestedDocs(((PointRangeQuery) query).getField(), searchExecutionContext); } else if (query instanceof IndexOrDocValuesQuery) { - return mightMatchNestedDocs(((IndexOrDocValuesQuery) query).getIndexQuery()); + return mightMatchNestedDocs(((IndexOrDocValuesQuery) query).getIndexQuery(), searchExecutionContext); } else if (query instanceof final BooleanQuery bq) { final boolean hasRequiredClauses = bq.clauses().stream().anyMatch(BooleanClause::isRequired); if (hasRequiredClauses) { @@ -69,13 +61,13 @@ public boolean mightMatchNestedDocs(Query query) { .stream() .filter(BooleanClause::isRequired) .map(BooleanClause::query) - .allMatch(this::mightMatchNestedDocs); + .allMatch(f -> mightMatchNestedDocs(f, searchExecutionContext)); } else { return bq.clauses() .stream() .filter(c -> c.occur() == Occur.SHOULD) .map(BooleanClause::query) - .anyMatch(this::mightMatchNestedDocs); + .anyMatch(f -> mightMatchNestedDocs(f, searchExecutionContext)); } } else if (query instanceof ESToParentBlockJoinQuery) { return ((ESToParentBlockJoinQuery) query).getPath() != null; @@ -85,7 +77,7 @@ public boolean mightMatchNestedDocs(Query query) { } /** Returns true if a query on the given field might match nested documents. */ - boolean mightMatchNestedDocs(String field) { + private static boolean mightMatchNestedDocs(String field, SearchExecutionContext searchExecutionContext) { if (field.startsWith("_")) { // meta field. Every meta field behaves differently, eg. nested // documents have the same _uid as their parent, put their path in @@ -94,36 +86,36 @@ boolean mightMatchNestedDocs(String field) { // we might add a nested filter when it is nor required. return true; } - if (isMappedFieldPredicate.test(field) == false) { + if (searchExecutionContext.isFieldMapped(field) == false) { // field does not exist return false; } - return nestedLookup.getNestedParent(field) != null; + return searchExecutionContext.nestedLookup().getNestedParent(field) != null; } /** Returns true if the given query might match parent documents or documents * that are nested under a different path. */ - public boolean mightMatchNonNestedDocs(Query query, String nestedPath) { + public static boolean mightMatchNonNestedDocs(Query query, String nestedPath, SearchExecutionContext searchExecutionContext) { if (query instanceof ConstantScoreQuery) { - return mightMatchNonNestedDocs(((ConstantScoreQuery) query).getQuery(), nestedPath); + return mightMatchNonNestedDocs(((ConstantScoreQuery) query).getQuery(), nestedPath, searchExecutionContext); } else if (query instanceof BoostQuery) { - return mightMatchNonNestedDocs(((BoostQuery) query).getQuery(), nestedPath); + return mightMatchNonNestedDocs(((BoostQuery) query).getQuery(), nestedPath, searchExecutionContext); } else if (query instanceof MatchAllDocsQuery) { return true; } else if (query instanceof MatchNoDocsQuery) { return false; } else if (query instanceof TermQuery) { - return mightMatchNonNestedDocs(((TermQuery) query).getTerm().field(), nestedPath); + return mightMatchNonNestedDocs(searchExecutionContext, ((TermQuery) query).getTerm().field(), nestedPath); } else if (query instanceof TermInSetQuery tis) { if (tis.getTermsCount() > 0) { - return mightMatchNonNestedDocs(tis.getField(), nestedPath); + return mightMatchNonNestedDocs(searchExecutionContext, tis.getField(), nestedPath); } else { return false; } } else if (query instanceof PointRangeQuery) { - return mightMatchNonNestedDocs(((PointRangeQuery) query).getField(), nestedPath); + return mightMatchNonNestedDocs(searchExecutionContext, ((PointRangeQuery) query).getField(), nestedPath); } else if (query instanceof IndexOrDocValuesQuery) { - return mightMatchNonNestedDocs(((IndexOrDocValuesQuery) query).getIndexQuery(), nestedPath); + return mightMatchNonNestedDocs(((IndexOrDocValuesQuery) query).getIndexQuery(), nestedPath, searchExecutionContext); } else if (query instanceof final BooleanQuery bq) { final boolean hasRequiredClauses = bq.clauses().stream().anyMatch(BooleanClause::isRequired); if (hasRequiredClauses) { @@ -131,13 +123,13 @@ public boolean mightMatchNonNestedDocs(Query query, String nestedPath) { .stream() .filter(BooleanClause::isRequired) .map(BooleanClause::query) - .allMatch(q -> mightMatchNonNestedDocs(q, nestedPath)); + .allMatch(q -> mightMatchNonNestedDocs(q, nestedPath, searchExecutionContext)); } else { return bq.clauses() .stream() .filter(c -> c.occur() == Occur.SHOULD) .map(BooleanClause::query) - .anyMatch(q -> mightMatchNonNestedDocs(q, nestedPath)); + .anyMatch(q -> mightMatchNonNestedDocs(q, nestedPath, searchExecutionContext)); } } else { return true; @@ -146,7 +138,7 @@ public boolean mightMatchNonNestedDocs(Query query, String nestedPath) { /** Returns true if a query on the given field might match parent documents * or documents that are nested under a different path. */ - boolean mightMatchNonNestedDocs(String field, String nestedPath) { + private static boolean mightMatchNonNestedDocs(SearchExecutionContext searchExecutionContext, String field, String nestedPath) { if (field.startsWith("_")) { // meta field. Every meta field behaves differently, eg. nested // documents have the same _uid as their parent, put their path in @@ -155,9 +147,10 @@ boolean mightMatchNonNestedDocs(String field, String nestedPath) { // we might add a nested filter when it is nor required. return true; } - if (isMappedFieldPredicate.test(field) == false) { + if (searchExecutionContext.isFieldMapped(field) == false) { return false; } + var nestedLookup = searchExecutionContext.nestedLookup(); String nestedParent = nestedLookup.getNestedParent(field); if (nestedParent == null || nestedParent.startsWith(nestedPath) == false) { // the field is not a sub field of the nested path diff --git a/server/src/main/java/org/elasticsearch/index/search/stats/SearchStats.java b/server/src/main/java/org/elasticsearch/index/search/stats/SearchStats.java index ff514091979c3..8b19d72ccc09d 100644 --- a/server/src/main/java/org/elasticsearch/index/search/stats/SearchStats.java +++ b/server/src/main/java/org/elasticsearch/index/search/stats/SearchStats.java @@ -105,7 +105,7 @@ private Stats(StreamInput in) throws IOException { suggestTimeInMillis = in.readVLong(); suggestCurrent = in.readVLong(); - if (in.getTransportVersion().onOrAfter(TransportVersions.SEARCH_FAILURE_STATS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { queryFailure = in.readVLong(); fetchFailure = in.readVLong(); } @@ -129,7 +129,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(suggestTimeInMillis); out.writeVLong(suggestCurrent); - if (out.getTransportVersion().onOrAfter(TransportVersions.SEARCH_FAILURE_STATS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVLong(queryFailure); out.writeVLong(fetchFailure); } diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index 993079a3106d7..f84ac22cd78e4 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -345,8 +345,9 @@ public IndexShard( this.mapperService = mapperService; this.indexCache = indexCache; this.internalIndexingStats = new InternalIndexingStats(); + var indexingFailuresDebugListener = new IndexingFailuresDebugListener(this); this.indexingOperationListeners = new IndexingOperationListener.CompositeListener( - CollectionUtils.appendToCopyNoNullElements(listeners, internalIndexingStats), + CollectionUtils.appendToCopyNoNullElements(listeners, internalIndexingStats, indexingFailuresDebugListener), logger ); this.bulkOperationListener = new ShardBulkStats(); diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexingFailuresDebugListener.java b/server/src/main/java/org/elasticsearch/index/shard/IndexingFailuresDebugListener.java new file mode 100644 index 0000000000000..13c0d917d492d --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexingFailuresDebugListener.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.shard; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.index.engine.Engine; + +import static org.elasticsearch.core.Strings.format; + +public class IndexingFailuresDebugListener implements IndexingOperationListener { + + private static final Logger LOGGER = LogManager.getLogger(IndexingFailuresDebugListener.class); + + private final IndexShard indexShard; + + public IndexingFailuresDebugListener(IndexShard indexShard) { + this.indexShard = indexShard; + } + + @Override + public void postIndex(ShardId shardId, Engine.Index index, Engine.IndexResult result) { + if (LOGGER.isDebugEnabled()) { + if (result.getResultType() == Engine.Result.Type.FAILURE) { + postIndex(shardId, index, result.getFailure()); + } + } + } + + @Override + public void postIndex(ShardId shardId, Engine.Index index, Exception ex) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + () -> format( + "index-fail [%s] seq# [%s] allocation-id [%s] primaryTerm [%s] operationPrimaryTerm [%s] origin [%s]", + index.id(), + index.seqNo(), + indexShard.routingEntry().allocationId(), + index.primaryTerm(), + indexShard.getOperationPrimaryTerm(), + index.origin() + ), + ex + ); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java index 0f1b40f80c36c..9db316d9683ed 100644 --- a/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java +++ b/server/src/main/java/org/elasticsearch/index/similarity/SimilarityService.java @@ -20,8 +20,6 @@ import org.apache.lucene.search.similarities.Similarity.SimScorer; import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.TriFunction; -import org.elasticsearch.common.logging.DeprecationCategory; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Nullable; @@ -40,7 +38,6 @@ import java.util.function.Supplier; public final class SimilarityService { - private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SimilarityService.class); public static final String DEFAULT_SIMILARITY = "BM25"; private static final Map>> DEFAULTS; public static final Map> BUILT_IN; @@ -115,13 +112,6 @@ public SimilarityService( defaultSimilarity = (providers.get("default") != null) ? providers.get("default").get() : providers.get(SimilarityService.DEFAULT_SIMILARITY).get(); - if (providers.get("base") != null) { - deprecationLogger.warn( - DeprecationCategory.QUERIES, - "base_similarity_ignored", - "The [base] similarity is ignored since query normalization and coords have been removed" - ); - } } /** diff --git a/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java b/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java index 9c666bd4a35f5..ee38273f13daf 100644 --- a/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/EmptySecretSettings.java @@ -44,7 +44,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 4497254aad1f0..c2d690d8160ac 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -112,6 +112,23 @@ void infer( ); /** + * Perform completion inference on the model using the unified schema. + * + * @param model The model + * @param request Parameters for the request + * @param timeout The timeout for the request + * @param listener Inference result listener + */ + void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ); + + /** + * Chunk long text. + * * @param model The model * @param query Inference query, mainly for re-ranking * @param input Inference input diff --git a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java index ebf32f0411555..53ce0bab63612 100644 --- a/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java +++ b/server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java @@ -121,7 +121,7 @@ public ModelConfigurations(StreamInput in) throws IOException { this.service = in.readString(); this.serviceSettings = in.readNamedWriteable(ServiceSettings.class); this.taskSettings = in.readNamedWriteable(TaskSettings.class); - this.chunkingSettings = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS) + this.chunkingSettings = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalNamedWriteable(ChunkingSettings.class) : null; } @@ -133,7 +133,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(service); out.writeNamedWriteable(serviceSettings); out.writeNamedWriteable(taskSettings); - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalNamedWriteable(chunkingSettings); } } diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index b0e5bababbbc0..fcb8ea7213795 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -38,6 +38,10 @@ public static TaskType fromString(String name) { } public static TaskType fromStringOrStatusException(String name) { + if (name == null) { + throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST); + } + try { TaskType taskType = TaskType.fromString(name); return Objects.requireNonNull(taskType); diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java new file mode 100644 index 0000000000000..e596be626b518 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -0,0 +1,425 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public record UnifiedCompletionRequest( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable List stop, + @Nullable Float temperature, + @Nullable ToolChoice toolChoice, + @Nullable List tools, + @Nullable Float topP +) implements Writeable { + + public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + UnifiedCompletionRequest.class.getSimpleName(), + args -> new UnifiedCompletionRequest( + (List) args[0], + (String) args[1], + (Long) args[2], + (List) args[3], + (Float) args[4], + (ToolChoice) args[5], + (List) args[6], + (Float) args[7] + ) + ); + + static { + PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); + PARSER.declareString(optionalConstructorArg(), new ParseField("model")); + PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); + PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseToolChoice(p), + new ParseField("tool_choice"), + ObjectParser.ValueType.OBJECT_OR_STRING + ); + PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tools")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p")); + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new), + new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new) + ); + } + + public static UnifiedCompletionRequest of(List messages) { + return new UnifiedCompletionRequest(messages, null, null, null, null, null, null, null); + } + + public UnifiedCompletionRequest(StreamInput in) throws IOException { + this( + in.readCollectionAsImmutableList(Message::new), + in.readOptionalString(), + in.readOptionalVLong(), + in.readOptionalStringCollectionAsList(), + in.readOptionalFloat(), + in.readOptionalNamedWriteable(ToolChoice.class), + in.readOptionalCollectionAsList(Tool::new), + in.readOptionalFloat() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(messages); + out.writeOptionalString(model); + out.writeOptionalVLong(maxCompletionTokens); + out.writeOptionalStringCollection(stop); + out.writeOptionalFloat(temperature); + out.writeOptionalNamedWriteable(toolChoice); + out.writeOptionalCollection(tools); + out.writeOptionalFloat(topP); + } + + public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) + implements + Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Message.class.getSimpleName(), + args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) + ); + + static { + PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareString(constructorArg(), new ParseField("role")); + PARSER.declareString(optionalConstructorArg(), new ParseField("name")); + PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); + PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls")); + } + + private static Content parseContent(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + var parsedContentObjects = XContentParserUtils.parseList(parser, (p) -> ContentObject.PARSER.apply(p, null)); + return new ContentObjects(parsedContentObjects); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ContentString.of(parser); + } + + throw new XContentParseException("Expected an array start token or a value string token but found token [" + token + "]"); + } + + public Message(StreamInput in) throws IOException { + this( + in.readNamedWriteable(Content.class), + in.readString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalCollectionAsList(ToolCall::new) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(content); + out.writeString(role); + out.writeOptionalString(name); + out.writeOptionalString(toolCallId); + out.writeOptionalCollection(toolCalls); + } + } + + public record ContentObjects(List contentObjects) implements Content, NamedWriteable { + + public static final String NAME = "content_objects"; + + public ContentObjects(StreamInput in) throws IOException { + this(in.readCollectionAsImmutableList(ContentObject::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(contentObjects); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record ContentObject(String text, String type) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ContentObject.class.getSimpleName(), + args -> new ContentObject((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("text")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ContentObject(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(text); + out.writeString(type); + } + + public String toString() { + return text + ":" + type; + } + + } + + public record ContentString(String content) implements Content, NamedWriteable { + public static final String NAME = "content_string"; + + public static ContentString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ContentString(content); + } + + public ContentString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(content); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public String toString() { + return content; + } + } + + public record ToolCall(String id, FunctionField function, String type) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolCall.class.getSimpleName(), + args -> new ToolCall((String) args[0], (FunctionField) args[1], (String) args[2]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("id")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ToolCall(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + function.writeTo(out); + out.writeString(type); + } + + public record FunctionField(String arguments, String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_call_function_field", + args -> new FunctionField((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("arguments")); + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(arguments); + out.writeString(name); + } + } + } + + private static ToolChoice parseToolChoice(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_OBJECT) { + return ToolChoiceObject.PARSER.apply(parser, null); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ToolChoiceString.of(parser); + } + + throw new XContentParseException("Unsupported token [" + token + "]"); + } + + public sealed interface ToolChoice extends NamedWriteable permits ToolChoiceObject, ToolChoiceString {} + + public record ToolChoiceObject(String type, FunctionField function) implements ToolChoice, NamedWriteable { + + public static final String NAME = "tool_choice_object"; + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolChoiceObject.class.getSimpleName(), + args -> new ToolChoiceObject((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public ToolChoiceObject(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public record FunctionField(String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_choice_function_field", + args -> new FunctionField((String) args[0]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + } + } + } + + public record ToolChoiceString(String value) implements ToolChoice, NamedWriteable { + public static final String NAME = "tool_choice_string"; + + public static ToolChoiceString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ToolChoiceString(content); + } + + public ToolChoiceString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record Tool(String type, FunctionField function) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Tool.class.getSimpleName(), + args -> new Tool((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public Tool(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + public record FunctionField( + @Nullable String description, + String name, + @Nullable Map parameters, + @Nullable Boolean strict + ) implements Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_function_field", + args -> new FunctionField((String) args[0], (String) args[1], (Map) args[2], (Boolean) args[3]) + ); + + static { + PARSER.declareString(optionalConstructorArg(), new ParseField("description")); + PARSER.declareString(constructorArg(), new ParseField("name")); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField("parameters")); + PARSER.declareBoolean(optionalConstructorArg(), new ParseField("strict")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readOptionalString(), in.readString(), in.readGenericMap(), in.readOptionalBoolean()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(description); + out.writeString(name); + out.writeGenericMap(parameters); + out.writeOptionalBoolean(strict); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/ingest/EnterpriseGeoIpTask.java b/server/src/main/java/org/elasticsearch/ingest/EnterpriseGeoIpTask.java index e696c38b9f017..ff6a687da9b4d 100644 --- a/server/src/main/java/org/elasticsearch/ingest/EnterpriseGeoIpTask.java +++ b/server/src/main/java/org/elasticsearch/ingest/EnterpriseGeoIpTask.java @@ -64,7 +64,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ENTERPRISE_GEOIP_DOWNLOADER; + return TransportVersions.V_8_16_0; } @Override diff --git a/server/src/main/java/org/elasticsearch/monitor/os/OsService.java b/server/src/main/java/org/elasticsearch/monitor/os/OsService.java index 7609cc14c6b3b..ceed2b0e41fc1 100644 --- a/server/src/main/java/org/elasticsearch/monitor/os/OsService.java +++ b/server/src/main/java/org/elasticsearch/monitor/os/OsService.java @@ -25,7 +25,6 @@ public class OsService implements ReportingService { private static final Logger logger = LogManager.getLogger(OsService.class); - private final OsProbe probe; private final OsInfo info; private final SingleObjectCache osStatsCache; @@ -37,10 +36,9 @@ public class OsService implements ReportingService { ); public OsService(Settings settings) throws IOException { - this.probe = OsProbe.getInstance(); TimeValue refreshInterval = REFRESH_INTERVAL_SETTING.get(settings); - this.info = probe.osInfo(refreshInterval.millis(), EsExecutors.nodeProcessors(settings)); - this.osStatsCache = new OsStatsCache(refreshInterval, probe.osStats()); + this.info = OsProbe.getInstance().osInfo(refreshInterval.millis(), EsExecutors.nodeProcessors(settings)); + this.osStatsCache = new OsStatsCache(refreshInterval); logger.debug("using refresh_interval [{}]", refreshInterval); } @@ -53,14 +51,28 @@ public OsStats stats() { return osStatsCache.getOrRefresh(); } - private class OsStatsCache extends SingleObjectCache { - OsStatsCache(TimeValue interval, OsStats initValue) { - super(interval, initValue); + private static class OsStatsCache extends SingleObjectCache { + + private static final OsStats MISSING = new OsStats( + 0L, + new OsStats.Cpu((short) 0, new double[0]), + new OsStats.Mem(0, 0, 0), + new OsStats.Swap(0, 0), + null + ); + + OsStatsCache(TimeValue interval) { + super(interval, MISSING); } @Override protected OsStats refresh() { - return probe.osStats(); + return OsProbe.getInstance().osStats(); + } + + @Override + protected boolean needsRefresh() { + return getNoRefresh() == MISSING || super.needsRefresh(); } } } diff --git a/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java b/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java index aa21e5c64d903..aadda93f977b6 100644 --- a/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java +++ b/server/src/main/java/org/elasticsearch/plugins/PluginsLoader.java @@ -50,7 +50,6 @@ * to have all the plugin information they need prior to starting. */ public class PluginsLoader { - /** * Contains information about the {@link ClassLoader} required to load a plugin */ @@ -64,18 +63,26 @@ public interface PluginLayer { * @return The {@link ClassLoader} used to instantiate the main class for the plugin */ ClassLoader pluginClassLoader(); + + /** + * @return The {@link ModuleLayer} for the plugin modules + */ + ModuleLayer pluginModuleLayer(); } /** * Contains information about the {@link ClassLoader}s and {@link ModuleLayer} required for loading a plugin - * @param pluginBundle Information about the bundle of jars used in this plugin + * + * @param pluginBundle Information about the bundle of jars used in this plugin * @param pluginClassLoader The {@link ClassLoader} used to instantiate the main class for the plugin - * @param spiClassLoader The exported {@link ClassLoader} visible to other Java modules - * @param spiModuleLayer The exported {@link ModuleLayer} visible to other Java modules + * @param pluginModuleLayer The {@link ModuleLayer} containing the Java modules of the plugin + * @param spiClassLoader The exported {@link ClassLoader} visible to other Java modules + * @param spiModuleLayer The exported {@link ModuleLayer} visible to other Java modules */ private record LoadedPluginLayer( PluginBundle pluginBundle, ClassLoader pluginClassLoader, + ModuleLayer pluginModuleLayer, ClassLoader spiClassLoader, ModuleLayer spiModuleLayer ) implements PluginLayer { @@ -103,6 +110,10 @@ public record LayerAndLoader(ModuleLayer layer, ClassLoader loader) { public static LayerAndLoader ofLoader(ClassLoader loader) { return new LayerAndLoader(ModuleLayer.boot(), loader); } + + public static LayerAndLoader ofUberModuleLoader(UberModuleClassLoader loader) { + return new LayerAndLoader(loader.getLayer(), loader); + } } private static final Logger logger = LogManager.getLogger(PluginsLoader.class); @@ -111,6 +122,7 @@ public static LayerAndLoader ofLoader(ClassLoader loader) { private final List moduleDescriptors; private final List pluginDescriptors; private final Map loadedPluginLayers; + private final Set allBundles; /** * Constructs a new PluginsLoader @@ -185,17 +197,19 @@ public static PluginsLoader createPluginsLoader(Path modulesDirectory, Path plug } } - return new PluginsLoader(moduleDescriptors, pluginDescriptors, loadedPluginLayers); + return new PluginsLoader(moduleDescriptors, pluginDescriptors, loadedPluginLayers, Set.copyOf(seenBundles)); } PluginsLoader( List moduleDescriptors, List pluginDescriptors, - Map loadedPluginLayers + Map loadedPluginLayers, + Set allBundles ) { this.moduleDescriptors = moduleDescriptors; this.pluginDescriptors = pluginDescriptors; this.loadedPluginLayers = loadedPluginLayers; + this.allBundles = allBundles; } public List moduleDescriptors() { @@ -210,6 +224,10 @@ public Stream pluginLayers() { return loadedPluginLayers.values().stream().map(Function.identity()); } + public Set allBundles() { + return allBundles; + } + private static void loadPluginLayer( PluginBundle bundle, Map loaded, @@ -239,7 +257,7 @@ private static void loadPluginLayer( } final ClassLoader pluginParentLoader = spiLayerAndLoader == null ? parentLoader : spiLayerAndLoader.loader(); - final LayerAndLoader pluginLayerAndLoader = createPlugin( + final LayerAndLoader pluginLayerAndLoader = createPluginLayerAndLoader( bundle, pluginParentLoader, extendedPlugins, @@ -253,7 +271,16 @@ private static void loadPluginLayer( spiLayerAndLoader = pluginLayerAndLoader; } - loaded.put(name, new LoadedPluginLayer(bundle, pluginClassLoader, spiLayerAndLoader.loader, spiLayerAndLoader.layer)); + loaded.put( + name, + new LoadedPluginLayer( + bundle, + pluginClassLoader, + pluginLayerAndLoader.layer(), + spiLayerAndLoader.loader, + spiLayerAndLoader.layer + ) + ); } static LayerAndLoader createSPI( @@ -277,7 +304,7 @@ static LayerAndLoader createSPI( } } - static LayerAndLoader createPlugin( + private static LayerAndLoader createPluginLayerAndLoader( PluginBundle bundle, ClassLoader pluginParentLoader, List extendedPlugins, @@ -294,7 +321,7 @@ static LayerAndLoader createPlugin( return createPluginModuleLayer(bundle, pluginParentLoader, parentLayers, qualifiedExports); } else if (plugin.isStable()) { logger.debug(() -> "Loading bundle: " + plugin.getName() + ", non-modular as synthetic module"); - return LayerAndLoader.ofLoader( + return LayerAndLoader.ofUberModuleLoader( UberModuleClassLoader.getInstance( pluginParentLoader, ModuleLayer.boot(), diff --git a/server/src/main/java/org/elasticsearch/rest/action/document/RestGetSourceAction.java b/server/src/main/java/org/elasticsearch/rest/action/document/RestGetSourceAction.java index a09fcbd0c5273..7e4d23db70288 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/document/RestGetSourceAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/document/RestGetSourceAction.java @@ -15,7 +15,6 @@ import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; @@ -40,9 +39,6 @@ */ @ServerlessScope(Scope.PUBLIC) public class RestGetSourceAction extends BaseRestHandler { - private final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestGetSourceAction.class); - static final String TYPES_DEPRECATION_MESSAGE = "[types removal] Specifying types in get_source and exist_source " - + "requests is deprecated."; @Override public List routes() { diff --git a/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiTermVectorsAction.java b/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiTermVectorsAction.java index 65aa1869a41e4..9d39bf7f343c6 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiTermVectorsAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/document/RestMultiTermVectorsAction.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.termvectors.TermVectorsRequest; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; @@ -28,8 +27,6 @@ @ServerlessScope(Scope.PUBLIC) public class RestMultiTermVectorsAction extends BaseRestHandler { - private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestMultiTermVectorsAction.class); - static final String TYPES_DEPRECATION_MESSAGE = "[types removal] Specifying types in multi term vector requests is deprecated."; @Override public List routes() { diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java index c1a55874bfc58..b0e08b376f9d0 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestCountAction.java @@ -14,7 +14,6 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; @@ -36,8 +35,6 @@ @ServerlessScope(Scope.PUBLIC) public class RestCountAction extends BaseRestHandler { - private final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestCountAction.class); - static final String TYPES_DEPRECATION_MESSAGE = "[types removal] Specifying types in count requests is deprecated."; @Override public List routes() { diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java index ff062084a3cbb..a9c2ff7576b05 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/RestSearchAction.java @@ -16,7 +16,6 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.client.internal.node.NodeClient; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; import org.elasticsearch.features.NodeFeature; @@ -56,8 +55,6 @@ @ServerlessScope(Scope.PUBLIC) public class RestSearchAction extends BaseRestHandler { - private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(RestSearchAction.class); - public static final String TYPES_DEPRECATION_MESSAGE = "[types removal] Specifying types in search requests is deprecated."; /** * Indicates whether hits.total should be rendered as an integer or an object diff --git a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java index 8ac35f7c40caa..b87d097413b67 100644 --- a/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/elasticsearch/search/DefaultSearchContext.java @@ -444,10 +444,9 @@ public void preProcess() { public Query buildFilteredQuery(Query query) { List filters = new ArrayList<>(); NestedLookup nestedLookup = searchExecutionContext.nestedLookup(); - NestedHelper nestedHelper = new NestedHelper(nestedLookup, searchExecutionContext::isFieldMapped); if (nestedLookup != NestedLookup.EMPTY - && nestedHelper.mightMatchNestedDocs(query) - && (aliasFilter == null || nestedHelper.mightMatchNestedDocs(aliasFilter))) { + && NestedHelper.mightMatchNestedDocs(query, searchExecutionContext) + && (aliasFilter == null || NestedHelper.mightMatchNestedDocs(aliasFilter, searchExecutionContext))) { filters.add(Queries.newNonNestedFilter(searchExecutionContext.indexVersionCreated())); } diff --git a/server/src/main/java/org/elasticsearch/search/DocValueFormat.java b/server/src/main/java/org/elasticsearch/search/DocValueFormat.java index a1e8eb25f4780..f8d161ef1f5e5 100644 --- a/server/src/main/java/org/elasticsearch/search/DocValueFormat.java +++ b/server/src/main/java/org/elasticsearch/search/DocValueFormat.java @@ -263,7 +263,7 @@ private DateTime(DateFormatter formatter, ZoneId timeZone, DateFieldMapper.Resol private DateTime(StreamInput in) throws IOException { String formatterPattern = in.readString(); - Locale locale = in.getTransportVersion().onOrAfter(TransportVersions.DATE_TIME_DOC_VALUES_LOCALES) + Locale locale = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? LocaleUtils.parse(in.readString()) : DateFieldMapper.DEFAULT_LOCALE; String zoneId = in.readString(); @@ -297,7 +297,7 @@ public static DateTime readFrom(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(formatter.pattern()); - if (out.getTransportVersion().onOrAfter(TransportVersions.DATE_TIME_DOC_VALUES_LOCALES)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeString(formatter.locale().toString()); } out.writeString(timeZone.getId()); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/BucketOrder.java b/server/src/main/java/org/elasticsearch/search/aggregations/BucketOrder.java index 2d360705f75b6..c412ecb5d6361 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/BucketOrder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/BucketOrder.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket; +import org.elasticsearch.search.aggregations.bucket.terms.BucketAndOrd; import org.elasticsearch.search.aggregations.support.AggregationPath; import org.elasticsearch.xcontent.ToXContentObject; @@ -20,13 +21,12 @@ import java.util.Comparator; import java.util.List; import java.util.function.BiFunction; -import java.util.function.ToLongFunction; /** * {@link Bucket} ordering strategy. Buckets can be order either as * "complete" buckets using {@link #comparator()} or against a combination * of the buckets internals with its ordinal with - * {@link #partiallyBuiltBucketComparator(ToLongFunction, Aggregator)}. + * {@link #partiallyBuiltBucketComparator(Aggregator)}. */ public abstract class BucketOrder implements ToXContentObject, Writeable { /** @@ -102,7 +102,7 @@ public final void validate(Aggregator aggregator) throws AggregationExecutionExc * to validate this order because doing so checks all of the appropriate * paths. */ - partiallyBuiltBucketComparator(null, aggregator); + partiallyBuiltBucketComparator(aggregator); } /** @@ -121,7 +121,7 @@ public final void validate(Aggregator aggregator) throws AggregationExecutionExc * with it all the time. *

*/ - public abstract Comparator partiallyBuiltBucketComparator(ToLongFunction ordinalReader, Aggregator aggregator); + public abstract Comparator> partiallyBuiltBucketComparator(Aggregator aggregator); /** * Build a comparator for fully built buckets. diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/InternalOrder.java b/server/src/main/java/org/elasticsearch/search/aggregations/InternalOrder.java index b2ca4a10dc4b3..3593eb5adf7e4 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/InternalOrder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/InternalOrder.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.search.aggregations.Aggregator.BucketComparator; import org.elasticsearch.search.aggregations.bucket.MultiBucketsAggregation.Bucket; +import org.elasticsearch.search.aggregations.bucket.terms.BucketAndOrd; import org.elasticsearch.search.aggregations.support.AggregationPath; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.sort.SortValue; @@ -30,7 +31,6 @@ import java.util.List; import java.util.Objects; import java.util.function.BiFunction; -import java.util.function.ToLongFunction; /** * Implementations for {@link Bucket} ordering strategies. @@ -63,10 +63,10 @@ public AggregationPath path() { } @Override - public Comparator partiallyBuiltBucketComparator(ToLongFunction ordinalReader, Aggregator aggregator) { + public Comparator> partiallyBuiltBucketComparator(Aggregator aggregator) { try { BucketComparator bucketComparator = path.bucketComparator(aggregator, order); - return (lhs, rhs) -> bucketComparator.compare(ordinalReader.applyAsLong(lhs), ordinalReader.applyAsLong(rhs)); + return (lhs, rhs) -> bucketComparator.compare(lhs.ord, rhs.ord); } catch (IllegalArgumentException e) { throw new AggregationExecutionException.InvalidPath("Invalid aggregation order path [" + path + "]. " + e.getMessage(), e); } @@ -188,12 +188,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public Comparator partiallyBuiltBucketComparator(ToLongFunction ordinalReader, Aggregator aggregator) { - List> comparators = orderElements.stream() - .map(oe -> oe.partiallyBuiltBucketComparator(ordinalReader, aggregator)) - .toList(); + public Comparator> partiallyBuiltBucketComparator(Aggregator aggregator) { + List>> comparators = new ArrayList<>(orderElements.size()); + for (BucketOrder order : orderElements) { + comparators.add(order.partiallyBuiltBucketComparator(aggregator)); + } return (lhs, rhs) -> { - for (Comparator c : comparators) { + for (Comparator> c : comparators) { int result = c.compare(lhs, rhs); if (result != 0) { return result; @@ -299,9 +300,9 @@ byte id() { } @Override - public Comparator partiallyBuiltBucketComparator(ToLongFunction ordinalReader, Aggregator aggregator) { + public Comparator> partiallyBuiltBucketComparator(Aggregator aggregator) { Comparator comparator = comparator(); - return comparator::compare; + return (lhs, rhs) -> comparator.compare(lhs.bucket, rhs.bucket); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/countedterms/CountedTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/countedterms/CountedTermsAggregator.java index 344b90b06c4f6..571ce3a9a4519 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/countedterms/CountedTermsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/countedterms/CountedTermsAggregator.java @@ -13,6 +13,7 @@ import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.core.Releasables; @@ -26,6 +27,7 @@ import org.elasticsearch.search.aggregations.InternalOrder; import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.bucket.terms.BucketAndOrd; import org.elasticsearch.search.aggregations.bucket.terms.BucketPriorityQueue; import org.elasticsearch.search.aggregations.bucket.terms.BytesKeyedBucketOrds; import org.elasticsearch.search.aggregations.bucket.terms.InternalTerms; @@ -38,7 +40,6 @@ import java.util.Arrays; import java.util.Map; import java.util.function.BiConsumer; -import java.util.function.Supplier; import static java.util.Collections.emptyList; import static org.elasticsearch.search.aggregations.InternalOrder.isKeyOrder; @@ -115,51 +116,57 @@ public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throw LongArray otherDocCounts = bigArrays().newLongArray(owningBucketOrds.size()); ObjectArray topBucketsPerOrd = bigArrays().newObjectArray(owningBucketOrds.size()) ) { - for (long ordIdx = 0; ordIdx < topBucketsPerOrd.size(); ordIdx++) { - int size = (int) Math.min(bucketOrds.size(), bucketCountThresholds.getShardSize()); - - // as users can't control sort order, in practice we'll always sort by doc count descending - try ( - BucketPriorityQueue ordered = new BucketPriorityQueue<>( - size, - bigArrays(), - partiallyBuiltBucketComparator - ) - ) { - StringTerms.Bucket spare = null; - BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrds.get(ordIdx)); - Supplier emptyBucketBuilder = () -> new StringTerms.Bucket( - new BytesRef(), - 0, - null, - false, - 0, - format - ); - while (ordsEnum.next()) { - long docCount = bucketDocCount(ordsEnum.ord()); - otherDocCounts.increment(ordIdx, docCount); - if (spare == null) { - checkRealMemoryCBForInternalBucket(); - spare = emptyBucketBuilder.get(); + try (IntArray bucketsToCollect = bigArrays().newIntArray(owningBucketOrds.size())) { + // find how many buckets we are going to collect + long ordsToCollect = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { + int size = (int) Math.min(bucketOrds.bucketsInOrd(owningBucketOrds.get(ordIdx)), bucketCountThresholds.getShardSize()); + bucketsToCollect.set(ordIdx, size); + ordsToCollect += size; + } + try (LongArray ordsArray = bigArrays().newLongArray(ordsToCollect)) { + long ordsCollected = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { + // as users can't control sort order, in practice we'll always sort by doc count descending + try ( + BucketPriorityQueue ordered = new BucketPriorityQueue<>( + bucketsToCollect.get(ordIdx), + bigArrays(), + order.partiallyBuiltBucketComparator(this) + ) + ) { + BucketAndOrd spare = null; + BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrds.get(ordIdx)); + while (ordsEnum.next()) { + long docCount = bucketDocCount(ordsEnum.ord()); + otherDocCounts.increment(ordIdx, docCount); + if (spare == null) { + checkRealMemoryCBForInternalBucket(); + spare = new BucketAndOrd<>(new StringTerms.Bucket(new BytesRef(), 0, null, false, 0, format)); + } + ordsEnum.readValue(spare.bucket.getTermBytes()); + spare.bucket.setDocCount(docCount); + spare.ord = ordsEnum.ord(); + spare = ordered.insertWithOverflow(spare); + } + final int orderedSize = (int) ordered.size(); + final StringTerms.Bucket[] buckets = new StringTerms.Bucket[orderedSize]; + for (int i = orderedSize - 1; i >= 0; --i) { + BucketAndOrd bucketAndOrd = ordered.pop(); + buckets[i] = bucketAndOrd.bucket; + ordsArray.set(ordsCollected + i, bucketAndOrd.ord); + otherDocCounts.increment(ordIdx, -bucketAndOrd.bucket.getDocCount()); + bucketAndOrd.bucket.setTermBytes(BytesRef.deepCopyOf(bucketAndOrd.bucket.getTermBytes())); + } + topBucketsPerOrd.set(ordIdx, buckets); + ordsCollected += orderedSize; } - ordsEnum.readValue(spare.getTermBytes()); - spare.setDocCount(docCount); - spare.setBucketOrd(ordsEnum.ord()); - spare = ordered.insertWithOverflow(spare); - } - - topBucketsPerOrd.set(ordIdx, new StringTerms.Bucket[(int) ordered.size()]); - for (int i = (int) ordered.size() - 1; i >= 0; --i) { - topBucketsPerOrd.get(ordIdx)[i] = ordered.pop(); - otherDocCounts.increment(ordIdx, -topBucketsPerOrd.get(ordIdx)[i].getDocCount()); - topBucketsPerOrd.get(ordIdx)[i].setTermBytes(BytesRef.deepCopyOf(topBucketsPerOrd.get(ordIdx)[i].getTermBytes())); } + assert ordsCollected == ordsArray.size(); + buildSubAggsForAllBuckets(topBucketsPerOrd, ordsArray, InternalTerms.Bucket::setAggregations); } } - buildSubAggsForAllBuckets(topBucketsPerOrd, InternalTerms.Bucket::getBucketOrd, InternalTerms.Bucket::setAggregations); - return buildAggregations(Math.toIntExact(owningBucketOrds.size()), ordIdx -> { final BucketOrder reduceOrder; if (isKeyOrder(order) == false) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketPriorityQueue.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketPriorityQueue.java index 7f8e5c8c885fa..9550003a5bd1e 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketPriorityQueue.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketPriorityQueue.java @@ -13,17 +13,17 @@ import java.util.Comparator; -public class BucketPriorityQueue extends ObjectArrayPriorityQueue { +public class BucketPriorityQueue extends ObjectArrayPriorityQueue> { - private final Comparator comparator; + private final Comparator> comparator; - public BucketPriorityQueue(int size, BigArrays bigArrays, Comparator comparator) { + public BucketPriorityQueue(int size, BigArrays bigArrays, Comparator> comparator) { super(size, bigArrays); this.comparator = comparator; } @Override - protected boolean lessThan(B a, B b) { + protected boolean lessThan(BucketAndOrd a, BucketAndOrd b) { return comparator.compare(a, b) > 0; // reverse, since we reverse again when adding to a list } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketSignificancePriorityQueue.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketSignificancePriorityQueue.java index fe751c9e79189..4736f52d93622 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketSignificancePriorityQueue.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/BucketSignificancePriorityQueue.java @@ -12,14 +12,14 @@ import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.ObjectArrayPriorityQueue; -public class BucketSignificancePriorityQueue extends ObjectArrayPriorityQueue { +public class BucketSignificancePriorityQueue extends ObjectArrayPriorityQueue> { public BucketSignificancePriorityQueue(int size, BigArrays bigArrays) { super(size, bigArrays); } @Override - protected boolean lessThan(SignificantTerms.Bucket o1, SignificantTerms.Bucket o2) { - return o1.getSignificanceScore() < o2.getSignificanceScore(); + protected boolean lessThan(BucketAndOrd o1, BucketAndOrd o2) { + return o1.bucket.getSignificanceScore() < o2.bucket.getSignificanceScore(); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java index 4cf710232c7a0..439b61cc43ddf 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/GlobalOrdinalsStringTermsAggregator.java @@ -20,12 +20,11 @@ import org.apache.lucene.util.PriorityQueue; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.common.util.LongHash; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.common.util.ObjectArrayPriorityQueue; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.DocValueFormat; @@ -102,14 +101,14 @@ public GlobalOrdinalsStringTermsAggregator( this.valueCount = valuesSupplier.get().getValueCount(); this.acceptedGlobalOrdinals = acceptedOrds; if (remapGlobalOrds) { - this.collectionStrategy = new RemapGlobalOrds(cardinality, excludeDeletedDocs); + this.collectionStrategy = new RemapGlobalOrds<>(this.resultStrategy, cardinality, excludeDeletedDocs); } else { this.collectionStrategy = cardinality.map(estimate -> { if (estimate > 1) { // This is a 500 class error, because we should never be able to reach it. throw new AggregationExecutionException("Dense ords don't know how to collect from many buckets"); } - return new DenseGlobalOrds(excludeDeletedDocs); + return new DenseGlobalOrds<>(this.resultStrategy, excludeDeletedDocs); }); } } @@ -193,7 +192,13 @@ public void collect(int doc, long owningBucketOrd) throws IOException { @Override public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException { - return resultStrategy.buildAggregations(owningBucketOrds); + if (valueCount == 0) { // no context in this reader + return GlobalOrdinalsStringTermsAggregator.this.buildAggregations( + Math.toIntExact(owningBucketOrds.size()), + ordIdx -> resultStrategy.buildNoValuesResult(owningBucketOrds.get(ordIdx)) + ); + } + return collectionStrategy.buildAggregations(owningBucketOrds); } @Override @@ -401,8 +406,8 @@ private void mapSegmentCountsToGlobalCounts(LongUnaryOperator mapping) throws IO * The {@link GlobalOrdinalsStringTermsAggregator} uses one of these * to collect the global ordinals by calling * {@link CollectionStrategy#collectGlobalOrd} for each global ordinal - * that it hits and then calling {@link CollectionStrategy#forEach} - * once to iterate on the results. + * that it hits and then calling {@link CollectionStrategy#buildAggregations} + * to generate the results. */ abstract static class CollectionStrategy implements Releasable { /** @@ -438,15 +443,9 @@ abstract static class CollectionStrategy implements Releasable { abstract long globalOrdToBucketOrd(long owningBucketOrd, long globalOrd); /** - * Iterate all of the buckets. Implementations take into account - * the {@link BucketCountThresholds}. In particular, - * if the {@link BucketCountThresholds#getMinDocCount()} is 0 then - * they'll make sure to iterate a bucket even if it was never - * {{@link #collectGlobalOrd collected}. - * If {@link BucketCountThresholds#getMinDocCount()} is not 0 then - * they'll skip all global ords that weren't collected. + * Create the aggregation result */ - abstract void forEach(long owningBucketOrd, BucketInfoConsumer consumer) throws IOException; + abstract InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException; } interface BucketInfoConsumer { @@ -457,12 +456,17 @@ interface BucketInfoConsumer { * {@linkplain CollectionStrategy} that just uses the global ordinal as the * bucket ordinal. */ - class DenseGlobalOrds extends CollectionStrategy { + class DenseGlobalOrds< + R extends InternalAggregation, + B extends InternalMultiBucketAggregation.InternalBucket, + TB extends InternalMultiBucketAggregation.InternalBucket> extends CollectionStrategy { private final boolean excludeDeletedDocs; + private final ResultStrategy collectionStrategy; - DenseGlobalOrds(boolean excludeDeletedDocs) { + DenseGlobalOrds(ResultStrategy collectionStrategy, boolean excludeDeletedDocs) { this.excludeDeletedDocs = excludeDeletedDocs; + this.collectionStrategy = collectionStrategy; } @Override @@ -492,9 +496,7 @@ long globalOrdToBucketOrd(long owningBucketOrd, long globalOrd) { return globalOrd; } - @Override - void forEach(long owningBucketOrd, BucketInfoConsumer consumer) throws IOException { - assert owningBucketOrd == 0; + private void collect(BucketInfoConsumer consumer) throws IOException { if (excludeDeletedDocs) { forEachExcludeDeletedDocs(consumer); } else { @@ -518,7 +520,7 @@ private void forEachAllowDeletedDocs(BucketInfoConsumer consumer) throws IOExcep * Excludes deleted docs in the results by cross-checking with liveDocs. */ private void forEachExcludeDeletedDocs(BucketInfoConsumer consumer) throws IOException { - try (LongHash accepted = new LongHash(20, new BigArrays(null, null, ""))) { + try (LongHash accepted = new LongHash(20, bigArrays())) { for (LeafReaderContext ctx : searcher().getTopReaderContext().leaves()) { LeafReader reader = ctx.reader(); Bits liveDocs = reader.getLiveDocs(); @@ -550,6 +552,62 @@ private void forEachExcludeDeletedDocs(BucketInfoConsumer consumer) throws IOExc @Override public void close() {} + + @Override + InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException { + assert owningBucketOrds.size() == 1 && owningBucketOrds.get(0) == 0; + try ( + LongArray otherDocCount = bigArrays().newLongArray(1, true); + ObjectArray topBucketsPreOrd = collectionStrategy.buildTopBucketsPerOrd(1) + ) { + GlobalOrdLookupFunction lookupGlobalOrd = valuesSupplier.get()::lookupOrd; + final int size = (int) Math.min(valueCount, bucketCountThresholds.getShardSize()); + try (ObjectArrayPriorityQueue> ordered = collectionStrategy.buildPriorityQueue(size)) { + BucketUpdater updater = collectionStrategy.bucketUpdater(0, lookupGlobalOrd); + collect(new BucketInfoConsumer() { + BucketAndOrd spare = null; + + @Override + public void accept(long globalOrd, long bucketOrd, long docCount) throws IOException { + otherDocCount.increment(0, docCount); + if (docCount >= bucketCountThresholds.getShardMinDocCount()) { + if (spare == null) { + checkRealMemoryCBForInternalBucket(); + spare = new BucketAndOrd<>(collectionStrategy.buildEmptyTemporaryBucket()); + } + spare.ord = bucketOrd; + updater.updateBucket(spare.bucket, globalOrd, docCount); + spare = ordered.insertWithOverflow(spare); + } + } + }); + + // Get the top buckets + int orderedSize = (int) ordered.size(); + try (LongArray ordsArray = bigArrays().newLongArray(orderedSize)) { + B[] buckets = collectionStrategy.buildBuckets(orderedSize); + for (int i = orderedSize - 1; i >= 0; --i) { + checkRealMemoryCBForInternalBucket(); + BucketAndOrd bucketAndOrd = ordered.pop(); + B bucket = collectionStrategy.convertTempBucketToRealBucket(bucketAndOrd.bucket, lookupGlobalOrd); + ordsArray.set(i, bucketAndOrd.ord); + buckets[i] = bucket; + otherDocCount.increment(0, -bucket.getDocCount()); + } + topBucketsPreOrd.set(0, buckets); + collectionStrategy.buildSubAggs(topBucketsPreOrd, ordsArray); + } + } + return GlobalOrdinalsStringTermsAggregator.this.buildAggregations( + Math.toIntExact(owningBucketOrds.size()), + ordIdx -> collectionStrategy.buildResult( + owningBucketOrds.get(ordIdx), + otherDocCount.get(ordIdx), + topBucketsPreOrd.get(ordIdx) + ) + ); + } + } } /** @@ -558,13 +616,22 @@ public void close() {} * {@link DenseGlobalOrds} when collecting every ordinal, but significantly * less when collecting only a few. */ - private class RemapGlobalOrds extends CollectionStrategy { + private class RemapGlobalOrds< + R extends InternalAggregation, + B extends InternalMultiBucketAggregation.InternalBucket, + TB extends InternalMultiBucketAggregation.InternalBucket> extends CollectionStrategy { private final LongKeyedBucketOrds bucketOrds; private final boolean excludeDeletedDocs; + private final ResultStrategy collectionStrategy; - private RemapGlobalOrds(CardinalityUpperBound cardinality, boolean excludeDeletedDocs) { + private RemapGlobalOrds( + ResultStrategy collectionStrategy, + CardinalityUpperBound cardinality, + boolean excludeDeletedDocs + ) { bucketOrds = LongKeyedBucketOrds.buildForValueRange(bigArrays(), cardinality, 0, valueCount - 1); this.excludeDeletedDocs = excludeDeletedDocs; + this.collectionStrategy = collectionStrategy; } @Override @@ -596,30 +663,14 @@ long globalOrdToBucketOrd(long owningBucketOrd, long globalOrd) { return bucketOrds.find(owningBucketOrd, globalOrd); } - @Override - void forEach(long owningBucketOrd, BucketInfoConsumer consumer) throws IOException { + private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOException { if (excludeDeletedDocs) { - forEachExcludeDeletedDocs(owningBucketOrd, consumer); - } else { - forEachAllowDeletedDocs(owningBucketOrd, consumer); - } - } - - void forEachAllowDeletedDocs(long owningBucketOrd, BucketInfoConsumer consumer) throws IOException { - if (bucketCountThresholds.getMinDocCount() == 0) { + forEachExcludeDeletedDocs(owningBucketOrd); + } else if (bucketCountThresholds.getMinDocCount() == 0) { for (long globalOrd = 0; globalOrd < valueCount; globalOrd++) { - if (false == acceptedGlobalOrdinals.test(globalOrd)) { - continue; - } - addBucketForMinDocCountZero(owningBucketOrd, globalOrd, consumer, null); - } - } else { - LongKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd); - while (ordsEnum.next()) { - if (false == acceptedGlobalOrdinals.test(ordsEnum.value())) { - continue; + if (acceptedGlobalOrdinals.test(globalOrd)) { + bucketOrds.add(owningBucketOrd, globalOrd); } - consumer.accept(ordsEnum.value(), ordsEnum.ord(), bucketDocCount(ordsEnum.ord())); } } } @@ -627,9 +678,9 @@ void forEachAllowDeletedDocs(long owningBucketOrd, BucketInfoConsumer consumer) /** * Excludes deleted docs in the results by cross-checking with liveDocs. */ - void forEachExcludeDeletedDocs(long owningBucketOrd, BucketInfoConsumer consumer) throws IOException { + private void forEachExcludeDeletedDocs(long owningBucketOrd) throws IOException { assert bucketCountThresholds.getMinDocCount() == 0; - try (LongHash accepted = new LongHash(20, new BigArrays(null, null, ""))) { + try (LongHash accepted = new LongHash(20, bigArrays())) { for (LeafReaderContext ctx : searcher().getTopReaderContext().leaves()) { LeafReader reader = ctx.reader(); Bits liveDocs = reader.getLiveDocs(); @@ -646,7 +697,8 @@ void forEachExcludeDeletedDocs(long owningBucketOrd, BucketInfoConsumer consumer if (false == acceptedGlobalOrdinals.test(globalOrd)) { continue; } - addBucketForMinDocCountZero(owningBucketOrd, globalOrd, consumer, accepted); + bucketOrds.add(owningBucketOrd, globalOrd); + accepted.add(globalOrd); } } } @@ -655,110 +707,93 @@ void forEachExcludeDeletedDocs(long owningBucketOrd, BucketInfoConsumer consumer } } - private void addBucketForMinDocCountZero( - long owningBucketOrd, - long globalOrd, - BucketInfoConsumer consumer, - @Nullable LongHash accepted - ) throws IOException { - /* - * Use `add` instead of `find` here to assign an ordinal - * even if the global ord wasn't found so we can build - * sub-aggregations without trouble even though we haven't - * hit any documents for them. This is wasteful, but - * settings minDocCount == 0 is wasteful in general..... - */ - long bucketOrd = bucketOrds.add(owningBucketOrd, globalOrd); - long docCount; - if (bucketOrd < 0) { - bucketOrd = -1 - bucketOrd; - docCount = bucketDocCount(bucketOrd); - } else { - docCount = 0; - } - assert globalOrd >= 0; - consumer.accept(globalOrd, bucketOrd, docCount); - if (accepted != null) { - accepted.add(globalOrd); - } - } - @Override public void close() { bucketOrds.close(); } - } - - /** - * Strategy for building results. - */ - abstract class ResultStrategy< - R extends InternalAggregation, - B extends InternalMultiBucketAggregation.InternalBucket, - TB extends InternalMultiBucketAggregation.InternalBucket> implements Releasable { - - private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException { - if (valueCount == 0) { // no context in this reader - return GlobalOrdinalsStringTermsAggregator.this.buildAggregations( - Math.toIntExact(owningBucketOrds.size()), - ordIdx -> buildNoValuesResult(owningBucketOrds.get(ordIdx)) - ); - } + @Override + InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throws IOException { try ( LongArray otherDocCount = bigArrays().newLongArray(owningBucketOrds.size(), true); - ObjectArray topBucketsPreOrd = buildTopBucketsPerOrd(owningBucketOrds.size()) + ObjectArray topBucketsPreOrd = collectionStrategy.buildTopBucketsPerOrd(owningBucketOrds.size()) ) { - GlobalOrdLookupFunction lookupGlobalOrd = valuesSupplier.get()::lookupOrd; - for (long ordIdx = 0; ordIdx < topBucketsPreOrd.size(); ordIdx++) { - final int size; - if (bucketCountThresholds.getMinDocCount() == 0) { - // if minDocCount == 0 then we can end up with more buckets then maxBucketOrd() returns - size = (int) Math.min(valueCount, bucketCountThresholds.getShardSize()); - } else { - size = (int) Math.min(maxBucketOrd(), bucketCountThresholds.getShardSize()); - } - try (ObjectArrayPriorityQueue ordered = buildPriorityQueue(size)) { - final long finalOrdIdx = ordIdx; + try (IntArray bucketsToCollect = bigArrays().newIntArray(owningBucketOrds.size())) { + long ordsToCollect = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { final long owningBucketOrd = owningBucketOrds.get(ordIdx); - BucketUpdater updater = bucketUpdater(owningBucketOrd, lookupGlobalOrd); - collectionStrategy.forEach(owningBucketOrd, new BucketInfoConsumer() { - TB spare = null; - - @Override - public void accept(long globalOrd, long bucketOrd, long docCount) throws IOException { - otherDocCount.increment(finalOrdIdx, docCount); - if (docCount >= bucketCountThresholds.getShardMinDocCount()) { + collectZeroDocEntriesIfNeeded(owningBucketOrd); + final int size = (int) Math.min(bucketOrds.bucketsInOrd(owningBucketOrd), bucketCountThresholds.getShardSize()); + ordsToCollect += size; + bucketsToCollect.set(ordIdx, size); + } + try (LongArray ordsArray = bigArrays().newLongArray(ordsToCollect)) { + long ordsCollected = 0; + GlobalOrdLookupFunction lookupGlobalOrd = valuesSupplier.get()::lookupOrd; + for (long ordIdx = 0; ordIdx < topBucketsPreOrd.size(); ordIdx++) { + long owningBucketOrd = owningBucketOrds.get(ordIdx); + try ( + ObjectArrayPriorityQueue> ordered = collectionStrategy.buildPriorityQueue( + bucketsToCollect.get(ordIdx) + ) + ) { + BucketUpdater updater = collectionStrategy.bucketUpdater(owningBucketOrd, lookupGlobalOrd); + LongKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd); + BucketAndOrd spare = null; + while (ordsEnum.next()) { + long docCount = bucketDocCount(ordsEnum.ord()); + otherDocCount.increment(ordIdx, docCount); + if (docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } if (spare == null) { checkRealMemoryCBForInternalBucket(); - spare = buildEmptyTemporaryBucket(); + spare = new BucketAndOrd<>(collectionStrategy.buildEmptyTemporaryBucket()); } - updater.updateBucket(spare, globalOrd, bucketOrd, docCount); + updater.updateBucket(spare.bucket, ordsEnum.value(), docCount); + spare.ord = ordsEnum.ord(); spare = ordered.insertWithOverflow(spare); } + // Get the top buckets + int orderedSize = (int) ordered.size(); + B[] buckets = collectionStrategy.buildBuckets(orderedSize); + for (int i = orderedSize - 1; i >= 0; --i) { + checkRealMemoryCBForInternalBucket(); + BucketAndOrd bucketAndOrd = ordered.pop(); + B bucket = collectionStrategy.convertTempBucketToRealBucket(bucketAndOrd.bucket, lookupGlobalOrd); + ordsArray.set(ordsCollected + i, bucketAndOrd.ord); + buckets[i] = bucket; + otherDocCount.increment(ordIdx, -bucket.getDocCount()); + } + topBucketsPreOrd.set(ordIdx, buckets); + ordsCollected += orderedSize; } - }); - - // Get the top buckets - topBucketsPreOrd.set(ordIdx, buildBuckets((int) ordered.size())); - for (int i = (int) ordered.size() - 1; i >= 0; --i) { - checkRealMemoryCBForInternalBucket(); - B bucket = convertTempBucketToRealBucket(ordered.pop(), lookupGlobalOrd); - topBucketsPreOrd.get(ordIdx)[i] = bucket; - otherDocCount.increment(ordIdx, -bucket.getDocCount()); } + assert ordsCollected == ordsArray.size(); + collectionStrategy.buildSubAggs(topBucketsPreOrd, ordsArray); } } - - buildSubAggs(topBucketsPreOrd); - return GlobalOrdinalsStringTermsAggregator.this.buildAggregations( Math.toIntExact(owningBucketOrds.size()), - ordIdx -> buildResult(owningBucketOrds.get(ordIdx), otherDocCount.get(ordIdx), topBucketsPreOrd.get(ordIdx)) + ordIdx -> collectionStrategy.buildResult( + owningBucketOrds.get(ordIdx), + otherDocCount.get(ordIdx), + topBucketsPreOrd.get(ordIdx) + ) ); } } + } + + /** + * Strategy for building results. + */ + abstract class ResultStrategy< + R extends InternalAggregation, + B extends InternalMultiBucketAggregation.InternalBucket, + TB extends InternalMultiBucketAggregation.InternalBucket> implements Releasable { + /** * Short description of the collection mechanism added to the profile * output to help with debugging. @@ -780,13 +815,13 @@ public void accept(long globalOrd, long bucketOrd, long docCount) throws IOExcep * Update fields in {@code spare} to reflect information collected for * this bucket ordinal. */ - abstract BucketUpdater bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd) throws IOException; + abstract BucketUpdater bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd); /** * Build a {@link PriorityQueue} to sort the buckets. After we've * collected all of the buckets we'll collect all entries in the queue. */ - abstract ObjectArrayPriorityQueue buildPriorityQueue(int size); + abstract ObjectArrayPriorityQueue> buildPriorityQueue(int size); /** * Build an array to hold the "top" buckets for each ordinal. @@ -808,7 +843,7 @@ public void accept(long globalOrd, long bucketOrd, long docCount) throws IOExcep * Build the sub-aggregations into the buckets. This will usually * delegate to {@link #buildSubAggsForAllBuckets}. */ - abstract void buildSubAggs(ObjectArray topBucketsPreOrd) throws IOException; + abstract void buildSubAggs(ObjectArray topBucketsPreOrd, LongArray ordsArray) throws IOException; /** * Turn the buckets into an aggregation result. @@ -829,7 +864,7 @@ public void accept(long globalOrd, long bucketOrd, long docCount) throws IOExcep } interface BucketUpdater { - void updateBucket(TB spare, long globalOrd, long bucketOrd, long docCount) throws IOException; + void updateBucket(TB spare, long globalOrd, long docCount) throws IOException; } /** @@ -862,30 +897,31 @@ OrdBucket buildEmptyTemporaryBucket() { } @Override - BucketUpdater bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd) throws IOException { - return (spare, globalOrd, bucketOrd, docCount) -> { + BucketUpdater bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd) { + return (spare, globalOrd, docCount) -> { spare.globalOrd = globalOrd; - spare.bucketOrd = bucketOrd; spare.docCount = docCount; }; } @Override - ObjectArrayPriorityQueue buildPriorityQueue(int size) { - return new BucketPriorityQueue<>(size, bigArrays(), partiallyBuiltBucketComparator); + ObjectArrayPriorityQueue> buildPriorityQueue(int size) { + return new BucketPriorityQueue<>( + size, + bigArrays(), + order.partiallyBuiltBucketComparator(GlobalOrdinalsStringTermsAggregator.this) + ); } @Override StringTerms.Bucket convertTempBucketToRealBucket(OrdBucket temp, GlobalOrdLookupFunction lookupGlobalOrd) throws IOException { BytesRef term = BytesRef.deepCopyOf(lookupGlobalOrd.apply(temp.globalOrd)); - StringTerms.Bucket result = new StringTerms.Bucket(term, temp.docCount, null, showTermDocCountError, 0, format); - result.bucketOrd = temp.bucketOrd; - return result; + return new StringTerms.Bucket(term, temp.docCount, null, showTermDocCountError, 0, format); } @Override - void buildSubAggs(ObjectArray topBucketsPreOrd) throws IOException { - buildSubAggsForAllBuckets(topBucketsPreOrd, b -> b.bucketOrd, (b, aggs) -> b.aggregations = aggs); + void buildSubAggs(ObjectArray topBucketsPreOrd, LongArray ordsArray) throws IOException { + buildSubAggsForAllBuckets(topBucketsPreOrd, ordsArray, (b, aggs) -> b.aggregations = aggs); } @Override @@ -1000,8 +1036,7 @@ private long subsetSize(long owningBucketOrd) { @Override BucketUpdater bucketUpdater(long owningBucketOrd, GlobalOrdLookupFunction lookupGlobalOrd) { long subsetSize = subsetSize(owningBucketOrd); - return (spare, globalOrd, bucketOrd, docCount) -> { - spare.bucketOrd = bucketOrd; + return (spare, globalOrd, docCount) -> { oversizedCopy(lookupGlobalOrd.apply(globalOrd), spare.termBytes); spare.subsetDf = docCount; spare.supersetDf = backgroundFrequencies.freq(spare.termBytes); @@ -1015,7 +1050,7 @@ BucketUpdater bucketUpdater(long owningBucketOrd, } @Override - ObjectArrayPriorityQueue buildPriorityQueue(int size) { + ObjectArrayPriorityQueue> buildPriorityQueue(int size) { return new BucketSignificancePriorityQueue<>(size, bigArrays()); } @@ -1028,8 +1063,8 @@ SignificantStringTerms.Bucket convertTempBucketToRealBucket( } @Override - void buildSubAggs(ObjectArray topBucketsPreOrd) throws IOException { - buildSubAggsForAllBuckets(topBucketsPreOrd, b -> b.bucketOrd, (b, aggs) -> b.aggregations = aggs); + void buildSubAggs(ObjectArray topBucketsPreOrd, LongArray ordsArray) throws IOException { + buildSubAggsForAllBuckets(topBucketsPreOrd, ordsArray, (b, aggs) -> b.aggregations = aggs); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java index 78ae2481f5d99..5108793b8a809 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalSignificantTerms.java @@ -10,12 +10,12 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.ObjectArrayPriorityQueue; import org.elasticsearch.common.util.ObjectObjectPagedHashMap; import org.elasticsearch.core.Releasables; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.aggregations.AggregationErrors; import org.elasticsearch.search.aggregations.AggregationReduceContext; -import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.AggregatorReducer; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -58,12 +58,6 @@ public interface Reader> { long subsetDf; long supersetDf; - /** - * Ordinal of the bucket while it is being built. Not used after it is - * returned from {@link Aggregator#buildAggregations(org.elasticsearch.common.util.LongArray)} and not - * serialized. - */ - transient long bucketOrd; double score; protected InternalAggregations aggregations; final transient DocValueFormat format; @@ -235,7 +229,12 @@ canLeadReduction here is essentially checking if this shard returned data. Unma public InternalAggregation get() { final SignificanceHeuristic heuristic = getSignificanceHeuristic().rewrite(reduceContext); final int size = (int) (reduceContext.isFinalReduce() == false ? buckets.size() : Math.min(requiredSize, buckets.size())); - try (BucketSignificancePriorityQueue ordered = new BucketSignificancePriorityQueue<>(size, reduceContext.bigArrays())) { + try (ObjectArrayPriorityQueue ordered = new ObjectArrayPriorityQueue(size, reduceContext.bigArrays()) { + @Override + protected boolean lessThan(B a, B b) { + return a.getSignificanceScore() < b.getSignificanceScore(); + } + }) { buckets.forEach(entry -> { final B b = createBucket( entry.value.subsetDf[0], diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java index 739f0b923eaab..de35046691b34 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/InternalTerms.java @@ -38,8 +38,6 @@ public interface Reader> { B read(StreamInput in, DocValueFormat format, boolean showDocCountError) throws IOException; } - long bucketOrd; - protected long docCount; private long docCountError; protected InternalAggregations aggregations; @@ -88,14 +86,6 @@ public void setDocCount(long docCount) { this.docCount = docCount; } - public long getBucketOrd() { - return bucketOrd; - } - - public void setBucketOrd(long bucketOrd) { - this.bucketOrd = bucketOrd; - } - @Override public long getDocCountError() { return docCountError; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java index b96c495d37489..026912a583ef3 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/MapStringTermsAggregator.java @@ -17,6 +17,7 @@ import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.PriorityQueue; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.common.util.ObjectArrayPriorityQueue; @@ -43,6 +44,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Function; @@ -287,40 +289,55 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro LongArray otherDocCounts = bigArrays().newLongArray(owningBucketOrds.size(), true); ObjectArray topBucketsPerOrd = buildTopBucketsPerOrd(Math.toIntExact(owningBucketOrds.size())) ) { - for (long ordIdx = 0; ordIdx < topBucketsPerOrd.size(); ordIdx++) { - long owningOrd = owningBucketOrds.get(ordIdx); - collectZeroDocEntriesIfNeeded(owningOrd, excludeDeletedDocs); - int size = (int) Math.min(bucketOrds.size(), bucketCountThresholds.getShardSize()); - - try (ObjectArrayPriorityQueue ordered = buildPriorityQueue(size)) { - B spare = null; - BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningOrd); - BucketUpdater bucketUpdater = bucketUpdater(owningOrd); - while (ordsEnum.next()) { - long docCount = bucketDocCount(ordsEnum.ord()); - otherDocCounts.increment(ordIdx, docCount); - if (docCount < bucketCountThresholds.getShardMinDocCount()) { - continue; - } - if (spare == null) { - checkRealMemoryCBForInternalBucket(); - spare = buildEmptyBucket(); + try (IntArray bucketsToCollect = bigArrays().newIntArray(owningBucketOrds.size())) { + long ordsToCollect = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { + final long owningBucketOrd = owningBucketOrds.get(ordIdx); + collectZeroDocEntriesIfNeeded(owningBucketOrd, excludeDeletedDocs); + final int size = (int) Math.min(bucketOrds.bucketsInOrd(owningBucketOrd), bucketCountThresholds.getShardSize()); + ordsToCollect += size; + bucketsToCollect.set(ordIdx, size); + } + try (LongArray ordsArray = bigArrays().newLongArray(ordsToCollect)) { + long ordsCollected = 0; + for (long ordIdx = 0; ordIdx < topBucketsPerOrd.size(); ordIdx++) { + long owningOrd = owningBucketOrds.get(ordIdx); + try (ObjectArrayPriorityQueue> ordered = buildPriorityQueue(bucketsToCollect.get(ordIdx))) { + BucketAndOrd spare = null; + BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningOrd); + BucketUpdater bucketUpdater = bucketUpdater(owningOrd); + while (ordsEnum.next()) { + long docCount = bucketDocCount(ordsEnum.ord()); + otherDocCounts.increment(ordIdx, docCount); + if (docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + if (spare == null) { + checkRealMemoryCBForInternalBucket(); + spare = new BucketAndOrd<>(buildEmptyBucket()); + } + bucketUpdater.updateBucket(spare.bucket, ordsEnum, docCount); + spare.ord = ordsEnum.ord(); + spare = ordered.insertWithOverflow(spare); + } + + final int orderedSize = (int) ordered.size(); + final B[] buckets = buildBuckets(orderedSize); + for (int i = orderedSize - 1; i >= 0; --i) { + BucketAndOrd bucketAndOrd = ordered.pop(); + finalizeBucket(bucketAndOrd.bucket); + buckets[i] = bucketAndOrd.bucket; + ordsArray.set(ordsCollected + i, bucketAndOrd.ord); + otherDocCounts.increment(ordIdx, -bucketAndOrd.bucket.getDocCount()); + } + topBucketsPerOrd.set(ordIdx, buckets); + ordsCollected += orderedSize; } - bucketUpdater.updateBucket(spare, ordsEnum, docCount); - spare = ordered.insertWithOverflow(spare); - } - - topBucketsPerOrd.set(ordIdx, buildBuckets((int) ordered.size())); - for (int i = (int) ordered.size() - 1; i >= 0; --i) { - topBucketsPerOrd.get(ordIdx)[i] = ordered.pop(); - otherDocCounts.increment(ordIdx, -topBucketsPerOrd.get(ordIdx)[i].getDocCount()); - finalizeBucket(topBucketsPerOrd.get(ordIdx)[i]); } + assert ordsCollected == ordsArray.size(); + buildSubAggs(topBucketsPerOrd, ordsArray); } } - - buildSubAggs(topBucketsPerOrd); - return MapStringTermsAggregator.this.buildAggregations( Math.toIntExact(owningBucketOrds.size()), ordIdx -> buildResult(owningBucketOrds.get(ordIdx), otherDocCounts.get(ordIdx), topBucketsPerOrd.get(ordIdx)) @@ -355,7 +372,7 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro * Build a {@link PriorityQueue} to sort the buckets. After we've * collected all of the buckets we'll collect all entries in the queue. */ - abstract ObjectArrayPriorityQueue buildPriorityQueue(int size); + abstract ObjectArrayPriorityQueue> buildPriorityQueue(int size); /** * Update fields in {@code spare} to reflect information collected for @@ -382,9 +399,9 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro /** * Build the sub-aggregations into the buckets. This will usually - * delegate to {@link #buildSubAggsForAllBuckets}. + * delegate to {@link #buildSubAggsForAllBuckets(ObjectArray, LongArray, BiConsumer)}. */ - abstract void buildSubAggs(ObjectArray topBucketsPerOrd) throws IOException; + abstract void buildSubAggs(ObjectArray topBucketsPerOrd, LongArray ordsArray) throws IOException; /** * Turn the buckets into an aggregation result. @@ -407,9 +424,11 @@ interface BucketUpdater */ class StandardTermsResults extends ResultStrategy { private final ValuesSource valuesSource; + private final Comparator> comparator; - StandardTermsResults(ValuesSource valuesSource) { + StandardTermsResults(ValuesSource valuesSource, Aggregator aggregator) { this.valuesSource = valuesSource; + this.comparator = order.partiallyBuiltBucketComparator(aggregator); } @Override @@ -498,8 +517,8 @@ StringTerms.Bucket buildEmptyBucket() { } @Override - ObjectArrayPriorityQueue buildPriorityQueue(int size) { - return new BucketPriorityQueue<>(size, bigArrays(), partiallyBuiltBucketComparator); + ObjectArrayPriorityQueue> buildPriorityQueue(int size) { + return new BucketPriorityQueue<>(size, bigArrays(), comparator); } @Override @@ -507,7 +526,6 @@ BucketUpdater bucketUpdater(long owningBucketOrd) { return (spare, ordsEnum, docCount) -> { ordsEnum.readValue(spare.termBytes); spare.docCount = docCount; - spare.bucketOrd = ordsEnum.ord(); }; } @@ -532,8 +550,8 @@ void finalizeBucket(StringTerms.Bucket bucket) { } @Override - void buildSubAggs(ObjectArray topBucketsPerOrd) throws IOException { - buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); + void buildSubAggs(ObjectArray topBucketsPerOrd, LongArray ordArray) throws IOException { + buildSubAggsForAllBuckets(topBucketsPerOrd, ordArray, (b, a) -> b.aggregations = a); } @Override @@ -625,7 +643,7 @@ SignificantStringTerms.Bucket buildEmptyBucket() { } @Override - ObjectArrayPriorityQueue buildPriorityQueue(int size) { + ObjectArrayPriorityQueue> buildPriorityQueue(int size) { return new BucketSignificancePriorityQueue<>(size, bigArrays()); } @@ -634,7 +652,6 @@ BucketUpdater bucketUpdater(long owningBucketOrd) long subsetSize = subsetSizes.get(owningBucketOrd); return (spare, ordsEnum, docCount) -> { ordsEnum.readValue(spare.termBytes); - spare.bucketOrd = ordsEnum.ord(); spare.subsetDf = docCount; spare.supersetDf = backgroundFrequencies.freq(spare.termBytes); /* @@ -667,8 +684,8 @@ void finalizeBucket(SignificantStringTerms.Bucket bucket) { } @Override - void buildSubAggs(ObjectArray topBucketsPerOrd) throws IOException { - buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); + void buildSubAggs(ObjectArray topBucketsPerOrd, LongArray ordsArray) throws IOException { + buildSubAggsForAllBuckets(topBucketsPerOrd, ordsArray, (b, a) -> b.aggregations = a); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java index 5d4c15d8a3b80..a54053f712f8d 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/NumericTermsAggregator.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.SortedNumericDocValues; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.util.NumericUtils; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.common.util.ObjectArrayPriorityQueue; @@ -40,6 +41,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Comparator; import java.util.Map; import java.util.function.BiConsumer; import java.util.function.Function; @@ -167,42 +169,56 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro LongArray otherDocCounts = bigArrays().newLongArray(owningBucketOrds.size(), true); ObjectArray topBucketsPerOrd = buildTopBucketsPerOrd(owningBucketOrds.size()) ) { - for (long ordIdx = 0; ordIdx < topBucketsPerOrd.size(); ordIdx++) { - final long owningBucketOrd = owningBucketOrds.get(ordIdx); - collectZeroDocEntriesIfNeeded(owningBucketOrd, excludeDeletedDocs); - long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrd); - - int size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize()); - try (ObjectArrayPriorityQueue ordered = buildPriorityQueue(size)) { - B spare = null; - BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd); - BucketUpdater bucketUpdater = bucketUpdater(owningBucketOrd); - while (ordsEnum.next()) { - long docCount = bucketDocCount(ordsEnum.ord()); - otherDocCounts.increment(ordIdx, docCount); - if (docCount < bucketCountThresholds.getShardMinDocCount()) { - continue; - } - if (spare == null) { - checkRealMemoryCBForInternalBucket(); - spare = buildEmptyBucket(); - } - bucketUpdater.updateBucket(spare, ordsEnum, docCount); - spare = ordered.insertWithOverflow(spare); - } + try (IntArray bucketsToCollect = bigArrays().newIntArray(owningBucketOrds.size())) { + long ordsToCollect = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { + final long owningBucketOrd = owningBucketOrds.get(ordIdx); + collectZeroDocEntriesIfNeeded(owningBucketOrd, excludeDeletedDocs); + int size = (int) Math.min(bucketOrds.bucketsInOrd(owningBucketOrd), bucketCountThresholds.getShardSize()); + bucketsToCollect.set(ordIdx, size); + ordsToCollect += size; + } + try (LongArray ordsArray = bigArrays().newLongArray(ordsToCollect)) { + long ordsCollected = 0; + for (long ordIdx = 0; ordIdx < topBucketsPerOrd.size(); ordIdx++) { + final long owningBucketOrd = owningBucketOrds.get(ordIdx); + try (ObjectArrayPriorityQueue> ordered = buildPriorityQueue(bucketsToCollect.get(ordIdx))) { + BucketAndOrd spare = null; + BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd); + BucketUpdater bucketUpdater = bucketUpdater(owningBucketOrd); + while (ordsEnum.next()) { + long docCount = bucketDocCount(ordsEnum.ord()); + otherDocCounts.increment(ordIdx, docCount); + if (docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + if (spare == null) { + checkRealMemoryCBForInternalBucket(); + spare = new BucketAndOrd<>(buildEmptyBucket()); + } + bucketUpdater.updateBucket(spare.bucket, ordsEnum, docCount); + spare.ord = ordsEnum.ord(); + spare = ordered.insertWithOverflow(spare); + } + + // Get the top buckets + final int orderedSize = (int) ordered.size(); + final B[] bucketsForOrd = buildBuckets(orderedSize); + for (int b = orderedSize - 1; b >= 0; --b) { + BucketAndOrd bucketAndOrd = ordered.pop(); + bucketsForOrd[b] = bucketAndOrd.bucket; + ordsArray.set(ordsCollected + b, bucketAndOrd.ord); + otherDocCounts.increment(ordIdx, -bucketAndOrd.bucket.getDocCount()); + } + topBucketsPerOrd.set(ordIdx, bucketsForOrd); + ordsCollected += orderedSize; - // Get the top buckets - B[] bucketsForOrd = buildBuckets((int) ordered.size()); - topBucketsPerOrd.set(ordIdx, bucketsForOrd); - for (int b = (int) ordered.size() - 1; b >= 0; --b) { - topBucketsPerOrd.get(ordIdx)[b] = ordered.pop(); - otherDocCounts.increment(ordIdx, -topBucketsPerOrd.get(ordIdx)[b].getDocCount()); + } } + assert ordsCollected == ordsArray.size(); + buildSubAggs(topBucketsPerOrd, ordsArray); } } - - buildSubAggs(topBucketsPerOrd); - return NumericTermsAggregator.this.buildAggregations( Math.toIntExact(owningBucketOrds.size()), ordIdx -> buildResult(owningBucketOrds.get(ordIdx), otherDocCounts.get(ordIdx), topBucketsPerOrd.get(ordIdx)) @@ -254,13 +270,13 @@ private InternalAggregation[] buildAggregations(LongArray owningBucketOrds) thro * Build a {@link ObjectArrayPriorityQueue} to sort the buckets. After we've * collected all of the buckets we'll collect all entries in the queue. */ - abstract ObjectArrayPriorityQueue buildPriorityQueue(int size); + abstract ObjectArrayPriorityQueue> buildPriorityQueue(int size); /** * Build the sub-aggregations into the buckets. This will usually - * delegate to {@link #buildSubAggsForAllBuckets}. + * delegate to {@link #buildSubAggsForAllBuckets(ObjectArray, LongArray, BiConsumer)}. */ - abstract void buildSubAggs(ObjectArray topBucketsPerOrd) throws IOException; + abstract void buildSubAggs(ObjectArray topBucketsPerOrd, LongArray ordsArray) throws IOException; /** * Collect extra entries for "zero" hit documents if they were requested @@ -287,9 +303,11 @@ interface BucketUpdater abstract class StandardTermsResultStrategy, B extends InternalTerms.Bucket> extends ResultStrategy { protected final boolean showTermDocCountError; + private final Comparator> comparator; - StandardTermsResultStrategy(boolean showTermDocCountError) { + StandardTermsResultStrategy(boolean showTermDocCountError, Aggregator aggregator) { this.showTermDocCountError = showTermDocCountError; + this.comparator = order.partiallyBuiltBucketComparator(aggregator); } @Override @@ -298,13 +316,13 @@ final LeafBucketCollector wrapCollector(LeafBucketCollector primary) { } @Override - final ObjectArrayPriorityQueue buildPriorityQueue(int size) { - return new BucketPriorityQueue<>(size, bigArrays(), partiallyBuiltBucketComparator); + final ObjectArrayPriorityQueue> buildPriorityQueue(int size) { + return new BucketPriorityQueue<>(size, bigArrays(), comparator); } @Override - final void buildSubAggs(ObjectArray topBucketsPerOrd) throws IOException { - buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, aggs) -> b.aggregations = aggs); + final void buildSubAggs(ObjectArray topBucketsPerOrd, LongArray ordsArray) throws IOException { + buildSubAggsForAllBuckets(topBucketsPerOrd, ordsArray, (b, aggs) -> b.aggregations = aggs); } @Override @@ -340,8 +358,8 @@ public final void close() {} } class LongTermsResults extends StandardTermsResultStrategy { - LongTermsResults(boolean showTermDocCountError) { - super(showTermDocCountError); + LongTermsResults(boolean showTermDocCountError, Aggregator aggregator) { + super(showTermDocCountError, aggregator); } @Override @@ -374,7 +392,6 @@ BucketUpdater bucketUpdater(long owningBucketOrd) { return (LongTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) -> { spare.term = ordsEnum.value(); spare.docCount = docCount; - spare.bucketOrd = ordsEnum.ord(); }; } @@ -424,8 +441,8 @@ LongTerms buildEmptyResult() { class DoubleTermsResults extends StandardTermsResultStrategy { - DoubleTermsResults(boolean showTermDocCountError) { - super(showTermDocCountError); + DoubleTermsResults(boolean showTermDocCountError, Aggregator aggregator) { + super(showTermDocCountError, aggregator); } @Override @@ -458,7 +475,6 @@ BucketUpdater bucketUpdater(long owningBucketOrd) { return (DoubleTerms.Bucket spare, BucketOrdsEnum ordsEnum, long docCount) -> { spare.term = NumericUtils.sortableLongToDouble(ordsEnum.value()); spare.docCount = docCount; - spare.bucketOrd = ordsEnum.ord(); }; } @@ -575,7 +591,6 @@ BucketUpdater bucketUpdater(long owningBucketOrd) { spare.term = ordsEnum.value(); spare.subsetDf = docCount; spare.supersetDf = backgroundFrequencies.freq(spare.term); - spare.bucketOrd = ordsEnum.ord(); // During shard-local down-selection we use subset/superset stats that are for this shard only // Back at the central reducer these properties will be updated with global stats spare.updateScore(significanceHeuristic, subsetSize, supersetSize); @@ -583,13 +598,13 @@ BucketUpdater bucketUpdater(long owningBucketOrd) { } @Override - ObjectArrayPriorityQueue buildPriorityQueue(int size) { + ObjectArrayPriorityQueue> buildPriorityQueue(int size) { return new BucketSignificancePriorityQueue<>(size, bigArrays()); } @Override - void buildSubAggs(ObjectArray topBucketsPerOrd) throws IOException { - buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, aggs) -> b.aggregations = aggs); + void buildSubAggs(ObjectArray topBucketsPerOrd, LongArray ordsArray) throws IOException { + buildSubAggsForAllBuckets(topBucketsPerOrd, ordsArray, (b, aggs) -> b.aggregations = aggs); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregator.java index 4922be7cec1ba..c07c0726a4ae1 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregator.java @@ -27,7 +27,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; -import java.util.Comparator; import java.util.HashSet; import java.util.Map; import java.util.Objects; @@ -190,7 +189,6 @@ public boolean equals(Object obj) { protected final DocValueFormat format; protected final BucketCountThresholds bucketCountThresholds; protected final BucketOrder order; - protected final Comparator> partiallyBuiltBucketComparator; protected final Set aggsUsedForSorting; protected final SubAggCollectionMode collectMode; @@ -209,7 +207,9 @@ public TermsAggregator( super(name, factories, context, parent, metadata); this.bucketCountThresholds = bucketCountThresholds; this.order = order; - partiallyBuiltBucketComparator = order == null ? null : order.partiallyBuiltBucketComparator(b -> b.bucketOrd, this); + if (order != null) { + order.validate(this); + } this.format = format; if ((subAggsNeedScore() && descendsFromNestedAggregator(parent)) || context.isInSortOrderExecutionRequired()) { /** diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java index 2c7b768fcdbb3..da5ae37b08228 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/terms/TermsAggregatorFactory.java @@ -195,12 +195,12 @@ private static TermsAggregatorSupplier numericSupplier() { if (includeExclude != null) { longFilter = includeExclude.convertToDoubleFilter(); } - resultStrategy = agg -> agg.new DoubleTermsResults(showTermDocCountError); + resultStrategy = agg -> agg.new DoubleTermsResults(showTermDocCountError, agg); } else { if (includeExclude != null) { longFilter = includeExclude.convertToLongFilter(valuesSourceConfig.format()); } - resultStrategy = agg -> agg.new LongTermsResults(showTermDocCountError); + resultStrategy = agg -> agg.new LongTermsResults(showTermDocCountError, agg); } return new NumericTermsAggregator( name, @@ -403,7 +403,7 @@ Aggregator create( name, factories, new MapStringTermsAggregator.ValuesSourceCollectorSource(valuesSourceConfig), - a -> a.new StandardTermsResults(valuesSourceConfig.getValuesSource()), + a -> a.new StandardTermsResults(valuesSourceConfig.getValuesSource(), a), order, valuesSourceConfig.format(), bucketCountThresholds, diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index 098a2b2f45d2f..3554a6dc08b90 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -19,7 +19,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Booleans; @@ -92,7 +91,6 @@ * @see SearchRequest#source(SearchSourceBuilder) */ public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable { - private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class); public static final ParseField FROM_FIELD = new ParseField("from"); public static final ParseField SIZE_FIELD = new ParseField("size"); diff --git a/server/src/main/java/org/elasticsearch/search/lookup/SearchLookup.java b/server/src/main/java/org/elasticsearch/search/lookup/SearchLookup.java index f7f8cee30ee15..9eb0170af5efb 100644 --- a/server/src/main/java/org/elasticsearch/search/lookup/SearchLookup.java +++ b/server/src/main/java/org/elasticsearch/search/lookup/SearchLookup.java @@ -102,6 +102,14 @@ private SearchLookup(SearchLookup searchLookup, Set fieldChain) { this.fieldLookupProvider = searchLookup.fieldLookupProvider; } + private SearchLookup(SearchLookup searchLookup, SourceProvider sourceProvider, Set fieldChain) { + this.fieldChain = Collections.unmodifiableSet(fieldChain); + this.sourceProvider = sourceProvider; + this.fieldTypeLookup = searchLookup.fieldTypeLookup; + this.fieldDataLookup = searchLookup.fieldDataLookup; + this.fieldLookupProvider = searchLookup.fieldLookupProvider; + } + /** * Creates a copy of the current {@link SearchLookup} that looks fields up in the same way, but also tracks field references * in order to detect cycles and prevent resolving fields that depend on more than {@link #MAX_FIELD_CHAIN_DEPTH} other fields. @@ -144,4 +152,8 @@ public IndexFieldData getForField(MappedFieldType fieldType, MappedFieldType. public Source getSource(LeafReaderContext ctx, int doc) throws IOException { return sourceProvider.getSource(ctx, doc); } + + public SearchLookup swapSourceProvider(SourceProvider sourceProvider) { + return new SearchLookup(this, sourceProvider, fieldChain); + } } diff --git a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java index 9ab14aa9362b5..d4127836a4e4a 100644 --- a/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java +++ b/server/src/main/java/org/elasticsearch/search/rank/RankDoc.java @@ -44,7 +44,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.RANK_DOCS_RETRIEVER; + return TransportVersions.V_8_16_0; } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index db839de9f573a..2ab6395db73b5 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.TransportMultiSearchAction; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.rest.RestStatus; @@ -46,6 +47,8 @@ */ public abstract class CompoundRetrieverBuilder> extends RetrieverBuilder { + public static final NodeFeature INNER_RETRIEVERS_FILTER_SUPPORT = new NodeFeature("inner_retrievers_filter_support"); + public record RetrieverSource(RetrieverBuilder retriever, SearchSourceBuilder source) {} protected final int rankWindowSize; @@ -64,9 +67,9 @@ public T addChild(RetrieverBuilder retrieverBuilder) { /** * Returns a clone of the original retriever, replacing the sub-retrievers with - * the provided {@code newChildRetrievers}. + * the provided {@code newChildRetrievers} and the filters with the {@code newPreFilterQueryBuilders}. */ - protected abstract T clone(List newChildRetrievers); + protected abstract T clone(List newChildRetrievers, List newPreFilterQueryBuilders); /** * Combines the provided {@code rankResults} to return the final top documents. @@ -85,13 +88,25 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio } // Rewrite prefilters - boolean hasChanged = false; + // We eagerly rewrite prefilters, because some of the innerRetrievers + // could be compound too, so we want to propagate all the necessary filter information to them + // and have it available as part of their own rewrite step var newPreFilters = rewritePreFilters(ctx); - hasChanged |= newPreFilters != preFilterQueryBuilders; + if (newPreFilters != preFilterQueryBuilders) { + return clone(innerRetrievers, newPreFilters); + } + boolean hasChanged = false; // Rewrite retriever sources List newRetrievers = new ArrayList<>(); for (var entry : innerRetrievers) { + // we propagate the filters only for compound retrievers as they won't be attached through + // the createSearchSourceBuilder. + // We could remove this check, but we would end up adding the same filters + // multiple times in case an inner retriever rewrites itself, when we re-enter to rewrite + if (entry.retriever.isCompound() && false == preFilterQueryBuilders.isEmpty()) { + entry.retriever.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); + } RetrieverBuilder newRetriever = entry.retriever.rewrite(ctx); if (newRetriever != entry.retriever) { newRetrievers.add(new RetrieverSource(newRetriever, null)); @@ -106,7 +121,7 @@ public final RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOExceptio } } if (hasChanged) { - return clone(newRetrievers); + return clone(newRetrievers, newPreFilters); } // execute searches @@ -166,12 +181,7 @@ public void onFailure(Exception e) { }); }); - return new RankDocsRetrieverBuilder( - rankWindowSize, - newRetrievers.stream().map(s -> s.retriever).toList(), - results::get, - newPreFilters - ); + return new RankDocsRetrieverBuilder(rankWindowSize, newRetrievers.stream().map(s -> s.retriever).toList(), results::get); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java index 8be9a78dae154..f1464c41ca3be 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java @@ -184,8 +184,7 @@ public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { ll.onResponse(null); })); }); - var rewritten = new KnnRetrieverBuilder(this, () -> toSet.get(), null); - return rewritten; + return new KnnRetrieverBuilder(this, () -> toSet.get(), null); } return super.rewrite(ctx); } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java index 02f890f51d011..4d3f3fefd4462 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilder.java @@ -33,19 +33,13 @@ public class RankDocsRetrieverBuilder extends RetrieverBuilder { final List sources; final Supplier rankDocs; - public RankDocsRetrieverBuilder( - int rankWindowSize, - List sources, - Supplier rankDocs, - List preFilterQueryBuilders - ) { + public RankDocsRetrieverBuilder(int rankWindowSize, List sources, Supplier rankDocs) { this.rankWindowSize = rankWindowSize; this.rankDocs = rankDocs; if (sources == null || sources.isEmpty()) { throw new IllegalArgumentException("sources must not be null or empty"); } this.sources = sources; - this.preFilterQueryBuilders = preFilterQueryBuilders; } @Override @@ -73,10 +67,6 @@ private boolean sourceShouldRewrite(QueryRewriteContext ctx) throws IOException @Override public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { assert false == sourceShouldRewrite(ctx) : "retriever sources should be rewritten first"; - var rewrittenFilters = rewritePreFilters(ctx); - if (rewrittenFilters != preFilterQueryBuilders) { - return new RankDocsRetrieverBuilder(rankWindowSize, sources, rankDocs, rewrittenFilters); - } return this; } @@ -94,7 +84,7 @@ public QueryBuilder topDocsQuery() { boolQuery.should(query); } } - // ignore prefilters of this level, they are already propagated to children + // ignore prefilters of this level, they were already propagated to children return boolQuery; } @@ -133,7 +123,7 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder } else { rankQuery = new RankDocsQueryBuilder(rankDocResults, null, false); } - // ignore prefilters of this level, they are already propagated to children + // ignore prefilters of this level, they were already propagated to children searchSourceBuilder.query(rankQuery); if (sourceHasMinScore()) { searchSourceBuilder.minScore(this.minScore() == null ? Float.MIN_VALUE : this.minScore()); diff --git a/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java b/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java index 6640f0f858404..2aaade35fb8f3 100644 --- a/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/sort/GeoDistanceSortBuilder.java @@ -28,7 +28,6 @@ import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.unit.DistanceUnit; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.index.fielddata.FieldData; @@ -67,7 +66,6 @@ * A geo distance based sorting on a geo point like field. */ public class GeoDistanceSortBuilder extends SortBuilder { - private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(GeoDistanceSortBuilder.class); public static final String NAME = "_geo_distance"; public static final String ALTERNATIVE_NAME = "_geoDistance"; diff --git a/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java index c8670a8dfeec2..77d708432cf26 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java @@ -55,8 +55,7 @@ public ExactKnnQueryBuilder(StreamInput in) throws IOException { this.query = VectorData.fromFloats(in.readFloatArray()); } this.field = in.readString(); - if (in.getTransportVersion().onOrAfter(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS) - || in.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_0)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { this.vectorSimilarity = in.readOptionalFloat(); } else { this.vectorSimilarity = null; @@ -88,8 +87,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeFloatArray(query.asFloatVector()); } out.writeString(field); - if (out.getTransportVersion().onOrAfter(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS) - || out.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_0)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { out.writeOptionalFloat(vectorSimilarity); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java index f52addefc8b1c..b5ba97906f0ec 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java @@ -71,8 +71,7 @@ public KnnScoreDocQueryBuilder(StreamInput in) throws IOException { this.fieldName = null; this.queryVector = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS) - || in.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_0)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { this.vectorSimilarity = in.readOptionalFloat(); } else { this.vectorSimilarity = null; @@ -116,8 +115,7 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeBoolean(false); } } - if (out.getTransportVersion().onOrAfter(TransportVersions.FIX_VECTOR_SIMILARITY_INNER_HITS) - || out.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_0)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { out.writeOptionalFloat(vectorSimilarity); } } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index deb7e6bd035b8..5dd2cbf32dd12 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -481,10 +481,9 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { } parentBitSet = context.bitsetFilter(parentFilter); if (filterQuery != null) { - NestedHelper nestedHelper = new NestedHelper(context.nestedLookup(), context::isFieldMapped); // We treat the provided filter as a filter over PARENT documents, so if it might match nested documents // we need to adjust it. - if (nestedHelper.mightMatchNestedDocs(filterQuery)) { + if (NestedHelper.mightMatchNestedDocs(filterQuery, context)) { // Ensure that the query only returns parent documents matching `filterQuery` filterQuery = Queries.filtered(filterQuery, parentFilter); } diff --git a/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java b/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java index f34b876697473..231894875b7fa 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RegisteredPolicySnapshots.java @@ -101,7 +101,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.REGISTER_SLM_STATS; + return TransportVersions.V_8_16_0; } @Override @@ -171,7 +171,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.REGISTER_SLM_STATS; + return TransportVersions.V_8_16_0; } } diff --git a/server/src/test/java/org/elasticsearch/TransportVersionTests.java b/server/src/test/java/org/elasticsearch/TransportVersionTests.java index 6c2cc5c1f4cc0..08b12cec2e17e 100644 --- a/server/src/test/java/org/elasticsearch/TransportVersionTests.java +++ b/server/src/test/java/org/elasticsearch/TransportVersionTests.java @@ -211,7 +211,7 @@ public void testDenseTransportVersions() { Set missingVersions = new TreeSet<>(); TransportVersion previous = null; for (var tv : TransportVersions.getAllVersions()) { - if (tv.before(TransportVersions.V_8_15_2)) { + if (tv.before(TransportVersions.V_8_16_0)) { continue; } if (previous == null) { diff --git a/server/src/test/java/org/elasticsearch/VersionTests.java b/server/src/test/java/org/elasticsearch/VersionTests.java index 0b35a3cc23c16..5e10a7d37aea1 100644 --- a/server/src/test/java/org/elasticsearch/VersionTests.java +++ b/server/src/test/java/org/elasticsearch/VersionTests.java @@ -179,8 +179,7 @@ public void testParseVersion() { } public void testAllVersionsMatchId() throws Exception { - final Set releasedVersions = new HashSet<>(VersionUtils.allReleasedVersions()); - final Set unreleasedVersions = new HashSet<>(VersionUtils.allUnreleasedVersions()); + final Set versions = new HashSet<>(VersionUtils.allVersions()); Map maxBranchVersions = new HashMap<>(); for (java.lang.reflect.Field field : Version.class.getFields()) { if (field.getName().matches("_ID")) { @@ -195,43 +194,15 @@ public void testAllVersionsMatchId() throws Exception { Version v = (Version) versionConstant.get(null); logger.debug("Checking {}", v); - if (field.getName().endsWith("_UNRELEASED")) { - assertTrue(unreleasedVersions.contains(v)); - } else { - assertTrue(releasedVersions.contains(v)); - } + assertTrue(versions.contains(v)); assertEquals("Version id " + field.getName() + " does not point to " + constantName, v, Version.fromId(versionId)); assertEquals("Version " + constantName + " does not have correct id", versionId, v.id); String number = v.toString(); assertEquals("V_" + number.replace('.', '_'), constantName); - - // only the latest version for a branch should be a snapshot (ie unreleased) - String branchName = "" + v.major + "." + v.minor; - Version maxBranchVersion = maxBranchVersions.get(branchName); - if (maxBranchVersion == null) { - maxBranchVersions.put(branchName, v); - } else if (v.after(maxBranchVersion)) { - if (v == Version.CURRENT) { - // Current is weird - it counts as released even though it shouldn't. - continue; - } - assertFalse( - "Version " + maxBranchVersion + " cannot be a snapshot because version " + v + " exists", - VersionUtils.allUnreleasedVersions().contains(maxBranchVersion) - ); - maxBranchVersions.put(branchName, v); - } } } } - public static void assertUnknownVersion(Version version) { - assertFalse( - "Version " + version + " has been releaed don't use a new instance of this version", - VersionUtils.allReleasedVersions().contains(version) - ); - } - public void testIsCompatible() { assertTrue(isCompatible(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion())); assertFalse(isCompatible(Version.V_7_0_0, Version.V_8_0_0)); @@ -279,14 +250,6 @@ public boolean isCompatible(Version left, Version right) { return result; } - // This exists because 5.1.0 was never released due to a mistake in the release process. - // This verifies that we never declare the version as "released" accidentally. - // It would never pass qa tests later on, but those come very far in the build and this is quick to check now. - public void testUnreleasedVersion() { - Version VERSION_5_1_0_UNRELEASED = Version.fromString("5.1.0"); - VersionTests.assertUnknownVersion(VERSION_5_1_0_UNRELEASED); - } - public void testIllegalMinorAndPatchNumbers() { IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> Version.fromString("8.2.999")); assertThat( diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParametersTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParametersTests.java index f37b1d1b41712..cfdbfdfbfcf8c 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParametersTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodesStatsRequestParametersTests.java @@ -23,7 +23,7 @@ public class NodesStatsRequestParametersTests extends ESTestCase { public void testReadWriteMetricSet() { - for (var version : List.of(TransportVersions.V_8_15_0, TransportVersions.NODES_STATS_ENUM_SET)) { + for (var version : List.of(TransportVersions.V_8_15_0, TransportVersions.V_8_16_0)) { var randSet = randomSubsetOf(Metric.ALL); var metricsOut = randSet.isEmpty() ? EnumSet.noneOf(Metric.class) : EnumSet.copyOf(randSet); try { diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java index 89ccd4ab63d7f..46b757407e6a9 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/stats/SearchUsageStatsTests.java @@ -199,7 +199,7 @@ public void testSerializationBWC() throws IOException { randomQueryUsage(QUERY_TYPES.size()), version.onOrAfter(TransportVersions.V_8_12_0) ? randomRescorerUsage(RESCORER_TYPES.size()) : Map.of(), randomSectionsUsage(SECTIONS.size()), - version.onOrAfter(TransportVersions.RETRIEVERS_TELEMETRY_ADDED) ? randomRetrieversUsage(RETRIEVERS.size()) : Map.of(), + version.onOrAfter(TransportVersions.V_8_16_0) ? randomRetrieversUsage(RETRIEVERS.size()) : Map.of(), randomLongBetween(0, Long.MAX_VALUE) ); assertSerialization(testInstance, version); diff --git a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java index 484b3c6b386fd..7a38858d8477a 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java +++ b/server/src/test/java/org/elasticsearch/action/search/MockSearchPhaseContext.java @@ -154,11 +154,6 @@ protected void executePhaseOnShard( }, shardIt); } - @Override - public void onFailure(Exception e) { - Assert.fail("should not be called"); - } - @Override public void sendReleaseSearchContext(ShardSearchContextId contextId, Transport.Connection connection, OriginalIndices originalIndices) { releasedSearchContexts.add(contextId); diff --git a/server/src/test/java/org/elasticsearch/bootstrap/PluginsResolverTests.java b/server/src/test/java/org/elasticsearch/bootstrap/PluginsResolverTests.java new file mode 100644 index 0000000000000..798b576500d72 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/bootstrap/PluginsResolverTests.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.bootstrap; + +import org.elasticsearch.plugins.PluginBundle; +import org.elasticsearch.plugins.PluginDescriptor; +import org.elasticsearch.plugins.PluginsLoader; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.compiler.InMemoryJavaCompiler; +import org.elasticsearch.test.jar.JarUtils; + +import java.io.IOException; +import java.lang.module.Configuration; +import java.lang.module.ModuleFinder; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; + +import static java.util.Map.entry; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@ESTestCase.WithoutSecurityManager +public class PluginsResolverTests extends ESTestCase { + + private record TestPluginLayer(PluginBundle pluginBundle, ClassLoader pluginClassLoader, ModuleLayer pluginModuleLayer) + implements + PluginsLoader.PluginLayer {} + + public void testResolveModularPlugin() throws IOException, ClassNotFoundException { + String moduleName = "modular.plugin"; + String pluginName = "modular-plugin"; + + final Path home = createTempDir(); + + Path jar = createModularPluginJar(home, pluginName, moduleName, "p", "A"); + + var layer = createModuleLayer(moduleName, jar); + var loader = layer.findLoader(moduleName); + + PluginBundle bundle = createMockBundle(pluginName, moduleName, "p.A"); + PluginsLoader mockPluginsLoader = mock(PluginsLoader.class); + + when(mockPluginsLoader.pluginLayers()).thenReturn(Stream.of(new TestPluginLayer(bundle, loader, layer))); + PluginsResolver pluginsResolver = PluginsResolver.create(mockPluginsLoader); + + var testClass = loader.loadClass("p.A"); + var resolvedPluginName = pluginsResolver.resolveClassToPluginName(testClass); + var unresolvedPluginName1 = pluginsResolver.resolveClassToPluginName(PluginsResolver.class); + var unresolvedPluginName2 = pluginsResolver.resolveClassToPluginName(String.class); + + assertEquals(pluginName, resolvedPluginName); + assertNull(unresolvedPluginName1); + assertNull(unresolvedPluginName2); + } + + public void testResolveMultipleModularPlugins() throws IOException, ClassNotFoundException { + final Path home = createTempDir(); + + Path jar1 = createModularPluginJar(home, "plugin1", "module.one", "p", "A"); + Path jar2 = createModularPluginJar(home, "plugin2", "module.two", "q", "B"); + + var layer1 = createModuleLayer("module.one", jar1); + var loader1 = layer1.findLoader("module.one"); + var layer2 = createModuleLayer("module.two", jar2); + var loader2 = layer2.findLoader("module.two"); + + PluginBundle bundle1 = createMockBundle("plugin1", "module.one", "p.A"); + PluginBundle bundle2 = createMockBundle("plugin2", "module.two", "q.B"); + PluginsLoader mockPluginsLoader = mock(PluginsLoader.class); + + when(mockPluginsLoader.pluginLayers()).thenReturn( + Stream.of(new TestPluginLayer(bundle1, loader1, layer1), new TestPluginLayer(bundle2, loader2, layer2)) + ); + PluginsResolver pluginsResolver = PluginsResolver.create(mockPluginsLoader); + + var testClass1 = loader1.loadClass("p.A"); + var testClass2 = loader2.loadClass("q.B"); + var resolvedPluginName1 = pluginsResolver.resolveClassToPluginName(testClass1); + var resolvedPluginName2 = pluginsResolver.resolveClassToPluginName(testClass2); + + assertEquals("plugin1", resolvedPluginName1); + assertEquals("plugin2", resolvedPluginName2); + } + + public void testResolveReferencedModulesInModularPlugins() throws IOException, ClassNotFoundException { + final Path home = createTempDir(); + + Path dependencyJar = createModularPluginJar(home, "plugin1", "module.one", "p", "A"); + Path pluginJar = home.resolve("plugin2.jar"); + + Map sources = Map.ofEntries( + entry("module-info", "module module.two { exports q; requires module.one; }"), + entry("q.B", "package q; public class B { public p.A a = null; }") + ); + + var classToBytes = InMemoryJavaCompiler.compile(sources, "--add-modules", "module.one", "-p", home.toString()); + JarUtils.createJarWithEntries( + pluginJar, + Map.ofEntries(entry("module-info.class", classToBytes.get("module-info")), entry("q/B.class", classToBytes.get("q.B"))) + ); + + var layer = createModuleLayer("module.two", pluginJar, dependencyJar); + var loader = layer.findLoader("module.two"); + + PluginBundle bundle = createMockBundle("plugin2", "module.two", "q.B"); + PluginsLoader mockPluginsLoader = mock(PluginsLoader.class); + + when(mockPluginsLoader.pluginLayers()).thenReturn(Stream.of(new TestPluginLayer(bundle, loader, layer))); + PluginsResolver pluginsResolver = PluginsResolver.create(mockPluginsLoader); + + var testClass1 = loader.loadClass("p.A"); + var testClass2 = loader.loadClass("q.B"); + var resolvedPluginName1 = pluginsResolver.resolveClassToPluginName(testClass1); + var resolvedPluginName2 = pluginsResolver.resolveClassToPluginName(testClass2); + + assertEquals("plugin2", resolvedPluginName1); + assertEquals("plugin2", resolvedPluginName2); + } + + public void testResolveMultipleNonModularPlugins() throws IOException, ClassNotFoundException { + final Path home = createTempDir(); + + Path jar1 = createNonModularPluginJar(home, "plugin1", "p", "A"); + Path jar2 = createNonModularPluginJar(home, "plugin2", "q", "B"); + + try (var loader1 = createClassLoader(jar1); var loader2 = createClassLoader(jar2)) { + + PluginBundle bundle1 = createMockBundle("plugin1", null, "p.A"); + PluginBundle bundle2 = createMockBundle("plugin2", null, "q.B"); + PluginsLoader mockPluginsLoader = mock(PluginsLoader.class); + + when(mockPluginsLoader.pluginLayers()).thenReturn( + Stream.of( + new TestPluginLayer(bundle1, loader1, ModuleLayer.boot()), + new TestPluginLayer(bundle2, loader2, ModuleLayer.boot()) + ) + ); + PluginsResolver pluginsResolver = PluginsResolver.create(mockPluginsLoader); + + var testClass1 = loader1.loadClass("p.A"); + var testClass2 = loader2.loadClass("q.B"); + var resolvedPluginName1 = pluginsResolver.resolveClassToPluginName(testClass1); + var resolvedPluginName2 = pluginsResolver.resolveClassToPluginName(testClass2); + + assertEquals("plugin1", resolvedPluginName1); + assertEquals("plugin2", resolvedPluginName2); + } + } + + public void testResolveNonModularPlugin() throws IOException, ClassNotFoundException { + String pluginName = "non-modular-plugin"; + + final Path home = createTempDir(); + + Path jar = createNonModularPluginJar(home, pluginName, "p", "A"); + + try (var loader = createClassLoader(jar)) { + PluginBundle bundle = createMockBundle(pluginName, null, "p.A"); + PluginsLoader mockPluginsLoader = mock(PluginsLoader.class); + + when(mockPluginsLoader.pluginLayers()).thenReturn(Stream.of(new TestPluginLayer(bundle, loader, ModuleLayer.boot()))); + PluginsResolver pluginsResolver = PluginsResolver.create(mockPluginsLoader); + + var testClass = loader.loadClass("p.A"); + var resolvedPluginName = pluginsResolver.resolveClassToPluginName(testClass); + var unresolvedPluginName1 = pluginsResolver.resolveClassToPluginName(PluginsResolver.class); + var unresolvedPluginName2 = pluginsResolver.resolveClassToPluginName(String.class); + + assertEquals(pluginName, resolvedPluginName); + assertNull(unresolvedPluginName1); + assertNull(unresolvedPluginName2); + } + } + + private static URLClassLoader createClassLoader(Path jar) throws MalformedURLException { + return new URLClassLoader(new URL[] { jar.toUri().toURL() }); + } + + private static ModuleLayer createModuleLayer(String moduleName, Path... jars) { + var finder = ModuleFinder.of(jars); + Configuration cf = ModuleLayer.boot().configuration().resolve(finder, ModuleFinder.of(), Set.of(moduleName)); + var moduleController = ModuleLayer.defineModulesWithOneLoader( + cf, + List.of(ModuleLayer.boot()), + ClassLoader.getPlatformClassLoader() + ); + return moduleController.layer(); + } + + private static PluginBundle createMockBundle(String pluginName, String moduleName, String fqClassName) { + PluginDescriptor pd = new PluginDescriptor( + pluginName, + null, + null, + null, + null, + fqClassName, + moduleName, + List.of(), + false, + false, + true, + false + ); + + PluginBundle bundle = mock(PluginBundle.class); + when(bundle.pluginDescriptor()).thenReturn(pd); + return bundle; + } + + private static Path createModularPluginJar(Path home, String pluginName, String moduleName, String packageName, String className) + throws IOException { + Path jar = home.resolve(pluginName + ".jar"); + String fqClassName = packageName + "." + className; + + Map sources = Map.ofEntries( + entry("module-info", "module " + moduleName + " { exports " + packageName + "; }"), + entry(fqClassName, "package " + packageName + "; public class " + className + " {}") + ); + + var classToBytes = InMemoryJavaCompiler.compile(sources); + JarUtils.createJarWithEntries( + jar, + Map.ofEntries( + entry("module-info.class", classToBytes.get("module-info")), + entry(packageName + "/" + className + ".class", classToBytes.get(fqClassName)) + ) + ); + return jar; + } + + private static Path createNonModularPluginJar(Path home, String pluginName, String packageName, String className) throws IOException { + Path jar = home.resolve(pluginName + ".jar"); + String fqClassName = packageName + "." + className; + + Map sources = Map.ofEntries( + entry(fqClassName, "package " + packageName + "; public class " + className + " {}") + ); + + var classToBytes = InMemoryJavaCompiler.compile(sources); + JarUtils.createJarWithEntries(jar, Map.ofEntries(entry(packageName + "/" + className + ".class", classToBytes.get(fqClassName)))); + return jar; + } +} diff --git a/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java b/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java index d2b6d0a6ec6d7..afaa7a9a32888 100644 --- a/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java +++ b/server/src/test/java/org/elasticsearch/common/io/stream/AbstractStreamTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -53,8 +54,6 @@ import static java.time.Instant.ofEpochSecond; import static java.time.ZonedDateTime.ofInstant; -import static org.elasticsearch.TransportVersions.ZDT_NANOS_SUPPORT; -import static org.elasticsearch.TransportVersions.ZDT_NANOS_SUPPORT_BROKEN; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasToString; @@ -729,15 +728,11 @@ public void testReadAfterReachingEndOfStream() throws IOException { } public void testZonedDateTimeSerialization() throws IOException { - checkZonedDateTimeSerialization(ZDT_NANOS_SUPPORT); - } - - public void testZonedDateTimeMillisBwcSerializationV1() throws IOException { - checkZonedDateTimeSerialization(TransportVersionUtils.getPreviousVersion(ZDT_NANOS_SUPPORT_BROKEN)); + checkZonedDateTimeSerialization(TransportVersions.V_8_16_0); } public void testZonedDateTimeMillisBwcSerialization() throws IOException { - checkZonedDateTimeSerialization(TransportVersionUtils.getPreviousVersion(ZDT_NANOS_SUPPORT)); + checkZonedDateTimeSerialization(TransportVersionUtils.getPreviousVersion(TransportVersions.V_8_16_0)); } public void checkZonedDateTimeSerialization(TransportVersion tv) throws IOException { @@ -745,12 +740,12 @@ public void checkZonedDateTimeSerialization(TransportVersion tv) throws IOExcept assertGenericRoundtrip(ofInstant(ofEpochSecond(1), randomZone()), tv); // just want to test a large number that will use 5+ bytes long maxEpochSecond = Integer.MAX_VALUE; - long minEpochSecond = tv.between(ZDT_NANOS_SUPPORT_BROKEN, ZDT_NANOS_SUPPORT) ? 0 : Integer.MIN_VALUE; + long minEpochSecond = Integer.MIN_VALUE; assertGenericRoundtrip(ofInstant(ofEpochSecond(maxEpochSecond), randomZone()), tv); assertGenericRoundtrip(ofInstant(ofEpochSecond(randomLongBetween(minEpochSecond, maxEpochSecond)), randomZone()), tv); assertGenericRoundtrip(ofInstant(ofEpochSecond(randomLongBetween(minEpochSecond, maxEpochSecond), 1_000_000), randomZone()), tv); assertGenericRoundtrip(ofInstant(ofEpochSecond(randomLongBetween(minEpochSecond, maxEpochSecond), 999_000_000), randomZone()), tv); - if (tv.onOrAfter(ZDT_NANOS_SUPPORT)) { + if (tv.onOrAfter(TransportVersions.V_8_16_0)) { assertGenericRoundtrip( ofInstant(ofEpochSecond(randomLongBetween(minEpochSecond, maxEpochSecond), 999_999_999), randomZone()), tv diff --git a/server/src/test/java/org/elasticsearch/common/time/DateUtilsTests.java b/server/src/test/java/org/elasticsearch/common/time/DateUtilsTests.java index 2dd0a28013058..e15bbbf75a529 100644 --- a/server/src/test/java/org/elasticsearch/common/time/DateUtilsTests.java +++ b/server/src/test/java/org/elasticsearch/common/time/DateUtilsTests.java @@ -20,7 +20,11 @@ import java.time.ZonedDateTime; import java.time.temporal.ChronoField; +import static org.elasticsearch.common.time.DateUtils.MAX_MILLIS_BEFORE_MINUS_9999; +import static org.elasticsearch.common.time.DateUtils.MAX_NANOSECOND_INSTANT; +import static org.elasticsearch.common.time.DateUtils.MAX_NANOSECOND_IN_MILLIS; import static org.elasticsearch.common.time.DateUtils.clampToNanosRange; +import static org.elasticsearch.common.time.DateUtils.compareNanosToMillis; import static org.elasticsearch.common.time.DateUtils.toInstant; import static org.elasticsearch.common.time.DateUtils.toLong; import static org.elasticsearch.common.time.DateUtils.toMilliSeconds; @@ -28,9 +32,45 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; public class DateUtilsTests extends ESTestCase { + public void testCompareNanosToMillis() { + assertThat(MAX_NANOSECOND_IN_MILLIS * 1_000_000, lessThan(Long.MAX_VALUE)); + + assertThat(compareNanosToMillis(toLong(Instant.EPOCH), Instant.EPOCH.toEpochMilli()), is(0)); + + // This should be 1, because the millisecond version should truncate a bit + assertThat(compareNanosToMillis(toLong(MAX_NANOSECOND_INSTANT), MAX_NANOSECOND_INSTANT.toEpochMilli()), is(1)); + + assertThat(compareNanosToMillis(toLong(MAX_NANOSECOND_INSTANT), -1000), is(1)); + // millis before epoch + assertCompareInstants( + randomInstantBetween(Instant.EPOCH, MAX_NANOSECOND_INSTANT), + randomInstantBetween(Instant.ofEpochMilli(MAX_MILLIS_BEFORE_MINUS_9999), Instant.ofEpochMilli(-1L)) + ); + + // millis after nanos range + assertCompareInstants( + randomInstantBetween(Instant.EPOCH, MAX_NANOSECOND_INSTANT), + randomInstantBetween(MAX_NANOSECOND_INSTANT.plusMillis(1), Instant.ofEpochMilli(Long.MAX_VALUE)) + ); + + // both in range + Instant nanos = randomInstantBetween(Instant.EPOCH, MAX_NANOSECOND_INSTANT); + Instant millis = randomInstantBetween(Instant.EPOCH, MAX_NANOSECOND_INSTANT); + + assertCompareInstants(nanos, millis); + } + + /** + * check that compareNanosToMillis is consistent with Instant#compare. + */ + private void assertCompareInstants(Instant nanos, Instant millis) { + assertThat(compareNanosToMillis(toLong(nanos), millis.toEpochMilli()), equalTo(nanos.compareTo(millis))); + } + public void testInstantToLong() { assertThat(toLong(Instant.EPOCH), is(0L)); diff --git a/server/src/test/java/org/elasticsearch/common/xcontent/json/JsonXContentTests.java b/server/src/test/java/org/elasticsearch/common/xcontent/json/JsonXContentTests.java index 55f6cc5498d80..4135ead545e07 100644 --- a/server/src/test/java/org/elasticsearch/common/xcontent/json/JsonXContentTests.java +++ b/server/src/test/java/org/elasticsearch/common/xcontent/json/JsonXContentTests.java @@ -11,6 +11,9 @@ import org.elasticsearch.common.xcontent.BaseXContentTestCase; import org.elasticsearch.xcontent.XContentGenerator; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; @@ -28,4 +31,14 @@ public void testBigInteger() throws Exception { XContentGenerator generator = JsonXContent.jsonXContent.createGenerator(os); doTestBigInteger(generator, os); } + + public void testMalformedJsonFieldThrowsXContentException() throws Exception { + String json = "{\"test\":\"/*/}"; + try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json)) { + parser.nextToken(); + parser.nextToken(); + parser.nextToken(); + assertThrows(XContentParseException.class, () -> parser.text()); + } + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java index ffa5bd339ae06..8e0cd97e518fa 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldMapperTests.java @@ -11,17 +11,24 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; -import org.apache.lucene.document.FeatureField; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.DocumentMapper; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperParsingException; +import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperTestCase; import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.search.lookup.Source; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.hamcrest.Matchers; @@ -29,18 +36,25 @@ import java.io.IOException; import java.util.Arrays; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import static org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.NEW_SPARSE_VECTOR_INDEX_VERSION; import static org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.PREVIOUS_SPARSE_VECTOR_INDEX_VERSION; +import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; public class SparseVectorFieldMapperTests extends MapperTestCase { @Override protected Object getSampleValueForDocument() { - return Map.of("ten", 10, "twenty", 20); + Map map = new LinkedHashMap<>(); + map.put("ten", 10f); + map.put("twenty", 20f); + return map; } @Override @@ -88,14 +102,18 @@ public void testDefaults() throws Exception { List fields = doc1.rootDoc().getFields("field"); assertEquals(2, fields.size()); - assertThat(fields.get(0), Matchers.instanceOf(FeatureField.class)); - FeatureField featureField1 = null; - FeatureField featureField2 = null; + if (IndexVersion.current().luceneVersion().major == 10) { + // TODO: Update to use Lucene's FeatureField after upgrading to Lucene 10.1. + assertThat(IndexVersion.current().luceneVersion().minor, equalTo(0)); + } + assertThat(fields.get(0), Matchers.instanceOf(XFeatureField.class)); + XFeatureField featureField1 = null; + XFeatureField featureField2 = null; for (IndexableField field : fields) { if (field.stringValue().equals("ten")) { - featureField1 = (FeatureField) field; + featureField1 = (XFeatureField) field; } else if (field.stringValue().equals("twenty")) { - featureField2 = (FeatureField) field; + featureField2 = (XFeatureField) field; } else { throw new UnsupportedOperationException(); } @@ -112,14 +130,14 @@ public void testDotInFieldName() throws Exception { List fields = parsedDocument.rootDoc().getFields("field"); assertEquals(2, fields.size()); - assertThat(fields.get(0), Matchers.instanceOf(FeatureField.class)); - FeatureField featureField1 = null; - FeatureField featureField2 = null; + assertThat(fields.get(0), Matchers.instanceOf(XFeatureField.class)); + XFeatureField featureField1 = null; + XFeatureField featureField2 = null; for (IndexableField field : fields) { if (field.stringValue().equals("foo.bar")) { - featureField1 = (FeatureField) field; + featureField1 = (XFeatureField) field; } else if (field.stringValue().equals("foobar")) { - featureField2 = (FeatureField) field; + featureField2 = (XFeatureField) field; } else { throw new UnsupportedOperationException(); } @@ -167,13 +185,13 @@ public void testHandlesMultiValuedFields() throws MapperParsingException, IOExce })); // then validate that the generate document stored both values appropriately and we have only the max value stored - FeatureField barField = ((FeatureField) doc1.rootDoc().getByKey("foo.field\\.bar")); + XFeatureField barField = ((XFeatureField) doc1.rootDoc().getByKey("foo.field\\.bar")); assertEquals(20, barField.getFeatureValue(), 1); - FeatureField storedBarField = ((FeatureField) doc1.rootDoc().getFields("foo.field").get(1)); + XFeatureField storedBarField = ((XFeatureField) doc1.rootDoc().getFields("foo.field").get(1)); assertEquals(20, storedBarField.getFeatureValue(), 1); - assertEquals(3, doc1.rootDoc().getFields().stream().filter((f) -> f instanceof FeatureField).count()); + assertEquals(3, doc1.rootDoc().getFields().stream().filter((f) -> f instanceof XFeatureField).count()); } public void testCannotBeUsedInMultiFields() { @@ -188,6 +206,53 @@ public void testCannotBeUsedInMultiFields() { assertThat(e.getMessage(), containsString("Field [feature] of type [sparse_vector] can't be used in multifields")); } + public void testStoreIsNotUpdateable() throws IOException { + var mapperService = createMapperService(fieldMapping(this::minimalMapping)); + XContentBuilder mapping = jsonBuilder().startObject() + .startObject("_doc") + .startObject("properties") + .startObject("field") + .field("type", "sparse_vector") + .field("store", true) + .endObject() + .endObject() + .endObject() + .endObject(); + var exc = expectThrows( + Exception.class, + () -> mapperService.merge("_doc", new CompressedXContent(Strings.toString(mapping)), MapperService.MergeReason.MAPPING_UPDATE) + ); + assertThat(exc.getMessage(), containsString("Cannot update parameter [store]")); + } + + @SuppressWarnings("unchecked") + public void testValueFetcher() throws Exception { + for (boolean store : new boolean[] { true, false }) { + var mapperService = createMapperService(fieldMapping(store ? this::minimalStoreMapping : this::minimalMapping)); + var mapper = mapperService.documentMapper(); + try (Directory directory = newDirectory()) { + RandomIndexWriter iw = new RandomIndexWriter(random(), directory); + var sourceToParse = source(this::writeField); + ParsedDocument doc1 = mapper.parse(sourceToParse); + iw.addDocument(doc1.rootDoc()); + iw.close(); + try (DirectoryReader reader = wrapInMockESDirectoryReader(DirectoryReader.open(directory))) { + LeafReader leafReader = getOnlyLeafReader(reader); + var searchContext = createSearchExecutionContext(mapperService, new IndexSearcher(leafReader)); + var fieldType = mapper.mappers().getFieldType("field"); + var valueFetcher = fieldType.valueFetcher(searchContext, null); + valueFetcher.setNextReader(leafReader.getContext()); + + var source = Source.fromBytes(sourceToParse.source()); + var result = valueFetcher.fetchValues(source, 0, List.of()); + assertThat(result.size(), equalTo(1)); + assertThat(result.get(0), instanceOf(Map.class)); + assertThat(toFloats((Map) result.get(0)), equalTo(toFloats((Map) source.source().get("field")))); + } + } + } + } + @Override protected Object generateRandomInputValue(MappedFieldType ft) { assumeFalse("Test implemented in a follow up", true); @@ -201,7 +266,29 @@ protected boolean allowsNullValues() { @Override protected SyntheticSourceSupport syntheticSourceSupport(boolean syntheticSource) { - throw new AssumptionViolatedException("not supported"); + boolean withStore = randomBoolean(); + return new SyntheticSourceSupport() { + @Override + public boolean preservesExactSource() { + return withStore == false; + } + + @Override + public SyntheticSourceExample example(int maxValues) { + return new SyntheticSourceExample(getSampleValueForDocument(), getSampleValueForDocument(), b -> { + if (withStore) { + minimalStoreMapping(b); + } else { + minimalMapping(b); + } + }); + } + + @Override + public List invalidExample() { + return List.of(); + } + }; } @Override @@ -234,4 +321,20 @@ public void testSparseVectorUnsupportedIndex() throws Exception { }))); assertThat(e.getMessage(), containsString(SparseVectorFieldMapper.ERROR_MESSAGE_8X)); } + + /** + * Handles float/double conversion when reading/writing with xcontent by converting all numbers to floats. + */ + private Map toFloats(Map value) { + // preserve order + Map result = new LinkedHashMap<>(); + for (var entry : value.entrySet()) { + if (entry.getValue() instanceof Number num) { + result.put(entry.getKey(), num.floatValue()); + } else { + throw new IllegalArgumentException("Expected Number, got: " + value.getClass().getSimpleName()); + } + } + return result; + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java index 4627d4d871957..0dbe3817c3e87 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/vectors/SparseVectorFieldTypeTests.java @@ -18,13 +18,13 @@ public class SparseVectorFieldTypeTests extends FieldTypeTestCase { public void testDocValuesDisabled() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", false, Collections.emptyMap()); assertFalse(fieldType.hasDocValues()); expectThrows(IllegalArgumentException.class, () -> fieldType.fielddataBuilder(FieldDataContext.noRuntimeFields("test"))); } public void testIsNotAggregatable() { - MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", Collections.emptyMap()); + MappedFieldType fieldType = new SparseVectorFieldMapper.SparseVectorFieldType("field", false, Collections.emptyMap()); assertFalse(fieldType.isAggregatable()); } } diff --git a/server/src/test/java/org/elasticsearch/index/search/NestedHelperTests.java b/server/src/test/java/org/elasticsearch/index/search/NestedHelperTests.java index a7a1d33badf25..b2583eb176deb 100644 --- a/server/src/test/java/org/elasticsearch/index/search/NestedHelperTests.java +++ b/server/src/test/java/org/elasticsearch/index/search/NestedHelperTests.java @@ -17,6 +17,7 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.join.ScoreMode; +import org.elasticsearch.index.mapper.MapperMetrics; import org.elasticsearch.index.mapper.MapperService; import org.elasticsearch.index.mapper.MapperServiceTestCase; import org.elasticsearch.index.query.MatchAllQueryBuilder; @@ -27,12 +28,15 @@ import java.io.IOException; import java.util.Collections; +import static java.util.Collections.emptyMap; import static org.mockito.Mockito.mock; public class NestedHelperTests extends MapperServiceTestCase { MapperService mapperService; + SearchExecutionContext searchExecutionContext; + @Override public void setUp() throws Exception { super.setUp(); @@ -68,167 +72,185 @@ public void setUp() throws Exception { } } """; mapperService = createMapperService(mapping); - } - - private static NestedHelper buildNestedHelper(MapperService mapperService) { - return new NestedHelper(mapperService.mappingLookup().nestedLookup(), field -> mapperService.fieldType(field) != null); + searchExecutionContext = new SearchExecutionContext( + 0, + 0, + mapperService.getIndexSettings(), + null, + null, + mapperService, + mapperService.mappingLookup(), + null, + null, + parserConfig(), + writableRegistry(), + null, + null, + System::currentTimeMillis, + null, + null, + () -> true, + null, + emptyMap(), + MapperMetrics.NOOP + ); } public void testMatchAll() { - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(new MatchAllDocsQuery())); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(new MatchAllDocsQuery(), searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(new MatchAllDocsQuery(), "nested_missing", searchExecutionContext)); } public void testMatchNo() { - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(new MatchNoDocsQuery())); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested1")); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested2")); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested3")); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(new MatchNoDocsQuery(), searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested1", searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested2", searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested3", searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(new MatchNoDocsQuery(), "nested_missing", searchExecutionContext)); } public void testTermsQuery() { Query termsQuery = mapperService.fieldType("foo").termsQuery(Collections.singletonList("bar"), null); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(termsQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(termsQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested_missing", searchExecutionContext)); termsQuery = mapperService.fieldType("nested1.foo").termsQuery(Collections.singletonList("bar"), null); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(termsQuery)); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(termsQuery, searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested_missing", searchExecutionContext)); termsQuery = mapperService.fieldType("nested2.foo").termsQuery(Collections.singletonList("bar"), null); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(termsQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(termsQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested_missing", searchExecutionContext)); termsQuery = mapperService.fieldType("nested3.foo").termsQuery(Collections.singletonList("bar"), null); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(termsQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termsQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(termsQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termsQuery, "nested_missing", searchExecutionContext)); } public void testTermQuery() { Query termQuery = mapperService.fieldType("foo").termQuery("bar", null); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(termQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(termQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested_missing", searchExecutionContext)); termQuery = mapperService.fieldType("nested1.foo").termQuery("bar", null); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(termQuery)); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(termQuery, searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested_missing", searchExecutionContext)); termQuery = mapperService.fieldType("nested2.foo").termQuery("bar", null); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(termQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(termQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested_missing", searchExecutionContext)); termQuery = mapperService.fieldType("nested3.foo").termQuery("bar", null); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(termQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(termQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(termQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(termQuery, "nested_missing", searchExecutionContext)); } public void testRangeQuery() { SearchExecutionContext context = mock(SearchExecutionContext.class); Query rangeQuery = mapperService.fieldType("foo2").rangeQuery(2, 5, true, true, null, null, null, context); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(rangeQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(rangeQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested_missing", searchExecutionContext)); rangeQuery = mapperService.fieldType("nested1.foo2").rangeQuery(2, 5, true, true, null, null, null, context); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(rangeQuery)); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(rangeQuery, searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested_missing", searchExecutionContext)); rangeQuery = mapperService.fieldType("nested2.foo2").rangeQuery(2, 5, true, true, null, null, null, context); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(rangeQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(rangeQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested_missing", searchExecutionContext)); rangeQuery = mapperService.fieldType("nested3.foo2").rangeQuery(2, 5, true, true, null, null, null, context); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(rangeQuery)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(rangeQuery, "nested_missing")); + assertTrue(NestedHelper.mightMatchNestedDocs(rangeQuery, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(rangeQuery, "nested_missing", searchExecutionContext)); } public void testDisjunction() { BooleanQuery bq = new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) .add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD) .build(); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertFalse(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested1.foo", "bar")), Occur.SHOULD) .add(new TermQuery(new Term("nested1.foo", "baz")), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested2.foo", "bar")), Occur.SHOULD) .add(new TermQuery(new Term("nested2.foo", "baz")), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested2")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested2", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested3.foo", "bar")), Occur.SHOULD) .add(new TermQuery(new Term("nested3.foo", "baz")), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested3")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested3", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.SHOULD) .add(new MatchAllDocsQuery(), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested1.foo", "bar")), Occur.SHOULD) .add(new MatchAllDocsQuery(), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested2.foo", "bar")), Occur.SHOULD) .add(new MatchAllDocsQuery(), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested2")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested2", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested3.foo", "bar")), Occur.SHOULD) .add(new MatchAllDocsQuery(), Occur.SHOULD) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested3")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested3", searchExecutionContext)); } private static Occur requiredOccur() { @@ -239,42 +261,42 @@ public void testConjunction() { BooleanQuery bq = new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), requiredOccur()) .add(new MatchAllDocsQuery(), requiredOccur()) .build(); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertFalse(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested1.foo", "bar")), requiredOccur()) .add(new MatchAllDocsQuery(), requiredOccur()) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertFalse(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertFalse(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested2.foo", "bar")), requiredOccur()) .add(new MatchAllDocsQuery(), requiredOccur()) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested2")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested2", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new TermQuery(new Term("nested3.foo", "bar")), requiredOccur()) .add(new MatchAllDocsQuery(), requiredOccur()) .build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested3")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested3", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new MatchAllDocsQuery(), requiredOccur()).add(new MatchAllDocsQuery(), requiredOccur()).build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new MatchAllDocsQuery(), requiredOccur()).add(new MatchAllDocsQuery(), requiredOccur()).build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested1")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested1", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new MatchAllDocsQuery(), requiredOccur()).add(new MatchAllDocsQuery(), requiredOccur()).build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested2")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested2", searchExecutionContext)); bq = new BooleanQuery.Builder().add(new MatchAllDocsQuery(), requiredOccur()).add(new MatchAllDocsQuery(), requiredOccur()).build(); - assertTrue(buildNestedHelper(mapperService).mightMatchNestedDocs(bq)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(bq, "nested3")); + assertTrue(NestedHelper.mightMatchNestedDocs(bq, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(bq, "nested3", searchExecutionContext)); } public void testNested() throws IOException { @@ -288,11 +310,11 @@ public void testNested() throws IOException { .build(); assertEquals(expectedChildQuery, query.getChildQuery()); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(query)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(query, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested_missing", searchExecutionContext)); queryBuilder = new NestedQueryBuilder("nested1", new TermQueryBuilder("nested1.foo", "bar"), ScoreMode.Avg); query = (ESToParentBlockJoinQuery) queryBuilder.toQuery(context); @@ -301,11 +323,11 @@ public void testNested() throws IOException { expectedChildQuery = new TermQuery(new Term("nested1.foo", "bar")); assertEquals(expectedChildQuery, query.getChildQuery()); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(query)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(query, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested_missing", searchExecutionContext)); queryBuilder = new NestedQueryBuilder("nested2", new TermQueryBuilder("nested2.foo", "bar"), ScoreMode.Avg); query = (ESToParentBlockJoinQuery) queryBuilder.toQuery(context); @@ -316,11 +338,11 @@ public void testNested() throws IOException { .build(); assertEquals(expectedChildQuery, query.getChildQuery()); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(query)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(query, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested_missing", searchExecutionContext)); queryBuilder = new NestedQueryBuilder("nested3", new TermQueryBuilder("nested3.foo", "bar"), ScoreMode.Avg); query = (ESToParentBlockJoinQuery) queryBuilder.toQuery(context); @@ -331,10 +353,10 @@ public void testNested() throws IOException { .build(); assertEquals(expectedChildQuery, query.getChildQuery()); - assertFalse(buildNestedHelper(mapperService).mightMatchNestedDocs(query)); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested1")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested2")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested3")); - assertTrue(buildNestedHelper(mapperService).mightMatchNonNestedDocs(query, "nested_missing")); + assertFalse(NestedHelper.mightMatchNestedDocs(query, searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested1", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested2", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested3", searchExecutionContext)); + assertTrue(NestedHelper.mightMatchNonNestedDocs(query, "nested_missing", searchExecutionContext)); } } diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexingFailuresDebugListenerTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexingFailuresDebugListenerTests.java new file mode 100644 index 0000000000000..43434a691bd90 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexingFailuresDebugListenerTests.java @@ -0,0 +1,138 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.shard; + +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.routing.ShardRoutingState; +import org.elasticsearch.cluster.routing.TestShardRouting; +import org.elasticsearch.common.logging.Loggers; +import org.elasticsearch.common.logging.MockAppender; +import org.elasticsearch.index.engine.Engine; +import org.elasticsearch.index.engine.EngineTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.Uid; +import org.elasticsearch.test.ESTestCase; +import org.junit.AfterClass; +import org.junit.BeforeClass; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class IndexingFailuresDebugListenerTests extends ESTestCase { + + static MockAppender appender; + static Logger testLogger1 = LogManager.getLogger(IndexingFailuresDebugListener.class); + static Level origLogLevel = testLogger1.getLevel(); + + @BeforeClass + public static void init() throws IllegalAccessException { + appender = new MockAppender("mock_appender"); + appender.start(); + Loggers.addAppender(testLogger1, appender); + Loggers.setLevel(testLogger1, randomBoolean() ? Level.DEBUG : Level.TRACE); + } + + @AfterClass + public static void cleanup() { + Loggers.removeAppender(testLogger1, appender); + appender.stop(); + + Loggers.setLevel(testLogger1, origLogLevel); + } + + public void testPostIndexException() { + var shardId = ShardId.fromString("[index][123]"); + var mockShard = mock(IndexShard.class); + var shardRouting = TestShardRouting.newShardRouting(shardId, "node-id", true, ShardRoutingState.STARTED); + when(mockShard.routingEntry()).thenReturn(shardRouting); + when(mockShard.getOperationPrimaryTerm()).thenReturn(1L); + IndexingFailuresDebugListener indexingFailuresDebugListener = new IndexingFailuresDebugListener(mockShard); + + ParsedDocument doc = EngineTestCase.createParsedDoc("1", null); + Engine.Index index = new Engine.Index(Uid.encodeId("doc_id"), 1, doc); + indexingFailuresDebugListener.postIndex(shardId, index, new RuntimeException("test exception")); + String message = appender.getLastEventAndReset().getMessage().getFormattedMessage(); + assertThat( + message, + equalTo( + "index-fail [1] seq# [-2] allocation-id [" + + shardRouting.allocationId() + + "] primaryTerm [1] operationPrimaryTerm [1] origin [PRIMARY]" + ) + ); + } + + public void testPostIndexExceptionInfoLevel() { + var previousLevel = testLogger1.getLevel(); + try { + Loggers.setLevel(testLogger1, randomBoolean() ? Level.INFO : Level.WARN); + var shardId = ShardId.fromString("[index][123]"); + var mockShard = mock(IndexShard.class); + var shardRouting = TestShardRouting.newShardRouting(shardId, "node-id", true, ShardRoutingState.STARTED); + when(mockShard.routingEntry()).thenReturn(shardRouting); + when(mockShard.getOperationPrimaryTerm()).thenReturn(1L); + IndexingFailuresDebugListener indexingFailuresDebugListener = new IndexingFailuresDebugListener(mockShard); + + ParsedDocument doc = EngineTestCase.createParsedDoc("1", null); + Engine.Index index = new Engine.Index(Uid.encodeId("doc_id"), 1, doc); + indexingFailuresDebugListener.postIndex(shardId, index, new RuntimeException("test exception")); + assertThat(appender.getLastEventAndReset(), nullValue()); + } finally { + Loggers.setLevel(testLogger1, previousLevel); + } + } + + public void testPostIndexFailure() { + var shardId = ShardId.fromString("[index][123]"); + var mockShard = mock(IndexShard.class); + var shardRouting = TestShardRouting.newShardRouting(shardId, "node-id", true, ShardRoutingState.STARTED); + when(mockShard.routingEntry()).thenReturn(shardRouting); + when(mockShard.getOperationPrimaryTerm()).thenReturn(1L); + IndexingFailuresDebugListener indexingFailuresDebugListener = new IndexingFailuresDebugListener(mockShard); + + ParsedDocument doc = EngineTestCase.createParsedDoc("1", null); + Engine.Index index = new Engine.Index(Uid.encodeId("doc_id"), 1, doc); + Engine.IndexResult indexResult = mock(Engine.IndexResult.class); + when(indexResult.getResultType()).thenReturn(Engine.Result.Type.FAILURE); + when(indexResult.getFailure()).thenReturn(new RuntimeException("test exception")); + indexingFailuresDebugListener.postIndex(shardId, index, indexResult); + String message = appender.getLastEventAndReset().getMessage().getFormattedMessage(); + assertThat( + message, + equalTo( + "index-fail [1] seq# [-2] allocation-id [" + + shardRouting.allocationId() + + "] primaryTerm [1] operationPrimaryTerm [1] origin [PRIMARY]" + ) + ); + } + + public void testPostIndex() { + var shardId = ShardId.fromString("[index][123]"); + var mockShard = mock(IndexShard.class); + var shardRouting = TestShardRouting.newShardRouting(shardId, "node-id", true, ShardRoutingState.STARTED); + when(mockShard.routingEntry()).thenReturn(shardRouting); + when(mockShard.getOperationPrimaryTerm()).thenReturn(1L); + IndexingFailuresDebugListener indexingFailuresDebugListener = new IndexingFailuresDebugListener(mockShard); + + ParsedDocument doc = EngineTestCase.createParsedDoc("1", null); + Engine.Index index = new Engine.Index(Uid.encodeId("doc_id"), 1, doc); + Engine.IndexResult indexResult = mock(Engine.IndexResult.class); + when(indexResult.getResultType()).thenReturn(Engine.Result.Type.SUCCESS); + when(indexResult.getFailure()).thenReturn(new RuntimeException("test exception")); + indexingFailuresDebugListener.postIndex(shardId, index, indexResult); + assertThat(appender.getLastEventAndReset(), nullValue()); + } + +} diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index d1ccfcbe78732..89fd25f638e1c 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -95,7 +95,6 @@ import org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; -import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.dfs.AggregatedDfs; import org.elasticsearch.search.fetch.FetchSearchResult; import org.elasticsearch.search.fetch.ShardFetchRequest; @@ -124,7 +123,6 @@ import org.elasticsearch.search.rank.feature.RankFeatureResult; import org.elasticsearch.search.rank.feature.RankFeatureShardRequest; import org.elasticsearch.search.rank.feature.RankFeatureShardResult; -import org.elasticsearch.search.slice.SliceBuilder; import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.tasks.TaskCancelHelper; import org.elasticsearch.tasks.TaskCancelledException; @@ -2930,119 +2928,6 @@ public void testSlicingBehaviourForParallelCollection() throws Exception { } } - /** - * This method tests validation that happens on the data nodes, which is now performed on the coordinating node. - * We still need the validation to cover for mixed cluster scenarios where the coordinating node does not perform the check yet. - */ - public void testParseSourceValidation() { - String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); - IndexService indexService = createIndex(index); - final SearchService service = getInstanceFromNode(SearchService.class); - { - // scroll and search_after - SearchRequest searchRequest = new SearchRequest().source(new SearchSourceBuilder()); - searchRequest.scroll(new TimeValue(1000)); - searchRequest.source().searchAfter(new String[] { "value" }); - assertCreateContextValidation(searchRequest, "`search_after` cannot be used in a scroll context.", indexService, service); - } - { - // scroll and collapse - SearchRequest searchRequest = new SearchRequest().source(new SearchSourceBuilder()); - searchRequest.scroll(new TimeValue(1000)); - searchRequest.source().collapse(new CollapseBuilder("field")); - assertCreateContextValidation(searchRequest, "cannot use `collapse` in a scroll context", indexService, service); - } - { - // search_after and `from` isn't valid - SearchRequest searchRequest = new SearchRequest().source(new SearchSourceBuilder()); - searchRequest.source().searchAfter(new String[] { "value" }); - searchRequest.source().from(10); - assertCreateContextValidation( - searchRequest, - "`from` parameter must be set to 0 when `search_after` is used", - indexService, - service - ); - } - { - // slice without scroll or pit - SearchRequest searchRequest = new SearchRequest().source(new SearchSourceBuilder()); - searchRequest.source().slice(new SliceBuilder(1, 10)); - assertCreateContextValidation( - searchRequest, - "[slice] can only be used with [scroll] or [point-in-time] requests", - indexService, - service - ); - } - { - // stored fields disabled with _source requested - SearchRequest searchRequest = new SearchRequest().source(new SearchSourceBuilder()); - searchRequest.source().storedField("_none_"); - searchRequest.source().fetchSource(true); - assertCreateContextValidation( - searchRequest, - "[stored_fields] cannot be disabled if [_source] is requested", - indexService, - service - ); - } - { - // stored fields disabled with fetch fields requested - SearchRequest searchRequest = new SearchRequest().source(new SearchSourceBuilder()); - searchRequest.source().storedField("_none_"); - searchRequest.source().fetchSource(false); - searchRequest.source().fetchField("field"); - assertCreateContextValidation( - searchRequest, - "[stored_fields] cannot be disabled when using the [fields] option", - indexService, - service - ); - } - } - - private static void assertCreateContextValidation( - SearchRequest searchRequest, - String errorMessage, - IndexService indexService, - SearchService searchService - ) { - ShardId shardId = new ShardId(indexService.index(), 0); - long nowInMillis = System.currentTimeMillis(); - String clusterAlias = randomBoolean() ? null : randomAlphaOfLengthBetween(3, 10); - searchRequest.allowPartialSearchResults(randomBoolean()); - ShardSearchRequest request = new ShardSearchRequest( - OriginalIndices.NONE, - searchRequest, - shardId, - 0, - indexService.numberOfShards(), - AliasFilter.EMPTY, - 1f, - nowInMillis, - clusterAlias - ); - - SearchShardTask task = new SearchShardTask(1, "type", "action", "description", null, emptyMap()); - - ReaderContext readerContext = null; - try { - ReaderContext createOrGetReaderContext = searchService.createOrGetReaderContext(request); - readerContext = createOrGetReaderContext; - IllegalArgumentException exception = expectThrows( - IllegalArgumentException.class, - () -> searchService.createContext(createOrGetReaderContext, request, task, ResultsType.QUERY, randomBoolean()) - ); - assertThat(exception.getMessage(), containsString(errorMessage)); - } finally { - if (readerContext != null) { - readerContext.close(); - searchService.freeReaderContext(readerContext.id()); - } - } - } - private static ReaderContext createReaderContext(IndexService indexService, IndexShard indexShard) { return new ReaderContext( new ShardSearchContextId(UUIDs.randomBase64UUID(), randomNonNegativeLong()), diff --git a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java index af6782c45dce8..ccf33c0b71b6b 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/RankDocsRetrieverBuilderTests.java @@ -95,12 +95,7 @@ private List preFilters(QueryRewriteContext queryRewriteContext) t } private RankDocsRetrieverBuilder createRandomRankDocsRetrieverBuilder(QueryRewriteContext queryRewriteContext) throws IOException { - return new RankDocsRetrieverBuilder( - randomIntBetween(1, 100), - innerRetrievers(queryRewriteContext), - rankDocsSupplier(), - preFilters(queryRewriteContext) - ); + return new RankDocsRetrieverBuilder(randomIntBetween(1, 100), innerRetrievers(queryRewriteContext), rankDocsSupplier()); } public void testExtractToSearchSourceBuilder() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java b/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java index c47c8c16f6a2f..5733a51bb7e9c 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/TestQueryVectorBuilderPlugin.java @@ -27,9 +27,9 @@ /** * A SearchPlugin to exercise query vector builder */ -class TestQueryVectorBuilderPlugin implements SearchPlugin { +public class TestQueryVectorBuilderPlugin implements SearchPlugin { - static class TestQueryVectorBuilder implements QueryVectorBuilder { + public static class TestQueryVectorBuilder implements QueryVectorBuilder { private static final String NAME = "test_query_vector_builder"; private static final ParseField QUERY_VECTOR = new ParseField("query_vector"); @@ -47,11 +47,11 @@ static class TestQueryVectorBuilder implements QueryVectorBuilder { private List vectorToBuild; - TestQueryVectorBuilder(List vectorToBuild) { + public TestQueryVectorBuilder(List vectorToBuild) { this.vectorToBuild = vectorToBuild; } - TestQueryVectorBuilder(float[] expected) { + public TestQueryVectorBuilder(float[] expected) { this.vectorToBuild = new ArrayList<>(expected.length); for (float f : expected) { vectorToBuild.add(f); diff --git a/settings.gradle b/settings.gradle index 4722fc311480a..747fbb3e439fe 100644 --- a/settings.gradle +++ b/settings.gradle @@ -73,6 +73,7 @@ List projects = [ 'distribution:packages:aarch64-rpm', 'distribution:packages:rpm', 'distribution:bwc:bugfix', + 'distribution:bwc:bugfix2', 'distribution:bwc:maintenance', 'distribution:bwc:minor', 'distribution:bwc:staged', diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index 8b9176a346e30..ace3db377664c 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -194,6 +194,13 @@ private void assertCircuitBreaks(ThrowingRunnable r) throws IOException { ); } + private void assertParseFailure(ThrowingRunnable r) throws IOException { + ResponseException e = expectThrows(ResponseException.class, r); + Map map = responseAsMap(e.getResponse()); + logger.info("expected parse failure {}", map); + assertMap(map, matchesMap().entry("status", 400).entry("error", matchesMap().extraOk().entry("type", "parsing_exception"))); + } + private Response sortByManyLongs(int count) throws IOException { logger.info("sorting by {} longs", count); return query(makeSortByManyLongs(count).toString(), null); @@ -318,6 +325,13 @@ public void testManyConcatFromRow() throws IOException { assertManyStrings(resp, strings); } + /** + * Fails to parse a huge huge query. + */ + public void testHugeHugeManyConcatFromRow() throws IOException { + assertParseFailure(() -> manyConcat("ROW a=9999, b=9999, c=9999, d=9999, e=9999", 50000)); + } + /** * Tests that generate many moderately long strings. */ @@ -378,6 +392,13 @@ public void testManyRepeatFromRow() throws IOException { assertManyStrings(resp, strings); } + /** + * Fails to parse a huge huge query. + */ + public void testHugeHugeManyRepeatFromRow() throws IOException { + assertParseFailure(() -> manyRepeat("ROW a = 99", 100000)); + } + /** * Tests that generate many moderately long strings. */ diff --git a/test/framework/build.gradle b/test/framework/build.gradle index 126b95041da11..c7e08eb3cdfa9 100644 --- a/test/framework/build.gradle +++ b/test/framework/build.gradle @@ -86,7 +86,6 @@ tasks.named("thirdPartyAudit").configure { tasks.named("test").configure { systemProperty 'tests.gradle_index_compat_versions', buildParams.bwcVersions.indexCompatible.join(',') systemProperty 'tests.gradle_wire_compat_versions', buildParams.bwcVersions.wireCompatible.join(',') - systemProperty 'tests.gradle_unreleased_versions', buildParams.bwcVersions.unreleased.join(',') } tasks.register("integTest", Test) { diff --git a/test/framework/src/main/java/org/elasticsearch/plugins/MockPluginsService.java b/test/framework/src/main/java/org/elasticsearch/plugins/MockPluginsService.java index a9a825af3b865..91875600ec000 100644 --- a/test/framework/src/main/java/org/elasticsearch/plugins/MockPluginsService.java +++ b/test/framework/src/main/java/org/elasticsearch/plugins/MockPluginsService.java @@ -45,7 +45,7 @@ public MockPluginsService(Settings settings, Environment environment, Collection super( settings, environment.configFile(), - new PluginsLoader(Collections.emptyList(), Collections.emptyList(), Collections.emptyMap()) + new PluginsLoader(Collections.emptyList(), Collections.emptyList(), Collections.emptyMap(), Collections.emptySet()) ); List pluginsLoaded = new ArrayList<>(); diff --git a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java index 9f199aa7f3ef8..4a5f280c10a99 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java +++ b/test/framework/src/main/java/org/elasticsearch/search/retriever/TestCompoundRetrieverBuilder.java @@ -10,6 +10,7 @@ package org.elasticsearch.search.retriever; import org.apache.lucene.search.ScoreDoc; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.xcontent.XContentBuilder; @@ -23,16 +24,17 @@ public class TestCompoundRetrieverBuilder extends CompoundRetrieverBuilder(), rankWindowSize); + this(new ArrayList<>(), rankWindowSize, new ArrayList<>()); } - TestCompoundRetrieverBuilder(List childRetrievers, int rankWindowSize) { + TestCompoundRetrieverBuilder(List childRetrievers, int rankWindowSize, List preFilterQueryBuilders) { super(childRetrievers, rankWindowSize); + this.preFilterQueryBuilders = preFilterQueryBuilders; } @Override - protected TestCompoundRetrieverBuilder clone(List newChildRetrievers) { - return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize); + protected TestCompoundRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { + return new TestCompoundRetrieverBuilder(newChildRetrievers, rankWindowSize, newPreFilterQueryBuilders); } @Override diff --git a/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java index 8bc81fef2157d..a2bf70bf6e087 100644 --- a/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/snapshots/AbstractSnapshotIntegTestCase.java @@ -128,6 +128,7 @@ protected Collection> nodePlugins() { @After public void assertConsistentHistoryInLuceneIndex() throws Exception { + internalCluster().beforeIndexDeletion(); internalCluster().assertConsistentHistoryBetweenTranslogAndLuceneIndex(); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index ea82c9d21ab89..7cd7bce4db187 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -9,6 +9,8 @@ package org.elasticsearch.test; +import com.carrotsearch.randomizedtesting.RandomizedTest; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.admin.cluster.remote.RemoteInfoRequest; @@ -36,10 +38,10 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.stream.Collectors; @@ -58,7 +60,7 @@ public abstract class AbstractMultiClustersTestCase extends ESTestCase { private static volatile ClusterGroup clusterGroup; - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return randomSubsetOf(List.of("cluster-a", "cluster-b")); } @@ -100,17 +102,23 @@ public final void startClusters() throws Exception { return; } stopClusters(); - final Map clusters = new HashMap<>(); + final Map clusters = new ConcurrentHashMap<>(); final List clusterAliases = new ArrayList<>(remoteClusterAlias()); clusterAliases.add(LOCAL_CLUSTER); - for (String clusterAlias : clusterAliases) { + final List> mockPlugins = List.of( + MockHttpTransport.TestPlugin.class, + MockTransportService.TestPlugin.class, + getTestTransportPlugin() + ); + // We are going to initialize multiple clusters concurrently, but there is a race condition around the lazy initialization of test + // groups in GroupEvaluator across multiple threads. See https://github.com/randomizedtesting/randomizedtesting/issues/311. + // Calling isNightly before parallelizing is enough to work around that issue. + @SuppressWarnings("unused") + boolean nightly = RandomizedTest.isNightly(); + runInParallel(clusterAliases.size(), i -> { + String clusterAlias = clusterAliases.get(i); final String clusterName = clusterAlias.equals(LOCAL_CLUSTER) ? "main-cluster" : clusterAlias; final int numberOfNodes = randomIntBetween(1, 3); - final List> mockPlugins = List.of( - MockHttpTransport.TestPlugin.class, - MockTransportService.TestPlugin.class, - getTestTransportPlugin() - ); final Collection> nodePlugins = nodePlugins(clusterAlias); final NodeConfigurationSource nodeConfigurationSource = nodeConfigurationSource(nodeSettings(), nodePlugins); @@ -128,10 +136,14 @@ public final void startClusters() throws Exception { mockPlugins, Function.identity() ); - cluster.beforeTest(random()); + try { + cluster.beforeTest(random()); + } catch (Exception e) { + throw new RuntimeException(e); + } clusters.put(clusterAlias, cluster); - } - clusterGroup = new ClusterGroup(clusters); + }); + clusterGroup = new ClusterGroup(Map.copyOf(clusters)); configureAndConnectsToRemoteClusters(); } diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index d983fc854bdfd..a71f61740e17b 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1205,10 +1205,30 @@ public static SecureString randomSecureStringOfLength(int codeUnits) { return new SecureString(randomAlpha.toCharArray()); } - public static String randomNullOrAlphaOfLength(int codeUnits) { + public static String randomAlphaOfLengthOrNull(int codeUnits) { return randomBoolean() ? null : randomAlphaOfLength(codeUnits); } + public static Long randomLongOrNull() { + return randomBoolean() ? null : randomLong(); + } + + public static Long randomPositiveLongOrNull() { + return randomBoolean() ? null : randomNonNegativeLong(); + } + + public static Integer randomIntOrNull() { + return randomBoolean() ? null : randomInt(); + } + + public static Integer randomPositiveIntOrNull() { + return randomBoolean() ? null : randomNonNegativeInt(); + } + + public static Float randomFloatOrNull() { + return randomBoolean() ? null : randomFloat(); + } + /** * Creates a valid random identifier such as node id or index name */ diff --git a/test/framework/src/main/java/org/elasticsearch/test/VersionUtils.java b/test/framework/src/main/java/org/elasticsearch/test/VersionUtils.java index d561c5512b614..8b7ab620774b9 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/VersionUtils.java +++ b/test/framework/src/main/java/org/elasticsearch/test/VersionUtils.java @@ -12,132 +12,15 @@ import org.elasticsearch.Build; import org.elasticsearch.Version; import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Tuple; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Random; -import java.util.stream.Collectors; -import java.util.stream.Stream; /** Utilities for selecting versions in tests */ public class VersionUtils { - /** - * Sort versions that have backwards compatibility guarantees from - * those that don't. Doesn't actually check whether or not the versions - * are released, instead it relies on gradle to have already checked - * this which it does in {@code :core:verifyVersions}. So long as the - * rules here match up with the rules in gradle then this should - * produce sensible results. - * @return a tuple containing versions with backwards compatibility - * guarantees in v1 and versions without the guranteees in v2 - */ - static Tuple, List> resolveReleasedVersions(Version current, Class versionClass) { - // group versions into major version - Map> majorVersions = Version.getDeclaredVersions(versionClass) - .stream() - .collect(Collectors.groupingBy(v -> (int) v.major)); - // this breaks b/c 5.x is still in version list but master doesn't care about it! - // assert majorVersions.size() == 2; - // TODO: remove oldVersions, we should only ever have 2 majors in Version - List> oldVersions = splitByMinor(majorVersions.getOrDefault((int) current.major - 2, Collections.emptyList())); - List> previousMajor = splitByMinor(majorVersions.get((int) current.major - 1)); - List> currentMajor = splitByMinor(majorVersions.get((int) current.major)); - - List unreleasedVersions = new ArrayList<>(); - final List> stableVersions; - if (currentMajor.size() == 1) { - // on master branch - stableVersions = previousMajor; - // remove current - moveLastToUnreleased(currentMajor, unreleasedVersions); - } else { - // on a stable or release branch, ie N.x - stableVersions = currentMajor; - // remove the next maintenance bugfix - moveLastToUnreleased(previousMajor, unreleasedVersions); - } - - // remove next minor - Version lastMinor = moveLastToUnreleased(stableVersions, unreleasedVersions); - if (lastMinor.revision == 0) { - if (stableVersions.get(stableVersions.size() - 1).size() == 1) { - // a minor is being staged, which is also unreleased - moveLastToUnreleased(stableVersions, unreleasedVersions); - } - // remove the next bugfix - if (stableVersions.isEmpty() == false) { - moveLastToUnreleased(stableVersions, unreleasedVersions); - } - } - - // If none of the previous major was released, then the last minor and bugfix of the old version was not released either. - if (previousMajor.isEmpty()) { - assert currentMajor.isEmpty() : currentMajor; - // minor of the old version is being staged - moveLastToUnreleased(oldVersions, unreleasedVersions); - // bugix of the old version is also being staged - moveLastToUnreleased(oldVersions, unreleasedVersions); - } - List releasedVersions = Stream.of(oldVersions, previousMajor, currentMajor) - .flatMap(List::stream) - .flatMap(List::stream) - .collect(Collectors.toList()); - Collections.sort(unreleasedVersions); // we add unreleased out of order, so need to sort here - return new Tuple<>(Collections.unmodifiableList(releasedVersions), Collections.unmodifiableList(unreleasedVersions)); - } - - // split the given versions into sub lists grouped by minor version - private static List> splitByMinor(List versions) { - Map> byMinor = versions.stream().collect(Collectors.groupingBy(v -> (int) v.minor)); - return byMinor.entrySet().stream().sorted(Map.Entry.comparingByKey()).map(Map.Entry::getValue).collect(Collectors.toList()); - } - - // move the last version of the last minor in versions to the unreleased versions - private static Version moveLastToUnreleased(List> versions, List unreleasedVersions) { - List lastMinor = new ArrayList<>(versions.get(versions.size() - 1)); - Version lastVersion = lastMinor.remove(lastMinor.size() - 1); - if (lastMinor.isEmpty()) { - versions.remove(versions.size() - 1); - } else { - versions.set(versions.size() - 1, lastMinor); - } - unreleasedVersions.add(lastVersion); - return lastVersion; - } - - private static final List RELEASED_VERSIONS; - private static final List UNRELEASED_VERSIONS; - private static final List ALL_VERSIONS; - - static { - Tuple, List> versions = resolveReleasedVersions(Version.CURRENT, Version.class); - RELEASED_VERSIONS = versions.v1(); - UNRELEASED_VERSIONS = versions.v2(); - List allVersions = new ArrayList<>(RELEASED_VERSIONS.size() + UNRELEASED_VERSIONS.size()); - allVersions.addAll(RELEASED_VERSIONS); - allVersions.addAll(UNRELEASED_VERSIONS); - Collections.sort(allVersions); - ALL_VERSIONS = Collections.unmodifiableList(allVersions); - } - - /** - * Returns an immutable, sorted list containing all released versions. - */ - public static List allReleasedVersions() { - return RELEASED_VERSIONS; - } - - /** - * Returns an immutable, sorted list containing all unreleased versions. - */ - public static List allUnreleasedVersions() { - return UNRELEASED_VERSIONS; - } + private static final List ALL_VERSIONS = Version.getDeclaredVersions(Version.class); /** * Returns an immutable, sorted list containing all versions, both released and unreleased. @@ -147,16 +30,16 @@ public static List allVersions() { } /** - * Get the released version before {@code version}. + * Get the version before {@code version}. */ public static Version getPreviousVersion(Version version) { - for (int i = RELEASED_VERSIONS.size() - 1; i >= 0; i--) { - Version v = RELEASED_VERSIONS.get(i); + for (int i = ALL_VERSIONS.size() - 1; i >= 0; i--) { + Version v = ALL_VERSIONS.get(i); if (v.before(version)) { return v; } } - throw new IllegalArgumentException("couldn't find any released versions before [" + version + "]"); + throw new IllegalArgumentException("couldn't find any versions before [" + version + "]"); } /** @@ -169,22 +52,22 @@ public static Version getPreviousVersion() { } /** - * Returns the released {@link Version} before the {@link Version#CURRENT} + * Returns the {@link Version} before the {@link Version#CURRENT} * where the minor version is less than the currents minor version. */ public static Version getPreviousMinorVersion() { - for (int i = RELEASED_VERSIONS.size() - 1; i >= 0; i--) { - Version v = RELEASED_VERSIONS.get(i); + for (int i = ALL_VERSIONS.size() - 1; i >= 0; i--) { + Version v = ALL_VERSIONS.get(i); if (v.minor < Version.CURRENT.minor || v.major < Version.CURRENT.major) { return v; } } - throw new IllegalArgumentException("couldn't find any released versions of the minor before [" + Build.current().version() + "]"); + throw new IllegalArgumentException("couldn't find any versions of the minor before [" + Build.current().version() + "]"); } - /** Returns the oldest released {@link Version} */ + /** Returns the oldest {@link Version} */ public static Version getFirstVersion() { - return RELEASED_VERSIONS.get(0); + return ALL_VERSIONS.get(0); } /** Returns a random {@link Version} from all available versions. */ diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java index b4f4243fb90fd..4428afaaeabe5 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestCase.java @@ -333,8 +333,11 @@ public void initClient() throws IOException { assert testFeatureServiceInitialized() == false; clusterHosts = parseClusterHosts(getTestRestCluster()); logger.info("initializing REST clients against {}", clusterHosts); - client = buildClient(restClientSettings(), clusterHosts.toArray(new HttpHost[clusterHosts.size()])); - adminClient = buildClient(restAdminSettings(), clusterHosts.toArray(new HttpHost[clusterHosts.size()])); + var clientSettings = restClientSettings(); + var adminSettings = restAdminSettings(); + var hosts = clusterHosts.toArray(new HttpHost[0]); + client = buildClient(clientSettings, hosts); + adminClient = clientSettings.equals(adminSettings) ? client : buildClient(adminSettings, hosts); availableFeatures = EnumSet.of(ProductFeature.LEGACY_TEMPLATES); Set versions = new HashSet<>(); diff --git a/test/framework/src/test/java/org/elasticsearch/test/VersionUtilsTests.java b/test/framework/src/test/java/org/elasticsearch/test/VersionUtilsTests.java index e0013e06f3248..5ae7e5640fc91 100644 --- a/test/framework/src/test/java/org/elasticsearch/test/VersionUtilsTests.java +++ b/test/framework/src/test/java/org/elasticsearch/test/VersionUtilsTests.java @@ -9,19 +9,11 @@ package org.elasticsearch.test; import org.elasticsearch.Version; -import org.elasticsearch.core.Booleans; -import org.elasticsearch.core.Tuple; import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; -import java.util.Set; import static org.elasticsearch.Version.fromId; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.lessThanOrEqualTo; /** * Tests VersionUtils. Note: this test should remain unchanged across major versions @@ -30,7 +22,7 @@ public class VersionUtilsTests extends ESTestCase { public void testAllVersionsSorted() { - List allVersions = VersionUtils.allReleasedVersions(); + List allVersions = VersionUtils.allVersions(); for (int i = 0, j = 1; j < allVersions.size(); ++i, ++j) { assertTrue(allVersions.get(i).before(allVersions.get(j))); } @@ -58,9 +50,9 @@ public void testRandomVersionBetween() { got = VersionUtils.randomVersionBetween(random(), null, fromId(7000099)); assertTrue(got.onOrAfter(VersionUtils.getFirstVersion())); assertTrue(got.onOrBefore(fromId(7000099))); - got = VersionUtils.randomVersionBetween(random(), null, VersionUtils.allReleasedVersions().get(0)); + got = VersionUtils.randomVersionBetween(random(), null, VersionUtils.allVersions().get(0)); assertTrue(got.onOrAfter(VersionUtils.getFirstVersion())); - assertTrue(got.onOrBefore(VersionUtils.allReleasedVersions().get(0))); + assertTrue(got.onOrBefore(VersionUtils.allVersions().get(0))); // unbounded upper got = VersionUtils.randomVersionBetween(random(), VersionUtils.getFirstVersion(), null); @@ -83,265 +75,34 @@ public void testRandomVersionBetween() { assertEquals(got, VersionUtils.getFirstVersion()); got = VersionUtils.randomVersionBetween(random(), Version.CURRENT, null); assertEquals(got, Version.CURRENT); - - if (Booleans.parseBoolean(System.getProperty("build.snapshot", "true"))) { - // max or min can be an unreleased version - final Version unreleased = randomFrom(VersionUtils.allUnreleasedVersions()); - assertThat(VersionUtils.randomVersionBetween(random(), null, unreleased), lessThanOrEqualTo(unreleased)); - assertThat(VersionUtils.randomVersionBetween(random(), unreleased, null), greaterThanOrEqualTo(unreleased)); - assertEquals(unreleased, VersionUtils.randomVersionBetween(random(), unreleased, unreleased)); - } - } - - public static class TestReleaseBranch { - public static final Version V_4_0_0 = Version.fromString("4.0.0"); - public static final Version V_4_0_1 = Version.fromString("4.0.1"); - public static final Version V_5_3_0 = Version.fromString("5.3.0"); - public static final Version V_5_3_1 = Version.fromString("5.3.1"); - public static final Version V_5_3_2 = Version.fromString("5.3.2"); - public static final Version V_5_4_0 = Version.fromString("5.4.0"); - public static final Version V_5_4_1 = Version.fromString("5.4.1"); - public static final Version CURRENT = V_5_4_1; - } - - public void testResolveReleasedVersionsForReleaseBranch() { - Tuple, List> t = VersionUtils.resolveReleasedVersions(TestReleaseBranch.CURRENT, TestReleaseBranch.class); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat( - released, - equalTo( - Arrays.asList( - TestReleaseBranch.V_4_0_0, - TestReleaseBranch.V_5_3_0, - TestReleaseBranch.V_5_3_1, - TestReleaseBranch.V_5_3_2, - TestReleaseBranch.V_5_4_0 - ) - ) - ); - assertThat(unreleased, equalTo(Arrays.asList(TestReleaseBranch.V_4_0_1, TestReleaseBranch.V_5_4_1))); - } - - public static class TestStableBranch { - public static final Version V_4_0_0 = Version.fromString("4.0.0"); - public static final Version V_4_0_1 = Version.fromString("4.0.1"); - public static final Version V_5_0_0 = Version.fromString("5.0.0"); - public static final Version V_5_0_1 = Version.fromString("5.0.1"); - public static final Version V_5_0_2 = Version.fromString("5.0.2"); - public static final Version V_5_1_0 = Version.fromString("5.1.0"); - public static final Version CURRENT = V_5_1_0; - } - - public void testResolveReleasedVersionsForUnreleasedStableBranch() { - Tuple, List> t = VersionUtils.resolveReleasedVersions(TestStableBranch.CURRENT, TestStableBranch.class); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat(released, equalTo(Arrays.asList(TestStableBranch.V_4_0_0, TestStableBranch.V_5_0_0, TestStableBranch.V_5_0_1))); - assertThat(unreleased, equalTo(Arrays.asList(TestStableBranch.V_4_0_1, TestStableBranch.V_5_0_2, TestStableBranch.V_5_1_0))); - } - - public static class TestStableBranchBehindStableBranch { - public static final Version V_4_0_0 = Version.fromString("4.0.0"); - public static final Version V_4_0_1 = Version.fromString("4.0.1"); - public static final Version V_5_3_0 = Version.fromString("5.3.0"); - public static final Version V_5_3_1 = Version.fromString("5.3.1"); - public static final Version V_5_3_2 = Version.fromString("5.3.2"); - public static final Version V_5_4_0 = Version.fromString("5.4.0"); - public static final Version V_5_5_0 = Version.fromString("5.5.0"); - public static final Version CURRENT = V_5_5_0; - } - - public void testResolveReleasedVersionsForStableBranchBehindStableBranch() { - Tuple, List> t = VersionUtils.resolveReleasedVersions( - TestStableBranchBehindStableBranch.CURRENT, - TestStableBranchBehindStableBranch.class - ); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat( - released, - equalTo( - Arrays.asList( - TestStableBranchBehindStableBranch.V_4_0_0, - TestStableBranchBehindStableBranch.V_5_3_0, - TestStableBranchBehindStableBranch.V_5_3_1 - ) - ) - ); - assertThat( - unreleased, - equalTo( - Arrays.asList( - TestStableBranchBehindStableBranch.V_4_0_1, - TestStableBranchBehindStableBranch.V_5_3_2, - TestStableBranchBehindStableBranch.V_5_4_0, - TestStableBranchBehindStableBranch.V_5_5_0 - ) - ) - ); - } - - public static class TestUnstableBranch { - public static final Version V_5_3_0 = Version.fromString("5.3.0"); - public static final Version V_5_3_1 = Version.fromString("5.3.1"); - public static final Version V_5_3_2 = Version.fromString("5.3.2"); - public static final Version V_5_4_0 = Version.fromString("5.4.0"); - public static final Version V_6_0_0 = Version.fromString("6.0.0"); - public static final Version CURRENT = V_6_0_0; - } - - public void testResolveReleasedVersionsForUnstableBranch() { - Tuple, List> t = VersionUtils.resolveReleasedVersions(TestUnstableBranch.CURRENT, TestUnstableBranch.class); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat(released, equalTo(Arrays.asList(TestUnstableBranch.V_5_3_0, TestUnstableBranch.V_5_3_1))); - assertThat(unreleased, equalTo(Arrays.asList(TestUnstableBranch.V_5_3_2, TestUnstableBranch.V_5_4_0, TestUnstableBranch.V_6_0_0))); - } - - public static class TestNewMajorRelease { - public static final Version V_5_6_0 = Version.fromString("5.6.0"); - public static final Version V_5_6_1 = Version.fromString("5.6.1"); - public static final Version V_5_6_2 = Version.fromString("5.6.2"); - public static final Version V_6_0_0 = Version.fromString("6.0.0"); - public static final Version V_6_0_1 = Version.fromString("6.0.1"); - public static final Version CURRENT = V_6_0_1; - } - - public void testResolveReleasedVersionsAtNewMajorRelease() { - Tuple, List> t = VersionUtils.resolveReleasedVersions( - TestNewMajorRelease.CURRENT, - TestNewMajorRelease.class - ); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat(released, equalTo(Arrays.asList(TestNewMajorRelease.V_5_6_0, TestNewMajorRelease.V_5_6_1, TestNewMajorRelease.V_6_0_0))); - assertThat(unreleased, equalTo(Arrays.asList(TestNewMajorRelease.V_5_6_2, TestNewMajorRelease.V_6_0_1))); - } - - public static class TestVersionBumpIn6x { - public static final Version V_5_6_0 = Version.fromString("5.6.0"); - public static final Version V_5_6_1 = Version.fromString("5.6.1"); - public static final Version V_5_6_2 = Version.fromString("5.6.2"); - public static final Version V_6_0_0 = Version.fromString("6.0.0"); - public static final Version V_6_0_1 = Version.fromString("6.0.1"); - public static final Version V_6_1_0 = Version.fromString("6.1.0"); - public static final Version CURRENT = V_6_1_0; - } - - public void testResolveReleasedVersionsAtVersionBumpIn6x() { - Tuple, List> t = VersionUtils.resolveReleasedVersions( - TestVersionBumpIn6x.CURRENT, - TestVersionBumpIn6x.class - ); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat(released, equalTo(Arrays.asList(TestVersionBumpIn6x.V_5_6_0, TestVersionBumpIn6x.V_5_6_1, TestVersionBumpIn6x.V_6_0_0))); - assertThat( - unreleased, - equalTo(Arrays.asList(TestVersionBumpIn6x.V_5_6_2, TestVersionBumpIn6x.V_6_0_1, TestVersionBumpIn6x.V_6_1_0)) - ); - } - - public static class TestNewMinorBranchIn6x { - public static final Version V_5_6_0 = Version.fromString("5.6.0"); - public static final Version V_5_6_1 = Version.fromString("5.6.1"); - public static final Version V_5_6_2 = Version.fromString("5.6.2"); - public static final Version V_6_0_0 = Version.fromString("6.0.0"); - public static final Version V_6_0_1 = Version.fromString("6.0.1"); - public static final Version V_6_1_0 = Version.fromString("6.1.0"); - public static final Version V_6_1_1 = Version.fromString("6.1.1"); - public static final Version V_6_1_2 = Version.fromString("6.1.2"); - public static final Version V_6_2_0 = Version.fromString("6.2.0"); - public static final Version CURRENT = V_6_2_0; - } - - public void testResolveReleasedVersionsAtNewMinorBranchIn6x() { - Tuple, List> t = VersionUtils.resolveReleasedVersions( - TestNewMinorBranchIn6x.CURRENT, - TestNewMinorBranchIn6x.class - ); - List released = t.v1(); - List unreleased = t.v2(); - - assertThat( - released, - equalTo( - Arrays.asList( - TestNewMinorBranchIn6x.V_5_6_0, - TestNewMinorBranchIn6x.V_5_6_1, - TestNewMinorBranchIn6x.V_6_0_0, - TestNewMinorBranchIn6x.V_6_0_1, - TestNewMinorBranchIn6x.V_6_1_0, - TestNewMinorBranchIn6x.V_6_1_1 - ) - ) - ); - assertThat( - unreleased, - equalTo(Arrays.asList(TestNewMinorBranchIn6x.V_5_6_2, TestNewMinorBranchIn6x.V_6_1_2, TestNewMinorBranchIn6x.V_6_2_0)) - ); } /** - * Tests that {@link Version#minimumCompatibilityVersion()} and {@link VersionUtils#allReleasedVersions()} + * Tests that {@link Version#minimumCompatibilityVersion()} and {@link VersionUtils#allVersions()} * agree with the list of wire compatible versions we build in gradle. */ public void testGradleVersionsMatchVersionUtils() { // First check the index compatible versions - List released = VersionUtils.allReleasedVersions() + List versions = VersionUtils.allVersions() .stream() /* Java lists all versions from the 5.x series onwards, but we only want to consider * ones that we're supposed to be compatible with. */ .filter(v -> v.onOrAfter(Version.CURRENT.minimumCompatibilityVersion())) + .map(Version::toString) .toList(); - VersionsFromProperty wireCompatible = new VersionsFromProperty("tests.gradle_wire_compat_versions"); - - Version minimumCompatibleVersion = Version.CURRENT.minimumCompatibilityVersion(); - List releasedWireCompatible = released.stream() - .filter(v -> Version.CURRENT.equals(v) == false) - .filter(v -> v.onOrAfter(minimumCompatibleVersion)) - .map(Object::toString) - .toList(); - assertEquals(releasedWireCompatible, wireCompatible.released); - - List unreleasedWireCompatible = VersionUtils.allUnreleasedVersions() - .stream() - .filter(v -> v.onOrAfter(minimumCompatibleVersion)) - .map(Object::toString) - .toList(); - assertEquals(unreleasedWireCompatible, wireCompatible.unreleased); + List gradleVersions = versionFromProperty("tests.gradle_wire_compat_versions"); + assertEquals(versions, gradleVersions); } - /** - * Read a versions system property as set by gradle into a tuple of {@code (releasedVersion, unreleasedVersion)}. - */ - private class VersionsFromProperty { - private final List released = new ArrayList<>(); - private final List unreleased = new ArrayList<>(); - - private VersionsFromProperty(String property) { - Set allUnreleased = new HashSet<>(Arrays.asList(System.getProperty("tests.gradle_unreleased_versions", "").split(","))); - if (allUnreleased.isEmpty()) { - fail("[tests.gradle_unreleased_versions] not set or empty. Gradle should set this before running."); - } - String versions = System.getProperty(property); - assertNotNull("Couldn't find [" + property + "]. Gradle should set this before running the tests.", versions); - logger.info("Looked up versions [{}={}]", property, versions); - - for (String version : versions.split(",")) { - if (allUnreleased.contains(version)) { - unreleased.add(version); - } else { - released.add(version); - } - } + private List versionFromProperty(String property) { + List versions = new ArrayList<>(); + String versionsString = System.getProperty(property); + assertNotNull("Couldn't find [" + property + "]. Gradle should set this before running the tests.", versionsString); + logger.info("Looked up versions [{}={}]", property, versionsString); + for (String version : versionsString.split(",")) { + versions.add(version); } + + return versions; } } diff --git a/test/immutable-collections-patch/build.gradle b/test/immutable-collections-patch/build.gradle index 85a199af2d477..852a19116fb71 100644 --- a/test/immutable-collections-patch/build.gradle +++ b/test/immutable-collections-patch/build.gradle @@ -17,8 +17,8 @@ configurations { } dependencies { - implementation 'org.ow2.asm:asm:9.7' - implementation 'org.ow2.asm:asm-tree:9.7' + implementation 'org.ow2.asm:asm:9.7.1' + implementation 'org.ow2.asm:asm-tree:9.7.1' } def outputDir = layout.buildDirectory.dir("jdk-patches") diff --git a/test/logger-usage/build.gradle b/test/logger-usage/build.gradle index 8677b1404a727..6d6c5ff889a45 100644 --- a/test/logger-usage/build.gradle +++ b/test/logger-usage/build.gradle @@ -10,9 +10,9 @@ apply plugin: 'elasticsearch.java' dependencies { - api 'org.ow2.asm:asm:9.7' - api 'org.ow2.asm:asm-tree:9.7' - api 'org.ow2.asm:asm-analysis:9.7' + api 'org.ow2.asm:asm:9.7.1' + api 'org.ow2.asm:asm-tree:9.7.1' + api 'org.ow2.asm:asm-analysis:9.7.1' api "org.apache.logging.log4j:log4j-api:${versions.log4j}" testImplementation project(":test:framework") } diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java index eb45aacda68da..13adde1da8a69 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java @@ -176,8 +176,9 @@ public long getPid(int index) { return nodes.get(index).getPid(); } + @Override public void stopNode(int index, boolean forcibly) { - nodes.get(index).stop(false); + nodes.get(index).stop(forcibly); } @Override @@ -252,9 +253,8 @@ private void writeUnicastHostsFile() { execute(() -> nodes.parallelStream().forEach(node -> { try { Path hostsFile = node.getWorkingDir().resolve("config").resolve("unicast_hosts.txt"); - if (Files.notExists(hostsFile)) { - Files.writeString(hostsFile, transportUris); - } + LOGGER.info("Writing unicast hosts file {} for node {}", hostsFile, node.getName()); + Files.writeString(hostsFile, transportUris); } catch (IOException e) { throw new UncheckedIOException("Failed to write unicast_hosts for: " + node, e); } diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/InternalMultiTerms.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/InternalMultiTerms.java index 0d42a2856a10e..85510c8a989c0 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/InternalMultiTerms.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/InternalMultiTerms.java @@ -37,9 +37,6 @@ public class InternalMultiTerms extends AbstractInternalTerms { - - long bucketOrd; - protected long docCount; protected InternalAggregations aggregations; private long docCountError; diff --git a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java index 1691aedf543f4..5c10e2c8feeb1 100644 --- a/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java +++ b/x-pack/plugin/analytics/src/main/java/org/elasticsearch/xpack/analytics/multiterms/MultiTermsAggregator.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.IntArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.common.util.ObjectArray; import org.elasticsearch.common.util.ObjectArrayPriorityQueue; @@ -40,6 +41,7 @@ import org.elasticsearch.search.aggregations.LeafBucketCollector; import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; import org.elasticsearch.search.aggregations.bucket.DeferableBucketAggregator; +import org.elasticsearch.search.aggregations.bucket.terms.BucketAndOrd; import org.elasticsearch.search.aggregations.bucket.terms.BucketPriorityQueue; import org.elasticsearch.search.aggregations.bucket.terms.BytesKeyedBucketOrds; import org.elasticsearch.search.aggregations.bucket.terms.TermsAggregator; @@ -72,7 +74,7 @@ class MultiTermsAggregator extends DeferableBucketAggregator { protected final List formats; protected final TermsAggregator.BucketCountThresholds bucketCountThresholds; protected final BucketOrder order; - protected final Comparator partiallyBuiltBucketComparator; + protected final Comparator> partiallyBuiltBucketComparator; protected final Set aggsUsedForSorting; protected final SubAggCollectionMode collectMode; private final List values; @@ -99,7 +101,7 @@ protected MultiTermsAggregator( super(name, factories, context, parent, metadata); this.bucketCountThresholds = bucketCountThresholds; this.order = order; - partiallyBuiltBucketComparator = order == null ? null : order.partiallyBuiltBucketComparator(b -> b.bucketOrd, this); + partiallyBuiltBucketComparator = order == null ? null : order.partiallyBuiltBucketComparator(this); this.formats = formats; this.showTermDocCountError = showTermDocCountError; if (subAggsNeedScore() && descendsFromNestedAggregator(parent) || context.isInSortOrderExecutionRequired()) { @@ -242,52 +244,67 @@ public InternalAggregation[] buildAggregations(LongArray owningBucketOrds) throw LongArray otherDocCounts = bigArrays().newLongArray(owningBucketOrds.size(), true); ObjectArray topBucketsPerOrd = bigArrays().newObjectArray(owningBucketOrds.size()) ) { - for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { - final long owningBucketOrd = owningBucketOrds.get(ordIdx); - long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrd); - - int size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize()); - try ( - ObjectArrayPriorityQueue ordered = new BucketPriorityQueue<>( - size, - bigArrays(), - partiallyBuiltBucketComparator - ) - ) { - InternalMultiTerms.Bucket spare = null; - BytesRef spareKey = null; - BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd); - while (ordsEnum.next()) { - long docCount = bucketDocCount(ordsEnum.ord()); - otherDocCounts.increment(ordIdx, docCount); - if (docCount < bucketCountThresholds.getShardMinDocCount()) { - continue; - } - if (spare == null) { - checkRealMemoryCBForInternalBucket(); - spare = new InternalMultiTerms.Bucket(null, 0, null, showTermDocCountError, 0, formats, keyConverters); - spareKey = new BytesRef(); - } - ordsEnum.readValue(spareKey); - spare.terms = unpackTerms(spareKey); - spare.docCount = docCount; - spare.bucketOrd = ordsEnum.ord(); - spare = ordered.insertWithOverflow(spare); - } + try (IntArray bucketsToCollect = bigArrays().newIntArray(owningBucketOrds.size())) { + long ordsToCollect = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { + int size = (int) Math.min(bucketOrds.bucketsInOrd(owningBucketOrds.get(ordIdx)), bucketCountThresholds.getShardSize()); + ordsToCollect += size; + bucketsToCollect.set(ordIdx, size); + } + try (LongArray ordsArray = bigArrays().newLongArray(ordsToCollect)) { + long ordsCollected = 0; + for (long ordIdx = 0; ordIdx < owningBucketOrds.size(); ordIdx++) { + final long owningBucketOrd = owningBucketOrds.get(ordIdx); + long bucketsInOrd = bucketOrds.bucketsInOrd(owningBucketOrd); + + int size = (int) Math.min(bucketsInOrd, bucketCountThresholds.getShardSize()); + try ( + ObjectArrayPriorityQueue> ordered = new BucketPriorityQueue<>( + size, + bigArrays(), + partiallyBuiltBucketComparator + ) + ) { + BucketAndOrd spare = null; + BytesRef spareKey = null; + BytesKeyedBucketOrds.BucketOrdsEnum ordsEnum = bucketOrds.ordsEnum(owningBucketOrd); + while (ordsEnum.next()) { + long docCount = bucketDocCount(ordsEnum.ord()); + otherDocCounts.increment(ordIdx, docCount); + if (docCount < bucketCountThresholds.getShardMinDocCount()) { + continue; + } + if (spare == null) { + checkRealMemoryCBForInternalBucket(); + spare = new BucketAndOrd<>( + new InternalMultiTerms.Bucket(null, 0, null, showTermDocCountError, 0, formats, keyConverters) + ); + spareKey = new BytesRef(); + } + ordsEnum.readValue(spareKey); + spare.bucket.terms = unpackTerms(spareKey); + spare.bucket.docCount = docCount; + spare.ord = ordsEnum.ord(); + spare = ordered.insertWithOverflow(spare); + } - // Get the top buckets - InternalMultiTerms.Bucket[] bucketsForOrd = new InternalMultiTerms.Bucket[(int) ordered.size()]; - topBucketsPerOrd.set(ordIdx, bucketsForOrd); - for (int b = (int) ordered.size() - 1; b >= 0; --b) { - InternalMultiTerms.Bucket[] buckets = topBucketsPerOrd.get(ordIdx); - buckets[b] = ordered.pop(); - otherDocCounts.increment(ordIdx, -buckets[b].getDocCount()); + // Get the top buckets + int orderedSize = (int) ordered.size(); + InternalMultiTerms.Bucket[] buckets = new InternalMultiTerms.Bucket[orderedSize]; + for (int i = orderedSize - 1; i >= 0; --i) { + BucketAndOrd bucketAndOrd = ordered.pop(); + buckets[i] = bucketAndOrd.bucket; + ordsArray.set(ordsCollected + i, bucketAndOrd.ord); + otherDocCounts.increment(ordIdx, -buckets[i].getDocCount()); + } + topBucketsPerOrd.set(ordIdx, buckets); + ordsCollected += orderedSize; + } } + buildSubAggsForAllBuckets(topBucketsPerOrd, ordsArray, (b, a) -> b.aggregations = a); } } - buildSubAggsForAllBuckets(topBucketsPerOrd, b -> b.bucketOrd, (b, a) -> b.aggregations = a); - return buildAggregations( Math.toIntExact(owningBucketOrds.size()), ordIdx -> buildResult(otherDocCounts.get(ordIdx), topBucketsPerOrd.get(ordIdx)) diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CCSUsageTelemetryAsyncSearchIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CCSUsageTelemetryAsyncSearchIT.java index 65f9f13846126..1b19f6f04693b 100644 --- a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CCSUsageTelemetryAsyncSearchIT.java +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CCSUsageTelemetryAsyncSearchIT.java @@ -60,7 +60,7 @@ protected boolean reuseClusters() { } @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE1, REMOTE2); } diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java index 3b5647da1399f..2a8daf8bfe12c 100644 --- a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/CrossClusterAsyncSearchIT.java @@ -88,7 +88,7 @@ public class CrossClusterAsyncSearchIT extends AbstractMultiClustersTestCase { private static final long LATEST_TIMESTAMP = 1691348820000L; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/termsenum/CCSTermsEnumIT.java b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/termsenum/CCSTermsEnumIT.java index 157628be9fbc9..f5c070073d9b5 100644 --- a/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/termsenum/CCSTermsEnumIT.java +++ b/x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/termsenum/CCSTermsEnumIT.java @@ -26,7 +26,7 @@ public class CCSTermsEnumIT extends AbstractMultiClustersTestCase { @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of("remote_cluster"); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datastreams/DataStreamLifecycleFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datastreams/DataStreamLifecycleFeatureSetUsage.java index 7a31888a440c3..a61a86eea7104 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datastreams/DataStreamLifecycleFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/datastreams/DataStreamLifecycleFeatureSetUsage.java @@ -111,7 +111,7 @@ public LifecycleStats( } public static LifecycleStats read(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.GLOBAL_RETENTION_TELEMETRY)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { return new LifecycleStats( in.readVLong(), in.readBoolean(), @@ -139,7 +139,7 @@ public static LifecycleStats read(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.GLOBAL_RETENTION_TELEMETRY)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVLong(dataStreamsWithLifecyclesCount); out.writeBoolean(defaultRolloverUsed); dataRetentionStats.writeTo(out); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/enrich/action/EnrichStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/enrich/action/EnrichStatsAction.java index 0457de6edcc9f..36322ed6c6cbd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/enrich/action/EnrichStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/enrich/action/EnrichStatsAction.java @@ -209,7 +209,7 @@ public CacheStats(StreamInput in) throws IOException { in.readVLong(), in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0) ? in.readLong() : -1, in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0) ? in.readLong() : -1, - in.getTransportVersion().onOrAfter(TransportVersions.ENRICH_CACHE_STATS_SIZE_ADDED) ? in.readLong() : -1 + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readLong() : -1 ); } @@ -237,7 +237,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(hitsTimeInMillis); out.writeLong(missesTimeInMillis); } - if (out.getTransportVersion().onOrAfter(TransportVersions.ENRICH_CACHE_STATS_SIZE_ADDED)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeLong(cacheSizeInBytes); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleExplainResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleExplainResponse.java index 33402671a2236..5d635c97d9c8c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleExplainResponse.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/IndexLifecycleExplainResponse.java @@ -328,7 +328,7 @@ public IndexLifecycleExplainResponse(StreamInput in) throws IOException { } else { indexCreationDate = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.RETAIN_ILM_STEP_INFO)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { previousStepInfo = in.readOptionalBytesReference(); } else { previousStepInfo = null; @@ -379,7 +379,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_1_0)) { out.writeOptionalLong(indexCreationDate); } - if (out.getTransportVersion().onOrAfter(TransportVersions.RETAIN_ILM_STEP_INFO)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalBytesReference(previousStepInfo); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SearchableSnapshotAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SearchableSnapshotAction.java index c06dcc0f083d1..da64df2672bdb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SearchableSnapshotAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ilm/SearchableSnapshotAction.java @@ -8,6 +8,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.elasticsearch.TransportVersions; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.health.ClusterHealthStatus; import org.elasticsearch.cluster.metadata.IndexAbstraction; @@ -32,7 +33,6 @@ import java.util.List; import java.util.Objects; -import static org.elasticsearch.TransportVersions.ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE; import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOTS_REPOSITORY_NAME_SETTING_KEY; import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOTS_SNAPSHOT_NAME_SETTING_KEY; import static org.elasticsearch.snapshots.SearchableSnapshotsSettings.SEARCHABLE_SNAPSHOT_PARTIAL_SETTING_KEY; @@ -102,9 +102,7 @@ public SearchableSnapshotAction(String snapshotRepository) { public SearchableSnapshotAction(StreamInput in) throws IOException { this.snapshotRepository = in.readString(); this.forceMergeIndex = in.readBoolean(); - this.totalShardsPerNode = in.getTransportVersion().onOrAfter(ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE) - ? in.readOptionalInt() - : null; + this.totalShardsPerNode = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalInt() : null; } boolean isForceMergeIndex() { @@ -424,7 +422,7 @@ public String getWriteableName() { public void writeTo(StreamOutput out) throws IOException { out.writeString(snapshotRepository); out.writeBoolean(forceMergeIndex); - if (out.getTransportVersion().onOrAfter(ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalInt(totalShardsPerNode); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java new file mode 100644 index 0000000000000..e426574c52ce6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; + +public abstract class BaseInferenceActionRequest extends ActionRequest { + + public BaseInferenceActionRequest() { + super(); + } + + public BaseInferenceActionRequest(StreamInput in) throws IOException { + super(in); + } + + public abstract boolean isStreaming(); + + public abstract TaskType getTaskType(); + + public abstract String getInferenceEntityId(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java index 226fe3630b387..c3f991a8b4e1e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java @@ -127,7 +127,7 @@ public Response(StreamInput in) throws IOException { pipelineIds = Set.of(); } - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { indexes = in.readCollectionAsSet(StreamInput::readString); dryRunMessage = in.readOptionalString(); } else { @@ -143,7 +143,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { out.writeCollection(pipelineIds, StreamOutput::writeString); } - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeCollection(indexes, StreamOutput::writeString); out.writeOptionalString(dryRunMessage); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java index ea0462d0f103e..ba3d417d02672 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/GetInferenceModelAction.java @@ -63,7 +63,7 @@ public Request(StreamInput in) throws IOException { this.inferenceEntityId = in.readString(); this.taskType = TaskType.fromStream(in); if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ) - || in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ_BACKPORT_8_16)) { + || in.getTransportVersion().isPatchFrom(TransportVersions.V_8_16_0)) { this.persistDefaultConfig = in.readBoolean(); } else { this.persistDefaultConfig = PERSIST_DEFAULT_CONFIGS; @@ -89,7 +89,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(inferenceEntityId); taskType.writeTo(out); if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ) - || out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_DONT_PERSIST_ON_READ_BACKPORT_8_16)) { + || out.getTransportVersion().isPatchFrom(TransportVersions.V_8_16_0)) { out.writeBoolean(this.persistDefaultConfig); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index a19edd5a08162..f88909ba4208e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -54,7 +53,7 @@ public InferenceAction() { super(NAME); } - public static class Request extends ActionRequest { + public static class Request extends BaseInferenceActionRequest { public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30); public static final ParseField INPUT = new ParseField("input"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java new file mode 100644 index 0000000000000..8d121463fb465 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class UnifiedCompletionAction extends ActionType { + public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction(); + public static final String NAME = "cluster:monitor/xpack/inference/unified"; + + public UnifiedCompletionAction() { + super(NAME); + } + + public static class Request extends BaseInferenceActionRequest { + public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) + throws IOException { + var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); + return new Request(inferenceEntityId, taskType, unifiedRequest, timeout); + } + + private final String inferenceEntityId; + private final TaskType taskType; + private final UnifiedCompletionRequest unifiedCompletionRequest; + private final TimeValue timeout; + + public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) { + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.taskType = Objects.requireNonNull(taskType); + this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest); + this.timeout = Objects.requireNonNull(timeout); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.inferenceEntityId = in.readString(); + this.taskType = TaskType.fromStream(in); + this.unifiedCompletionRequest = new UnifiedCompletionRequest(in); + this.timeout = in.readTimeValue(); + } + + public TaskType getTaskType() { + return taskType; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public UnifiedCompletionRequest getUnifiedCompletionRequest() { + return unifiedCompletionRequest; + } + + /** + * The Unified API only supports streaming so we always return true here. + * @return true + */ + public boolean isStreaming() { + return true; + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public ActionRequestValidationException validate() { + if (unifiedCompletionRequest == null || unifiedCompletionRequest.messages() == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be null"); + return e; + } + + if (unifiedCompletionRequest.messages().isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be an empty array"); + return e; + } + + if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [taskType] must be [completion]"); + return e; + } + + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + taskType.writeTo(out); + unifiedCompletionRequest.writeTo(out); + out.writeTimeValue(timeout); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(inferenceEntityId, request.inferenceEntityId) + && taskType == request.taskType + && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) + && Objects.equals(timeout, request.timeout); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest, timeout); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java new file mode 100644 index 0000000000000..90038c67036c4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -0,0 +1,329 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Flow; + +/** + * Chat Completion results that only contain a Flow.Publisher. + */ +public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) + implements + InferenceServiceResults { + + public static final String NAME = "chat_completion_chunk"; + public static final String MODEL_FIELD = "model"; + public static final String OBJECT_FIELD = "object"; + public static final String USAGE_FIELD = "usage"; + public static final String INDEX_FIELD = "index"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_NAME_FIELD = "name"; + public static final String FUNCTION_ARGUMENTS_FIELD = "arguments"; + public static final String FUNCTION_FIELD = "function"; + public static final String CHOICES_FIELD = "choices"; + public static final String DELTA_FIELD = "delta"; + public static final String CONTENT_FIELD = "content"; + public static final String REFUSAL_FIELD = "refusal"; + public static final String ROLE_FIELD = "role"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String FINISH_REASON_FIELD = "finish_reason"; + public static final String COMPLETION_TOKENS_FIELD = "completion_tokens"; + public static final String TOTAL_TOKENS_FIELD = "total_tokens"; + public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; + public static final String TYPE_FIELD = "type"; + + @Override + public boolean isStreaming() { + return true; + } + + @Override + public List transformToCoordinationFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public List transformToLegacyFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Map asMap() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + throw new UnsupportedOperationException("Not implemented"); + } + + public record Results(Deque chunks) implements ChunkedToXContent { + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params))); + } + } + + public static class ChatCompletionChunk implements ChunkedToXContent { + private final String id; + + public String getId() { + return id; + } + + public List getChoices() { + return choices; + } + + public String getModel() { + return model; + } + + public String getObject() { + return object; + } + + public Usage getUsage() { + return usage; + } + + private final List choices; + private final String model; + private final String object; + private final ChatCompletionChunk.Usage usage; + + public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) { + this.id = id; + this.choices = choices; + this.model = model; + this.object = object; + this.usage = usage; + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + + Iterator choicesIterator = Collections.emptyIterator(); + if (choices != null) { + choicesIterator = Iterators.concat( + ChunkedToXContentHelper.startArray(CHOICES_FIELD), + Iterators.flatMap(choices.iterator(), c -> c.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + + Iterator usageIterator = Collections.emptyIterator(); + if (usage != null) { + usageIterator = Iterators.concat( + ChunkedToXContentHelper.startObject(USAGE_FIELD), + ChunkedToXContentHelper.field(COMPLETION_TOKENS_FIELD, usage.completionTokens()), + ChunkedToXContentHelper.field(PROMPT_TOKENS_FIELD, usage.promptTokens()), + ChunkedToXContentHelper.field(TOTAL_TOKENS_FIELD, usage.totalTokens()), + ChunkedToXContentHelper.endObject() + ); + } + + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(ID_FIELD, id), + choicesIterator, + ChunkedToXContentHelper.field(MODEL_FIELD, model), + ChunkedToXContentHelper.field(OBJECT_FIELD, object), + usageIterator, + ChunkedToXContentHelper.endObject() + ); + } + + public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) { + + /* + choices: Array<{ + delta: { ... }; + finish_reason: string | null; + index: number; + }>; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + delta.toXContentChunked(params), + ChunkedToXContentHelper.optionalField(FINISH_REASON_FIELD, finishReason), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.endObject() + ); + } + + public static class Delta { + private final String content; + private final String refusal; + private final String role; + private List toolCalls; + + public Delta(String content, String refusal, String role, List toolCalls) { + this.content = content; + this.refusal = refusal; + this.role = role; + this.toolCalls = toolCalls; + } + + /* + delta: { + content?: string | null; + refusal?: string | null; + role?: 'system' | 'user' | 'assistant' | 'tool'; + tool_calls?: Array<{ ... }>; + }; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + var xContent = Iterators.concat( + ChunkedToXContentHelper.startObject(DELTA_FIELD), + ChunkedToXContentHelper.optionalField(CONTENT_FIELD, content), + ChunkedToXContentHelper.optionalField(REFUSAL_FIELD, refusal), + ChunkedToXContentHelper.optionalField(ROLE_FIELD, role) + ); + + if (toolCalls != null && toolCalls.isEmpty() == false) { + xContent = Iterators.concat( + xContent, + ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD), + Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + xContent = Iterators.concat(xContent, ChunkedToXContentHelper.endObject()); + return xContent; + + } + + public String getContent() { + return content; + } + + public String getRefusal() { + return refusal; + } + + public String getRole() { + return role; + } + + public List getToolCalls() { + return toolCalls; + } + + public static class ToolCall { + private final int index; + private final String id; + public ChatCompletionChunk.Choice.Delta.ToolCall.Function function; + private final String type; + + public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) { + this.index = index; + this.id = id; + this.function = function; + this.type = type; + } + + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() { + return function; + } + + public String getType() { + return type; + } + + /* + index: number; + id?: string; + function?: { + arguments?: string; + name?: string; + }; + type?: 'function'; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + var content = Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.optionalField(ID_FIELD, id) + ); + + if (function != null) { + content = Iterators.concat( + content, + ChunkedToXContentHelper.startObject(FUNCTION_FIELD), + ChunkedToXContentHelper.optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()), + ChunkedToXContentHelper.optionalField(FUNCTION_NAME_FIELD, function.getName()), + ChunkedToXContentHelper.endObject() + ); + } + + content = Iterators.concat( + content, + ChunkedToXContentHelper.field(TYPE_FIELD, type), + ChunkedToXContentHelper.endObject() + ); + return content; + } + + public static class Function { + private final String arguments; + private final String name; + + public Function(String arguments, String name) { + this.arguments = arguments; + this.name = name; + } + + public String getArguments() { + return arguments; + } + + public String getName() { + return name; + } + } + } + } + } + + public record Usage(int completionTokens, int promptTokens, int totalTokens) {} + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java index 0645299dfc30e..8c4611f05e72a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java @@ -66,7 +66,7 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException { this.analyticsUsage = in.readGenericMap(); this.inferenceUsage = in.readGenericMap(); this.nodeCount = in.readInt(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.memoryUsage = in.readGenericMap(); } else { this.memoryUsage = Map.of(); @@ -86,7 +86,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeGenericMap(analyticsUsage); out.writeGenericMap(inferenceUsage); out.writeInt(nodeCount); - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeGenericMap(memoryUsage); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java index c6976ab4b513e..2aedb46347534 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java @@ -47,7 +47,7 @@ public Request(StartTrainedModelDeploymentAction.TaskParams taskParams, Adaptive public Request(StreamInput in) throws IOException { super(in); this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in); - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); } else { this.adaptiveAllocationsSettings = null; @@ -63,7 +63,7 @@ public ActionRequestValidationException validate() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); taskParams.writeTo(out); - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(adaptiveAllocationsSettings); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index b298d486c9e03..1bf92262b30fb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -169,7 +169,7 @@ public Request(StreamInput in) throws IOException { modelId = in.readString(); timeout = in.readTimeValue(); waitForState = in.readEnum(AllocationStatus.State.class); - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { numberOfAllocations = in.readOptionalVInt(); } else { numberOfAllocations = in.readVInt(); @@ -189,7 +189,7 @@ public Request(StreamInput in) throws IOException { } else { this.deploymentId = modelId; } - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); } else { this.adaptiveAllocationsSettings = null; @@ -297,7 +297,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeTimeValue(timeout); out.writeEnum(waitForState); - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalVInt(numberOfAllocations); } else { out.writeVInt(numberOfAllocations); @@ -313,7 +313,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { out.writeString(deploymentId); } - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(adaptiveAllocationsSettings); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java index cb578fdb157de..2018c9526ec83 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java @@ -87,7 +87,7 @@ public Request(String deploymentId) { public Request(StreamInput in) throws IOException { super(in); deploymentId = in.readString(); - if (in.getTransportVersion().before(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().before(TransportVersions.V_8_16_0)) { numberOfAllocations = in.readVInt(); adaptiveAllocationsSettings = null; isInternal = false; @@ -134,7 +134,7 @@ public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(deploymentId); - if (out.getTransportVersion().before(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().before(TransportVersions.V_8_16_0)) { out.writeVInt(numberOfAllocations); } else { out.writeOptionalVInt(numberOfAllocations); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java index b007c1da451f5..742daa1bf6137 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/calendars/ScheduledEvent.java @@ -115,7 +115,7 @@ public ScheduledEvent(StreamInput in) throws IOException { description = in.readString(); startTime = in.readInstant(); endTime = in.readInstant(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { skipResult = in.readBoolean(); skipModelUpdate = in.readBoolean(); forceTimeShift = in.readOptionalInt(); @@ -204,7 +204,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(description); out.writeInstant(startTime); out.writeInstant(endTime); - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_SCHEDULED_EVENT_TIME_SHIFT_CONFIGURATION)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeBoolean(skipResult); out.writeBoolean(skipModelUpdate); out.writeOptionalInt(forceTimeShift); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java index 858d97bf6f956..31b513eea161e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java @@ -483,7 +483,7 @@ public AssignmentStats(StreamInput in) throws IOException { } else { deploymentId = modelId; } - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); } else { adaptiveAllocationsSettings = null; @@ -666,7 +666,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { out.writeString(deploymentId); } - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(adaptiveAllocationsSettings); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index efd07cceae09b..249e27d6f25e0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -178,7 +178,7 @@ public TrainedModelAssignment(StreamInput in) throws IOException { } else { this.maxAssignedAllocations = totalCurrentAllocations(); } - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); } else { this.adaptiveAllocationsSettings = null; @@ -382,7 +382,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { out.writeVInt(maxAssignedAllocations); } - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(adaptiveAllocationsSettings); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfig.java index 9929e59a9c803..a4d7c9c7fa08f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/LearningToRankConfig.java @@ -41,7 +41,6 @@ public class LearningToRankConfig extends RegressionConfig implements Rewriteable { public static final ParseField NAME = new ParseField("learning_to_rank"); - static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersions.LTR_SERVERLESS_RELEASE; public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); public static final ParseField FEATURE_EXTRACTORS = new ParseField("feature_extractors"); public static final ParseField DEFAULT_PARAMS = new ParseField("default_params"); @@ -226,7 +225,7 @@ public MlConfigVersion getMinimalSupportedMlConfigVersion() { @Override public TransportVersion getMinimalSupportedTransportVersion() { - return MIN_SUPPORTED_TRANSPORT_VERSION; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/DetectionRule.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/DetectionRule.java index eb952a7dc7e5c..4bdced325311f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/DetectionRule.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/DetectionRule.java @@ -68,7 +68,7 @@ public DetectionRule(StreamInput in) throws IOException { actions = in.readEnumSet(RuleAction.class); scope = new RuleScope(in); conditions = in.readCollectionAsList(RuleCondition::new); - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_ADD_DETECTION_RULE_PARAMS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { params = new RuleParams(in); } else { params = new RuleParams(); @@ -80,7 +80,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeEnumSet(actions); scope.writeTo(out); out.writeCollection(conditions); - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_ADD_DETECTION_RULE_PARAMS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { params.writeTo(out); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java index 2793ddea3bd06..33f1a9a469b69 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/SecurityFeatureSetUsage.java @@ -55,10 +55,8 @@ public SecurityFeatureSetUsage(StreamInput in) throws IOException { realmsUsage = in.readGenericMap(); rolesStoreUsage = in.readGenericMap(); sslUsage = in.readGenericMap(); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_2_0)) { - tokenServiceUsage = in.readGenericMap(); - apiKeyServiceUsage = in.readGenericMap(); - } + tokenServiceUsage = in.readGenericMap(); + apiKeyServiceUsage = in.readGenericMap(); auditUsage = in.readGenericMap(); ipFilterUsage = in.readGenericMap(); anonymousUsage = in.readGenericMap(); @@ -125,10 +123,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeGenericMap(realmsUsage); out.writeGenericMap(rolesStoreUsage); out.writeGenericMap(sslUsage); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_2_0)) { - out.writeGenericMap(tokenServiceUsage); - out.writeGenericMap(apiKeyServiceUsage); - } + out.writeGenericMap(tokenServiceUsage); + out.writeGenericMap(apiKeyServiceUsage); out.writeGenericMap(auditUsage); out.writeGenericMap(ipFilterUsage); out.writeGenericMap(anonymousUsage); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java index 8fe018a825468..59c16fc8a7a72 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java @@ -59,9 +59,6 @@ public TokensInvalidationResult(StreamInput in) throws IOException { this.invalidatedTokens = in.readStringCollectionAsList(); this.previouslyInvalidatedTokens = in.readStringCollectionAsList(); this.errors = in.readCollectionAsList(StreamInput::readException); - if (in.getTransportVersion().before(TransportVersions.V_7_2_0)) { - in.readVInt(); - } if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_0_0)) { this.restStatus = RestStatus.readFrom(in); } @@ -111,9 +108,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(invalidatedTokens); out.writeStringCollection(previouslyInvalidatedTokens); out.writeCollection(errors, StreamOutput::writeException); - if (out.getTransportVersion().before(TransportVersions.V_7_2_0)) { - out.writeVInt(5); - } if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_0_0)) { RestStatus.writeTo(out, restStatus); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/DocumentPermissions.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/DocumentPermissions.java index 14ecf4cb0d6e9..24f0a52436203 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/DocumentPermissions.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/permission/DocumentPermissions.java @@ -160,10 +160,8 @@ private static void buildRoleQuery( failIfQueryUsesClient(queryBuilder, context); Query roleQuery = context.toQuery(queryBuilder).query(); filter.add(roleQuery, SHOULD); - NestedLookup nestedLookup = context.nestedLookup(); - if (nestedLookup != NestedLookup.EMPTY) { - NestedHelper nestedHelper = new NestedHelper(nestedLookup, context::isFieldMapped); - if (nestedHelper.mightMatchNestedDocs(roleQuery)) { + if (context.nestedLookup() != NestedLookup.EMPTY) { + if (NestedHelper.mightMatchNestedDocs(roleQuery, context)) { roleQuery = new BooleanQuery.Builder().add(roleQuery, FILTER) .add(Queries.newNonNestedFilter(context.indexVersionCreated()), FILTER) .build(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java index b93aa079a28d2..148fdf21fd2df 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ConfigurableClusterPrivileges.java @@ -82,7 +82,7 @@ public static ConfigurableClusterPrivilege[] readArray(StreamInput in) throws IO * Utility method to write an array of {@link ConfigurableClusterPrivilege} objects to a {@link StreamOutput} */ public static void writeArray(StreamOutput out, ConfigurableClusterPrivilege[] privileges) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeArray(WRITER, privileges); } else { out.writeArray( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java index cc589b53eaa1a..5e19b26b8f4de 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/store/KibanaOwnedReservedRoleDescriptors.java @@ -331,6 +331,8 @@ static RoleDescriptor kibanaSystem(String name) { ".logs-endpoint.diagnostic.collection-*", "logs-apm-*", "logs-apm.*-*", + "logs-cloud_security_posture.findings-*", + "logs-cloud_security_posture.vulnerabilities-*", "metrics-apm-*", "metrics-apm.*-*", "traces-apm-*", diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/WatcherIndexTemplateRegistryField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/WatcherIndexTemplateRegistryField.java index 20dcb84dffe3f..098549029e0ce 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/WatcherIndexTemplateRegistryField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/watcher/support/WatcherIndexTemplateRegistryField.java @@ -22,8 +22,9 @@ public final class WatcherIndexTemplateRegistryField { // version 14: move watch history to data stream // version 15: remove watches and triggered watches, these are now system indices // version 16: change watch history ILM policy + // version 17: exclude input chain from indexing // Note: if you change this, also inform the kibana team around the watcher-ui - public static final int INDEX_TEMPLATE_VERSION = 16; + public static final int INDEX_TEMPLATE_VERSION = 17; public static final String HISTORY_TEMPLATE_NAME = ".watch-history-" + INDEX_TEMPLATE_VERSION; public static final String HISTORY_TEMPLATE_NAME_NO_ILM = ".watch-history-no-ilm-" + INDEX_TEMPLATE_VERSION; public static final String[] TEMPLATE_NAMES = new String[] { HISTORY_TEMPLATE_NAME }; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a9ca5e6da8720..01c0ff88be222 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -41,8 +41,7 @@ protected InferenceAction.Request createTestInstance() { return new InferenceAction.Request( randomFrom(TaskType.values()), randomAlphaOfLength(6), - // null, - randomNullOrAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java new file mode 100644 index 0000000000000..1872ac3caa230 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase { + + public void testValidation_ReturnsException_When_UnifiedCompletionRequestMessage_Is_Null() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(null), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be null;")); + } + + public void testValidation_ReturnsException_When_UnifiedCompletionRequest_Is_EmptyArray() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(List.of()), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be an empty array;")); + } + + public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.SPARSE_EMBEDDING, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];")); + } + + public void testValidation_ReturnsNull_When_TaskType_IsAny() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.ANY, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + assertNull(request.validate()); + } + + @Override + protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionAction.Request::new; + } + + @Override + protected UnifiedCompletionAction.Request createTestInstance() { + return new UnifiedCompletionAction.Request( + randomAlphaOfLength(10), + randomFrom(TaskType.values()), + UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(), + TimeValue.timeValueMillis(randomLongBetween(1, 2048)) + ); + } + + @Override + protected UnifiedCompletionAction.Request mutateInstance(UnifiedCompletionAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java new file mode 100644 index 0000000000000..47a0814a584b7 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -0,0 +1,293 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationTestCase { + + public void testParseAllFields() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ], + "name": "a name", + "tool_call_id": "100", + "tool_calls": [ + { + "id": "call_62136354", + "type": "function", + "function": { + "arguments": "{'order_id': 'order_12345'}", + "name": "get_delivery_date" + } + } + ] + } + ], + "max_completion_tokens": 100, + "stop": ["stop"], + "temperature": 0.1, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": { + "type": "function", + "function": { + "name": "some function" + } + }, + "top_p": 0.2 + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects( + List.of(new UnifiedCompletionRequest.ContentObject("some text", "string")) + ), + "user", + "a name", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{'order_id': 'order_12345'}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gpt-4o", + 100L, + List.of("stop"), + 0.1F, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + 0.2F + ); + + assertThat(request, is(expected)); + } + } + + public void testParsing() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What is the weather like in Boston today?" + } + ], + "stop": "none", + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": "auto" + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"), + "user", + null, + null, + null + ) + ), + "gpt-4o", + null, + List.of("none"), + null, + new UnifiedCompletionRequest.ToolChoiceString("auto"), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + null + ); + + assertThat(request, is(expected)); + } + } + + public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { + return new UnifiedCompletionRequest( + randomList(5, UnifiedCompletionRequestTests::randomMessage), + randomAlphaOfLengthOrNull(10), + randomPositiveLongOrNull(), + randomStopOrNull(), + randomFloatOrNull(), + randomToolChoiceOrNull(), + randomToolListOrNull(), + randomFloatOrNull() + ); + } + + public static UnifiedCompletionRequest.Message randomMessage() { + return new UnifiedCompletionRequest.Message( + randomContent(), + randomAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), + randomAlphaOfLengthOrNull(10), + randomToolCallListOrNull() + ); + } + + public static UnifiedCompletionRequest.Content randomContent() { + return randomBoolean() + ? new UnifiedCompletionRequest.ContentString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ContentObjects(randomList(10, UnifiedCompletionRequestTests::randomContentObject)); + } + + public static UnifiedCompletionRequest.ContentObject randomContentObject() { + return new UnifiedCompletionRequest.ContentObject(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static List randomToolCallListOrNull() { + return randomBoolean() ? randomList(10, UnifiedCompletionRequestTests::randomToolCall) : null; + } + + public static UnifiedCompletionRequest.ToolCall randomToolCall() { + return new UnifiedCompletionRequest.ToolCall(randomAlphaOfLength(10), randomToolCallFunctionField(), randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunctionField() { + return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static List randomStopOrNull() { + return randomBoolean() ? randomStop() : null; + } + + public static List randomStop() { + return randomList(5, () -> randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolChoice randomToolChoiceOrNull() { + return randomBoolean() ? randomToolChoice() : null; + } + + public static UnifiedCompletionRequest.ToolChoice randomToolChoice() { + return randomBoolean() + ? new UnifiedCompletionRequest.ToolChoiceString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ToolChoiceObject(randomAlphaOfLength(10), randomToolChoiceObjectFunctionField()); + } + + public static UnifiedCompletionRequest.ToolChoiceObject.FunctionField randomToolChoiceObjectFunctionField() { + return new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomAlphaOfLength(10)); + } + + public static List randomToolListOrNull() { + return randomBoolean() ? randomList(10, UnifiedCompletionRequestTests::randomTool) : null; + } + + public static UnifiedCompletionRequest.Tool randomTool() { + return new UnifiedCompletionRequest.Tool(randomAlphaOfLength(10), randomToolFunctionField()); + } + + public static UnifiedCompletionRequest.Tool.FunctionField randomToolFunctionField() { + return new UnifiedCompletionRequest.Tool.FunctionField( + randomAlphaOfLengthOrNull(10), + randomAlphaOfLength(10), + null, + randomOptionalBoolean() + ); + } + + @Override + protected UnifiedCompletionRequest mutateInstanceForVersion(UnifiedCompletionRequest instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionRequest::new; + } + + @Override + protected UnifiedCompletionRequest createTestInstance() { + return randomUnifiedCompletionRequest(); + } + + @Override + protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java new file mode 100644 index 0000000000000..a8f569dbef9d1 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { + + public void testResults_toXContentChunked() throws IOException { + String expected = """ + { + "id": "chunk1", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + ], + "model": "example_model", + "object": "example_object", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15 + } + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "chunk1", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ) + ), + "example_model", + "example_object", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(10, 5, 15) + ); + + Deque deque = new ArrayDeque<>(); + deque.add(chunk); + StreamingUnifiedChatCompletionResults.Results results = new StreamingUnifiedChatCompletionResults.Results(deque); + XContentBuilder builder = JsonXContent.contentBuilder(); + results.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testChoiceToXContentChunked() throws IOException { + String expected = """ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + choice.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testToolCallToXContentChunked() throws IOException { + String expected = """ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + toolCall.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java index 87d658c6f983c..e9ec8dfe8ee52 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsageTests.java @@ -57,7 +57,7 @@ protected MachineLearningFeatureSetUsage mutateInstance(MachineLearningFeatureSe @Override protected MachineLearningFeatureSetUsage mutateInstanceForVersion(MachineLearningFeatureSetUsage instance, TransportVersion version) { - if (version.before(TransportVersions.ML_TELEMETRY_MEMORY_ADDED)) { + if (version.before(TransportVersions.V_8_16_0)) { return new MachineLearningFeatureSetUsage( instance.available(), instance.enabled(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java index 9872d95de024a..a5c1ba45d90b7 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilderTests.java @@ -232,6 +232,12 @@ public void testToQuery() throws IOException { private void testDoToQuery(SparseVectorQueryBuilder queryBuilder, SearchExecutionContext context) throws IOException { Query query = queryBuilder.doToQuery(context); + + // test query builder can randomly have no vectors, which rewrites to a MatchNoneQuery - nothing more to do in this case. + if (query instanceof MatchNoDocsQuery) { + return; + } + assertTrue(query instanceof SparseVectorQueryWrapper); var sparseQuery = (SparseVectorQueryWrapper) query; if (queryBuilder.shouldPruneTokens()) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 17579fd6368ce..b69b0ece89960 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -1586,10 +1586,8 @@ public void testKibanaSystemRole() { final IndexAbstraction indexAbstraction = mockIndexAbstraction(cspIndex); assertThat(kibanaRole.indices().allowedIndicesMatcher("indices:foo").test(indexAbstraction), is(false)); assertThat(kibanaRole.indices().allowedIndicesMatcher("indices:bar").test(indexAbstraction), is(false)); - assertThat( - kibanaRole.indices().allowedIndicesMatcher(TransportDeleteIndexAction.TYPE.name()).test(indexAbstraction), - is(false) - ); + // Ensure privileges necessary for ILM policies in Cloud Security Posture Package + assertThat(kibanaRole.indices().allowedIndicesMatcher(TransportDeleteIndexAction.TYPE.name()).test(indexAbstraction), is(true)); assertThat(kibanaRole.indices().allowedIndicesMatcher(GetIndexAction.NAME).test(indexAbstraction), is(true)); assertThat( kibanaRole.indices().allowedIndicesMatcher(TransportCreateIndexAction.TYPE.name()).test(indexAbstraction), @@ -1613,10 +1611,9 @@ public void testKibanaSystemRole() { final IndexAbstraction indexAbstraction = mockIndexAbstraction(cspIndex); assertThat(kibanaRole.indices().allowedIndicesMatcher("indices:foo").test(indexAbstraction), is(false)); assertThat(kibanaRole.indices().allowedIndicesMatcher("indices:bar").test(indexAbstraction), is(false)); - assertThat( - kibanaRole.indices().allowedIndicesMatcher(TransportDeleteIndexAction.TYPE.name()).test(indexAbstraction), - is(false) - ); + // Ensure privileges necessary for ILM policies in Cloud Security Posture Package + assertThat(kibanaRole.indices().allowedIndicesMatcher(TransportDeleteIndexAction.TYPE.name()).test(indexAbstraction), is(true)); + assertThat(kibanaRole.indices().allowedIndicesMatcher(TransportDeleteIndexAction.TYPE.name()).test(indexAbstraction), is(true)); assertThat(kibanaRole.indices().allowedIndicesMatcher(GetIndexAction.NAME).test(indexAbstraction), is(true)); assertThat( kibanaRole.indices().allowedIndicesMatcher(TransportCreateIndexAction.TYPE.name()).test(indexAbstraction), @@ -1710,6 +1707,7 @@ public void testKibanaSystemRole() { kibanaRole.indices().allowedIndicesMatcher("indices:monitor/" + randomAlphaOfLengthBetween(3, 8)).test(indexAbstraction), is(true) ); + }); // cloud_defend @@ -4175,6 +4173,7 @@ public void testInferenceUserRole() { assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication)); diff --git a/x-pack/plugin/core/template-resources/src/main/resources/watch-history-no-ilm.json b/x-pack/plugin/core/template-resources/src/main/resources/watch-history-no-ilm.json index 2eed69c7c58e6..da459cda13463 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/watch-history-no-ilm.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/watch-history-no-ilm.json @@ -54,6 +54,15 @@ "enabled": false } } + }, + { + "disabled_result_input_chain_fields": { + "path_match": "result.input.chain", + "mapping": { + "type": "object", + "enabled": false + } + } } ], "dynamic": false, diff --git a/x-pack/plugin/core/template-resources/src/main/resources/watch-history.json b/x-pack/plugin/core/template-resources/src/main/resources/watch-history.json index 19e4dc022daa1..2abf6570d1f8e 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/watch-history.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/watch-history.json @@ -55,6 +55,15 @@ "enabled": false } } + }, + { + "disabled_result_input_chain_fields": { + "path_match": "result.input.chain", + "mapping": { + "type": "object", + "enabled": false + } + } } ], "dynamic": false, diff --git a/x-pack/plugin/deprecation/qa/rest/src/javaRestTest/java/org/elasticsearch/xpack/deprecation/DeprecationHttpIT.java b/x-pack/plugin/deprecation/qa/rest/src/javaRestTest/java/org/elasticsearch/xpack/deprecation/DeprecationHttpIT.java index 4a17c2abbd797..2136129a671c8 100644 --- a/x-pack/plugin/deprecation/qa/rest/src/javaRestTest/java/org/elasticsearch/xpack/deprecation/DeprecationHttpIT.java +++ b/x-pack/plugin/deprecation/qa/rest/src/javaRestTest/java/org/elasticsearch/xpack/deprecation/DeprecationHttpIT.java @@ -339,12 +339,12 @@ public void testDeprecationMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "settings"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "deprecated_settings"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "WARN"), hasKey("log.logger"), hasEntry("message", "[deprecated_settings] usage is deprecated. use [settings] instead") @@ -357,12 +357,12 @@ public void testDeprecationMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "api"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "deprecated_route_GET_/_test_cluster/deprecated_settings"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "WARN"), hasKey("log.logger"), hasEntry("message", "[/_test_cluster/deprecated_settings] exists for deprecated tests") @@ -402,12 +402,12 @@ public void testDeprecationCriticalWarnMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "settings"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "deprecated_critical_settings"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasKey("log.logger"), hasEntry("message", "[deprecated_settings] usage is deprecated. use [settings] instead") @@ -443,12 +443,12 @@ public void testDeprecationWarnMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "settings"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "deprecated_warn_settings"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "WARN"), hasKey("log.logger"), hasEntry("message", "[deprecated_warn_settings] usage is deprecated but won't be breaking in next version") @@ -461,12 +461,12 @@ public void testDeprecationWarnMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "api"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "deprecated_route_GET_/_test_cluster/deprecated_settings"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "WARN"), hasKey("log.logger"), hasEntry("message", "[/_test_cluster/deprecated_settings] exists for deprecated tests") @@ -619,12 +619,12 @@ public void testCompatibleMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "compatible_api"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "compatible_key"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasKey("log.logger"), hasEntry("message", "You are using a compatible API for this request") @@ -637,12 +637,12 @@ public void testCompatibleMessagesCanBeIndexed() throws Exception { hasEntry("elasticsearch.event.category", "compatible_api"), hasKey("elasticsearch.node.id"), hasKey("elasticsearch.node.name"), - hasEntry("data_stream.dataset", "deprecation.elasticsearch"), + hasEntry("data_stream.dataset", "elasticsearch.deprecation"), hasEntry("data_stream.namespace", "default"), hasEntry("data_stream.type", "logs"), hasKey("ecs.version"), hasEntry(KEY_FIELD_NAME, "deprecated_route_GET_/_test_cluster/compat_only"), - hasEntry("event.dataset", "deprecation.elasticsearch"), + hasEntry("event.dataset", "elasticsearch.deprecation"), hasEntry("log.level", "CRITICAL"), hasKey("log.logger"), hasEntry("message", "[/_test_cluster/deprecated_settings] exists for deprecated tests") diff --git a/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/DeprecationIndexingComponent.java b/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/DeprecationIndexingComponent.java index 29041b0c58434..507f4b18c79e9 100644 --- a/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/DeprecationIndexingComponent.java +++ b/x-pack/plugin/deprecation/src/main/java/org/elasticsearch/xpack/deprecation/logging/DeprecationIndexingComponent.java @@ -91,7 +91,7 @@ private DeprecationIndexingComponent( final Configuration configuration = context.getConfiguration(); final EcsLayout ecsLayout = ECSJsonLayout.newBuilder() - .setDataset("deprecation.elasticsearch") + .setDataset("elasticsearch.deprecation") .setConfiguration(configuration) .build(); diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRulesetListItem.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRulesetListItem.java index 3a61c848d3813..d694b2681ee88 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRulesetListItem.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/QueryRulesetListItem.java @@ -68,8 +68,7 @@ public QueryRulesetListItem(StreamInput in) throws IOException { this.criteriaTypeToCountMap = Map.of(); } TransportVersion streamTransportVersion = in.getTransportVersion(); - if (streamTransportVersion.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_15) - || streamTransportVersion.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16) + if (streamTransportVersion.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16) || streamTransportVersion.onOrAfter(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES)) { this.ruleTypeToCountMap = in.readMap(m -> in.readEnum(QueryRule.QueryRuleType.class), StreamInput::readInt); } else { @@ -104,8 +103,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeMap(criteriaTypeToCountMap, StreamOutput::writeEnum, StreamOutput::writeInt); } TransportVersion streamTransportVersion = out.getTransportVersion(); - if (streamTransportVersion.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_15) - || streamTransportVersion.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16) + if (streamTransportVersion.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16) || streamTransportVersion.onOrAfter(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES)) { out.writeMap(ruleTypeToCountMap, StreamOutput::writeEnum, StreamOutput::writeInt); } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java index 54a89d061de35..5b27cc7a3e05a 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java @@ -110,12 +110,14 @@ public QueryRuleRetrieverBuilder( Map matchCriteria, List retrieverSource, int rankWindowSize, - String retrieverName + String retrieverName, + List preFilterQueryBuilders ) { super(retrieverSource, rankWindowSize); this.rulesetIds = rulesetIds; this.matchCriteria = matchCriteria; this.retrieverName = retrieverName; + this.preFilterQueryBuilders = preFilterQueryBuilders; } @Override @@ -156,8 +158,15 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept } @Override - protected QueryRuleRetrieverBuilder clone(List newChildRetrievers) { - return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, newChildRetrievers, rankWindowSize, retrieverName); + protected QueryRuleRetrieverBuilder clone(List newChildRetrievers, List newPreFilterQueryBuilders) { + return new QueryRuleRetrieverBuilder( + rulesetIds, + matchCriteria, + newChildRetrievers, + rankWindowSize, + retrieverName, + newPreFilterQueryBuilders + ); } @Override diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/ListQueryRulesetsActionResponseBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/ListQueryRulesetsActionResponseBWCSerializingTests.java index 27d5e240534b2..c822dd123d3f8 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/ListQueryRulesetsActionResponseBWCSerializingTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/ListQueryRulesetsActionResponseBWCSerializingTests.java @@ -59,8 +59,7 @@ protected ListQueryRulesetsAction.Response mutateInstanceForVersion( ListQueryRulesetsAction.Response instance, TransportVersion version ) { - if (version.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_15) - || version.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16) + if (version.isPatchFrom(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES_BACKPORT_8_16) || version.onOrAfter(TransportVersions.QUERY_RULES_LIST_INCLUDES_TYPES)) { return instance; } else if (version.onOrAfter(QueryRulesetListItem.EXPANDED_RULESET_COUNT_TRANSPORT_VERSION)) { diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionRequestBWCSerializingTests.java index 7041de1106b50..8582ee1bd8d24 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionRequestBWCSerializingTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionRequestBWCSerializingTests.java @@ -51,6 +51,6 @@ protected TestQueryRulesetAction.Request mutateInstanceForVersion(TestQueryRules @Override protected List bwcVersions() { - return getAllBWCVersions().stream().filter(v -> v.onOrAfter(TransportVersions.QUERY_RULE_TEST_API)).collect(Collectors.toList()); + return getAllBWCVersions().stream().filter(v -> v.onOrAfter(TransportVersions.V_8_16_0)).collect(Collectors.toList()); } } diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionResponseBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionResponseBWCSerializingTests.java index a6562fb7b52af..142310ac40332 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionResponseBWCSerializingTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/rules/action/TestQueryRulesetActionResponseBWCSerializingTests.java @@ -47,6 +47,6 @@ protected TestQueryRulesetAction.Response mutateInstanceForVersion(TestQueryRule @Override protected List bwcVersions() { - return getAllBWCVersions().stream().filter(v -> v.onOrAfter(TransportVersions.QUERY_RULE_TEST_API)).collect(Collectors.toList()); + return getAllBWCVersions().stream().filter(v -> v.onOrAfter(TransportVersions.V_8_16_0)).collect(Collectors.toList()); } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/AsyncTaskManagementService.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/AsyncTaskManagementService.java index 94bac95b91501..91fdb9c39b6e3 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/AsyncTaskManagementService.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/AsyncTaskManagementService.java @@ -208,7 +208,7 @@ private ActionListener wrapStoringListener( ActionListener listener ) { AtomicReference> exclusiveListener = new AtomicReference<>(listener); - // This is will performed in case of timeout + // This will be performed in case of timeout Scheduler.ScheduledCancellable timeoutHandler = threadPool.schedule(() -> { ActionListener acquiredListener = exclusiveListener.getAndSet(null); if (acquiredListener != null) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java index 8baffbf887e47..4e4338aad3704 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java @@ -132,8 +132,16 @@ public static String name(Expression e) { return e instanceof NamedExpression ne ? ne.name() : e.sourceText(); } - public static boolean isNull(Expression e) { - return e.dataType() == DataType.NULL || (e.foldable() && e.fold() == null); + /** + * Is this {@linkplain Expression} guaranteed to have + * only the {@code null} value. {@linkplain Expression}s that + * {@link Expression#fold()} to {@code null} may + * return {@code false} here, but should eventually be folded + * into a {@link Literal} containing {@code null} which will return + * {@code true} from here. + */ + public static boolean isGuaranteedNull(Expression e) { + return e.dataType() == DataType.NULL || (e instanceof Literal lit && lit.value() == null); } public static List names(Collection e) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/ExpressionTranslators.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/ExpressionTranslators.java index 7836522c77130..468d076c1b7ef 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/ExpressionTranslators.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/planner/ExpressionTranslators.java @@ -107,9 +107,7 @@ protected Query asQuery(Not not, TranslatorHandler handler) { } public static Query doTranslate(Not not, TranslatorHandler handler) { - Query wrappedQuery = handler.asQuery(not.field()); - Query q = wrappedQuery.negate(not.source()); - return q; + return handler.asQuery(not.field()).negate(not.source()); } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java index a63571093ba58..d86cdb0de038c 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/DataType.java @@ -32,6 +32,113 @@ import static org.elasticsearch.xpack.esql.core.util.PlanStreamInput.readCachedStringWithVersionCheck; import static org.elasticsearch.xpack.esql.core.util.PlanStreamOutput.writeCachedStringWithVersionCheck; +/** + * This enum represents data types the ES|QL query processing layer is able to + * interact with in some way. This includes fully representable types (e.g. + * {@link DataType#LONG}, numeric types which we promote (e.g. {@link DataType#SHORT}) + * or fold into other types (e.g. {@link DataType#DATE_PERIOD}) early in the + * processing pipeline, types for internal use + * cases (e.g. {@link DataType#PARTIAL_AGG}), and types which the language + * doesn't support, but require special handling anyway (e.g. + * {@link DataType#OBJECT}) + * + *

Process for adding a new data type

+ * Note: it is not expected that all the following steps be done in a single PR. + * Use capabilities to gate tests as you go, and use as many PRs as you think + * appropriate. New data types are complex, and smaller PRs will make reviews + * easier. + *
    + *
  • + * Create a new feature flag for the type in {@link EsqlCorePlugin}. We + * recommend developing the data type over a series of smaller PRs behind + * a feature flag; even for relatively simple data types.
  • + *
  • + * Add a capability to EsqlCapabilities related to the new type, and + * gated by the feature flag you just created. Again, using the feature + * flag is preferred over snapshot-only. As development progresses, you may + * need to add more capabilities related to the new type, e.g. for + * supporting specific functions. This is fine, and expected.
  • + *
  • + * Create a new CSV test file for the new type. You'll either need to + * create a new data file as well, or add values of the new type to + * and existing data file. See CsvTestDataLoader for creating a new data + * set.
  • + *
  • + * In the new CSV test file, start adding basic functionality tests. + * These should include reading and returning values, both from indexed data + * and from the ROW command. It should also include functions that support + * "every" type, such as Case or MvFirst.
  • + *
  • + * Add the new type to the CsvTestUtils#Type enum, if it isn't already + * there. You also need to modify CsvAssert to support reading values + * of the new type.
  • + *
  • + * At this point, the CSV tests should fail with a sensible ES|QL error + * message. Make sure they're failing in ES|QL, not in the test + * framework.
  • + *
  • + * Add the new data type to this enum. This will cause a bunch of + * compile errors for switch statements throughout the code. Resolve those + * as appropriate. That is the main way in which the new type will be tied + * into the framework.
  • + *
  • + * Add the new type to the {@link DataType#UNDER_CONSTRUCTION} + * collection. This is used by the test framework to disable some checks + * around how functions report their supported types, which would otherwise + * generate a lot of noise while the type is still in development.
  • + *
  • + * Add typed data generators to TestCaseSupplier, and make sure all + * functions that support the new type have tests for it.
  • + *
  • + * Work to support things all types should do. Equality and the + * "typeless" MV functions (MvFirst, MvLast, and MvCount) should work for + * most types. Case and Coalesce should also support all types. + * If the type has a natural ordering, make sure to test + * sorting and the other binary comparisons. Make sure these functions all + * have CSV tests that run against indexed data.
  • + *
  • + * Add conversion functions as appropriate. Almost all types should + * support ToString, and should have a "ToType" function that accepts a + * string. There may be other logical conversions depending on the nature + * of the type. Make sure to add the conversion function to the + * TYPE_TO_CONVERSION_FUNCTION map in EsqlDataTypeConverter. Make sure the + * conversion functions have CSV tests that run against indexed data.
  • + *
  • + * Support the new type in aggregations that are type independent. + * This includes Values, Count, and Count Distinct. Make sure there are + * CSV tests against indexed data for these.
  • + *
  • + * Support other functions and aggregations as appropriate, making sure + * to included CSV tests.
  • + *
  • + * Consider how the type will interact with other types. For example, + * if the new type is numeric, it may be good for it to be comparable with + * other numbers. Supporting this may require new logic in + * EsqlDataTypeConverter#commonType, individual function type checking, the + * verifier rules, or other places. We suggest starting with CSV tests and + * seeing where they fail.
  • + *
+ * There are some additional steps that should be taken when removing the + * feature flag and getting ready for a release: + *
    + *
  • + * Ensure the capabilities for this type are always enabled + *
  • + *
  • + * Remove the type from the {@link DataType#UNDER_CONSTRUCTION} + * collection
  • + *
  • + * Fix new test failures related to declared function types + *
  • + *
  • + * Make sure to run the full test suite locally via gradle to generate + * the function type tables and helper files with the new type. Ensure all + * the functions that support the type have appropriate docs for it.
  • + *
  • + * If appropriate, remove the type from the ESQL limitations list of + * unsupported types.
  • + *
+ */ public enum DataType { /** * Fields of this type are unsupported by any functions and are always diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java index 47dadcbb11de2..73e2d5ec626ac 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/type/EsField.java @@ -72,7 +72,7 @@ public EsField(StreamInput in) throws IOException { private DataType readDataType(StreamInput in) throws IOException { String name = readCachedStringWithVersionCheck(in); - if (in.getTransportVersion().before(TransportVersions.ESQL_NESTED_UNSUPPORTED) && name.equalsIgnoreCase("NESTED")) { + if (in.getTransportVersion().before(TransportVersions.V_8_16_0) && name.equalsIgnoreCase("NESTED")) { /* * The "nested" data type existed in older versions of ESQL but was * entirely used to filter mappings away. Those versions will still diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java index 8bfcf4ca5c405..ce0540687121f 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/CollectionUtils.java @@ -30,12 +30,8 @@ public static List combine(List left, List righ } List list = new ArrayList<>(left.size() + right.size()); - if (left.isEmpty() == false) { - list.addAll(left); - } - if (right.isEmpty() == false) { - list.addAll(right); - } + list.addAll(left); + list.addAll(right); return list; } @@ -73,13 +69,6 @@ public static List combine(Collection left, T... entries) { return list; } - public static int mapSize(int size) { - if (size < 2) { - return size + 1; - } - return (int) (size / 0.75f + 1f); - } - @SafeVarargs @SuppressWarnings("varargs") public static List nullSafeList(T... entries) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java index e8ccae3429001..b570a50535a59 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamInput.java @@ -52,7 +52,7 @@ public interface PlanStreamInput { String readCachedString() throws IOException; static String readCachedStringWithVersionCheck(StreamInput planStreamInput) throws IOException { - if (planStreamInput.getTransportVersion().before(TransportVersions.ESQL_CACHED_STRING_SERIALIZATION)) { + if (planStreamInput.getTransportVersion().before(TransportVersions.V_8_16_0)) { return planStreamInput.readString(); } return ((PlanStreamInput) planStreamInput).readCachedString(); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java index fb4af33d2fd60..a5afcb5fa29a6 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/PlanStreamOutput.java @@ -37,7 +37,7 @@ public interface PlanStreamOutput { void writeCachedString(String field) throws IOException; static void writeCachedStringWithVersionCheck(StreamOutput planStreamOutput, String string) throws IOException { - if (planStreamOutput.getTransportVersion().before(TransportVersions.ESQL_CACHED_STRING_SERIALIZATION)) { + if (planStreamOutput.getTransportVersion().before(TransportVersions.V_8_16_0)) { planStreamOutput.writeString(string); } else { ((PlanStreamOutput) planStreamOutput).writeCachedString(string); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java new file mode 100644 index 0000000000000..69df0fb8ceff1 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.compute.operator.FailureCollector; +import org.elasticsearch.core.Releasable; + +/** + * Similar to {@link org.elasticsearch.action.support.RefCountingListener}, + * but prefers non-task-cancelled exceptions over task-cancelled ones as they are more useful for diagnosing issues. + * @see FailureCollector + */ +public final class EsqlRefCountingListener implements Releasable { + private final FailureCollector failureCollector; + private final RefCountingRunnable refs; + + public EsqlRefCountingListener(ActionListener delegate) { + this.failureCollector = new FailureCollector(); + this.refs = new RefCountingRunnable(() -> { + Exception error = failureCollector.getFailure(); + if (error != null) { + delegate.onFailure(error); + } else { + delegate.onResponse(null); + } + }); + } + + public ActionListener acquire() { + return refs.acquireListener().delegateResponse((l, e) -> { + failureCollector.unwrapAndCollect(e); + l.onFailure(e); + }); + } + + @Override + public void close() { + refs.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java index 9338077a55570..f57f450c7ee39 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AggregationOperator.java @@ -219,7 +219,7 @@ public Status(long aggregationNanos, long aggregationFinishNanos, int pagesProce protected Status(StreamInput in) throws IOException { aggregationNanos = in.readVLong(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_AGGREGATION_OPERATOR_STATUS_FINISH_NANOS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { aggregationFinishNanos = in.readOptionalVLong(); } else { aggregationFinishNanos = null; @@ -230,7 +230,7 @@ protected Status(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { out.writeVLong(aggregationNanos); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_AGGREGATION_OPERATOR_STATUS_FINISH_NANOS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalVLong(aggregationFinishNanos); } out.writeVInt(pagesProcessed); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java index d98613f1817ab..c071b5055df76 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverProfile.java @@ -79,7 +79,7 @@ public DriverProfile( } public DriverProfile(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE_SLEEPS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.startMillis = in.readVLong(); this.stopMillis = in.readVLong(); } else { @@ -101,7 +101,7 @@ public DriverProfile(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PROFILE_SLEEPS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVLong(startMillis); out.writeVLong(stopMillis); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverSleeps.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverSleeps.java index 01e9a73c4fb5f..d8856ebedb80b 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverSleeps.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverSleeps.java @@ -76,7 +76,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws static final int RECORDS = 10; public static DriverSleeps read(StreamInput in) throws IOException { - if (in.getTransportVersion().before(TransportVersions.ESQL_PROFILE_SLEEPS)) { + if (in.getTransportVersion().before(TransportVersions.V_8_16_0)) { return empty(); } return new DriverSleeps( @@ -88,7 +88,7 @@ public static DriverSleeps read(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().before(TransportVersions.ESQL_PROFILE_SLEEPS)) { + if (out.getTransportVersion().before(TransportVersions.V_8_16_0)) { return; } out.writeMap(counts, StreamOutput::writeVLong); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java index 943ba4dc1f4fa..337075edbdcf6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java @@ -13,9 +13,8 @@ import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.transport.TransportException; -import java.util.List; import java.util.Queue; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.Semaphore; /** * {@code FailureCollector} is responsible for collecting exceptions that occur in the compute engine. @@ -26,12 +25,11 @@ */ public final class FailureCollector { private final Queue cancelledExceptions = ConcurrentCollections.newQueue(); - private final AtomicInteger cancelledExceptionsCount = new AtomicInteger(); + private final Semaphore cancelledExceptionsPermits; private final Queue nonCancelledExceptions = ConcurrentCollections.newQueue(); - private final AtomicInteger nonCancelledExceptionsCount = new AtomicInteger(); + private final Semaphore nonCancelledExceptionsPermits; - private final int maxExceptions; private volatile boolean hasFailure = false; private Exception finalFailure = null; @@ -43,7 +41,8 @@ public FailureCollector(int maxExceptions) { if (maxExceptions <= 0) { throw new IllegalArgumentException("maxExceptions must be at least one"); } - this.maxExceptions = maxExceptions; + this.cancelledExceptionsPermits = new Semaphore(maxExceptions); + this.nonCancelledExceptionsPermits = new Semaphore(maxExceptions); } private static Exception unwrapTransportException(TransportException te) { @@ -60,13 +59,12 @@ private static Exception unwrapTransportException(TransportException te) { public void unwrapAndCollect(Exception e) { e = e instanceof TransportException te ? unwrapTransportException(te) : e; if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { - if (cancelledExceptionsCount.incrementAndGet() <= maxExceptions) { + if (nonCancelledExceptions.isEmpty() && cancelledExceptionsPermits.tryAcquire()) { cancelledExceptions.add(e); } - } else { - if (nonCancelledExceptionsCount.incrementAndGet() <= maxExceptions) { - nonCancelledExceptions.add(e); - } + } else if (nonCancelledExceptionsPermits.tryAcquire()) { + nonCancelledExceptions.add(e); + cancelledExceptions.clear(); } hasFailure = true; } @@ -99,20 +97,22 @@ public Exception getFailure() { private Exception buildFailure() { assert hasFailure; assert Thread.holdsLock(this); - int total = 0; Exception first = null; - for (var exceptions : List.of(nonCancelledExceptions, cancelledExceptions)) { - for (Exception e : exceptions) { - if (first == null) { - first = e; - total++; - } else if (first != e) { - first.addSuppressed(e); - total++; - } - if (total >= maxExceptions) { - return first; - } + for (Exception e : nonCancelledExceptions) { + if (first == null) { + first = e; + } else if (first != e) { + first.addSuppressed(e); + } + } + if (first != null) { + return first; + } + for (Exception e : cancelledExceptions) { + if (first == null) { + first = e; + } else if (first != e) { + first.addSuppressed(e); } } assert first != null; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index 00c68c4f48e86..62cc4daf5fde5 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -23,6 +24,7 @@ import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockStreamInput; +import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.CancellableTask; @@ -40,10 +42,11 @@ import java.io.IOException; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; /** * {@link ExchangeService} is responsible for exchanging pages between exchange sinks and sources on the same or different nodes. @@ -293,7 +296,7 @@ static final class TransportRemoteSink implements RemoteSink { final Executor responseExecutor; final AtomicLong estimatedPageSizeInBytes = new AtomicLong(0L); - final AtomicBoolean finished = new AtomicBoolean(false); + final AtomicReference> completionListenerRef = new AtomicReference<>(null); TransportRemoteSink( TransportService transportService, @@ -318,13 +321,14 @@ public void fetchPageAsync(boolean allSourcesFinished, ActionListener completionListener = completionListenerRef.get(); + if (completionListener != null) { + completionListener.addListener(listener.map(unused -> new ExchangeResponse(blockFactory, null, true))); return; } doFetchPageAsync(false, ActionListener.wrap(r -> { if (r.finished()) { - finished.set(true); + completionListenerRef.compareAndSet(null, SubscribableListener.newSucceeded(null)); } listener.onResponse(r); }, e -> close(ActionListener.running(() -> listener.onFailure(e))))); @@ -356,10 +360,19 @@ private void doFetchPageAsync(boolean allSourcesFinished, ActionListener listener) { - if (finished.compareAndSet(false, true)) { - doFetchPageAsync(true, listener.delegateFailure((l, unused) -> l.onResponse(null))); - } else { - listener.onResponse(null); + final SubscribableListener candidate = new SubscribableListener<>(); + final SubscribableListener actual = completionListenerRef.updateAndGet( + curr -> Objects.requireNonNullElse(curr, candidate) + ); + actual.addListener(listener); + if (candidate == actual) { + doFetchPageAsync(true, ActionListener.wrap(r -> { + final Page page = r.takePage(); + if (page != null) { + page.releaseBlocks(); + } + candidate.onResponse(null); + }, e -> candidate.onResponse(null))); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java index 375016a5d51d5..aa722695b841e 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java @@ -9,15 +9,18 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.FailureCollector; import org.elasticsearch.compute.operator.IsBlockedResult; import org.elasticsearch.core.Releasable; import java.util.List; +import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; @@ -40,6 +43,9 @@ public final class ExchangeSourceHandler { // The final failure collected will be notified to callers via the {@code completionListener}. private final FailureCollector failure = new FailureCollector(); + private final AtomicInteger nextSinkId = new AtomicInteger(); + private final Map remoteSinks = ConcurrentCollections.newConcurrentMap(); + /** * Creates a new ExchangeSourceHandler. * @@ -52,22 +58,25 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi this.buffer = new ExchangeBuffer(maxBufferSize); this.fetchExecutor = fetchExecutor; this.outstandingSinks = new PendingInstances(() -> buffer.finish(false)); - this.outstandingSources = new PendingInstances(() -> buffer.finish(true)); + final PendingInstances closingSinks = new PendingInstances(() -> {}); + closingSinks.trackNewInstance(); + this.outstandingSources = new PendingInstances(() -> finishEarly(true, ActionListener.running(closingSinks::finishInstance))); buffer.addCompletionListener(ActionListener.running(() -> { - final ActionListener listener = ActionListener.assertAtLeastOnce(completionListener).delegateFailure((l, unused) -> { + final ActionListener listener = ActionListener.assertAtLeastOnce(completionListener); + try (RefCountingRunnable refs = new RefCountingRunnable(() -> { final Exception e = failure.getFailure(); if (e != null) { - l.onFailure(e); + listener.onFailure(e); } else { - l.onResponse(null); + listener.onResponse(null); } - }); - try (RefCountingListener refs = new RefCountingListener(listener)) { + })) { + closingSinks.completion.addListener(refs.acquireListener()); for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) { // Create an outstanding instance and then finish to complete the completionListener // if we haven't registered any instances of exchange sinks or exchange sources before. pending.trackNewInstance(); - pending.completion.addListener(refs.acquire()); + pending.completion.addListener(refs.acquireListener()); pending.finishInstance(); } } @@ -256,7 +265,11 @@ void onSinkComplete() { * @see ExchangeSinkHandler#fetchPageAsync(boolean, ActionListener) */ public void addRemoteSink(RemoteSink remoteSink, boolean failFast, int instances, ActionListener listener) { - final ActionListener sinkListener = ActionListener.assertAtLeastOnce(ActionListener.notifyOnce(listener)); + final int sinkId = nextSinkId.incrementAndGet(); + remoteSinks.put(sinkId, remoteSink); + final ActionListener sinkListener = ActionListener.assertAtLeastOnce( + ActionListener.notifyOnce(ActionListener.runBefore(listener, () -> remoteSinks.remove(sinkId))) + ); fetchExecutor.execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { @@ -269,7 +282,7 @@ public void onFailure(Exception e) { @Override protected void doRun() { - try (RefCountingListener refs = new RefCountingListener(sinkListener)) { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(sinkListener)) { for (int i = 0; i < instances; i++) { var fetcher = new RemoteSinkFetcher(remoteSink, failFast, refs.acquire()); fetcher.fetchPage(); @@ -290,6 +303,22 @@ public Releasable addEmptySink() { return outstandingSinks::finishInstance; } + /** + * Gracefully terminates the exchange source early by instructing all remote exchange sinks to stop their computations. + * This can happen when the exchange source has accumulated enough data (e.g., reaching the LIMIT) or when users want to + * see the current result immediately. + * + * @param drainingPages whether to discard pages already fetched in the exchange + */ + public void finishEarly(boolean drainingPages, ActionListener listener) { + buffer.finish(drainingPages); + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(listener)) { + for (RemoteSink remoteSink : remoteSinks.values()) { + remoteSink.close(refs.acquire()); + } + } + } + private static class PendingInstances { private final AtomicInteger instances = new AtomicInteger(); private final SubscribableListener completion = new SubscribableListener<>(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java index aaa937ef17c0e..63b5d324ce851 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java @@ -8,6 +8,7 @@ package org.elasticsearch.compute.operator.exchange; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.compute.data.Page; public interface RemoteSink { @@ -15,11 +16,11 @@ public interface RemoteSink { default void close(ActionListener listener) { fetchPageAsync(true, listener.delegateFailure((l, r) -> { - try { - r.close(); - } finally { - l.onResponse(null); + final Page page = r.takePage(); + if (page != null) { + page.releaseBlocks(); } + l.onResponse(null); })); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java index 637cbe8892b3e..5fec82b32ddac 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.operator; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -86,6 +87,14 @@ public void testCollect() throws Exception { assertNotNull(failure); assertThat(failure, Matchers.in(nonCancelledExceptions)); assertThat(failure.getSuppressed().length, lessThan(maxExceptions)); + assertTrue( + "cancellation exceptions must be ignored", + ExceptionsHelper.unwrapCausesAndSuppressed(failure, t -> t instanceof TaskCancelledException).isEmpty() + ); + assertTrue( + "remote transport exception must be unwrapped", + ExceptionsHelper.unwrapCausesAndSuppressed(failure, t -> t instanceof TransportException).isEmpty() + ); } public void testEmpty() { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index fc6c850ba187b..8f7532b582bc2 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -55,7 +55,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Queue; import java.util.Set; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -421,7 +423,7 @@ public void testExchangeSourceContinueOnFailure() { } } - public void testEarlyTerminate() { + public void testClosingSinks() { BlockFactory blockFactory = blockFactory(); IntBlock block1 = blockFactory.newConstantIntBlockWith(1, 2); IntBlock block2 = blockFactory.newConstantIntBlockWith(1, 2); @@ -441,6 +443,57 @@ public void testEarlyTerminate() { assertTrue(sink.isFinished()); } + public void testFinishEarly() throws Exception { + ExchangeSourceHandler sourceHandler = new ExchangeSourceHandler(20, threadPool.generic(), ActionListener.noop()); + Semaphore permits = new Semaphore(between(1, 5)); + BlockFactory blockFactory = blockFactory(); + Queue pages = ConcurrentCollections.newQueue(); + ExchangeSource exchangeSource = sourceHandler.createExchangeSource(); + AtomicBoolean sinkClosed = new AtomicBoolean(); + PlainActionFuture sinkCompleted = new PlainActionFuture<>(); + sourceHandler.addRemoteSink((allSourcesFinished, listener) -> { + if (allSourcesFinished) { + sinkClosed.set(true); + permits.release(10); + listener.onResponse(new ExchangeResponse(blockFactory, null, sinkClosed.get())); + } else { + try { + if (permits.tryAcquire(between(0, 100), TimeUnit.MICROSECONDS)) { + boolean closed = sinkClosed.get(); + final Page page; + if (closed) { + page = new Page(blockFactory.newConstantIntBlockWith(1, 1)); + pages.add(page); + } else { + page = null; + } + listener.onResponse(new ExchangeResponse(blockFactory, page, closed)); + } else { + listener.onResponse(new ExchangeResponse(blockFactory, null, sinkClosed.get())); + } + } catch (Exception e) { + throw new AssertionError(e); + } + } + }, false, between(1, 3), sinkCompleted); + threadPool.schedule( + () -> sourceHandler.finishEarly(randomBoolean(), ActionListener.noop()), + TimeValue.timeValueMillis(between(0, 10)), + threadPool.generic() + ); + sinkCompleted.actionGet(); + Page p; + while ((p = exchangeSource.pollPage()) != null) { + assertSame(p, pages.poll()); + p.releaseBlocks(); + } + while ((p = pages.poll()) != null) { + p.releaseBlocks(); + } + assertTrue(exchangeSource.isFinished()); + exchangeSource.finish(); + } + public void testConcurrentWithTransportActions() { MockTransportService node0 = newTransportService(); ExchangeService exchange0 = new ExchangeService(Settings.EMPTY, threadPool, ESQL_TEST_EXECUTOR, blockFactory()); diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RequestIndexFilteringTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RequestIndexFilteringTestCase.java index 3314430d63eaa..406997b66dbf0 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RequestIndexFilteringTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RequestIndexFilteringTestCase.java @@ -101,7 +101,7 @@ public void testFieldExistsFilter_KeepWildcard() throws IOException { indexTimestampData(docsTest1, "test1", "2024-11-26", "id1"); indexTimestampData(docsTest2, "test2", "2023-11-26", "id2"); - // filter includes only test1. Columns are rows of test2 are filtered out + // filter includes only test1. Columns and rows of test2 are filtered out RestEsqlTestCase.RequestObjectBuilder builder = existsFilter("id1").query("FROM test*"); Map result = runEsql(builder); assertMap( @@ -253,6 +253,9 @@ protected void indexTimestampData(int docs, String indexName, String date, Strin "@timestamp": { "type": "date" }, + "value": { + "type": "long" + }, "%differentiator_field_name%": { "type": "integer" } diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java index 505ab3adc553b..6a8779eef4efc 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java @@ -350,21 +350,21 @@ public void testTextMode() throws IOException { int count = randomIntBetween(0, 100); bulkLoadTestData(count); var builder = requestObjectBuilder().query(fromIndex() + " | keep keyword, integer | sort integer asc | limit 100"); - assertEquals(expectedTextBody("txt", count, null), runEsqlAsTextWithFormat(builder, "txt", null)); + assertEquals(expectedTextBody("txt", count, null), runEsqlAsTextWithFormat(builder, "txt", null, mode)); } public void testCSVMode() throws IOException { int count = randomIntBetween(0, 100); bulkLoadTestData(count); var builder = requestObjectBuilder().query(fromIndex() + " | keep keyword, integer | sort integer asc | limit 100"); - assertEquals(expectedTextBody("csv", count, '|'), runEsqlAsTextWithFormat(builder, "csv", '|')); + assertEquals(expectedTextBody("csv", count, '|'), runEsqlAsTextWithFormat(builder, "csv", '|', mode)); } public void testTSVMode() throws IOException { int count = randomIntBetween(0, 100); bulkLoadTestData(count); var builder = requestObjectBuilder().query(fromIndex() + " | keep keyword, integer | sort integer asc | limit 100"); - assertEquals(expectedTextBody("tsv", count, null), runEsqlAsTextWithFormat(builder, "tsv", null)); + assertEquals(expectedTextBody("tsv", count, null), runEsqlAsTextWithFormat(builder, "tsv", null, mode)); } public void testCSVNoHeaderMode() throws IOException { @@ -1003,53 +1003,35 @@ public static Map runEsqlSync(RequestObjectBuilder requestObject } public static Map runEsqlAsync(RequestObjectBuilder requestObject) throws IOException { - return runEsqlAsync(requestObject, new AssertWarnings.NoWarnings()); + return runEsqlAsync(requestObject, randomBoolean(), new AssertWarnings.NoWarnings()); } static Map runEsql(RequestObjectBuilder requestObject, AssertWarnings assertWarnings, Mode mode) throws IOException { if (mode == ASYNC) { - return runEsqlAsync(requestObject, assertWarnings); + return runEsqlAsync(requestObject, randomBoolean(), assertWarnings); } else { return runEsqlSync(requestObject, assertWarnings); } } public static Map runEsqlSync(RequestObjectBuilder requestObject, AssertWarnings assertWarnings) throws IOException { - requestObject.build(); - Request request = prepareRequest(SYNC); - String mediaType = attachBody(requestObject, request); - - RequestOptions.Builder options = request.getOptions().toBuilder(); - options.setWarningsHandler(WarningsHandler.PERMISSIVE); // We assert the warnings ourselves - options.addHeader("Content-Type", mediaType); - - if (randomBoolean()) { - options.addHeader("Accept", mediaType); - } else { - request.addParameter("format", requestObject.contentType().queryParameter()); - } - request.setOptions(options); + Request request = prepareRequestWithOptions(requestObject, SYNC); HttpEntity entity = performRequest(request, assertWarnings); return entityToMap(entity, requestObject.contentType()); } public static Map runEsqlAsync(RequestObjectBuilder requestObject, AssertWarnings assertWarnings) throws IOException { - addAsyncParameters(requestObject); - requestObject.build(); - Request request = prepareRequest(ASYNC); - String mediaType = attachBody(requestObject, request); - - RequestOptions.Builder options = request.getOptions().toBuilder(); - options.setWarningsHandler(WarningsHandler.PERMISSIVE); // We assert the warnings ourselves - options.addHeader("Content-Type", mediaType); + return runEsqlAsync(requestObject, randomBoolean(), assertWarnings); + } - if (randomBoolean()) { - options.addHeader("Accept", mediaType); - } else { - request.addParameter("format", requestObject.contentType().queryParameter()); - } - request.setOptions(options); + public static Map runEsqlAsync( + RequestObjectBuilder requestObject, + boolean keepOnCompletion, + AssertWarnings assertWarnings + ) throws IOException { + addAsyncParameters(requestObject, keepOnCompletion); + Request request = prepareRequestWithOptions(requestObject, ASYNC); if (shouldLog()) { LOGGER.info("REQUEST={}", request); @@ -1061,7 +1043,7 @@ public static Map runEsqlAsync(RequestObjectBuilder requestObjec Object initialColumns = null; Object initialValues = null; var json = entityToMap(entity, requestObject.contentType()); - checkKeepOnCompletion(requestObject, json); + checkKeepOnCompletion(requestObject, json, keepOnCompletion); String id = (String) json.get("id"); var supportsAsyncHeaders = clusterHasCapability("POST", "/_query", List.of(), List.of("async_query_status_headers")).orElse(false); @@ -1101,7 +1083,7 @@ public static Map runEsqlAsync(RequestObjectBuilder requestObjec // issue a second request to "async get" the results Request getRequest = prepareAsyncGetRequest(id); - getRequest.setOptions(options); + getRequest.setOptions(request.getOptions()); response = performRequest(getRequest); entity = response.getEntity(); } @@ -1119,6 +1101,66 @@ public static Map runEsqlAsync(RequestObjectBuilder requestObjec return removeAsyncProperties(result); } + public void testAsyncGetWithoutContentType() throws IOException { + int count = randomIntBetween(0, 100); + bulkLoadTestData(count); + var requestObject = requestObjectBuilder().query(fromIndex() + " | keep keyword, integer | sort integer asc | limit 100"); + + addAsyncParameters(requestObject, true); + Request request = prepareRequestWithOptions(requestObject, ASYNC); + + if (shouldLog()) { + LOGGER.info("REQUEST={}", request); + } + + Response response = performRequest(request); + HttpEntity entity = response.getEntity(); + + var json = entityToMap(entity, requestObject.contentType()); + checkKeepOnCompletion(requestObject, json, true); + String id = (String) json.get("id"); + // results won't be returned since keepOnCompletion is true + assertThat(id, is(not(emptyOrNullString()))); + + // issue an "async get" request with no Content-Type + Request getRequest = prepareAsyncGetRequest(id); + response = performRequest(getRequest); + entity = response.getEntity(); + var result = entityToMap(entity, XContentType.JSON); + + ListMatcher values = matchesList(); + for (int i = 0; i < count; i++) { + values = values.item(matchesList().item("keyword" + i).item(i)); + } + assertMap( + result, + matchesMap().entry( + "columns", + matchesList().item(matchesMap().entry("name", "keyword").entry("type", "keyword")) + .item(matchesMap().entry("name", "integer").entry("type", "integer")) + ).entry("values", values).entry("took", greaterThanOrEqualTo(0)).entry("id", id).entry("is_running", false) + ); + + } + + static Request prepareRequestWithOptions(RequestObjectBuilder requestObject, Mode mode) throws IOException { + requestObject.build(); + Request request = prepareRequest(mode); + String mediaType = attachBody(requestObject, request); + + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.setWarningsHandler(WarningsHandler.PERMISSIVE); // We assert the warnings ourselves + options.addHeader("Content-Type", mediaType); + + if (randomBoolean()) { + options.addHeader("Accept", mediaType); + } else { + request.addParameter("format", requestObject.contentType().queryParameter()); + } + request.setOptions(options); + return request; + } + // Removes async properties, otherwise consuming assertions would need to handle sync and async differences static Map removeAsyncProperties(Map map) { Map copy = new HashMap<>(map); @@ -1139,17 +1181,20 @@ protected static Map entityToMap(HttpEntity entity, XContentType } } - static void addAsyncParameters(RequestObjectBuilder requestObject) throws IOException { + static void addAsyncParameters(RequestObjectBuilder requestObject, boolean keepOnCompletion) throws IOException { // deliberately short in order to frequently trigger return without results requestObject.waitForCompletion(TimeValue.timeValueNanos(randomIntBetween(1, 100))); - requestObject.keepOnCompletion(randomBoolean()); + requestObject.keepOnCompletion(keepOnCompletion); requestObject.keepAlive(TimeValue.timeValueDays(randomIntBetween(1, 10))); } // If keep_on_completion is set then an id must always be present, regardless of the value of any other property. - static void checkKeepOnCompletion(RequestObjectBuilder requestObject, Map json) { + static void checkKeepOnCompletion(RequestObjectBuilder requestObject, Map json, boolean keepOnCompletion) { if (requestObject.keepOnCompletion()) { + assertTrue(keepOnCompletion); assertThat((String) json.get("id"), not(emptyOrNullString())); + } else { + assertFalse(keepOnCompletion); } } @@ -1167,14 +1212,19 @@ static void deleteNonExistent(Request request) throws IOException { assertEquals(404, response.getStatusLine().getStatusCode()); } - static String runEsqlAsTextWithFormat(RequestObjectBuilder builder, String format, @Nullable Character delimiter) throws IOException { - Request request = prepareRequest(SYNC); + static String runEsqlAsTextWithFormat(RequestObjectBuilder builder, String format, @Nullable Character delimiter, Mode mode) + throws IOException { + Request request = prepareRequest(mode); + if (mode == ASYNC) { + addAsyncParameters(builder, randomBoolean()); + } String mediaType = attachBody(builder.build(), request); RequestOptions.Builder options = request.getOptions().toBuilder(); options.addHeader("Content-Type", mediaType); - if (randomBoolean()) { + boolean addParam = randomBoolean(); + if (addParam) { request.addParameter("format", format); } else { switch (format) { @@ -1188,8 +1238,75 @@ static String runEsqlAsTextWithFormat(RequestObjectBuilder builder, String forma } request.setOptions(options); - HttpEntity entity = performRequest(request, new AssertWarnings.NoWarnings()); - return Streams.copyToString(new InputStreamReader(entity.getContent(), StandardCharsets.UTF_8)); + if (shouldLog()) { + LOGGER.info("REQUEST={}", request); + } + + Response response = performRequest(request); + HttpEntity entity = assertWarnings(response, new AssertWarnings.NoWarnings()); + + // get the content, it could be empty because the request might have not completed + String initialValue = Streams.copyToString(new InputStreamReader(entity.getContent(), StandardCharsets.UTF_8)); + String id = response.getHeader("X-Elasticsearch-Async-Id"); + + if (mode == SYNC) { + assertThat(id, is(emptyOrNullString())); + return initialValue; + } + + if (id == null) { + // no id returned from an async call, must have completed immediately and without keep_on_completion + assertThat(builder.keepOnCompletion(), either(nullValue()).or(is(false))); + assertNull(response.getHeader("is_running")); + // the content cant be empty + assertThat(initialValue, not(emptyOrNullString())); + return initialValue; + } else { + // async may not return results immediately, so may need an async get + assertThat(id, is(not(emptyOrNullString()))); + String isRunning = response.getHeader("X-Elasticsearch-Async-Is-Running"); + if ("?0".equals(isRunning)) { + // must have completed immediately so keep_on_completion must be true + assertThat(builder.keepOnCompletion(), is(true)); + } else { + // did not return results immediately, so we will need an async get + // Also, different format modes return different results. + switch (format) { + case "txt" -> assertThat(initialValue, emptyOrNullString()); + case "csv" -> { + assertEquals(initialValue, "\r\n"); + initialValue = ""; + } + case "tsv" -> { + assertEquals(initialValue, "\n"); + initialValue = ""; + } + } + } + // issue a second request to "async get" the results + Request getRequest = prepareAsyncGetRequest(id); + if (delimiter != null) { + getRequest.addParameter("delimiter", String.valueOf(delimiter)); + } + // If the `format` parameter is not added, the GET request will return a response + // with the `Content-Type` type due to the lack of an `Accept` header. + if (addParam) { + getRequest.addParameter("format", format); + } + // if `addParam` is false, `options` will already have an `Accept` header + getRequest.setOptions(options); + response = performRequest(getRequest); + entity = assertWarnings(response, new AssertWarnings.NoWarnings()); + } + String newValue = Streams.copyToString(new InputStreamReader(entity.getContent(), StandardCharsets.UTF_8)); + + // assert initial contents, if any, are the same as async get contents + if (initialValue != null && initialValue.isEmpty() == false) { + assertEquals(initialValue, newValue); + } + + assertDeletable(id); + return newValue; } private static Request prepareRequest(Mode mode) { diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvSpecReader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvSpecReader.java index 84e06e0c1b674..ba0d11059a69b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvSpecReader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvSpecReader.java @@ -80,7 +80,12 @@ public Object parse(String line) { testCase.expectedWarningsRegexString.add(regex); testCase.expectedWarningsRegex.add(warningRegexToPattern(regex)); } else if (lower.startsWith("ignoreorder:")) { - testCase.ignoreOrder = Boolean.parseBoolean(line.substring("ignoreOrder:".length()).trim()); + String value = lower.substring("ignoreOrder:".length()).trim(); + if ("true".equals(value)) { + testCase.ignoreOrder = true; + } else if ("false".equals(value) == false) { + throw new IllegalArgumentException("Invalid value for ignoreOrder: [" + value + "], it can only be true or false"); + } } else if (line.startsWith(";")) { testCase.expectedResults = data.toString(); // clean-up and emit diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index d6715a932c075..5535e801b1b0c 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -10,6 +10,7 @@ import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.NoopCircuitBreaker; @@ -30,7 +31,9 @@ import org.elasticsearch.geo.ShapeTestUtils; import org.elasticsearch.index.IndexMode; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; @@ -129,6 +132,8 @@ import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.PATTERN; import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.VALUE; import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; public final class EsqlTestUtils { @@ -726,7 +731,7 @@ public static Literal randomLiteral(DataType type) { case UNSIGNED_LONG, LONG, COUNTER_LONG -> randomLong(); case DATE_PERIOD -> Period.of(randomIntBetween(-1000, 1000), randomIntBetween(-13, 13), randomIntBetween(-32, 32)); case DATETIME -> randomMillisUpToYear9999(); - case DATE_NANOS -> randomLong(); + case DATE_NANOS -> randomLongBetween(0, Long.MAX_VALUE); case DOUBLE, SCALED_FLOAT, COUNTER_DOUBLE -> randomDouble(); case FLOAT -> randomFloat(); case HALF_FLOAT -> HalfFloatPoint.sortableShortToHalfFloat(HalfFloatPoint.halfFloatToSortableShort(randomFloat())); @@ -784,4 +789,17 @@ public static QueryParam paramAsIdentifier(String name, Object value) { public static QueryParam paramAsPattern(String name, Object value) { return new QueryParam(name, value, NULL, PATTERN); } + + /** + * Asserts that: + * 1. Cancellation exceptions are ignored when more relevant exceptions exist. + * 2. Transport exceptions are unwrapped, and the actual causes are reported to users. + */ + public static void assertEsqlFailure(Exception e) { + assertNotNull(e); + var cancellationFailure = ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t instanceof TaskCancelledException).orElse(null); + assertNull("cancellation exceptions must be ignored", cancellationFailure); + ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t instanceof RemoteTransportException) + .ifPresent(transportFailure -> assertNull("remote transport exception must be unwrapped", transportFailure.getCause())); + } } diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec index daa45825b93fc..0d113c0422562 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/date_nanos.csv-spec @@ -216,11 +216,40 @@ millis:date | nanos:date_nanos | num:long 2023-10-23T13:33:34.937Z | 2023-10-23T13:33:34.937193000Z | 1698068014937193000 ; +date nanos greater than millis +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) > TO_DATETIME("2023-10-23T12:27:28.948Z") | SORT nanos DESC; + +millis:date | nanos:date_nanos | num:long +2023-10-23T13:55:01.543Z | 2023-10-23T13:55:01.543123456Z | 1698069301543123456 +2023-10-23T13:53:55.832Z | 2023-10-23T13:53:55.832987654Z | 1698069235832987654 +2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015787878Z | 1698069175015787878 +2023-10-23T13:51:54.732Z | 2023-10-23T13:51:54.732102837Z | 1698069114732102837 +2023-10-23T13:33:34.937Z | 2023-10-23T13:33:34.937193000Z | 1698068014937193000 +; + date nanos greater than or equal required_capability: to_date_nanos required_capability: date_nanos_binary_comparison -FROM date_nanos | WHERE MV_MIN(nanos) >= TO_DATE_NANOS("2023-10-23T12:27:28.948000000Z") | SORT nanos DESC; +FROM date_nanos | WHERE MV_MIN(nanos) >= TO_DATE_NANOS("2023-10-23T12:27:28.948Z") | SORT nanos DESC; + +millis:date | nanos:date_nanos | num:long +2023-10-23T13:55:01.543Z | 2023-10-23T13:55:01.543123456Z | 1698069301543123456 +2023-10-23T13:53:55.832Z | 2023-10-23T13:53:55.832987654Z | 1698069235832987654 +2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015787878Z | 1698069175015787878 +2023-10-23T13:51:54.732Z | 2023-10-23T13:51:54.732102837Z | 1698069114732102837 +2023-10-23T13:33:34.937Z | 2023-10-23T13:33:34.937193000Z | 1698068014937193000 +2023-10-23T12:27:28.948Z | 2023-10-23T12:27:28.948000000Z | 1698064048948000000 +; + +date nanos greater than or equal millis +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) >= TO_DATETIME("2023-10-23T12:27:28.948Z") | SORT nanos DESC; millis:date | nanos:date_nanos | num:long 2023-10-23T13:55:01.543Z | 2023-10-23T13:55:01.543123456Z | 1698069301543123456 @@ -231,11 +260,23 @@ millis:date | nanos:date_nanos | num:long 2023-10-23T12:27:28.948Z | 2023-10-23T12:27:28.948000000Z | 1698064048948000000 ; + date nanos less than required_capability: to_date_nanos required_capability: date_nanos_binary_comparison -FROM date_nanos | WHERE MV_MIN(nanos) < TO_DATE_NANOS("2023-10-23T12:27:28.948000000Z") AND millis > "2000-01-01" | SORT nanos DESC; +FROM date_nanos | WHERE MV_MIN(nanos) < TO_DATE_NANOS("2023-10-23T12:27:28.948Z") AND millis > "2000-01-01" | SORT nanos DESC; + +millis:date | nanos:date_nanos | num:long +2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 +2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 +; + +date nanos less than millis +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) < TO_DATETIME("2023-10-23T12:27:28.948Z") AND millis > "2000-01-01" | SORT nanos DESC; millis:date | nanos:date_nanos | num:long 2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 @@ -246,7 +287,19 @@ date nanos less than equal required_capability: to_date_nanos required_capability: date_nanos_binary_comparison -FROM date_nanos | WHERE MV_MIN(nanos) <= TO_DATE_NANOS("2023-10-23T12:27:28.948000000Z") AND millis > "2000-01-01" | SORT nanos DESC; +FROM date_nanos | WHERE MV_MIN(nanos) <= TO_DATE_NANOS("2023-10-23T12:27:28.948Z") AND millis > "2000-01-01" | SORT nanos DESC; + +millis:date | nanos:date_nanos | num:long +2023-10-23T12:27:28.948Z | 2023-10-23T12:27:28.948000000Z | 1698064048948000000 +2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 +2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 +; + +date nanos less than equal millis +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) <= TO_DATETIME("2023-10-23T12:27:28.948Z") AND millis > "2000-01-01" | SORT nanos DESC; millis:date | nanos:date_nanos | num:long 2023-10-23T12:27:28.948Z | 2023-10-23T12:27:28.948000000Z | 1698064048948000000 @@ -254,6 +307,7 @@ millis:date | nanos:date_nanos | num:long 2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 ; + date nanos equals required_capability: to_date_nanos required_capability: date_nanos_binary_comparison @@ -264,6 +318,25 @@ millis:date | nanos:date_nanos | num:long 2023-10-23T12:27:28.948Z | 2023-10-23T12:27:28.948000000Z | 1698064048948000000 ; +date nanos equals millis exact match +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) == TO_DATETIME("2023-10-23T12:27:28.948Z"); + +millis:date | nanos:date_nanos | num:long +2023-10-23T12:27:28.948Z | 2023-10-23T12:27:28.948000000Z | 1698064048948000000 +; + +date nanos equals millis without exact match +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) == TO_DATETIME("2023-10-23T13:33:34.937"); + +millis:date | nanos:date_nanos | num:long +; + date nanos not equals required_capability: to_date_nanos required_capability: date_nanos_binary_comparison @@ -280,6 +353,22 @@ millis:date | nanos:date_nanos | num:long 2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 ; +date nanos not equals millis +required_capability: date_nanos_type +required_capability: date_nanos_compare_to_millis + +FROM date_nanos | WHERE MV_MIN(nanos) != TO_DATETIME("2023-10-23T12:27:28.948Z") AND millis > "2000-01-01" | SORT nanos DESC; + +millis:date | nanos:date_nanos | num:long +2023-10-23T13:55:01.543Z | 2023-10-23T13:55:01.543123456Z | 1698069301543123456 +2023-10-23T13:53:55.832Z | 2023-10-23T13:53:55.832987654Z | 1698069235832987654 +2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015787878Z | 1698069175015787878 +2023-10-23T13:51:54.732Z | 2023-10-23T13:51:54.732102837Z | 1698069114732102837 +2023-10-23T13:33:34.937Z | 2023-10-23T13:33:34.937193000Z | 1698068014937193000 +2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 +2023-10-23T12:15:03.360Z | 2023-10-23T12:15:03.360103847Z | 1698063303360103847 +; + date nanos to long, index version required_capability: to_date_nanos diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec index f2800456ceb33..2d4c105cfff20 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/lookup-join.csv-spec @@ -120,6 +120,19 @@ left:keyword | client_ip:keyword | right:keyword | env:keyword left | 172.21.0.5 | right | Development ; +lookupIPFromRowWithShadowingKeepReordered +required_capability: join_lookup_v4 + +ROW left = "left", client_ip = "172.21.0.5", env = "env", right = "right" +| EVAL client_ip = client_ip::keyword +| LOOKUP JOIN clientips_lookup ON client_ip +| KEEP right, env, client_ip +; + +right:keyword | env:keyword | client_ip:keyword +right | Development | 172.21.0.5 +; + lookupIPFromIndex required_capability: join_lookup_v4 @@ -127,6 +140,7 @@ FROM sample_data | EVAL client_ip = client_ip::keyword | LOOKUP JOIN clientips_lookup ON client_ip ; +ignoreOrder:true @timestamp:date | event_duration:long | message:keyword | client_ip:keyword | env:keyword 2023-10-23T13:55:01.543Z | 1756467 | Connected to 10.1.0.1 | 172.21.3.15 | Production @@ -146,6 +160,7 @@ FROM sample_data | LOOKUP JOIN clientips_lookup ON client_ip | KEEP @timestamp, client_ip, event_duration, message, env ; +ignoreOrder:true @timestamp:date | client_ip:keyword | event_duration:long | message:keyword | env:keyword 2023-10-23T13:55:01.543Z | 172.21.3.15 | 1756467 | Connected to 10.1.0.1 | Production @@ -230,6 +245,7 @@ required_capability: join_lookup_v4 FROM sample_data | LOOKUP JOIN message_types_lookup ON message ; +ignoreOrder:true @timestamp:date | client_ip:ip | event_duration:long | message:keyword | type:keyword 2023-10-23T13:55:01.543Z | 172.21.3.15 | 1756467 | Connected to 10.1.0.1 | Success @@ -248,6 +264,7 @@ FROM sample_data | LOOKUP JOIN message_types_lookup ON message | KEEP @timestamp, client_ip, event_duration, message, type ; +ignoreOrder:true @timestamp:date | client_ip:ip | event_duration:long | message:keyword | type:keyword 2023-10-23T13:55:01.543Z | 172.21.3.15 | 1756467 | Connected to 10.1.0.1 | Success @@ -259,6 +276,24 @@ FROM sample_data 2023-10-23T12:15:03.360Z | 172.21.2.162 | 3450233 | Connected to 10.1.0.3 | Success ; +lookupMessageFromIndexKeepReordered +required_capability: join_lookup_v4 + +FROM sample_data +| LOOKUP JOIN message_types_lookup ON message +| KEEP type, client_ip, event_duration, message +; + +type:keyword | client_ip:ip | event_duration:long | message:keyword +Success | 172.21.3.15 | 1756467 | Connected to 10.1.0.1 +Error | 172.21.3.15 | 5033755 | Connection error +Error | 172.21.3.15 | 8268153 | Connection error +Error | 172.21.3.15 | 725448 | Connection error +Disconnected | 172.21.0.5 | 1232382 | Disconnected +Success | 172.21.2.113 | 2764889 | Connected to 10.1.0.2 +Success | 172.21.2.162 | 3450233 | Connected to 10.1.0.3 +; + lookupMessageFromIndexStats required_capability: join_lookup_v4 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/term-function.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/term-function.csv-spec new file mode 100644 index 0000000000000..0c72cad02eed1 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/term-function.csv-spec @@ -0,0 +1,206 @@ +############################################### +# Tests for Term function +# + +termWithTextField +required_capability: term_function + +// tag::term-with-field[] +FROM books +| WHERE TERM(author, "gabriel") +| KEEP book_no, title +| LIMIT 3; +// end::term-with-field[] +ignoreOrder:true + +book_no:keyword | title:text +4814 | El Coronel No Tiene Quien Le Escriba / No One Writes to the Colonel (Spanish Edition) +4917 | Autumn of the Patriarch +6380 | La hojarasca (Spanish Edition) +; + +termWithKeywordField +required_capability: term_function + +from employees +| where term(first_name, "Guoxiang") +| keep emp_no, first_name; + +// tag::term-with-keyword-field-result[] +emp_no:integer | first_name:keyword +10015 | Guoxiang +; +// end::term-with-keyword-field-result[] + +termWithQueryExpressions +required_capability: term_function + +from books +| where term(author, CONCAT("gab", "riel")) +| keep book_no, title; +ignoreOrder:true + +book_no:keyword | title:text +4814 | El Coronel No Tiene Quien Le Escriba / No One Writes to the Colonel (Spanish Edition) +4917 | Autumn of the Patriarch +6380 | La hojarasca (Spanish Edition) +; + +termAfterKeep +required_capability: term_function + +from books +| keep book_no, author +| where term(author, "faulkner") +| sort book_no +| limit 5; + +book_no:keyword | author:text +2378 | [Carol Faulkner, Holly Byers Ochoa, Lucretia Mott] +2713 | William Faulkner +2847 | Colleen Faulkner +2883 | William Faulkner +3293 | Danny Faulkner +; + +termAfterDrop +required_capability: term_function + +from books +| drop ratings, description, year, publisher, title, author.keyword +| where term(author, "william") +| keep book_no, author +| sort book_no +| limit 2; + +book_no:keyword | author:text +2713 | William Faulkner +2883 | William Faulkner +; + +termAfterEval +required_capability: term_function + +from books +| eval stars = to_long(ratings / 2.0) +| where term(author, "colleen") +| sort book_no +| keep book_no, author, stars +| limit 2; + +book_no:keyword | author:text | stars:long +2847 | Colleen Faulkner | 3 +4502 | Colleen Faulkner | 3 +; + +termWithConjunction +required_capability: term_function + +from books +| where term(author, "tolkien") and ratings > 4.95 +| eval author = mv_sort(author) +| keep book_no, ratings, author; +ignoreOrder:true + +book_no:keyword | ratings:double | author:keyword +2301 | 5.0 | John Ronald Reuel Tolkien +3254 | 5.0 | [Christopher Tolkien, John Ronald Reuel Tolkien] +7350 | 5.0 | [Christopher Tolkien, John Ronald Reuel Tolkien] +; + +termWithConjunctionAndSort +required_capability: term_function + +from books +| where term(author, "tolkien") and ratings > 4.95 +| eval author = mv_sort(author) +| keep book_no, ratings, author +| sort book_no; + +book_no:keyword | ratings:double | author:keyword +2301 | 5.0 | John Ronald Reuel Tolkien +3254 | 5.0 | [Christopher Tolkien, John Ronald Reuel Tolkien] +7350 | 5.0 | [Christopher Tolkien, John Ronald Reuel Tolkien] +; + +termWithFunctionPushedToLucene +required_capability: term_function + +from hosts +| where term(host, "beta") and cidr_match(ip1, "127.0.0.2/32", "127.0.0.3/32") +| keep card, host, ip0, ip1; +ignoreOrder:true + +card:keyword |host:keyword |ip0:ip |ip1:ip +eth1 |beta |127.0.0.1 |127.0.0.2 +; + +termWithNonPushableConjunction +required_capability: term_function + +from books +| where term(title, "rings") and length(title) > 75 +| keep book_no, title; +ignoreOrder:true + +book_no:keyword | title:text +4023 | A Tolkien Compass: Including J. R. R. Tolkien's Guide to the Names in The Lord of the Rings +; + +termWithMultipleWhereClauses +required_capability: term_function + +from books +| where term(title, "rings") +| where term(title, "lord") +| keep book_no, title; +ignoreOrder:true + +book_no:keyword | title:text +2675 | The Lord of the Rings - Boxed Set +2714 | Return of the King Being the Third Part of The Lord of the Rings +4023 | A Tolkien Compass: Including J. R. R. Tolkien's Guide to the Names in The Lord of the Rings +7140 | The Lord of the Rings Poster Collection: Six Paintings by Alan Lee (No. 1) +; + +termWithMultivaluedField +required_capability: term_function + +from employees +| where term(job_positions, "Data Scientist") +| keep emp_no, first_name, last_name +| sort emp_no asc +| limit 2; +ignoreOrder:true + +emp_no:integer | first_name:keyword | last_name:keyword +10014 | Berni | Genin +10017 | Cristinel | Bouloucos +; + +testWithMultiValuedFieldWithConjunction +required_capability: term_function + +from employees +| where term(job_positions, "Data Scientist") and term(first_name, "Cristinel") +| keep emp_no, first_name, last_name +| limit 1; + +emp_no:integer | first_name:keyword | last_name:keyword +10017 | Cristinel | Bouloucos +; + +termWithConjQueryStringFunctions +required_capability: term_function +required_capability: qstr_function + +from employees +| where term(job_positions, "Data Scientist") and qstr("first_name: Cristinel and gender: F") +| keep emp_no, first_name, last_name +| sort emp_no ASC +| limit 1; +ignoreOrder:true + +emp_no:integer | first_name:keyword | last_name:keyword +10017 | Cristinel | Bouloucos +; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryIT.java index 440582dcfbb45..c8206621de419 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterAsyncQueryIT.java @@ -66,7 +66,7 @@ public class CrossClusterAsyncQueryIT extends AbstractMultiClustersTestCase { private static final String INDEX_WITH_RUNTIME_MAPPING = "blocking"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER_1, REMOTE_CLUSTER_2); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterEnrichUnavailableClustersIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterEnrichUnavailableClustersIT.java index d142752d0c408..5c3e1974e924f 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterEnrichUnavailableClustersIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterEnrichUnavailableClustersIT.java @@ -53,7 +53,7 @@ public class CrossClusterEnrichUnavailableClustersIT extends AbstractMultiCluste public static String REMOTE_CLUSTER_2 = "c2"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER_1, REMOTE_CLUSTER_2); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryUnavailableRemotesIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryUnavailableRemotesIT.java index 0f1aa8541fdd9..d1c9b5cfb2ac7 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryUnavailableRemotesIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClusterQueryUnavailableRemotesIT.java @@ -42,7 +42,7 @@ public class CrossClusterQueryUnavailableRemotesIT extends AbstractMultiClusters private static final String REMOTE_CLUSTER_2 = "cluster-b"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER_1, REMOTE_CLUSTER_2); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java index f29f79976dc0d..5291ad3b0d039 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java @@ -55,7 +55,7 @@ public class CrossClustersCancellationIT extends AbstractMultiClustersTestCase { private static final String REMOTE_CLUSTER = "cluster-a"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersEnrichIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersEnrichIT.java index e8e9f45694e9c..57f85751999a5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersEnrichIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersEnrichIT.java @@ -64,7 +64,7 @@ public class CrossClustersEnrichIT extends AbstractMultiClustersTestCase { @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of("c1", "c2"); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersQueryIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersQueryIT.java index 596c70e57ccd6..46bbad5551e6b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersQueryIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersQueryIT.java @@ -67,7 +67,7 @@ public class CrossClustersQueryIT extends AbstractMultiClustersTestCase { private static String REMOTE_INDEX = "logs-2"; @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER_1, REMOTE_CLUSTER_2); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java index dab99a0f719dd..c4da0bf32ef96 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java @@ -143,6 +143,7 @@ protected EsqlQueryResponse run(EsqlQueryRequest request) { return client.execute(EsqlQueryAction.INSTANCE, request).actionGet(2, TimeUnit.MINUTES); } catch (Exception e) { logger.info("request failed", e); + EsqlTestUtils.assertEsqlFailure(e); ensureBlocksReleased(); } finally { setRequestCircuitBreakerLimit(null); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java index 37833d8aed2d3..ec7ee8b61c2d5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.junit.annotations.TestLogging; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import java.util.ArrayList; import java.util.Collection; @@ -85,6 +86,7 @@ private EsqlQueryResponse runWithBreaking(EsqlQueryRequest request) throws Circu } catch (Exception e) { logger.info("request failed", e); ensureBlocksReleased(); + EsqlTestUtils.assertEsqlFailure(e); throw e; } finally { setRequestCircuitBreakerLimit(null); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java index 147b13b36c44b..00f53d31165b1 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionIT.java @@ -18,6 +18,7 @@ import org.elasticsearch.client.internal.ClusterAdminClient; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.Index; @@ -1648,6 +1649,44 @@ public void testMaxTruncationSizeSetting() { } } + public void testScriptField() throws Exception { + XContentBuilder mapping = JsonXContent.contentBuilder(); + mapping.startObject(); + { + mapping.startObject("runtime"); + { + mapping.startObject("k1"); + mapping.field("type", "long"); + mapping.endObject(); + mapping.startObject("k2"); + mapping.field("type", "long"); + mapping.endObject(); + } + mapping.endObject(); + { + mapping.startObject("properties"); + mapping.startObject("meter").field("type", "double").endObject(); + mapping.endObject(); + } + } + mapping.endObject(); + String sourceMode = randomBoolean() ? "stored" : "synthetic"; + Settings.Builder settings = indexSettings(1, 0).put(indexSettings()).put("index.mapping.source.mode", sourceMode); + client().admin().indices().prepareCreate("test-script").setMapping(mapping).setSettings(settings).get(); + for (int i = 0; i < 10; i++) { + index("test-script", Integer.toString(i), Map.of("k1", i, "k2", "b-" + i, "meter", 10000 * i)); + } + refresh("test-script"); + try (EsqlQueryResponse resp = run("FROM test-script | SORT k1 | LIMIT 10")) { + List k1Column = Iterators.toList(resp.column(0)); + assertThat(k1Column, contains(0L, 1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L)); + List k2Column = Iterators.toList(resp.column(1)); + assertThat(k2Column, contains(null, null, null, null, null, null, null, null, null, null)); + List meterColumn = Iterators.toList(resp.column(2)); + assertThat(meterColumn, contains(0.0, 10000.0, 20000.0, 30000.0, 40000.0, 50000.0, 60000.0, 70000.0, 80000.0, 90000.0)); + } + } + private void clearPersistentSettings(Setting... settings) { Settings.Builder clearedSettings = Settings.builder(); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java index 1939f81353c0e..abd4f6b49d7b4 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java @@ -36,6 +36,7 @@ import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.junit.Before; @@ -338,7 +339,15 @@ private void assertCancelled(ActionFuture response) throws Ex */ assertThat( cancelException.getMessage(), - in(List.of("test cancel", "task cancelled", "request cancelled test cancel", "parent task was cancelled [test cancel]")) + in( + List.of( + "test cancel", + "task cancelled", + "request cancelled test cancel", + "parent task was cancelled [test cancel]", + "cancelled on failure" + ) + ) ); assertBusy( () -> assertThat( @@ -434,6 +443,7 @@ protected void doRun() throws Exception { allowedFetching.countDown(); } Exception failure = expectThrows(Exception.class, () -> future.actionGet().close()); + EsqlTestUtils.assertEsqlFailure(failure); assertThat(failure.getMessage(), containsString("failed to fetch pages")); // If we proceed without waiting for pages, we might cancel the main request before starting the data-node request. // As a result, the exchange sinks on data-nodes won't be removed until the inactive_timeout elapses, which is diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java index e9eada5def0dc..72a60a6b6b928 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.test.disruption.ServiceDisruptionScheme; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.transport.TransportSettings; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import java.util.ArrayList; import java.util.Collection; @@ -111,6 +112,7 @@ private EsqlQueryResponse runQueryWithDisruption(EsqlQueryRequest request) { assertTrue("request must be failed or completed after clearing disruption", future.isDone()); ensureBlocksReleased(); logger.info("--> failed to execute esql query with disruption; retrying...", e); + EsqlTestUtils.assertEsqlFailure(e); return client().execute(EsqlQueryAction.INSTANCE, request).actionGet(2, TimeUnit.MINUTES); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/TermIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/TermIT.java new file mode 100644 index 0000000000000..4bb4897c9db5f --- /dev/null +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/plugin/TermIT.java @@ -0,0 +1,139 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.xpack.esql.VerificationException; +import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase; +import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.action.EsqlQueryRequest; +import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; +import org.junit.Before; + +import java.util.List; + +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; +import static org.hamcrest.CoreMatchers.containsString; + +public class TermIT extends AbstractEsqlIntegTestCase { + + @Before + public void setupIndex() { + createAndPopulateIndex(); + } + + @Override + protected EsqlQueryResponse run(EsqlQueryRequest request) { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + return super.run(request); + } + + public void testSimpleTermQuery() throws Exception { + var query = """ + FROM test + | WHERE term(content,"dog") + | KEEP id + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id")); + assertColumnTypes(resp.columns(), List.of("integer")); + assertValues(resp.values(), List.of(List.of(1), List.of(3), List.of(4), List.of(5))); + } + } + + public void testTermWithinEval() { + var query = """ + FROM test + | EVAL term_query = term(title,"fox") + """; + + var error = expectThrows(VerificationException.class, () -> run(query)); + assertThat(error.getMessage(), containsString("[Term] function is only supported in WHERE commands")); + } + + public void testMultipleTerm() { + var query = """ + FROM test + | WHERE term(content,"fox") AND term(content,"brown") + | KEEP id + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id")); + assertColumnTypes(resp.columns(), List.of("integer")); + assertValues(resp.values(), List.of(List.of(2), List.of(4), List.of(5))); + } + } + + public void testNotWhereTerm() { + var query = """ + FROM test + | WHERE NOT term(content,"brown") + | KEEP id + | SORT id + """; + + try (var resp = run(query)) { + assertColumnNames(resp.columns(), List.of("id")); + assertColumnTypes(resp.columns(), List.of("integer")); + assertValues(resp.values(), List.of(List.of(3))); + } + } + + private void createAndPopulateIndex() { + var indexName = "test"; + var client = client().admin().indices(); + var CreateRequest = client.prepareCreate(indexName) + .setSettings(Settings.builder().put("index.number_of_shards", 1)) + .setMapping("id", "type=integer", "content", "type=text"); + assertAcked(CreateRequest); + client().prepareBulk() + .add( + new IndexRequest(indexName).id("1") + .source("id", 1, "content", "The quick brown animal swiftly jumps over a lazy dog", "title", "A Swift Fox's Journey") + ) + .add( + new IndexRequest(indexName).id("2") + .source("id", 2, "content", "A speedy brown fox hops effortlessly over a sluggish canine", "title", "The Fox's Leap") + ) + .add( + new IndexRequest(indexName).id("3") + .source("id", 3, "content", "Quick and nimble, the fox vaults over the lazy dog", "title", "Brown Fox in Action") + ) + .add( + new IndexRequest(indexName).id("4") + .source( + "id", + 4, + "content", + "A fox that is quick and brown jumps over a dog that is quite lazy", + "title", + "Speedy Animals" + ) + ) + .add( + new IndexRequest(indexName).id("5") + .source( + "id", + 5, + "content", + "With agility, a quick brown fox bounds over a slow-moving dog", + "title", + "Foxes and Canines" + ) + ) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .get(); + ensureYellow(indexName); + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsMillisNanosEvaluator.java new file mode 100644 index 0000000000000..b5013c4080507 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsMillisNanosEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Equals}. + * This class is generated. Do not edit it. + */ +public final class EqualsMillisNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public EqualsMillisNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(Equals.processMillisNanos(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, Equals.processMillisNanos(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "EqualsMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public EqualsMillisNanosEvaluator get(DriverContext context) { + return new EqualsMillisNanosEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "EqualsMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsNanosMillisEvaluator.java new file mode 100644 index 0000000000000..3ed1e922608e6 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsNanosMillisEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link Equals}. + * This class is generated. Do not edit it. + */ +public final class EqualsNanosMillisEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public EqualsNanosMillisEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(Equals.processNanosMillis(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, Equals.processNanosMillis(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "EqualsNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public EqualsNanosMillisEvaluator get(DriverContext context) { + return new EqualsNanosMillisEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "EqualsNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanMillisNanosEvaluator.java new file mode 100644 index 0000000000000..bdd877c7f866e --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanMillisNanosEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link GreaterThan}. + * This class is generated. Do not edit it. + */ +public final class GreaterThanMillisNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public GreaterThanMillisNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(GreaterThan.processMillisNanos(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, GreaterThan.processMillisNanos(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "GreaterThanMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public GreaterThanMillisNanosEvaluator get(DriverContext context) { + return new GreaterThanMillisNanosEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "GreaterThanMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanNanosMillisEvaluator.java new file mode 100644 index 0000000000000..d509547eb17ce --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanNanosMillisEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link GreaterThan}. + * This class is generated. Do not edit it. + */ +public final class GreaterThanNanosMillisEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public GreaterThanNanosMillisEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(GreaterThan.processNanosMillis(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, GreaterThan.processNanosMillis(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "GreaterThanNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public GreaterThanNanosMillisEvaluator get(DriverContext context) { + return new GreaterThanNanosMillisEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "GreaterThanNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualMillisNanosEvaluator.java new file mode 100644 index 0000000000000..7a0da0a55d0dc --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualMillisNanosEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link GreaterThanOrEqual}. + * This class is generated. Do not edit it. + */ +public final class GreaterThanOrEqualMillisNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public GreaterThanOrEqualMillisNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(GreaterThanOrEqual.processMillisNanos(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, GreaterThanOrEqual.processMillisNanos(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "GreaterThanOrEqualMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public GreaterThanOrEqualMillisNanosEvaluator get(DriverContext context) { + return new GreaterThanOrEqualMillisNanosEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "GreaterThanOrEqualMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualNanosMillisEvaluator.java new file mode 100644 index 0000000000000..d4386a64aaf8a --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualNanosMillisEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link GreaterThanOrEqual}. + * This class is generated. Do not edit it. + */ +public final class GreaterThanOrEqualNanosMillisEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public GreaterThanOrEqualNanosMillisEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(GreaterThanOrEqual.processNanosMillis(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, GreaterThanOrEqual.processNanosMillis(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "GreaterThanOrEqualNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public GreaterThanOrEqualNanosMillisEvaluator get(DriverContext context) { + return new GreaterThanOrEqualNanosMillisEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "GreaterThanOrEqualNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanMillisNanosEvaluator.java new file mode 100644 index 0000000000000..21d7d50af5b1e --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanMillisNanosEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link LessThan}. + * This class is generated. Do not edit it. + */ +public final class LessThanMillisNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public LessThanMillisNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(LessThan.processMillisNanos(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, LessThan.processMillisNanos(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "LessThanMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public LessThanMillisNanosEvaluator get(DriverContext context) { + return new LessThanMillisNanosEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "LessThanMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanNanosMillisEvaluator.java new file mode 100644 index 0000000000000..48593f9d537f3 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanNanosMillisEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link LessThan}. + * This class is generated. Do not edit it. + */ +public final class LessThanNanosMillisEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public LessThanNanosMillisEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(LessThan.processNanosMillis(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, LessThan.processNanosMillis(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "LessThanNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public LessThanNanosMillisEvaluator get(DriverContext context) { + return new LessThanNanosMillisEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "LessThanNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualMillisNanosEvaluator.java new file mode 100644 index 0000000000000..06973e71e834a --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualMillisNanosEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link LessThanOrEqual}. + * This class is generated. Do not edit it. + */ +public final class LessThanOrEqualMillisNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public LessThanOrEqualMillisNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(LessThanOrEqual.processMillisNanos(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, LessThanOrEqual.processMillisNanos(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "LessThanOrEqualMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public LessThanOrEqualMillisNanosEvaluator get(DriverContext context) { + return new LessThanOrEqualMillisNanosEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "LessThanOrEqualMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualNanosMillisEvaluator.java new file mode 100644 index 0000000000000..4763629873d02 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualNanosMillisEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link LessThanOrEqual}. + * This class is generated. Do not edit it. + */ +public final class LessThanOrEqualNanosMillisEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public LessThanOrEqualNanosMillisEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(LessThanOrEqual.processNanosMillis(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, LessThanOrEqual.processNanosMillis(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "LessThanOrEqualNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public LessThanOrEqualNanosMillisEvaluator get(DriverContext context) { + return new LessThanOrEqualNanosMillisEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "LessThanOrEqualNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsMillisNanosEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsMillisNanosEvaluator.java new file mode 100644 index 0000000000000..9bede03737a5f --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsMillisNanosEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link NotEquals}. + * This class is generated. Do not edit it. + */ +public final class NotEqualsMillisNanosEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public NotEqualsMillisNanosEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(NotEquals.processMillisNanos(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, NotEquals.processMillisNanos(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "NotEqualsMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public NotEqualsMillisNanosEvaluator get(DriverContext context) { + return new NotEqualsMillisNanosEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "NotEqualsMillisNanosEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsNanosMillisEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsNanosMillisEvaluator.java new file mode 100644 index 0000000000000..e8e28eec7ee27 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsNanosMillisEvaluator.java @@ -0,0 +1,148 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.predicate.operator.comparison; + +import java.lang.IllegalArgumentException; +import java.lang.Override; +import java.lang.String; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link NotEquals}. + * This class is generated. Do not edit it. + */ +public final class NotEqualsNanosMillisEvaluator implements EvalOperator.ExpressionEvaluator { + private final Source source; + + private final EvalOperator.ExpressionEvaluator lhs; + + private final EvalOperator.ExpressionEvaluator rhs; + + private final DriverContext driverContext; + + private Warnings warnings; + + public NotEqualsNanosMillisEvaluator(Source source, EvalOperator.ExpressionEvaluator lhs, + EvalOperator.ExpressionEvaluator rhs, DriverContext driverContext) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock lhsBlock = (LongBlock) lhs.eval(page)) { + try (LongBlock rhsBlock = (LongBlock) rhs.eval(page)) { + LongVector lhsVector = lhsBlock.asVector(); + if (lhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + LongVector rhsVector = rhsBlock.asVector(); + if (rhsVector == null) { + return eval(page.getPositionCount(), lhsBlock, rhsBlock); + } + return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + } + } + } + + public BooleanBlock eval(int positionCount, LongBlock lhsBlock, LongBlock rhsBlock) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + if (lhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (lhsBlock.getValueCount(p) != 1) { + if (lhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + if (rhsBlock.isNull(p)) { + result.appendNull(); + continue position; + } + if (rhsBlock.getValueCount(p) != 1) { + if (rhsBlock.getValueCount(p) > 1) { + warnings().registerException(new IllegalArgumentException("single-value function encountered multi-value")); + } + result.appendNull(); + continue position; + } + result.appendBoolean(NotEquals.processNanosMillis(lhsBlock.getLong(lhsBlock.getFirstValueIndex(p)), rhsBlock.getLong(rhsBlock.getFirstValueIndex(p)))); + } + return result.build(); + } + } + + public BooleanVector eval(int positionCount, LongVector lhsVector, LongVector rhsVector) { + try(BooleanVector.FixedBuilder result = driverContext.blockFactory().newBooleanVectorFixedBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + result.appendBoolean(p, NotEquals.processNanosMillis(lhsVector.getLong(p), rhsVector.getLong(p))); + } + return result.build(); + } + } + + @Override + public String toString() { + return "NotEqualsNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(lhs, rhs); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory lhs; + + private final EvalOperator.ExpressionEvaluator.Factory rhs; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory lhs, + EvalOperator.ExpressionEvaluator.Factory rhs) { + this.source = source; + this.lhs = lhs; + this.rhs = rhs; + } + + @Override + public NotEqualsNanosMillisEvaluator get(DriverContext context) { + return new NotEqualsNanosMillisEvaluator(source, lhs.get(context), rhs.get(context), context); + } + + @Override + public String toString() { + return "NotEqualsNanosMillisEvaluator[" + "lhs=" + lhs + ", rhs=" + rhs + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 19ba6a5151eaf..7c3f2a45df6a0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -345,6 +345,11 @@ public enum Cap { */ DATE_NANOS_BINARY_COMPARISON(), + /** + * Support for mixed comparisons between nanosecond and millisecond dates + */ + DATE_NANOS_COMPARE_TO_MILLIS(), + /** * Support Least and Greatest functions on Date Nanos type */ @@ -550,7 +555,12 @@ public enum Cap { /** * Support the "METADATA _score" directive to enable _score column. */ - METADATA_SCORE(Build.current().isSnapshot()); + METADATA_SCORE(Build.current().isSnapshot()), + + /** + * Term function + */ + TERM_FUNCTION(Build.current().isSnapshot()); private final boolean enabled; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java index ba7a7e8266845..52170dfb05256 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlExecutionInfo.java @@ -107,7 +107,7 @@ public EsqlExecutionInfo(StreamInput in) throws IOException { clusterList.forEach(c -> m.put(c.getClusterAlias(), c)); this.clusterInfo = m; } - if (in.getTransportVersion().onOrAfter(TransportVersions.OPT_IN_ESQL_CCS_EXECUTION_INFO)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.includeCCSMetadata = in.readBoolean(); } else { this.includeCCSMetadata = false; @@ -124,7 +124,7 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeCollection(Collections.emptyList()); } - if (out.getTransportVersion().onOrAfter(TransportVersions.OPT_IN_ESQL_CCS_EXECUTION_INFO)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeBoolean(includeCCSMetadata); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java index 77aed298baea5..dc0e9fd1fb06d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlQueryResponse.java @@ -113,7 +113,7 @@ static EsqlQueryResponse deserialize(BlockStreamInput in) throws IOException { } boolean columnar = in.readBoolean(); EsqlExecutionInfo executionInfo = null; - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { executionInfo = in.readOptionalWriteable(EsqlExecutionInfo::new); } return new EsqlQueryResponse(columns, pages, profile, columnar, asyncExecutionId, isRunning, isAsync, executionInfo); @@ -132,7 +132,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(profile); } out.writeBoolean(columnar); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(executionInfo); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResolveFieldsAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResolveFieldsAction.java index f7e6793fc4fb3..f7fd991a9ef16 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResolveFieldsAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResolveFieldsAction.java @@ -58,7 +58,7 @@ void executeRemoteRequest( ActionListener remoteListener ) { remoteClient.getConnection(remoteRequest, remoteListener.delegateFailure((l, conn) -> { - var remoteAction = conn.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES) + var remoteAction = conn.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? RESOLVE_REMOTE_TYPE : TransportFieldCapabilitiesAction.REMOTE_TYPE; remoteClient.execute(conn, remoteAction, remoteRequest, l); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java index 1c88fe6f45d81..fb7e0f651458c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlResponseListener.java @@ -22,6 +22,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener; import org.elasticsearch.xcontent.MediaType; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.esql.arrow.ArrowFormat; import org.elasticsearch.xpack.esql.arrow.ArrowResponse; import org.elasticsearch.xpack.esql.formatter.TextFormat; @@ -87,7 +88,7 @@ public TimeValue stop() { /** * Keep the initial query for logging purposes. */ - private final String esqlQuery; + private final String esqlQueryOrId; /** * Stop the time it took to build a response to later log it. Use something thread-safe here because stopping time requires state and * {@link EsqlResponseListener} might be used from different threads. @@ -98,29 +99,23 @@ public TimeValue stop() { * To correctly time the execution of a request, a {@link EsqlResponseListener} must be constructed immediately before execution begins. */ public EsqlResponseListener(RestChannel channel, RestRequest restRequest, EsqlQueryRequest esqlRequest) { - super(channel); + this(channel, restRequest, esqlRequest.query(), EsqlMediaTypeParser.getResponseMediaType(restRequest, esqlRequest)); + } + /** + * Async query GET API does not have an EsqlQueryRequest. + */ + public EsqlResponseListener(RestChannel channel, RestRequest getRequest) { + this(channel, getRequest, getRequest.param("id"), EsqlMediaTypeParser.getResponseMediaType(getRequest, XContentType.JSON)); + } + + private EsqlResponseListener(RestChannel channel, RestRequest restRequest, String esqlQueryOrId, MediaType mediaType) { + super(channel); this.channel = channel; this.restRequest = restRequest; - this.esqlQuery = esqlRequest.query(); - mediaType = EsqlMediaTypeParser.getResponseMediaType(restRequest, esqlRequest); - - /* - * Special handling for the "delimiter" parameter which should only be - * checked for being present or not in the case of CSV format. We cannot - * override {@link BaseRestHandler#responseParams()} because this - * parameter should only be checked for CSV, not other formats. - */ - if (mediaType != CSV && restRequest.hasParam(URL_PARAM_DELIMITER)) { - String message = String.format( - Locale.ROOT, - "parameter: [%s] can only be used with the format [%s] for request [%s]", - URL_PARAM_DELIMITER, - CSV.queryParameter(), - restRequest.path() - ); - throw new IllegalArgumentException(message); - } + this.esqlQueryOrId = esqlQueryOrId; + this.mediaType = mediaType; + checkDelimiter(); } @Override @@ -197,14 +192,18 @@ public ActionListener wrapWithLogging() { listener.onResponse(r); // At this point, the StopWatch should already have been stopped, so we log a consistent time. LOGGER.debug( - "Finished execution of ESQL query.\nQuery string: [{}]\nExecution time: [{}]ms", - esqlQuery, + "Finished execution of ESQL query.\nQuery string or async ID: [{}]\nExecution time: [{}]ms", + esqlQueryOrId, getTook(r, TimeUnit.MILLISECONDS) ); }, ex -> { // In case of failure, stop the time manually before sending out the response. long timeMillis = getTook(null, TimeUnit.MILLISECONDS); - LOGGER.debug("Failed execution of ESQL query.\nQuery string: [{}]\nExecution time: [{}]ms", esqlQuery, timeMillis); + LOGGER.debug( + "Failed execution of ESQL query.\nQuery string or async ID: [{}]\nExecution time: [{}]ms", + esqlQueryOrId, + timeMillis + ); listener.onFailure(ex); }); } @@ -213,4 +212,23 @@ static void logOnFailure(Throwable throwable) { RestStatus status = ExceptionsHelper.status(throwable); LOGGER.log(status.getStatus() >= 500 ? Level.WARN : Level.DEBUG, () -> "Request failed with status [" + status + "]: ", throwable); } + + /* + * Special handling for the "delimiter" parameter which should only be + * checked for being present or not in the case of CSV format. We cannot + * override {@link BaseRestHandler#responseParams()} because this + * parameter should only be checked for CSV, not other formats. + */ + private void checkDelimiter() { + if (mediaType != CSV && restRequest.hasParam(URL_PARAM_DELIMITER)) { + String message = String.format( + Locale.ROOT, + "parameter: [%s] can only be used with the format [%s] for request [%s]", + URL_PARAM_DELIMITER, + CSV.queryParameter(), + restRequest.path() + ); + throw new IllegalArgumentException(message); + } + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RestEsqlGetAsyncResultAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RestEsqlGetAsyncResultAction.java index b5a1821350e5e..848a75d7fb19f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RestEsqlGetAsyncResultAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/RestEsqlGetAsyncResultAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.Scope; import org.elasticsearch.rest.ServerlessScope; -import org.elasticsearch.rest.action.RestRefCountedChunkedToXContentListener; import org.elasticsearch.xpack.core.async.GetAsyncResultRequest; import java.util.List; @@ -43,7 +42,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli if (request.hasParam("keep_alive")) { get.setKeepAlive(request.paramAsTime("keep_alive", get.getKeepAlive())); } - return channel -> client.execute(EsqlAsyncGetResultAction.INSTANCE, get, new RestRefCountedChunkedToXContentListener<>(channel)); + return channel -> client.execute(EsqlAsyncGetResultAction.INSTANCE, get, new EsqlResponseListener(channel, request)); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index b847508d2b161..cf91c7df9a034 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -633,9 +633,10 @@ private Join resolveLookupJoin(LookupJoin join) { config = new JoinConfig(coreJoin, leftKeys, leftKeys, rightKeys); join = new LookupJoin(join.source(), join.left(), join.right(), config); - } - // everything else is unsupported for now - else { + } else if (type != JoinTypes.LEFT) { + // everything else is unsupported for now + // LEFT can only happen by being mapped from a USING above. So we need to exclude this as well because this rule can be run + // more than once. UnresolvedAttribute errorAttribute = new UnresolvedAttribute(join.source(), "unsupported", "Unsupported join type"); // add error message return join.withConfig(new JoinConfig(type, singletonList(errorAttribute), emptyList(), emptyList())); @@ -651,7 +652,7 @@ private List resolveUsingColumns(List cols, List"), enrichResolution); + this( + configuration, + functionRegistry, + indexResolution, + IndexResolution.invalid("AnalyzerContext constructed without any lookup join resolution"), + enrichResolution + ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 49d8a5ee8caad..ecfe1aa7f9169 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.analysis; -import org.elasticsearch.index.IndexMode; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; import org.elasticsearch.xpack.esql.common.Failure; @@ -37,10 +36,12 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Kql; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.EsqlBinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; @@ -55,7 +56,8 @@ import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; -import org.elasticsearch.xpack.esql.plan.logical.join.LookupJoin; +import org.elasticsearch.xpack.esql.plan.logical.join.Join; +import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; import org.elasticsearch.xpack.esql.stats.FeatureMetric; import org.elasticsearch.xpack.esql.stats.Metrics; @@ -172,20 +174,6 @@ else if (p instanceof Lookup lookup) { else { lookup.matchFields().forEach(unresolvedExpressions); } - } else if (p instanceof LookupJoin lj) { - // expect right side to always be a lookup index - lj.right().forEachUp(EsRelation.class, r -> { - if (r.indexMode() != IndexMode.LOOKUP) { - failures.add( - fail( - r, - "LOOKUP JOIN right side [{}] must be a lookup index (index_mode=lookup, not [{}]", - r.index().name(), - r.indexMode().getName() - ) - ); - } - }); } else { @@ -217,6 +205,7 @@ else if (p instanceof Lookup lookup) { checkSort(p, failures); checkFullTextQueryFunctions(p, failures); + checkJoin(p, failures); }); checkRemoteEnrich(plan, failures); checkMetadataScoreNameReserved(plan, failures); @@ -608,7 +597,11 @@ private void gatherMetrics(LogicalPlan plan, BitSet b) { } /** - * Limit QL's comparisons to types we support. + * Limit QL's comparisons to types we support. This should agree with + * {@link EsqlBinaryComparison}'s checkCompatibility method + * + * @return null if the given binary comparison has valid input types, + * otherwise a failure message suitable to return to the user. */ public static Failure validateBinaryComparison(BinaryComparison bc) { if (bc.left().dataType().isNumeric()) { @@ -653,6 +646,12 @@ public static Failure validateBinaryComparison(BinaryComparison bc) { if (DataType.isString(bc.left().dataType()) && DataType.isString(bc.right().dataType())) { return null; } + + // Allow mixed millisecond and nanosecond binary comparisons + if (bc.left().dataType().isDate() && bc.right().dataType().isDate()) { + return null; + } + if (bc.left().dataType() != bc.right().dataType()) { return fail( bc, @@ -791,6 +790,35 @@ private static void checkNotPresentInDisjunctions( }); } + /** + * Checks Joins for invalid usage. + * + * @param plan root plan to check + * @param failures failures found + */ + private static void checkJoin(LogicalPlan plan, Set failures) { + if (plan instanceof Join join) { + JoinConfig config = join.config(); + for (int i = 0; i < config.leftFields().size(); i++) { + Attribute leftField = config.leftFields().get(i); + Attribute rightField = config.rightFields().get(i); + if (leftField.dataType() != rightField.dataType()) { + failures.add( + fail( + leftField, + "JOIN left field [{}] of type [{}] is incompatible with right field [{}] of type [{}]", + leftField.name(), + leftField.dataType(), + rightField.name(), + rightField.dataType() + ) + ); + } + } + + } + } + /** * Checks full text query functions for invalid usage. * @@ -821,6 +849,14 @@ private static void checkFullTextQueryFunctions(LogicalPlan plan, Set f m -> "[" + m.functionName() + "] " + m.functionType(), failures ); + checkCommandsBeforeExpression( + plan, + condition, + Term.class, + lp -> (lp instanceof Limit == false) && (lp instanceof Aggregate == false), + m -> "[" + m.functionName() + "] " + m.functionType(), + failures + ); checkNotPresentInDisjunctions(condition, ftf -> "[" + ftf.functionName() + "] " + ftf.functionType(), failures); checkFullTextFunctionsParents(condition, failures); } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java index e891089aa55b5..64595e776a96e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/ResolvedEnrichPolicy.java @@ -35,8 +35,7 @@ public ResolvedEnrichPolicy(StreamInput in) throws IOException { } private static Reader getEsFieldReader(StreamInput in) { - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) - || in.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { return EsField::readFrom; } return EsField::new; @@ -56,8 +55,7 @@ public void writeTo(StreamOutput out) throws IOException { */ (o, v) -> { var field = new EsField(v.getName(), v.getDataType(), v.getProperties(), v.isAggregatable(), v.isAlias()); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) - || out.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { field.writeTo(o); } else { field.writeContent(o); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index c66a5293eb14a..3749b46879354 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Kql; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; @@ -424,7 +425,8 @@ private static FunctionDefinition[][] snapshotFunctions() { // This is an experimental function and can be removed without notice. def(Delay.class, Delay::new, "delay"), def(Kql.class, Kql::new, "kql"), - def(Rate.class, Rate::withUnresolvedTimestamp, "rate") } }; + def(Rate.class, Rate::withUnresolvedTimestamp, "rate"), + def(Term.class, Term::new, "term") } }; } public EsqlFunctionRegistry snapshotRegistry() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java index d372eddb961ae..089f6db373c54 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java @@ -81,8 +81,7 @@ private UnsupportedAttribute(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), readCachedStringWithVersionCheck(in), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) - || in.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2) ? EsField.readFrom(in) : new UnsupportedEsField(in), + in.getTransportVersion().onOrAfter(TransportVersions.V_8_15_2) ? EsField.readFrom(in) : new UnsupportedEsField(in), in.readOptionalString(), NameId.readFrom((PlanStreamInput) in) ); @@ -93,8 +92,7 @@ public void writeTo(StreamOutput out) throws IOException { if (((PlanStreamOutput) out).writeAttributeCacheHeader(this)) { Source.EMPTY.writeTo(out); writeCachedStringWithVersionCheck(out, name()); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) - || out.getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { field().writeTo(out); } else { field().writeContent(out); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 87efccfc90ab3..265b08de5556d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -53,10 +53,8 @@ protected AggregateFunction(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class) : emptyList() ); @@ -66,7 +64,7 @@ protected AggregateFunction(StreamInput in) throws IOException { public final void writeTo(StreamOutput out) throws IOException { Source.EMPTY.writeTo(out); out.writeNamedWriteable(field); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeNamedWriteable(filter); out.writeNamedWriteableCollection(parameters); } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java index 2e45b1c1fe082..7436db9e00dd2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java @@ -147,10 +147,8 @@ private CountDistinct(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class) : nullSafeList(in.readOptionalNamedWriteable(Expression.class)) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FromPartial.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FromPartial.java index 0f9037a28d7d7..a67b87c7617c4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FromPartial.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/FromPartial.java @@ -58,10 +58,8 @@ private FromPartial(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class).get(0) : in.readNamedWriteable(Expression.class) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java index febd9f28b2291..0d57267da1e29 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java @@ -92,10 +92,8 @@ private Percentile(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class).get(0) : in.readNamedWriteable(Expression.class) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java index b7b04658f8d58..87ac9b77a6826 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java @@ -74,10 +74,8 @@ public Rate(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class) : nullSafeList(in.readNamedWriteable(Expression.class), in.readOptionalNamedWriteable(Expression.class)) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java index cffac616b3c8c..a2856f60e4c51 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java @@ -80,10 +80,8 @@ private ToPartial(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class).get(0) : in.readNamedWriteable(Expression.class) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java index e0a7da806b3ac..40777b4d78dc2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java @@ -81,10 +81,8 @@ private Top(StreamInput in) throws IOException { super( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class) : asList(in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java index dbcc50cea3b9b..49c68d002440f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java @@ -68,10 +68,8 @@ private WeightedAvg(StreamInput in) throws IOException { this( Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) - ? in.readNamedWriteable(Expression.class) - : Literal.TRUE, - in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteable(Expression.class) : Literal.TRUE, + in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readNamedWriteableCollectionAsList(Expression.class).get(0) : in.readNamedWriteable(Expression.class) ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java index 8804a031de78c..d6b79d16b74f6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextWritables.java @@ -29,6 +29,9 @@ public static List getNamedWriteables() { if (EsqlCapabilities.Cap.KQL_FUNCTION.isEnabled()) { entries.add(Kql.ENTRY); } + if (EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()) { + entries.add(Term.ENTRY); + } return Collections.unmodifiableList(entries); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java new file mode 100644 index 0000000000000..125a5b02b6e1c --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Term.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.fulltext; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.capabilities.Validatable; +import org.elasticsearch.xpack.esql.common.Failure; +import org.elasticsearch.xpack.esql.common.Failures; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.querydsl.query.TermQuery; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isString; + +/** + * Full text function that performs a {@link TermQuery} . + */ +public class Term extends FullTextFunction implements Validatable { + + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Term", Term::readFrom); + + private final Expression field; + + @FunctionInfo( + returnType = "boolean", + preview = true, + description = "Performs a Term query on the specified field. Returns true if the provided term matches the row.", + examples = { @Example(file = "term-function", tag = "term-with-field") } + ) + public Term( + Source source, + @Param(name = "field", type = { "keyword", "text" }, description = "Field that the query will target.") Expression field, + @Param( + name = "query", + type = { "keyword", "text" }, + description = "Term you wish to find in the provided field." + ) Expression termQuery + ) { + super(source, termQuery, List.of(field, termQuery)); + this.field = field; + } + + private static Term readFrom(StreamInput in) throws IOException { + Source source = Source.readFrom((PlanStreamInput) in); + Expression field = in.readNamedWriteable(Expression.class); + Expression query = in.readNamedWriteable(Expression.class); + return new Term(source, field, query); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field()); + out.writeNamedWriteable(query()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected TypeResolution resolveNonQueryParamTypes() { + return isNotNull(field, sourceText(), FIRST).and(isString(field, sourceText(), FIRST)).and(super.resolveNonQueryParamTypes()); + } + + @Override + public void validate(Failures failures) { + if (field instanceof FieldAttribute == false) { + failures.add( + Failure.fail( + field, + "[{}] {} cannot operate on [{}], which is not a field from an index mapping", + functionName(), + functionType(), + field.sourceText() + ) + ); + } + } + + @Override + public Expression replaceChildren(List newChildren) { + return new Term(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, Term::new, field, query()); + } + + protected TypeResolutions.ParamOrdinal queryParamOrdinal() { + return SECOND; + } + + public Expression field() { + return field; + } + + @Override + public String functionName() { + return ENTRY.name; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/Equals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/Equals.java index 6bb249385affe..464553977d3cc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/Equals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/Equals.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -95,11 +96,28 @@ public Equals( description = "An expression." ) Expression right ) { - super(source, left, right, BinaryComparisonOperation.EQ, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.EQ, + evaluatorMap, + EqualsNanosMillisEvaluator.Factory::new, + EqualsMillisNanosEvaluator.Factory::new + ); } public Equals(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonOperation.EQ, zoneId, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.EQ, + zoneId, + evaluatorMap, + EqualsNanosMillisEvaluator.Factory::new, + EqualsMillisNanosEvaluator.Factory::new + ); } @Override @@ -142,6 +160,16 @@ static boolean processLongs(long lhs, long rhs) { return lhs == rhs; } + @Evaluator(extraName = "MillisNanos") + static boolean processMillisNanos(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(rhs, lhs) == 0; + } + + @Evaluator(extraName = "NanosMillis") + static boolean processNanosMillis(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(lhs, rhs) == 0; + } + @Evaluator(extraName = "Doubles") static boolean processDoubles(double lhs, double rhs) { return lhs == rhs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java index cbbf87fb6c4cb..217c6528c9fd6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java @@ -35,6 +35,8 @@ public abstract class EsqlBinaryComparison extends BinaryComparison implements E private final Map evaluatorMap; private final BinaryComparisonOperation functionType; + private final EsqlArithmeticOperation.BinaryEvaluator nanosToMillisEvaluator; + private final EsqlArithmeticOperation.BinaryEvaluator millisToNanosEvaluator; @FunctionalInterface public interface BinaryOperatorConstructor { @@ -118,9 +120,11 @@ protected EsqlBinaryComparison( Expression left, Expression right, BinaryComparisonOperation operation, - Map evaluatorMap + Map evaluatorMap, + EsqlArithmeticOperation.BinaryEvaluator nanosToMillisEvaluator, + EsqlArithmeticOperation.BinaryEvaluator millisToNanosEvaluator ) { - this(source, left, right, operation, null, evaluatorMap); + this(source, left, right, operation, null, evaluatorMap, nanosToMillisEvaluator, millisToNanosEvaluator); } protected EsqlBinaryComparison( @@ -130,11 +134,15 @@ protected EsqlBinaryComparison( BinaryComparisonOperation operation, // TODO: We are definitely not doing the right thing with this zoneId ZoneId zoneId, - Map evaluatorMap + Map evaluatorMap, + EsqlArithmeticOperation.BinaryEvaluator nanosToMillisEvaluator, + EsqlArithmeticOperation.BinaryEvaluator millisToNanosEvaluator ) { super(source, left, right, operation.shim, zoneId); this.evaluatorMap = evaluatorMap; this.functionType = operation; + this.nanosToMillisEvaluator = nanosToMillisEvaluator; + this.millisToNanosEvaluator = millisToNanosEvaluator; } public static EsqlBinaryComparison readFrom(StreamInput in) throws IOException { @@ -163,11 +171,24 @@ public BinaryComparisonOperation getFunctionType() { @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { - // Our type is always boolean, so figure out the evaluator type from the inputs - DataType commonType = commonType(left().dataType(), right().dataType()); EvalOperator.ExpressionEvaluator.Factory lhs; EvalOperator.ExpressionEvaluator.Factory rhs; + // Special cases for mixed nanosecond and millisecond comparisions + if (left().dataType() == DataType.DATE_NANOS && right().dataType() == DataType.DATETIME) { + lhs = toEvaluator.apply(left()); + rhs = toEvaluator.apply(right()); + return nanosToMillisEvaluator.apply(source(), lhs, rhs); + } + + if (left().dataType() == DataType.DATETIME && right().dataType() == DataType.DATE_NANOS) { + lhs = toEvaluator.apply(left()); + rhs = toEvaluator.apply(right()); + return millisToNanosEvaluator.apply(source(), lhs, rhs); + } + + // Our type is always boolean, so figure out the evaluator type from the inputs + DataType commonType = commonType(left().dataType(), right().dataType()); if (commonType.isNumeric()) { lhs = Cast.cast(source(), left().dataType(), commonType, toEvaluator.apply(left())); rhs = Cast.cast(source(), right().dataType(), commonType, toEvaluator.apply(right())); @@ -209,7 +230,9 @@ protected TypeResolution resolveInputType(Expression e, TypeResolutions.ParamOrd } /** - * Check if the two input types are compatible for this operation + * Check if the two input types are compatible for this operation. + * NOTE: this method should be consistent with + * {@link org.elasticsearch.xpack.esql.analysis.Verifier#validateBinaryComparison(BinaryComparison)} * * @return TypeResolution.TYPE_RESOLVED iff the types are compatible. Otherwise, an appropriate type resolution error. */ @@ -225,6 +248,7 @@ protected TypeResolution checkCompatibility() { if ((leftType.isNumeric() && rightType.isNumeric()) || (DataType.isString(leftType) && DataType.isString(rightType)) + || (leftType.isDate() && rightType.isDate()) // Millis and Nanos || leftType.equals(rightType) || DataType.isNull(leftType) || DataType.isNull(rightType)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThan.java index 3a46070389368..6087240387f01 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThan.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -62,11 +63,28 @@ public GreaterThan( description = "An expression." ) Expression right ) { - super(source, left, right, BinaryComparisonOperation.GT, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.GT, + evaluatorMap, + GreaterThanNanosMillisEvaluator.Factory::new, + GreaterThanMillisNanosEvaluator.Factory::new + ); } public GreaterThan(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonOperation.GT, zoneId, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.GT, + zoneId, + evaluatorMap, + GreaterThanNanosMillisEvaluator.Factory::new, + GreaterThanMillisNanosEvaluator.Factory::new + ); } @Override @@ -109,6 +127,17 @@ static boolean processLongs(long lhs, long rhs) { return lhs > rhs; } + @Evaluator(extraName = "MillisNanos") + static boolean processMillisNanos(long lhs, long rhs) { + // Note, parameters are reversed, so we need to invert the check. + return DateUtils.compareNanosToMillis(rhs, lhs) < 0; + } + + @Evaluator(extraName = "NanosMillis") + static boolean processNanosMillis(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(lhs, rhs) > 0; + } + @Evaluator(extraName = "Doubles") static boolean processDoubles(double lhs, double rhs) { return lhs > rhs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqual.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqual.java index 841fe5294c660..7ec1e5590bef6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqual.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqual.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -62,11 +63,28 @@ public GreaterThanOrEqual( description = "An expression." ) Expression right ) { - super(source, left, right, BinaryComparisonOperation.GTE, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.GTE, + evaluatorMap, + GreaterThanOrEqualNanosMillisEvaluator.Factory::new, + GreaterThanOrEqualMillisNanosEvaluator.Factory::new + ); } public GreaterThanOrEqual(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonOperation.GTE, zoneId, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.GTE, + zoneId, + evaluatorMap, + GreaterThanOrEqualNanosMillisEvaluator.Factory::new, + GreaterThanOrEqualMillisNanosEvaluator.Factory::new + ); } @Override @@ -109,6 +127,17 @@ static boolean processLongs(long lhs, long rhs) { return lhs >= rhs; } + @Evaluator(extraName = "MillisNanos") + static boolean processMillisNanos(long lhs, long rhs) { + // Note, parameters are reversed, so we need to invert the check. + return DateUtils.compareNanosToMillis(rhs, lhs) <= 0; + } + + @Evaluator(extraName = "NanosMillis") + static boolean processNanosMillis(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(lhs, rhs) >= 0; + } + @Evaluator(extraName = "Doubles") static boolean processDoubles(double lhs, double rhs) { return lhs >= rhs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java index eda6aadccc86a..f6c23304c189b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java @@ -151,14 +151,14 @@ public Expression replaceChildren(List newChildren) { public boolean foldable() { // QL's In fold()s to null, if value() is null, but isn't foldable() unless all children are // TODO: update this null check in QL too? - return Expressions.isNull(value) + return Expressions.isGuaranteedNull(value) || Expressions.foldable(children()) - || (Expressions.foldable(list) && list.stream().allMatch(Expressions::isNull)); + || (Expressions.foldable(list) && list.stream().allMatch(Expressions::isGuaranteedNull)); } @Override public Object fold() { - if (Expressions.isNull(value) || list.stream().allMatch(Expressions::isNull)) { + if (Expressions.isGuaranteedNull(value) || list.stream().allMatch(Expressions::isGuaranteedNull)) { return null; } return super.fold(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThan.java index 3ae7bd93092ef..5f130c054cd6f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThan.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -66,7 +67,16 @@ public LessThan( } public LessThan(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonOperation.LT, zoneId, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.LT, + zoneId, + evaluatorMap, + LessThanNanosMillisEvaluator.Factory::new, + LessThanMillisNanosEvaluator.Factory::new + ); } @Override @@ -109,6 +119,17 @@ static boolean processLongs(long lhs, long rhs) { return lhs < rhs; } + @Evaluator(extraName = "MillisNanos") + static boolean processMillisNanos(long lhs, long rhs) { + // Note, parameters are reversed, so we need to invert the check. + return DateUtils.compareNanosToMillis(rhs, lhs) > 0; + } + + @Evaluator(extraName = "NanosMillis") + static boolean processNanosMillis(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(lhs, rhs) < 0; + } + @Evaluator(extraName = "Doubles") static boolean processDoubles(double lhs, double rhs) { return lhs < rhs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqual.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqual.java index e084eee1e8c20..0904c408bfab5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqual.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqual.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -66,7 +67,16 @@ public LessThanOrEqual( } public LessThanOrEqual(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonOperation.LTE, zoneId, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.LTE, + zoneId, + evaluatorMap, + LessThanOrEqualNanosMillisEvaluator.Factory::new, + LessThanOrEqualMillisNanosEvaluator.Factory::new + ); } @Override @@ -109,6 +119,17 @@ static boolean processLongs(long lhs, long rhs) { return lhs <= rhs; } + @Evaluator(extraName = "MillisNanos") + static boolean processMillisNanos(long lhs, long rhs) { + // Note, parameters are reversed, so we need to invert the check. + return DateUtils.compareNanosToMillis(rhs, lhs) >= 0; + } + + @Evaluator(extraName = "NanosMillis") + static boolean processNanosMillis(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(lhs, rhs) <= 0; + } + @Evaluator(extraName = "Doubles") static boolean processDoubles(double lhs, double rhs) { return lhs <= rhs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEquals.java index 9e961c04153d6..d4f86e9a878a9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEquals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEquals.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.time.DateUtils; import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -95,11 +96,28 @@ public NotEquals( description = "An expression." ) Expression right ) { - super(source, left, right, BinaryComparisonOperation.NEQ, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.NEQ, + evaluatorMap, + NotEqualsNanosMillisEvaluator.Factory::new, + NotEqualsMillisNanosEvaluator.Factory::new + ); } public NotEquals(Source source, Expression left, Expression right, ZoneId zoneId) { - super(source, left, right, BinaryComparisonOperation.NEQ, zoneId, evaluatorMap); + super( + source, + left, + right, + BinaryComparisonOperation.NEQ, + zoneId, + evaluatorMap, + NotEqualsNanosMillisEvaluator.Factory::new, + NotEqualsMillisNanosEvaluator.Factory::new + ); } @Override @@ -117,6 +135,16 @@ static boolean processLongs(long lhs, long rhs) { return lhs != rhs; } + @Evaluator(extraName = "MillisNanos") + static boolean processMillisNanos(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(rhs, lhs) != 0; + } + + @Evaluator(extraName = "NanosMillis") + static boolean processNanosMillis(long lhs, long rhs) { + return DateUtils.compareNanosToMillis(lhs, rhs) != 0; + } + @Evaluator(extraName = "Doubles") static boolean processDoubles(double lhs, double rhs) { return lhs != rhs; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/EsIndex.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/EsIndex.java index ce52b3a7611b3..ee51a6f391a65 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/EsIndex.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/index/EsIndex.java @@ -50,7 +50,7 @@ public void writeTo(StreamOutput out) throws IOException { @SuppressWarnings("unchecked") private static Map readIndexNameWithModes(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ADD_INDEX_MODE_CONCRETE_INDICES)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { return in.readMap(IndexMode::readFrom); } else { Set indices = (Set) in.readGenericValue(); @@ -60,7 +60,7 @@ private static Map readIndexNameWithModes(StreamInput in) thr } private static void writeIndexNameWithModes(Map concreteIndices, StreamOutput out) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ADD_INDEX_MODE_CONCRETE_INDICES)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeMap(concreteIndices, (o, v) -> IndexMode.writeTo(v, out)); } else { out.writeGenericValue(concreteIndices.keySet()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java index 47e5b9acfbf9d..948fd1c683544 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java @@ -182,8 +182,7 @@ public NameId mapNameId(long l) { @Override @SuppressWarnings("unchecked") public A readAttributeWithCache(CheckedFunction constructor) throws IOException { - if (getTransportVersion().onOrAfter(TransportVersions.ESQL_ATTRIBUTE_CACHED_SERIALIZATION) - || getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { // it's safe to cast to int, since the max value for this is {@link PlanStreamOutput#MAX_SERIALIZED_ATTRIBUTES} int cacheId = Math.toIntExact(readZLong()); if (cacheId < 0) { @@ -222,8 +221,7 @@ private void cacheAttribute(int id, Attribute attr) { @SuppressWarnings("unchecked") public A readEsFieldWithCache() throws IOException { - if (getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) - || getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { // it's safe to cast to int, since the max value for this is {@link PlanStreamOutput#MAX_SERIALIZED_ATTRIBUTES} int cacheId = Math.toIntExact(readZLong()); if (cacheId < 0) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java index 615c4266620c7..63d95c21d7d9d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java @@ -154,8 +154,7 @@ public void writeCachedBlock(Block block) throws IOException { @Override public boolean writeAttributeCacheHeader(Attribute attribute) throws IOException { - if (getTransportVersion().onOrAfter(TransportVersions.ESQL_ATTRIBUTE_CACHED_SERIALIZATION) - || getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { Integer cacheId = attributeIdFromCache(attribute); if (cacheId != null) { writeZLong(cacheId); @@ -186,8 +185,7 @@ private int cacheAttribute(Attribute attr) { @Override public boolean writeEsFieldCacheHeader(EsField field) throws IOException { - if (getTransportVersion().onOrAfter(TransportVersions.ESQL_ES_FIELD_CACHED_SERIALIZATION) - || getTransportVersion().isPatchFrom(TransportVersions.V_8_15_2)) { + if (getTransportVersion().onOrAfter(TransportVersions.V_8_15_2)) { Integer cacheId = esFieldIdFromCache(field); if (cacheId != null) { writeZLong(cacheId); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java index 638fa1b8db456..4f97bf60bd863 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java @@ -30,7 +30,7 @@ public Expression rule(Expression e) { // perform this early to prevent the rule from converting the null filter into nullifying the whole expression // P.S. this could be done inside the Aggregate but this place better centralizes the logic if (e instanceof AggregateFunction agg) { - if (Expressions.isNull(agg.filter())) { + if (Expressions.isGuaranteedNull(agg.filter())) { return agg.withFilter(Literal.of(agg.filter(), false)); } } @@ -38,13 +38,13 @@ public Expression rule(Expression e) { if (result != e) { return result; } else if (e instanceof In in) { - if (Expressions.isNull(in.value())) { + if (Expressions.isGuaranteedNull(in.value())) { return Literal.of(in, null); } } else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE && e instanceof Categorize == false - && Expressions.anyMatch(e.children(), Expressions::isNull)) { + && Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) { return Literal.of(e, null); } return e; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java index b6f7ac9e464f4..00698d009ea23 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java @@ -29,7 +29,7 @@ protected LogicalPlan rule(Filter filter) { if (TRUE.equals(condition)) { return filter.child(); } - if (FALSE.equals(condition) || Expressions.isNull(condition)) { + if (FALSE.equals(condition) || Expressions.isGuaranteedNull(condition)) { return PruneEmptyPlans.skipPlan(filter); } } @@ -42,8 +42,8 @@ protected LogicalPlan rule(Filter filter) { private static Expression foldBinaryLogic(BinaryLogic binaryLogic) { if (binaryLogic instanceof Or or) { - boolean nullLeft = Expressions.isNull(or.left()); - boolean nullRight = Expressions.isNull(or.right()); + boolean nullLeft = Expressions.isGuaranteedNull(or.left()); + boolean nullRight = Expressions.isGuaranteedNull(or.right()); if (nullLeft && nullRight) { return new Literal(binaryLogic.source(), null, DataType.NULL); } @@ -55,7 +55,7 @@ private static Expression foldBinaryLogic(BinaryLogic binaryLogic) { } } if (binaryLogic instanceof And and) { - if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) { + if (Expressions.isGuaranteedNull(and.left()) || Expressions.isGuaranteedNull(and.right())) { return new Literal(binaryLogic.source(), null, DataType.NULL); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java index 930b485dbd374..9e9ae6a9a559d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java @@ -30,7 +30,7 @@ public Expression rule(In in) { List foldables = new ArrayList<>(in.list().size()); List nonFoldables = new ArrayList<>(in.list().size()); in.list().forEach(e -> { - if (e.foldable() && Expressions.isNull(e) == false) { // keep `null`s, needed for the 3VL + if (e.foldable() && Expressions.isGuaranteedNull(e) == false) { // keep `null`s, needed for the 3VL foldables.add(e); } else { nonFoldables.add(e); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java index dc32a4ad3c282..ed8851b64c27e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/InsertFieldExtraction.java @@ -11,14 +11,12 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; -import org.elasticsearch.xpack.esql.core.expression.TypedAttribute; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.optimizer.rules.physical.ProjectAwayColumns; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.esql.plan.physical.FieldExtractExec; import org.elasticsearch.xpack.esql.plan.physical.LeafExec; -import org.elasticsearch.xpack.esql.plan.physical.LookupJoinExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.rule.Rule; @@ -102,25 +100,17 @@ private static Set missingAttributes(PhysicalPlan p) { var missing = new LinkedHashSet(); var input = p.inputSet(); - // For LOOKUP JOIN we only need field-extraction on left fields used to match, since the right side is always materialized - if (p instanceof LookupJoinExec join) { - join.leftFields().forEach(f -> { - if (input.contains(f) == false) { - missing.add(f); - } - }); - return missing; - } - - // collect field attributes used inside expressions - // TODO: Rather than going over all expressions manually, this should just call .references() - p.forEachExpression(TypedAttribute.class, f -> { + // Collect field attributes referenced by this plan but not yet present in the child's output. + // This is also correct for LookupJoinExec, where we only need field extraction on the left fields used to match, since the right + // side is always materialized. + p.references().forEach(f -> { if (f instanceof FieldAttribute || f instanceof MetadataAttribute) { if (input.contains(f) == false) { missing.add(f); } } }); + return missing; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java index 3d6c35e914294..9d02af0efbab0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushFiltersToSource.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.esql.core.util.Queries; import org.elasticsearch.xpack.esql.expression.function.fulltext.FullTextFunction; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.BinarySpatialFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; @@ -254,6 +255,8 @@ static boolean canPushToSource(Expression exp, LucenePushdownPredicates lucenePu return canPushSpatialFunctionToSource(spatial, lucenePushdownPredicates); } else if (exp instanceof Match mf) { return mf.field() instanceof FieldAttribute && DataType.isString(mf.field().dataType()); + } else if (exp instanceof Term term) { + return term.field() instanceof FieldAttribute && DataType.isString(term.field().dataType()); } else if (exp instanceof FullTextFunction) { return true; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java index 620a25e0170ea..2e55b4df1e223 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java @@ -33,6 +33,15 @@ public class EsqlParser { private static final Logger log = LogManager.getLogger(EsqlParser.class); + /** + * Maximum number of characters in an ESQL query. Antlr may parse the entire + * query into tokens to make the choices, buffering the world. There's a lot we + * can do in the grammar to prevent that, but let's be paranoid and assume we'll + * fail at preventing antlr from slurping in the world. Instead, let's make sure + * that the world just isn't that big. + */ + public static final int MAX_LENGTH = 1_000_000; + private EsqlConfig config = new EsqlConfig(); public EsqlConfig config() { @@ -60,8 +69,14 @@ private T invokeParser( Function parseFunction, BiFunction result ) { + if (query.length() > MAX_LENGTH) { + throw new org.elasticsearch.xpack.esql.core.ParsingException( + "ESQL statement is too large [{} characters > {}]", + query.length(), + MAX_LENGTH + ); + } try { - // new CaseChangingCharStream() EsqlBaseLexer lexer = new EsqlBaseLexer(CharStreams.fromString(query)); lexer.removeErrorListeners(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java index dff55f0738975..891d03c571b27 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/AggregateExec.java @@ -85,7 +85,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(child()); out.writeNamedWriteableCollection(groupings()); out.writeNamedWriteableCollection(aggregates()); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_AGGREGATE_EXEC_TRACKS_INTERMEDIATE_ATTRS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeEnum(getMode()); out.writeNamedWriteableCollection(intermediateAttributes()); } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 7bf7d0e2d08eb..17468f7afec1b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -39,6 +39,7 @@ import org.elasticsearch.index.mapper.FieldNamesFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NestedLookup; +import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -297,15 +298,11 @@ public SourceLoader newSourceLoader() { @Override public Query toQuery(QueryBuilder queryBuilder) { Query query = ctx.toQuery(queryBuilder).query(); - NestedLookup nestedLookup = ctx.nestedLookup(); - if (nestedLookup != NestedLookup.EMPTY) { - NestedHelper nestedHelper = new NestedHelper(nestedLookup, ctx::isFieldMapped); - if (nestedHelper.mightMatchNestedDocs(query)) { - // filter out nested documents - query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.MUST) - .add(newNonNestedFilter(ctx.indexVersionCreated()), BooleanClause.Occur.FILTER) - .build(); - } + if (ctx.nestedLookup() != NestedLookup.EMPTY && NestedHelper.mightMatchNestedDocs(query, ctx)) { + // filter out nested documents + query = new BooleanQuery.Builder().add(query, BooleanClause.Occur.MUST) + .add(newNonNestedFilter(ctx.indexVersionCreated()), BooleanClause.Occur.FILTER) + .build(); } if (aliasFilter != AliasFilter.EMPTY) { Query filterQuery = ctx.toQuery(aliasFilter.getQueryBuilder()).query(); @@ -348,7 +345,16 @@ public MappedFieldType.FieldExtractPreference fieldExtractPreference() { @Override public SearchLookup lookup() { - return ctx.lookup(); + boolean syntheticSource = SourceFieldMapper.isSynthetic(indexSettings()); + var searchLookup = ctx.lookup(); + if (syntheticSource) { + // in the context of scripts and when synthetic source is used the search lookup can't always be reused between + // users of SearchLookup. This is only an issue when scripts fallback to _source, but since we can't always + // accurately determine whether a script uses _source, we should do this for all script usages. + // This lookup() method is only invoked for scripts / runtime fields, so it is ok to do here. + searchLookup = searchLookup.swapSourceProvider(ctx.createSourceProvider()); + } + return searchLookup; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java index 1580b77931240..1aee8f029e474 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.Kql; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.fulltext.QueryString; +import org.elasticsearch.xpack.esql.expression.function.fulltext.Term; import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesUtils; @@ -92,6 +93,7 @@ public final class EsqlExpressionTranslators { new MatchFunctionTranslator(), new QueryStringFunctionTranslator(), new KqlFunctionTranslator(), + new TermFunctionTranslator(), new Scalars() ); @@ -548,4 +550,12 @@ protected Query asQuery(Kql kqlFunction, TranslatorHandler handler) { return new KqlQuery(kqlFunction.source(), kqlFunction.queryAsText()); } } + + public static class TermFunctionTranslator extends ExpressionTranslator { + @Override + protected Query asQuery(Term term, TranslatorHandler handler) { + return new TermQuery(term.source(), ((FieldAttribute) term.field()).name(), term.queryAsText()); + } + } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 8c0488afdd42a..b85340936497e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -565,21 +565,12 @@ private PhysicalOperation planHashJoin(HashJoinExec join, LocalExecutionPlannerC private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(join.left(), context); - // TODO: The source builder includes incoming fields including the ones we're going to drop Layout.Builder layoutBuilder = source.layout.builder(); for (Attribute f : join.addedFields()) { layoutBuilder.append(f); } Layout layout = layoutBuilder.build(); - // TODO: this works when the join happens on the coordinator - /* - * But when it happens on the data node we get a - * \_FieldExtractExec[language_code{f}#15, language_name{f}#16]<[]> - * \_EsQueryExec[languages_lookup], indexMode[lookup], query[][_doc{f}#18], limit[], sort[] estimatedRowSize[62] - * Which we'd prefer not to do - at least for now. We already know the fields we're loading - * and don't want any local planning. - */ EsQueryExec localSourceExec = (EsQueryExec) join.lookup(); if (localSourceExec.indexMode() != IndexMode.LOOKUP) { throw new IllegalArgumentException("can't plan [" + join + "]"); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java index 8d041ffbdf0e4..8bd23230fcde7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java @@ -9,8 +9,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.operator.DriverProfile; -import org.elasticsearch.compute.operator.FailureCollector; import org.elasticsearch.compute.operator.ResponseHeadersCollector; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; @@ -39,8 +39,7 @@ final class ComputeListener implements Releasable { private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); - private final RefCountingListener refs; - private final FailureCollector failureCollector = new FailureCollector(); + private final EsqlRefCountingListener refs; private final AtomicBoolean cancelled = new AtomicBoolean(); private final CancellableTask task; private final TransportService transportService; @@ -105,7 +104,7 @@ private ComputeListener( : "clusterAlias and executionInfo must both be null or both non-null"; // listener that executes after all the sub-listeners refs (created via acquireCompute) have completed - this.refs = new RefCountingListener(1, ActionListener.wrap(ignored -> { + this.refs = new EsqlRefCountingListener(delegate.delegateFailure((l, ignored) -> { responseHeaders.finish(); ComputeResponse result; @@ -131,7 +130,7 @@ private ComputeListener( } } delegate.onResponse(result); - }, e -> delegate.onFailure(failureCollector.getFailure()))); + })); } private static void setFinalStatusAndShardCounts(String clusterAlias, EsqlExecutionInfo executionInfo) { @@ -191,7 +190,6 @@ private boolean isCCSListener(String computeClusterAlias) { */ ActionListener acquireAvoid() { return refs.acquire().delegateResponse((l, e) -> { - failureCollector.unwrapAndCollect(e); try { if (cancelled.compareAndSet(false, true)) { LOGGER.debug("cancelling ESQL task {} on failure", task); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeResponse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeResponse.java index 308192704fe0e..8d2e092cd4149 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeResponse.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeResponse.java @@ -61,7 +61,7 @@ final class ComputeResponse extends TransportResponse { } else { profiles = null; } - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.took = in.readOptionalTimeValue(); this.totalShards = in.readVInt(); this.successfulShards = in.readVInt(); @@ -86,7 +86,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(profiles); } } - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalTimeValue(took); out.writeVInt(totalShards); out.writeVInt(successfulShards); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index c9c8635a60f57..9b59b98a7cdc2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -16,11 +16,11 @@ import org.elasticsearch.action.search.SearchShardsRequest; import org.elasticsearch.action.search.SearchShardsResponse; import org.elasticsearch.action.support.ChannelActionListener; -import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; @@ -45,6 +45,7 @@ import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.lookup.SourceProvider; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; @@ -87,6 +88,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; @@ -373,7 +375,7 @@ private void startComputeOnDataNodes( var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); // SearchShards API can_match is done in lookupDataNodes lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> { - try (RefCountingListener refs = new RefCountingListener(lookupListener)) { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(lookupListener)) { // update ExecutionInfo with shard counts (total and skipped) executionInfo.swapCluster( clusterAlias, @@ -434,7 +436,7 @@ private void startComputeOnRemoteClusters( ) { var queryPragmas = configuration.pragmas(); var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); - try (RefCountingListener refs = new RefCountingListener(linkExchangeListeners)) { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(linkExchangeListeners)) { for (RemoteCluster cluster : clusters) { final var childSessionId = newChildSession(sessionId); ExchangeService.openExchange( @@ -471,12 +473,17 @@ void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, List contexts = new ArrayList<>(context.searchContexts.size()); for (int i = 0; i < context.searchContexts.size(); i++) { SearchContext searchContext = context.searchContexts.get(i); + var searchExecutionContext = new SearchExecutionContext(searchContext.getSearchExecutionContext()) { + + @Override + public SourceProvider createSourceProvider() { + final Supplier supplier = () -> super.createSourceProvider(); + return new ReinitializingSourceProvider(supplier); + + } + }; contexts.add( - new EsPhysicalOperationProviders.DefaultShardContext( - i, - searchContext.getSearchExecutionContext(), - searchContext.request().getAliasFilter() - ) + new EsPhysicalOperationProviders.DefaultShardContext(i, searchExecutionContext, searchContext.request().getAliasFilter()) ); } final List drivers; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java index 8f890e63bf54e..4c01d326ed7bc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequest.java @@ -81,7 +81,7 @@ final class DataNodeRequest extends TransportRequest implements IndicesRequest.R this.shardIds = in.readCollectionAsList(ShardId::new); this.aliasFilters = in.readMap(Index::new, AliasFilter::readFrom); this.plan = new PlanStreamInput(in, in.namedWriteableRegistry(), configuration).readNamedWriteable(PhysicalPlan.class); - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.indices = in.readStringArray(); this.indicesOptions = IndicesOptions.readIndicesOptions(in); } else { @@ -101,7 +101,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(shardIds); out.writeMap(aliasFilters); new PlanStreamOutput(out, configuration).writeNamedWriteable(plan); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeStringArray(indices); indicesOptions.writeIndicesOptions(out); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParser.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParser.java index 17329ca2e0054..1931692cea8bc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParser.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParser.java @@ -42,16 +42,23 @@ public class EsqlMediaTypeParser { * combinations are detected. */ public static MediaType getResponseMediaType(RestRequest request, EsqlQueryRequest esqlRequest) { - var mediaType = request.hasParam(URL_PARAM_FORMAT) ? mediaTypeFromParams(request) : mediaTypeFromHeaders(request); + var mediaType = getResponseMediaType(request, (MediaType) null); validateColumnarRequest(esqlRequest.columnar(), mediaType); validateIncludeCCSMetadata(esqlRequest.includeCCSMetadata(), mediaType); return checkNonNullMediaType(mediaType, request); } + /* + * Retrieve the mediaType of a REST request. If no mediaType can be established from the request, return the provided default. + */ + public static MediaType getResponseMediaType(RestRequest request, MediaType defaultMediaType) { + var mediaType = request.hasParam(URL_PARAM_FORMAT) ? mediaTypeFromParams(request) : mediaTypeFromHeaders(request); + return mediaType == null ? defaultMediaType : mediaType; + } + private static MediaType mediaTypeFromHeaders(RestRequest request) { ParsedMediaType acceptType = request.getParsedAccept(); - MediaType mediaType = acceptType != null ? acceptType.toMediaType(MEDIA_TYPE_REGISTRY) : request.getXContentType(); - return checkNonNullMediaType(mediaType, request); + return acceptType != null ? acceptType.toMediaType(MEDIA_TYPE_REGISTRY) : request.getXContentType(); } private static MediaType mediaTypeFromParams(RestRequest request) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java new file mode 100644 index 0000000000000..b6b2c6dfec755 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ReinitializingSourceProvider.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.search.lookup.Source; +import org.elasticsearch.search.lookup.SourceProvider; + +import java.io.IOException; +import java.util.function.Supplier; + +/** + * This is a workaround for when compute engine executes concurrently with data partitioning by docid. + */ +final class ReinitializingSourceProvider implements SourceProvider { + + private PerThreadSourceProvider perThreadProvider; + private final Supplier sourceProviderFactory; + + ReinitializingSourceProvider(Supplier sourceProviderFactory) { + this.sourceProviderFactory = sourceProviderFactory; + } + + @Override + public Source getSource(LeafReaderContext ctx, int doc) throws IOException { + var currentThread = Thread.currentThread(); + PerThreadSourceProvider provider = perThreadProvider; + if (provider == null || provider.creatingThread != currentThread) { + provider = new PerThreadSourceProvider(sourceProviderFactory.get(), currentThread); + this.perThreadProvider = provider; + } + return perThreadProvider.source.getSource(ctx, doc); + } + + private record PerThreadSourceProvider(SourceProvider source, Thread creatingThread) { + + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java index 031bfd7139a84..aed196f963e9b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/RemoteClusterPlan.java @@ -23,7 +23,7 @@ static RemoteClusterPlan from(PlanStreamInput planIn) throws IOException { var plan = planIn.readNamedWriteable(PhysicalPlan.class); var targetIndices = planIn.readStringArray(); final OriginalIndices originalIndices; - if (planIn.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) { + if (planIn.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { originalIndices = OriginalIndices.readOriginalIndices(planIn); } else { // fallback to the previous behavior @@ -35,7 +35,7 @@ static RemoteClusterPlan from(PlanStreamInput planIn) throws IOException { public void writeTo(PlanStreamOutput out) throws IOException { out.writeNamedWriteable(plan); out.writeStringArray(targetIndices); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_ORIGINAL_INDICES)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { OriginalIndices.writeOriginalIndices(originalIndices, out); } else { out.writeStringArray(originalIndices.indices()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java index 8d33e9b480594..bc11d246904d5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/querydsl/query/SingleValueQuery.java @@ -107,7 +107,7 @@ public static class Builder extends AbstractQueryBuilder { super(in); this.next = in.readNamedWriteable(QueryBuilder.class); this.field = in.readString(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_SINGLE_VALUE_QUERY_SOURCE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { if (in instanceof PlanStreamInput psi) { this.source = Source.readFrom(psi); } else { @@ -128,7 +128,7 @@ public static class Builder extends AbstractQueryBuilder { protected void doWriteTo(StreamOutput out) throws IOException { out.writeNamedWriteable(next); out.writeString(field); - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_SINGLE_VALUE_QUERY_SOURCE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { source.writeTo(out); } else if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) { writeOldSource(out, source); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java index 4ec2746b24ee4..997f3265803f7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Configuration.java @@ -101,7 +101,7 @@ public Configuration(BlockStreamInput in) throws IOException { } else { this.tables = Map.of(); } - if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.queryStartTimeNanos = in.readLong(); } else { this.queryStartTimeNanos = -1; @@ -127,7 +127,7 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { out.writeMap(tables, (o1, columns) -> o1.writeMap(columns, StreamOutput::writeWriteable)); } - if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeLong(queryStartTimeNanos); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 71fba5683644d..4f7c620bc8d12 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -374,10 +374,11 @@ private void preAnalyzeLookupIndices(List indices, ListenerResult lis // call the EsqlResolveFieldsAction (field-caps) to resolve indices and get field types indexResolver.resolveAsMergedMapping( table.index(), - Set.of("*"), // Current LOOKUP JOIN syntax does not allow for field selection + Set.of("*"), // TODO: for LOOKUP JOIN, this currently declares all lookup index fields relevant and might fetch too many. null, listener.map(indexResolution -> listenerResult.withLookupIndexResolution(indexResolution)) ); + // TODO: Verify that the resolved index actually has indexMode: "lookup" } else { try { // No lookup indices specified diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index 2e8b856cf82a6..2834e5f3f8358 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -265,6 +265,10 @@ public final void test() throws Throwable { "lookup join disabled for csv tests", testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.JOIN_LOOKUP_V4.capabilityName()) ); + assumeFalse( + "can't use TERM function in csv tests", + testCase.requiredCapabilities.contains(EsqlCapabilities.Cap.TERM_FUNCTION.capabilityName()) + ); if (Build.current().isSnapshot()) { assertThat( "Capability is not included in the enabled list capabilities on a snapshot build. Spelling mistake?", diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index a63ee53cdd498..4e89a09db9ed4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -38,7 +38,7 @@ public static Analyzer defaultAnalyzer() { } public static Analyzer expandedDefaultAnalyzer() { - return analyzer(analyzerExpandedDefaultMapping()); + return analyzer(expandedDefaultIndexResolution()); } public static Analyzer analyzer(IndexResolution indexResolution) { @@ -47,18 +47,33 @@ public static Analyzer analyzer(IndexResolution indexResolution) { public static Analyzer analyzer(IndexResolution indexResolution, Verifier verifier) { return new Analyzer( - new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), indexResolution, defaultEnrichResolution()), + new AnalyzerContext( + EsqlTestUtils.TEST_CFG, + new EsqlFunctionRegistry(), + indexResolution, + defaultLookupResolution(), + defaultEnrichResolution() + ), verifier ); } public static Analyzer analyzer(IndexResolution indexResolution, Verifier verifier, Configuration config) { - return new Analyzer(new AnalyzerContext(config, new EsqlFunctionRegistry(), indexResolution, defaultEnrichResolution()), verifier); + return new Analyzer( + new AnalyzerContext(config, new EsqlFunctionRegistry(), indexResolution, defaultLookupResolution(), defaultEnrichResolution()), + verifier + ); } public static Analyzer analyzer(Verifier verifier) { return new Analyzer( - new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), analyzerDefaultMapping(), defaultEnrichResolution()), + new AnalyzerContext( + EsqlTestUtils.TEST_CFG, + new EsqlFunctionRegistry(), + analyzerDefaultMapping(), + defaultLookupResolution(), + defaultEnrichResolution() + ), verifier ); } @@ -98,10 +113,14 @@ public static IndexResolution analyzerDefaultMapping() { return loadMapping("mapping-basic.json", "test"); } - public static IndexResolution analyzerExpandedDefaultMapping() { + public static IndexResolution expandedDefaultIndexResolution() { return loadMapping("mapping-default.json", "test"); } + public static IndexResolution defaultLookupResolution() { + return loadMapping("mapping-languages.json", "languages_lookup"); + } + public static EnrichResolution defaultEnrichResolution() { EnrichResolution enrichResolution = new EnrichResolution(); loadEnrichPolicyResolution(enrichResolution, MATCH_TYPE, "languages", "language_code", "languages_idx", "mapping-languages.json"); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 5a1e109041a16..6edbb55af463d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.analysis.IndexAnalyzers; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.LoadMapping; import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; @@ -73,6 +74,8 @@ import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzer; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyzerDefaultMapping; +import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultEnrichResolution; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.tsdbIndexResolution; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; @@ -83,6 +86,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.matchesRegex; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.startsWith; //@TestLogging(value = "org.elasticsearch.xpack.esql.analysis:TRACE", reason = "debug") @@ -2002,6 +2006,58 @@ public void testLookupMatchTypeWrong() { assertThat(e.getMessage(), containsString("column type mismatch, table column was [integer] and original column was [keyword]")); } + public void testLookupJoinUnknownIndex() { + assumeTrue("requires LOOKUP JOIN capability", EsqlCapabilities.Cap.JOIN_LOOKUP_V4.isEnabled()); + + String errorMessage = "Unknown index [foobar]"; + IndexResolution missingLookupIndex = IndexResolution.invalid(errorMessage); + + Analyzer analyzerMissingLookupIndex = new Analyzer( + new AnalyzerContext( + EsqlTestUtils.TEST_CFG, + new EsqlFunctionRegistry(), + analyzerDefaultMapping(), + missingLookupIndex, + defaultEnrichResolution() + ), + TEST_VERIFIER + ); + + String query = "FROM test | LOOKUP JOIN foobar ON last_name"; + + VerificationException e = expectThrows(VerificationException.class, () -> analyze(query, analyzerMissingLookupIndex)); + assertThat(e.getMessage(), containsString("1:25: " + errorMessage)); + + String query2 = "FROM test | LOOKUP JOIN foobar ON missing_field"; + + e = expectThrows(VerificationException.class, () -> analyze(query2, analyzerMissingLookupIndex)); + assertThat(e.getMessage(), containsString("1:25: " + errorMessage)); + assertThat(e.getMessage(), not(containsString("[missing_field]"))); + } + + public void testLookupJoinUnknownField() { + assumeTrue("requires LOOKUP JOIN capability", EsqlCapabilities.Cap.JOIN_LOOKUP_V4.isEnabled()); + + String query = "FROM test | LOOKUP JOIN languages_lookup ON last_name"; + String errorMessage = "1:45: Unknown column [last_name] in right side of join"; + + VerificationException e = expectThrows(VerificationException.class, () -> analyze(query)); + assertThat(e.getMessage(), containsString(errorMessage)); + + String query2 = "FROM test | LOOKUP JOIN languages_lookup ON language_code"; + String errorMessage2 = "1:45: Unknown column [language_code] in left side of join"; + + e = expectThrows(VerificationException.class, () -> analyze(query2)); + assertThat(e.getMessage(), containsString(errorMessage2)); + + String query3 = "FROM test | LOOKUP JOIN languages_lookup ON missing_altogether"; + String errorMessage3 = "1:45: Unknown column [missing_altogether] in "; + + e = expectThrows(VerificationException.class, () -> analyze(query3)); + assertThat(e.getMessage(), containsString(errorMessage3 + "left side of join")); + assertThat(e.getMessage(), containsString(errorMessage3 + "right side of join")); + } + public void testImplicitCasting() { var e = expectThrows(VerificationException.class, () -> analyze(""" from test | eval x = concat("2024", "-04", "-01") + 1 day diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java index 3cafd42b731f6..68529e99c6b1b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java @@ -103,6 +103,14 @@ public void testInlineCast() throws IOException { logger.info("Wrote to file: {}", file); } + public void testTooBigQuery() { + StringBuilder query = new StringBuilder("FROM foo | EVAL a = a"); + while (query.length() < EsqlParser.MAX_LENGTH) { + query.append(", a = CONCAT(a, a)"); + } + assertEquals("-1:0: ESQL statement is too large [1000011 characters > 1000000]", error(query.toString())); + } + private String functionName(EsqlFunctionRegistry registry, Expression functionCall) { for (FunctionDefinition def : registry.listFunctions()) { if (functionCall.getClass().equals(def.clazz())) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 74e2de1141728..7e3ef4f1f5f87 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -1337,6 +1337,11 @@ public void testMatchFunctionOnlyAllowedInWhere() throws Exception { checkFullTextFunctionsOnlyAllowedInWhere("MATCH", "match(first_name, \"Anna\")", "function"); } + public void testTermFunctionOnlyAllowedInWhere() throws Exception { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + checkFullTextFunctionsOnlyAllowedInWhere("Term", "term(first_name, \"Anna\")", "function"); + } + public void testMatchOperatornOnlyAllowedInWhere() throws Exception { checkFullTextFunctionsOnlyAllowedInWhere(":", "first_name:\"Anna\"", "operator"); } @@ -1401,6 +1406,11 @@ public void testMatchFunctionWithDisjunctions() { checkWithDisjunctions("MATCH", "match(first_name, \"Anna\")", "function"); } + public void testTermFunctionWithDisjunctions() { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + checkWithDisjunctions("Term", "term(first_name, \"Anna\")", "function"); + } + public void testMatchOperatorWithDisjunctions() { checkWithDisjunctions(":", "first_name : \"Anna\"", "operator"); } @@ -1463,6 +1473,11 @@ public void testMatchFunctionWithNonBooleanFunctions() { checkFullTextFunctionsWithNonBooleanFunctions("MATCH", "match(first_name, \"Anna\")", "function"); } + public void testTermFunctionWithNonBooleanFunctions() { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + checkFullTextFunctionsWithNonBooleanFunctions("Term", "term(first_name, \"Anna\")", "function"); + } + public void testMatchOperatorWithNonBooleanFunctions() { checkFullTextFunctionsWithNonBooleanFunctions(":", "first_name:\"Anna\"", "operator"); } @@ -1563,6 +1578,45 @@ public void testMatchTargetsExistingField() throws Exception { assertEquals("1:33: Unknown column [first_name]", error("from test | keep emp_no | where first_name : \"Anna\"")); } + public void testTermFunctionArgNotConstant() throws Exception { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + assertEquals( + "1:19: second argument of [term(first_name, first_name)] must be a constant, received [first_name]", + error("from test | where term(first_name, first_name)") + ); + assertEquals( + "1:59: second argument of [term(first_name, query)] must be a constant, received [query]", + error("from test | eval query = concat(\"first\", \" name\") | where term(first_name, query)") + ); + // Other value types are tested in QueryStringFunctionTests + } + + // These should pass eventually once we lift some restrictions on match function + public void testTermFunctionCurrentlyUnsupportedBehaviour() throws Exception { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + assertEquals( + "1:67: Unknown column [first_name]", + error("from test | stats max_salary = max(salary) by emp_no | where term(first_name, \"Anna\")") + ); + } + + public void testTermFunctionNullArgs() throws Exception { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + assertEquals( + "1:19: first argument of [term(null, \"query\")] cannot be null, received [null]", + error("from test | where term(null, \"query\")") + ); + assertEquals( + "1:19: second argument of [term(first_name, null)] cannot be null, received [null]", + error("from test | where term(first_name, null)") + ); + } + + public void testTermTargetsExistingField() throws Exception { + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + assertEquals("1:38: Unknown column [first_name]", error("from test | keep emp_no | where term(first_name, \"Anna\")")); + } + public void testCoalesceWithMixedNumericTypes() { assertEquals( "1:22: second argument of [coalesce(languages, height)] must be [integer], found value [height] type [double]", @@ -1926,6 +1980,17 @@ public void testSortByAggregate() { assertEquals("1:18: Aggregate functions are not allowed in SORT [COUNT]", error("FROM test | SORT count(*)")); } + public void testLookupJoinDataTypeMismatch() { + assumeTrue("requires LOOKUP JOIN capability", EsqlCapabilities.Cap.JOIN_LOOKUP_V4.isEnabled()); + + query("FROM test | EVAL language_code = languages | LOOKUP JOIN languages_lookup ON language_code"); + + assertEquals( + "1:87: JOIN left field [language_code] of type [KEYWORD] is incompatible with right field [language_code] of type [INTEGER]", + error("FROM test | EVAL language_code = languages::keyword | LOOKUP JOIN languages_lookup ON language_code") + ); + } + private void query(String query) { defaultAnalyzer.analyze(parser.createStatement(query)); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index 377027b70fb54..2004fa3a1cdb0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -1791,9 +1791,9 @@ public TypedData withData(Object data) { @Override public String toString() { if (type == DataType.UNSIGNED_LONG && data instanceof Long longData) { - return type.toString() + "(" + NumericUtils.unsignedLongAsBigInteger(longData).toString() + ")"; + return type + "(" + NumericUtils.unsignedLongAsBigInteger(longData).toString() + ")"; } - return type.toString() + "(" + (data == null ? "null" : data.toString()) + ")"; + return type.toString() + "(" + (data == null ? "null" : getValue().toString()) + ")"; } /** diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/TermTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/TermTests.java new file mode 100644 index 0000000000000..c1c0dc26880ab --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/fulltext/TermTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.fulltext; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.xpack.core.security.authc.support.mapper.expressiondsl.FieldExpression; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.FunctionName; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; + +@FunctionName("term") +public class TermTests extends AbstractFunctionTestCase { + + public TermTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List> supportedPerPosition = supportedParams(); + List suppliers = new LinkedList<>(); + for (DataType fieldType : DataType.stringTypes()) { + for (DataType queryType : DataType.stringTypes()) { + addPositiveTestCase(List.of(fieldType, queryType), suppliers); + addNonFieldTestCase(List.of(fieldType, queryType), supportedPerPosition, suppliers); + } + } + + List suppliersWithErrors = errorsForCasesWithoutExamples(suppliers, (v, p) -> "string"); + + // Don't test null, as it is not allowed but the expected message is not a type error - so we check it separately in VerifierTests + return parameterSuppliersFromTypedData( + suppliersWithErrors.stream().filter(s -> s.types().contains(DataType.NULL) == false).toList() + ); + } + + protected static List> supportedParams() { + Set supportedTextParams = Set.of(DataType.KEYWORD, DataType.TEXT); + Set supportedNumericParams = Set.of(DataType.DOUBLE, DataType.INTEGER); + Set supportedFuzzinessParams = Set.of(DataType.INTEGER, DataType.KEYWORD, DataType.TEXT); + List> supportedPerPosition = List.of( + supportedTextParams, + supportedTextParams, + supportedNumericParams, + supportedFuzzinessParams + ); + return supportedPerPosition; + } + + protected static void addPositiveTestCase(List paramDataTypes, List suppliers) { + + // Positive case - creates an ES field from the field parameter type + suppliers.add( + new TestCaseSupplier( + getTestCaseName(paramDataTypes, "-ES field"), + paramDataTypes, + () -> new TestCaseSupplier.TestCase( + getTestParams(paramDataTypes), + "EndsWithEvaluator[str=Attribute[channel=0], suffix=Attribute[channel=1]]", + DataType.BOOLEAN, + equalTo(true) + ) + ) + ); + } + + private static void addNonFieldTestCase( + List paramDataTypes, + List> supportedPerPosition, + List suppliers + ) { + // Negative case - use directly the field parameter type + suppliers.add( + new TestCaseSupplier( + getTestCaseName(paramDataTypes, "-non ES field"), + paramDataTypes, + typeErrorSupplier(true, supportedPerPosition, paramDataTypes, TermTests::matchTypeErrorSupplier) + ) + ); + } + + private static List getTestParams(List paramDataTypes) { + String fieldName = randomIdentifier(); + List params = new ArrayList<>(); + params.add( + new TestCaseSupplier.TypedData( + new FieldExpression(fieldName, List.of(new FieldExpression.FieldValue(fieldName))), + paramDataTypes.get(0), + "field" + ) + ); + params.add(new TestCaseSupplier.TypedData(new BytesRef(randomAlphaOfLength(10)), paramDataTypes.get(1), "query")); + return params; + } + + private static String getTestCaseName(List paramDataTypes, String fieldType) { + StringBuilder sb = new StringBuilder(); + sb.append("<"); + sb.append(paramDataTypes.get(0)).append(fieldType).append(", "); + sb.append(paramDataTypes.get(1)); + sb.append(">"); + return sb.toString(); + } + + private static String matchTypeErrorSupplier(boolean includeOrdinal, List> validPerPosition, List types) { + return "[] cannot operate on [" + types.getFirst().typeName() + "], which is not a field from an index mapping"; + } + + @Override + protected Expression build(Source source, List args) { + return new Match(source, args.get(0), args.get(1)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java index 0fb416584b472..6666eb8adab61 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EqualsTests.java @@ -144,6 +144,34 @@ public static Iterable parameters() { ) ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "EqualsNanosMillisEvaluator", + "lhs", + "rhs", + Object::equals, + DataType.BOOLEAN, + TestCaseSupplier.dateNanosCases(), + TestCaseSupplier.dateCases(), + List.of(), + false + ) + ); + + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "EqualsMillisNanosEvaluator", + "lhs", + "rhs", + Object::equals, + DataType.BOOLEAN, + TestCaseSupplier.dateCases(), + TestCaseSupplier.dateNanosCases(), + List.of(), + false + ) + ); + suppliers.addAll( TestCaseSupplier.stringCases( Object::equals, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualTests.java index 395a574028f6a..0fbd49abd885b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanOrEqualTests.java @@ -121,6 +121,34 @@ public static Iterable parameters() { throw new UnsupportedOperationException("Got some weird types"); }, DataType.BOOLEAN, TestCaseSupplier.dateNanosCases(), TestCaseSupplier.dateNanosCases(), List.of(), false)); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "GreaterThanOrEqualNanosMillisEvaluator", + "lhs", + "rhs", + (lhs, rhs) -> (((Instant) lhs).isAfter((Instant) rhs) || lhs.equals(rhs)), + DataType.BOOLEAN, + TestCaseSupplier.dateNanosCases(), + TestCaseSupplier.dateCases(), + List.of(), + false + ) + ); + + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "GreaterThanOrEqualMillisNanosEvaluator", + "lhs", + "rhs", + (lhs, rhs) -> (((Instant) lhs).isAfter((Instant) rhs) || lhs.equals(rhs)), + DataType.BOOLEAN, + TestCaseSupplier.dateCases(), + TestCaseSupplier.dateNanosCases(), + List.of(), + false + ) + ); + suppliers.addAll( TestCaseSupplier.stringCases( (l, r) -> ((BytesRef) l).compareTo((BytesRef) r) >= 0, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanTests.java index b56ecd7392ba6..ccc66df60fb3f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/GreaterThanTests.java @@ -135,6 +135,34 @@ public static Iterable parameters() { ) ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "GreaterThanNanosMillisEvaluator", + "lhs", + "rhs", + (l, r) -> ((Instant) l).isAfter((Instant) r), + DataType.BOOLEAN, + TestCaseSupplier.dateNanosCases(), + TestCaseSupplier.dateCases(), + List.of(), + false + ) + ); + + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "GreaterThanMillisNanosEvaluator", + "lhs", + "rhs", + (l, r) -> ((Instant) l).isAfter((Instant) r), + DataType.BOOLEAN, + TestCaseSupplier.dateCases(), + TestCaseSupplier.dateNanosCases(), + List.of(), + false + ) + ); + suppliers.addAll( TestCaseSupplier.stringCases( (l, r) -> ((BytesRef) l).compareTo((BytesRef) r) > 0, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualTests.java index 60062f071c183..1e91a65e04c0e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanOrEqualTests.java @@ -121,6 +121,34 @@ public static Iterable parameters() { throw new UnsupportedOperationException("Got some weird types"); }, DataType.BOOLEAN, TestCaseSupplier.dateNanosCases(), TestCaseSupplier.dateNanosCases(), List.of(), false)); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "LessThanOrEqualNanosMillisEvaluator", + "lhs", + "rhs", + (l, r) -> (((Instant) l).isBefore((Instant) r) || l.equals(r)), + DataType.BOOLEAN, + TestCaseSupplier.dateNanosCases(), + TestCaseSupplier.dateCases(), + List.of(), + false + ) + ); + + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "LessThanOrEqualMillisNanosEvaluator", + "lhs", + "rhs", + (l, r) -> (((Instant) l).isBefore((Instant) r) || l.equals(r)), + DataType.BOOLEAN, + TestCaseSupplier.dateCases(), + TestCaseSupplier.dateNanosCases(), + List.of(), + false + ) + ); + suppliers.addAll( TestCaseSupplier.stringCases( (l, r) -> ((BytesRef) l).compareTo((BytesRef) r) <= 0, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanTests.java index 30812cf8e538d..69dc59bac6456 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/LessThanTests.java @@ -135,6 +135,34 @@ public static Iterable parameters() { ) ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "LessThanNanosMillisEvaluator", + "lhs", + "rhs", + (l, r) -> ((Instant) l).isBefore((Instant) r), + DataType.BOOLEAN, + TestCaseSupplier.dateNanosCases(), + TestCaseSupplier.dateCases(), + List.of(), + false + ) + ); + + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "LessThanMillisNanosEvaluator", + "lhs", + "rhs", + (l, r) -> ((Instant) l).isBefore((Instant) r), + DataType.BOOLEAN, + TestCaseSupplier.dateCases(), + TestCaseSupplier.dateNanosCases(), + List.of(), + false + ) + ); + suppliers.addAll( TestCaseSupplier.stringCases( (l, r) -> ((BytesRef) l).compareTo((BytesRef) r) < 0, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsTests.java index 53676a43b16a0..7b57b97dfe28e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/NotEqualsTests.java @@ -128,7 +128,7 @@ public static Iterable parameters() { false ) ); - // Datetime + // Datenanos suppliers.addAll( TestCaseSupplier.forBinaryNotCasting( "NotEqualsLongsEvaluator", @@ -142,6 +142,36 @@ public static Iterable parameters() { false ) ); + + // nanoseconds to milliseconds. NB: these have different evaluator names depending on the direction + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "NotEqualsNanosMillisEvaluator", + "lhs", + "rhs", + (l, r) -> false == l.equals(r), + DataType.BOOLEAN, + TestCaseSupplier.dateNanosCases(), + TestCaseSupplier.dateCases(), + List.of(), + false + ) + ); + + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "NotEqualsMillisNanosEvaluator", + "lhs", + "rhs", + (l, r) -> false == l.equals(r), + DataType.BOOLEAN, + TestCaseSupplier.dateCases(), + TestCaseSupplier.dateNanosCases(), + List.of(), + false + ) + ); + suppliers.addAll( TestCaseSupplier.stringCases( (l, r) -> false == l.equals(r), diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 86f5c812737b1..d32124c1aaf32 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -1391,6 +1391,35 @@ public void testMultipleMatchFilterPushdown() { assertThat(actualLuceneQuery.toString(), is(expectedLuceneQuery.toString())); } + /** + * Expecting + * LimitExec[1000[INTEGER]] + * \_ExchangeExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, job{f}#9, job.raw{f}#10, languages{f}#5, last_na + * me{f}#6, long_noidx{f}#11, salary{f}#7],false] + * \_ProjectExec[[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gender{f}#4, job{f}#9, job.raw{f}#10, languages{f}#5, last_na + * me{f}#6, long_noidx{f}#11, salary{f}#7]] + * \_FieldExtractExec[_meta_field{f}#8, emp_no{f}#2, first_name{f}#3, gen] + * \_EsQueryExec[test], indexMode[standard], query[{"term":{"last_name":{"query":"Smith"}}}] + */ + public void testTermFunction() { + // Skip test if the term function is not enabled. + assumeTrue("term function capability not available", EsqlCapabilities.Cap.TERM_FUNCTION.isEnabled()); + + var plan = plannerOptimizer.plan(""" + from test + | where term(last_name, "Smith") + """, IS_SV_STATS); + + var limit = as(plan, LimitExec.class); + var exchange = as(limit.child(), ExchangeExec.class); + var project = as(exchange.child(), ProjectExec.class); + var field = as(project.child(), FieldExtractExec.class); + var query = as(field.child(), EsQueryExec.class); + assertThat(query.limit().fold(), is(1000)); + var expected = QueryBuilders.termQuery("last_name", "Smith"); + assertThat(query.query().toString(), is(expected.toString())); + } + private QueryBuilder wrapWithSingleQuery(String query, QueryBuilder inner, String fieldName, Source source) { return FilterTests.singleValueQuery(query, inner, fieldName, source); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index b76781f76f4af..c2a26845d4e88 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -4820,7 +4820,7 @@ private static boolean oneLeaveIsNull(Expression e) { e.forEachUp(node -> { if (node.children().size() == 0) { - result.set(result.get() || Expressions.isNull(node)); + result.set(result.get() || Expressions.isGuaranteedNull(node)); } }); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java index 07ca112e8c527..3dfc0f611eb2b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java @@ -156,11 +156,7 @@ protected ClusterComputeRequest mutateInstance(ClusterComputeRequest in) throws public void testFallbackIndicesOptions() throws Exception { ClusterComputeRequest request = createTestInstance(); - var version = TransportVersionUtils.randomVersionBetween( - random(), - TransportVersions.V_8_14_0, - TransportVersions.ESQL_ORIGINAL_INDICES - ); + var version = TransportVersionUtils.randomVersionBetween(random(), TransportVersions.V_8_14_0, TransportVersions.V_8_16_0); ClusterComputeRequest cloned = copyInstance(request, version); assertThat(cloned.clusterAlias(), equalTo(request.clusterAlias())); assertThat(cloned.sessionId(), equalTo(request.sessionId())); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParserTests.java index 4b9166c621940..4758f83c42bb7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/EsqlMediaTypeParserTests.java @@ -17,6 +17,7 @@ import java.util.Collections; import java.util.Map; +import static org.elasticsearch.xcontent.XContentType.JSON; import static org.elasticsearch.xpack.esql.formatter.TextFormat.CSV; import static org.elasticsearch.xpack.esql.formatter.TextFormat.PLAIN_TEXT; import static org.elasticsearch.xpack.esql.formatter.TextFormat.TSV; @@ -123,11 +124,17 @@ public void testIncludeCCSMetadataWithNonJSONMediaTypesInParams() { public void testNoFormat() { IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> getResponseMediaType(new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(), createTestInstance(false)) + () -> getResponseMediaType(emptyRequest(), createTestInstance(false)) ); assertEquals(e.getMessage(), "Invalid request content type: Accept=[null], Content-Type=[null], format=[null]"); } + public void testNoContentType() { + RestRequest fakeRestRequest = emptyRequest(); + assertThat(getResponseMediaType(fakeRestRequest, CSV), is(CSV)); + assertThat(getResponseMediaType(fakeRestRequest, JSON), is(JSON)); + } + private static RestRequest reqWithAccept(String acceptHeader) { return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withHeaders( Map.of("Content-Type", Collections.singletonList("application/json"), "Accept", Collections.singletonList(acceptHeader)) @@ -140,6 +147,10 @@ private static RestRequest reqWithParams(Map params) { ).withParams(params).build(); } + private static RestRequest emptyRequest() { + return new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + } + protected EsqlQueryRequest createTestInstance(boolean columnar) { var request = new EsqlQueryRequest(); request.columnar(columnar); diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 3c19e11a450b4..1d0236a5834e5 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -26,10 +26,6 @@ base { archivesName = 'x-pack-inference' } -versions << [ - 'aws2': '2.28.13' -] - dependencies { implementation project(path: ':libs:logging') compileOnly project(":server") @@ -62,36 +58,36 @@ dependencies { implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1' /* AWS SDK v2 */ - implementation ("software.amazon.awssdk:bedrockruntime:${versions.aws2}") - api "software.amazon.awssdk:protocol-core:${versions.aws2}" - api "software.amazon.awssdk:aws-json-protocol:${versions.aws2}" - api "software.amazon.awssdk:third-party-jackson-core:${versions.aws2}" - api "software.amazon.awssdk:http-auth-aws:${versions.aws2}" - api "software.amazon.awssdk:checksums-spi:${versions.aws2}" - api "software.amazon.awssdk:checksums:${versions.aws2}" - api "software.amazon.awssdk:sdk-core:${versions.aws2}" + implementation ("software.amazon.awssdk:bedrockruntime:${versions.awsv2sdk}") + api "software.amazon.awssdk:protocol-core:${versions.awsv2sdk}" + api "software.amazon.awssdk:aws-json-protocol:${versions.awsv2sdk}" + api "software.amazon.awssdk:third-party-jackson-core:${versions.awsv2sdk}" + api "software.amazon.awssdk:http-auth-aws:${versions.awsv2sdk}" + api "software.amazon.awssdk:checksums-spi:${versions.awsv2sdk}" + api "software.amazon.awssdk:checksums:${versions.awsv2sdk}" + api "software.amazon.awssdk:sdk-core:${versions.awsv2sdk}" api "org.reactivestreams:reactive-streams:1.0.4" api "org.reactivestreams:reactive-streams-tck:1.0.4" - api "software.amazon.awssdk:profiles:${versions.aws2}" - api "software.amazon.awssdk:retries:${versions.aws2}" - api "software.amazon.awssdk:auth:${versions.aws2}" - api "software.amazon.awssdk:http-auth-aws-eventstream:${versions.aws2}" + api "software.amazon.awssdk:profiles:${versions.awsv2sdk}" + api "software.amazon.awssdk:retries:${versions.awsv2sdk}" + api "software.amazon.awssdk:auth:${versions.awsv2sdk}" + api "software.amazon.awssdk:http-auth-aws-eventstream:${versions.awsv2sdk}" api "software.amazon.eventstream:eventstream:1.0.1" - api "software.amazon.awssdk:http-auth-spi:${versions.aws2}" - api "software.amazon.awssdk:http-auth:${versions.aws2}" - api "software.amazon.awssdk:identity-spi:${versions.aws2}" - api "software.amazon.awssdk:http-client-spi:${versions.aws2}" - api "software.amazon.awssdk:regions:${versions.aws2}" - api "software.amazon.awssdk:annotations:${versions.aws2}" - api "software.amazon.awssdk:utils:${versions.aws2}" - api "software.amazon.awssdk:aws-core:${versions.aws2}" - api "software.amazon.awssdk:metrics-spi:${versions.aws2}" - api "software.amazon.awssdk:json-utils:${versions.aws2}" - api "software.amazon.awssdk:endpoints-spi:${versions.aws2}" - api "software.amazon.awssdk:retries-spi:${versions.aws2}" + api "software.amazon.awssdk:http-auth-spi:${versions.awsv2sdk}" + api "software.amazon.awssdk:http-auth:${versions.awsv2sdk}" + api "software.amazon.awssdk:identity-spi:${versions.awsv2sdk}" + api "software.amazon.awssdk:http-client-spi:${versions.awsv2sdk}" + api "software.amazon.awssdk:regions:${versions.awsv2sdk}" + api "software.amazon.awssdk:annotations:${versions.awsv2sdk}" + api "software.amazon.awssdk:utils:${versions.awsv2sdk}" + api "software.amazon.awssdk:aws-core:${versions.awsv2sdk}" + api "software.amazon.awssdk:metrics-spi:${versions.awsv2sdk}" + api "software.amazon.awssdk:json-utils:${versions.awsv2sdk}" + api "software.amazon.awssdk:endpoints-spi:${versions.awsv2sdk}" + api "software.amazon.awssdk:retries-spi:${versions.awsv2sdk}" /* Netty (via AWS SDKv2) */ - implementation "software.amazon.awssdk:netty-nio-client:${versions.aws2}" + implementation "software.amazon.awssdk:netty-nio-client:${versions.awsv2sdk}" runtimeOnly "io.netty:netty-buffer:${versions.netty}" runtimeOnly "io.netty:netty-codec-dns:${versions.netty}" runtimeOnly "io.netty:netty-codec-http2:${versions.netty}" diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 4e32ef99d06dd..07ce2fe00642b 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -21,6 +21,9 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.junit.ClassRule; @@ -341,10 +344,21 @@ protected Deque streamInferOnMockService(String modelId, TaskTy return callAsync(endpoint, input); } + protected Deque unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List input) + throws Exception { + var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); + return callAsyncUnified(endpoint, input, "user"); + } + private Deque callAsync(String endpoint, List input) throws Exception { - var responseConsumer = new AsyncInferenceResponseConsumer(); var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); + + return execAsyncCall(request); + } + + private Deque execAsyncCall(Request request) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @@ -362,6 +376,22 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } + private Deque callAsyncUnified(String endpoint, List input, String role) throws Exception { + var request = new Request("POST", endpoint); + + request.setJsonEntity(createUnifiedJsonBody(input, role)); + return execAsyncCall(request); + } + + private String createUnifiedJsonBody(List input, String role) throws IOException { + var messages = input.stream().map(i -> Map.of("content", i, "role", role)).toList(); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("messages", messages); + builder.endObject(); + return org.elasticsearch.common.Strings.toString(builder); + } + protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return inferInternal(endpoint, input, Map.of()); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index f5773e73f2b22..1e19491aeaa60 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -11,13 +11,18 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -481,6 +486,56 @@ public void testSupportedStream() throws Exception { } } + public void testUnifiedCompletionInference() throws Exception { + String modelId = "streaming"; + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + var singleModel = getModel(modelId); + assertEquals(modelId, singleModel.get("inference_id")); + assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + + var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomUUID()).toList(); + try { + var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input); + var expectedResponses = expectedResultsIterator(input); + assertThat(events.size(), equalTo((input.size() + 1) * 2)); + events.forEach(event -> { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); + case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); + } + }); + } finally { + deleteModel(modelId); + } + } + + private static Iterator expectedResultsIterator(List input) { + return Stream.concat(input.stream().map(String::toUpperCase).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]")).iterator(); + } + + private static String expectedResult(String input) { + try { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("id", "id"); + builder.startArray("choices"); + builder.startObject(); + builder.startObject("delta"); + builder.field("content", input); + builder.endObject(); + builder.field("index", 0); + builder.endObject(); + builder.endArray(); + builder.field("model", "gpt-4o-2024-08-06"); + builder.field("object", "chat.completion.chunk"); + builder.endObject(); + + return Strings.toString(builder); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + public void testGetZeroModels() throws IOException { var models = getModels("_all", TaskType.RERANK); assertThat(models, empty()); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ae11a02d312e2..f5f682b143a72 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -132,6 +133,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 9320571572f0a..fa1e27005c287 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -120,6 +121,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index fe0223cce0323..64569fd8c5c6a 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -123,6 +124,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("unifiedCompletionInfer not supported"); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 6d7983bc8cb53..f7a05a27354ef 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -30,12 +30,14 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import java.io.IOException; import java.util.EnumSet; @@ -121,6 +123,24 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeUnifiedResults(request)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + private StreamingChatCompletionResults makeResults(List input) { var responseIter = input.stream().map(String::toUpperCase).iterator(); return new StreamingChatCompletionResults(subscriber -> { @@ -152,6 +172,59 @@ private ChunkedToXContent completionChunk(String delta) { ); } + private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) { + var responseIter = request.messages().stream().map(message -> message.content().toString().toUpperCase()).iterator(); + return new StreamingUnifiedChatCompletionResults(subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (responseIter.hasNext()) { + subscriber.onNext(unifiedCompletionChunk(responseIter.next())); + } else { + subscriber.onComplete(); + } + } + + @Override + public void cancel() {} + }); + }); + } + + /* + The response format looks like this + { + "id": "chatcmpl-AarrzyuRflye7yzDF4lmVnenGmQCF", + "choices": [ + { + "delta": { + "content": " information" + }, + "index": 0 + } + ], + "model": "gpt-4o-2024-08-06", + "object": "chat.completion.chunk" + } + */ + private ChunkedToXContent unifiedCompletionChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("id", "id"), + ChunkedToXContentHelper.startArray("choices"), + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startObject("delta"), + ChunkedToXContentHelper.field("content", delta), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.field("index", 0), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"), + ChunkedToXContentHelper.field("object", "chat.completion.chunk"), + ChunkedToXContentHelper.endObject() + ); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index c82f287792a7c..67892dfe78624 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -33,6 +33,8 @@ public Set getFeatures() { ); } + private static final NodeFeature SEMANTIC_TEXT_HIGHLIGHTER = new NodeFeature("semantic_text.highlighter"); + @Override public Set getTestFeatures() { return Set.of( @@ -40,7 +42,8 @@ public Set getTestFeatures() { SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX + SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX, + SEMANTIC_TEXT_HIGHLIGHTER ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 2320cca8295d1..b83c098ca808c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; @@ -137,11 +138,18 @@ public static List getNamedWriteables() { addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); + addUnifiedNamedWriteables(namedWriteables); + namedWriteables.addAll(StreamingTaskManager.namedWriteables()); return namedWriteables; } + private static void addUnifiedNamedWriteables(List namedWriteables) { + var writeables = UnifiedCompletionRequest.getNamedWriteables(); + namedWriteables.addAll(writeables); + } + private static void addAmazonBedrockNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 48458bf4f5086..148a784456361 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -37,6 +37,7 @@ import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; @@ -50,6 +51,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; @@ -58,6 +60,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; @@ -67,7 +70,9 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; +import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder; import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder; @@ -83,6 +88,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; +import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; @@ -156,8 +162,9 @@ public InferencePlugin(Settings settings) { @Override public List> getActions() { - return List.of( + var availableActions = List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), + new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), @@ -166,6 +173,13 @@ public InferencePlugin(Settings settings) { new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler<>(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class) ); + + List> conditionalActions = + UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? List.of(new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)) + : List.of(); + + return Stream.concat(availableActions.stream(), conditionalActions.stream()).toList(); } @Override @@ -180,7 +194,7 @@ public List getRestHandlers( Supplier nodesInCluster, Predicate clusterSupportsFeature ) { - return List.of( + var availableRestActions = List.of( new RestInferenceAction(), new RestStreamInferenceAction(), new RestGetInferenceModelAction(), @@ -190,6 +204,11 @@ public List getRestHandlers( new RestGetInferenceDiagnosticsAction(), new RestGetInferenceServicesAction() ); + List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? List.of(new RestUnifiedCompletionInferenceAction()) + : List.of(); + + return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); } @Override @@ -392,7 +411,12 @@ public void close() { @Override public Map getMappers() { - return Map.of(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); + return Map.of( + SemanticTextFieldMapper.CONTENT_TYPE, + SemanticTextFieldMapper.PARSER, + OffsetSourceFieldMapper.CONTENT_TYPE, + OffsetSourceFieldMapper.PARSER + ); } @Override @@ -411,4 +435,9 @@ public List> getRetrievers() { new RetrieverSpec<>(new ParseField(RandomRankBuilder.NAME), RandomRankRetrieverBuilder::fromXContent) ); } + + @Override + public Map getHighlighters() { + return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java new file mode 100644 index 0000000000000..3e13d0c1e39de --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * Unified Completion feature flag. When the feature is complete, this flag will be removed. + * Enable feature via JVM option: `-Des.inference_unified_feature_flag_enabled=true`. + */ +public class UnifiedCompletionFeature { + public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("inference_unified"); + + private UnifiedCompletionFeature() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java new file mode 100644 index 0000000000000..2a0e8e1775279 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; + +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; + +public abstract class BaseTransportInferenceAction extends HandledTransportAction< + Request, + InferenceAction.Response> { + + private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class); + private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; + private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + private final InferenceStats inferenceStats; + private final StreamingTaskManager streamingTaskManager; + + public BaseTransportInferenceAction( + String inferenceActionName, + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager, + Writeable.Reader requestReader + ) { + super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + this.inferenceStats = inferenceStats; + this.streamingTaskManager = streamingTaskManager; + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + var timer = InferenceTimer.start(); + + var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { + var service = serviceRegistry.getService(unparsedModel.service()); + try { + validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); + validationHelper( + () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, + () -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType()) + ); + validationHelper( + () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), + () -> createInvalidTaskTypeException(request, unparsedModel) + ); + } catch (Exception e) { + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + + var model = service.get() + .parsePersistedConfigWithSecrets( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); + inferOnServiceWithMetrics(model, request, service.get(), timer, listener); + }, e -> { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); + } catch (Exception metricsException) { + log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); + } + listener.onFailure(e); + }); + + modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); + } + + private static void validationHelper(Supplier validationFailure, Supplier exceptionCreator) { + if (validationFailure.get()) { + throw exceptionCreator.get(); + } + } + + protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel); + + protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel); + + private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); + } + } + + private void inferOnServiceWithMetrics( + Model model, + Request request, + InferenceService service, + InferenceTimer timer, + ActionListener listener + ) { + inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); + inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { + if (request.isStreaming()) { + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + + var instrumentedStream = new PublisherWithMetrics(timer, model); + taskProcessor.subscribe(instrumentedStream); + + listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); + } else { + recordMetrics(model, timer, null); + listener.onResponse(new InferenceAction.Response(inferenceResults)); + } + }, e -> { + recordMetrics(model, timer, e); + listener.onFailure(e); + })); + } + + private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); + } + } + + private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + doInference(model, request, service, listener); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + protected abstract void doInference( + Model model, + Request request, + InferenceService service, + ActionListener listener + ); + + private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) { + var supportedTasks = service.supportedStreamingTasks(); + if (supportedTasks.isEmpty()) { + return new ElasticsearchStatusException( + format("Streaming is not allowed for service [%s].", service.name()), + RestStatus.METHOD_NOT_ALLOWED + ); + } else { + var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); + return new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", + service.name(), + request.getTaskType(), + validTasks + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } + } + + private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { + return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); + } + + private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) { + return new ElasticsearchStatusException( + "Incompatible task_type, the requested type [{}] does not match the model type [{}]", + RestStatus.BAD_REQUEST, + requested, + expected + ); + } + + private class PublisherWithMetrics extends DelegatingProcessor { + + private final InferenceTimer timer; + private final Model model; + + private PublisherWithMetrics(InferenceTimer timer, Model model) { + this.timer = timer; + this.model = model; + } + + @Override + protected void next(ChunkedToXContent item) { + downstream().onNext(item); + } + + @Override + public void onError(Throwable throwable) { + recordMetrics(model, timer, throwable); + super.onError(throwable); + } + + @Override + protected void onCancel() { + recordMetrics(model, timer, null); + super.onCancel(); + } + + @Override + public void onComplete() { + recordMetrics(model, timer, null); + super.onComplete(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ba9ab3c133731..08e6d869a553d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -7,47 +7,22 @@ package org.elasticsearch.xpack.inference.action; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; -import java.util.stream.Collectors; - -import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; - -public class TransportInferenceAction extends HandledTransportAction { - private static final Logger log = LogManager.getLogger(TransportInferenceAction.class); - private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; - private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; - - private final ModelRegistry modelRegistry; - private final InferenceServiceRegistry serviceRegistry; - private final InferenceStats inferenceStats; - private final StreamingTaskManager streamingTaskManager; +public class TransportInferenceAction extends BaseTransportInferenceAction { @Inject public TransportInferenceAction( @@ -58,184 +33,44 @@ public TransportInferenceAction( InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager ) { - super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); - this.modelRegistry = modelRegistry; - this.serviceRegistry = serviceRegistry; - this.inferenceStats = inferenceStats; - this.streamingTaskManager = streamingTaskManager; + super( + InferenceAction.NAME, + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + InferenceAction.Request::new + ); } @Override - protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - var timer = InferenceTimer.start(); - - var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { - var service = serviceRegistry.getService(unparsedModel.service()); - if (service.isEmpty()) { - var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { - // not the wildcard task type and not the model task type - var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - var model = service.get() - .parsePersistedConfigWithSecrets( - unparsedModel.inferenceEntityId(), - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ); - inferOnServiceWithMetrics(model, request, service.get(), timer, listener); - }, e -> { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); - } catch (Exception metricsException) { - log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); - } - listener.onFailure(e); - }); - - modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); - } - - private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } - } - - private void inferOnServiceWithMetrics( - Model model, - InferenceAction.Request request, - InferenceService service, - InferenceTimer timer, - ActionListener listener - ) { - inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); - inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { - if (request.isStreaming()) { - var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); - inferenceResults.publisher().subscribe(taskProcessor); - - var instrumentedStream = new PublisherWithMetrics(timer, model); - taskProcessor.subscribe(instrumentedStream); - - listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); - } else { - recordMetrics(model, timer, null); - listener.onResponse(new InferenceAction.Response(inferenceResults)); - } - }, e -> { - recordMetrics(model, timer, e); - listener.onFailure(e); - })); + protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, UnparsedModel unparsedModel) { + return false; } - private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + @Override + protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) { + return null; } - private void inferOnService( + @Override + protected void doInference( Model model, InferenceAction.Request request, InferenceService service, ActionListener listener ) { - if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - service.infer( - model, - request.getQuery(), - request.getInput(), - request.isStreaming(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - listener - ); - } else { - listener.onFailure(unsupportedStreamingTaskException(request, service)); - } - } - - private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { - var supportedTasks = service.supportedStreamingTasks(); - if (supportedTasks.isEmpty()) { - return new ElasticsearchStatusException( - format("Streaming is not allowed for service [%s].", service.name()), - RestStatus.METHOD_NOT_ALLOWED - ); - } else { - var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); - return new ElasticsearchStatusException( - format( - "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", - service.name(), - request.getTaskType(), - validTasks - ), - RestStatus.METHOD_NOT_ALLOWED - ); - } - } - - private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { - return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); - } - - private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { - return new ElasticsearchStatusException( - "Incompatible task_type, the requested type [{}] does not match the model type [{}]", - RestStatus.BAD_REQUEST, - requested, - expected + service.infer( + model, + request.getQuery(), + request.getInput(), + request.isStreaming(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + listener ); } - - private class PublisherWithMetrics extends DelegatingProcessor { - private final InferenceTimer timer; - private final Model model; - - private PublisherWithMetrics(InferenceTimer timer, Model model) { - this.timer = timer; - this.model = model; - } - - @Override - protected void next(ChunkedToXContent item) { - downstream().onNext(item); - } - - @Override - public void onError(Throwable throwable) { - recordMetrics(model, timer, throwable); - super.onError(throwable); - } - - @Override - protected void onCancel() { - recordMetrics(model, timer, null); - super.onCancel(); - } - - @Override - public void onComplete() { - recordMetrics(model, timer, null); - super.onComplete(); - } - } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java new file mode 100644 index 0000000000000..f0906231d8f42 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction { + + @Inject + public TransportUnifiedCompletionInferenceAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + super( + UnifiedCompletionAction.NAME, + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + UnifiedCompletionAction.Request::new + ); + } + + @Override + protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { + return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION; + } + + @Override + protected ElasticsearchStatusException createInvalidTaskTypeException( + UnifiedCompletionAction.Request request, + UnparsedModel unparsedModel + ) { + return new ElasticsearchStatusException( + "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", + RestStatus.BAD_REQUEST, + request.getTaskType(), + TaskType.COMPLETION.toString() + ); + } + + @Override + protected void doInference( + Model model, + UnifiedCompletionAction.Request request, + InferenceService service, + ActionListener listener + ) { + service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index def52e97666f9..9d6f5bb89218f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -49,7 +49,7 @@ public SentenceBoundaryChunkingSettings(Integer maxChunkSize, @Nullable Integer public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { maxChunkSize = in.readInt(); - if (in.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { sentenceOverlap = in.readVInt(); } } @@ -113,13 +113,13 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS; + return TransportVersions.V_8_16_0; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeInt(maxChunkSize); - if (out.getTransportVersion().onOrAfter(TransportVersions.CHUNK_SENTENCE_OVERLAP_SETTING_ADDED)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(sentenceOverlap); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 7fb0fdc91bf72..7e0378d5b0cd1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -104,7 +104,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 03e794e42c3a2..eda3fc0f3bfdb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -9,7 +9,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -25,6 +32,33 @@ public abstract class DelegatingProcessor implements Flow.Processor private Flow.Subscriber downstream; private Flow.Subscription upstream; + public static Deque parseEvent( + Deque item, + ParseChunkFunction parseFunction, + XContentParserConfiguration parserConfig, + Logger logger + ) throws Exception { + var results = new ArrayDeque(item.size()); + for (ServerSentEvent event : item) { + if (ServerSentEventField.DATA == event.name() && event.hasValue()) { + try { + var delta = parseFunction.apply(parserConfig, event); + delta.forEachRemaining(results::offer); + } catch (Exception e) { + logger.warn("Failed to parse event from inference provider: {}", event); + throw e; + } + } + } + + return results; + } + + @FunctionalInterface + public interface ParseChunkFunction { + Iterator apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException; + } + @Override public void subscribe(Flow.Subscriber subscriber) { if (downstream != null) { @@ -51,7 +85,7 @@ public void request(long n) { if (isClosed.get()) { downstream.onComplete(); } else if (upstream != null) { - upstream.request(n); + upstreamRequest(n); } else { pendingRequests.accumulateAndGet(n, Long::sum); } @@ -67,6 +101,13 @@ public void cancel() { }; } + /** + * Guaranteed to be called when the upstream is set and this processor had not been closed. + */ + protected void upstreamRequest(long n) { + upstream.request(n); + } + protected void onCancel() {} @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java index 4e97554b56445..b43e5ab70e2f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -34,13 +33,7 @@ public SingleInputSenderExecutableAction( @Override public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { - if (inferenceInputs instanceof DocumentsOnlyInput == false) { - listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR)); - return; - } - - var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; - if (docsOnlyInput.getInputs().size() > 1) { + if (inferenceInputs.inputSize() > 1) { listener.onFailure( new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java index 9c83264b5581f..bd5c53d589df0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java @@ -26,7 +26,7 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. */ public class OpenAiActionCreator implements OpenAiActionVisitor { - private static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; + public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; private final Sender sender; private final ServiceComponents serviceComponents; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java index a0a44e62f9f73..e7a960f1316f2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java @@ -69,7 +69,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + List input = inferenceInputs.castTo(ChatCompletionInput.class).getInputs(); AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java index 69a5c665feb86..3929585a0745d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -44,10 +44,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs); var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java index 5418b3dd9840b..6d4aeb9e31bac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index 21cec68b14a49..affd2e3a7760e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -41,10 +41,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, inputs, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index d036559ec3dcb..c2f5f3e9db5ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java new file mode 100644 index 0000000000000..928da95d9c2f0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import java.util.List; +import java.util.Objects; + +/** + * This class encapsulates the input text passed by the request and indicates whether the response should be streamed. + * The main difference between this class and {@link UnifiedChatInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#infer} code path. These are requests sent to the + * API without using the _unified route. + */ +public class ChatCompletionInput extends InferenceInputs { + private final List input; + + public ChatCompletionInput(List input) { + this(input, false); + } + + public ChatCompletionInput(List input, boolean stream) { + super(stream); + this.input = Objects.requireNonNull(input); + } + + public List getInputs() { + return this.input; + } + + public int inputSize() { + return input.size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index ae46fbe0fef87..40cd03c87664e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -50,10 +50,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + CohereCompletionRequest request = new CohereCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index 8cf411d84c932..3feb79d3de6cc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -14,30 +14,28 @@ public class DocumentsOnlyInput extends InferenceInputs { public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof DocumentsOnlyInput == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, DocumentsOnlyInput.class); } return (DocumentsOnlyInput) inferenceInputs; } private final List input; - private final boolean stream; public DocumentsOnlyInput(List input) { this(input, false); } public DocumentsOnlyInput(List input, boolean stream) { - super(); + super(stream); this.input = Objects.requireNonNull(input); - this.stream = stream; } public List getInputs() { return this.input; } - public boolean stream() { - return stream; + public int inputSize() { + return input.size(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index abe50c6fae3f9..0097f9c08ea21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -51,7 +51,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(DocumentsOnlyInput.of(inferenceInputs), model); + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest( + inferenceInputs.castTo(ChatCompletionInput.class), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index dd241857ef0c4..e85ea6f1d9b35 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -10,7 +10,29 @@ import org.elasticsearch.common.Strings; public abstract class InferenceInputs { - public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { - return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); + private final boolean stream; + + public InferenceInputs(boolean stream) { + this.stream = stream; + } + + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class clazz) { + return new IllegalArgumentException( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz) + ); } + + public T castTo(Class clazz) { + if (clazz.isInstance(this) == false) { + throw createUnsupportedTypeException(this, clazz); + } + + return clazz.cast(this); + } + + public boolean stream() { + return stream; + } + + public abstract int inputSize(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index cea89332e5bf0..4d730be6aa6bd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -15,7 +15,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; @@ -25,8 +25,8 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager { private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class); - private static final ResponseHandler HANDLER = createCompletionHandler(); + static final String USER_ROLE = "user"; public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); @@ -35,7 +35,7 @@ public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, private final OpenAiChatCompletionModel model; private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri); + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); this.model = Objects.requireNonNull(model); } @@ -46,10 +46,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream); + var chatCompletionInputs = inferenceInputs.castTo(ChatCompletionInput.class); + var request = new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(chatCompletionInputs, USER_ROLE), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java new file mode 100644 index 0000000000000..3b0f770e3e061 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { + + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { + return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final OpenAiChatCompletionModel model; + + private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new OpenAiUnifiedChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 50bb77b307db3..5af5245ac5b40 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -14,7 +14,7 @@ public class QueryAndDocsInputs extends InferenceInputs { public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof QueryAndDocsInputs == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, QueryAndDocsInputs.class); } return (QueryAndDocsInputs) inferenceInputs; @@ -22,17 +22,15 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final String query; private final List chunks; - private final boolean stream; public QueryAndDocsInputs(String query, List chunks) { this(query, chunks, false); } public QueryAndDocsInputs(String query, List chunks, boolean stream) { - super(); + super(stream); this.query = Objects.requireNonNull(query); this.chunks = Objects.requireNonNull(chunks); - this.stream = stream; } public String getQuery() { @@ -43,8 +41,7 @@ public List getChunks() { return chunks; } - public boolean stream() { - return stream; + public int inputSize() { + return chunks.size(); } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java new file mode 100644 index 0000000000000..f89fa1ee37a6f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnifiedCompletionRequest; + +import java.util.List; +import java.util.Objects; + +/** + * This class encapsulates the unified request. + * The main difference between this class and {@link ChatCompletionInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#unifiedCompletionInfer(Model, UnifiedCompletionRequest, TimeValue, ActionListener)} + * code path. These are requests sent to the API with the _unified route. + */ +public class UnifiedChatInput extends InferenceInputs { + private final UnifiedCompletionRequest request; + + public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { + super(stream); + this.request = Objects.requireNonNull(request); + } + + public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) { + this(completionInput.getInputs(), roleValue, completionInput.stream()); + } + + public UnifiedChatInput(List inputs, String roleValue, boolean stream) { + this(UnifiedCompletionRequest.of(convertToMessages(inputs, roleValue)), stream); + } + + private static List convertToMessages(List inputs, String roleValue) { + return inputs.stream() + .map( + value -> new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(value), + roleValue, + null, + null, + null + ) + ) + .toList(); + } + + public UnifiedCompletionRequest getRequest() { + return request; + } + + public int inputSize() { + return request.messages().size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 6e006fe255956..48c8132035b50 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -18,10 +18,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.Deque; import java.util.Iterator; @@ -115,19 +113,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - var results = new ArrayDeque(item.size()); - for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - try { - var delta = parse(parserConfig, event); - delta.forEachRemaining(results::offer); - } catch (Exception e) { - log.warn("Failed to parse event from inference provider: {}", event); - throw e; - } - } - } + var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log); if (results.isEmpty()) { upstream().request(1); @@ -136,7 +122,7 @@ protected void next(Deque item) throws Exception { } } - private Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) + private static Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException { if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { return Collections.emptyIterator(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..fce2556efc5e0 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; + +public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingUnifiedChatCompletionResults(openAiProcessor); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java new file mode 100644 index 0000000000000..599d71df3dcfa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { + public static final String FUNCTION_FIELD = "function"; + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); + + private static final String CHOICES_FIELD = "choices"; + private static final String DELTA_FIELD = "delta"; + private static final String CONTENT_FIELD = "content"; + private static final String DONE_MESSAGE = "[done]"; + private static final String REFUSAL_FIELD = "refusal"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ROLE_FIELD = "role"; + public static final String FINISH_REASON_FIELD = "finish_reason"; + public static final String INDEX_FIELD = "index"; + public static final String OBJECT_FIELD = "object"; + public static final String MODEL_FIELD = "model"; + public static final String ID_FIELD = "id"; + public static final String CHOICE_FIELD = "choice"; + public static final String USAGE_FIELD = "usage"; + public static final String TYPE_FIELD = "type"; + public static final String NAME_FIELD = "name"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String COMPLETION_TOKENS_FIELD = "completion_tokens"; + public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; + public static final String TOTAL_TOKENS_FIELD = "total_tokens"; + + private final Deque buffer = new LinkedBlockingDeque<>(); + + @Override + protected void upstreamRequest(long n) { + if (buffer.isEmpty()) { + super.upstreamRequest(n); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); + } + } + + @Override + protected void next(Deque item) throws Exception { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger); + + if (results.isEmpty()) { + upstream().request(1); + } else if (results.size() == 1) { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + // results > 1, but openai spec only wants 1 chunk per SSE event + var firstItem = singleItem(results.poll()); + while (results.isEmpty() == false) { + buffer.offer(results.poll()); + } + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); + } + } + + private static Iterator parse( + XContentParserConfiguration parserConfig, + ServerSentEvent event + ) throws IOException { + if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { + return Collections.emptyIterator(); + } + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); + + return Collections.singleton(chunk).iterator(); + } + } + + public static class ChatCompletionChunkParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "chat_completion_chunk", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + (String) args[0], + (List) args[1], + (String) args[2], + (String) args[3], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage) args[4] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ID_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p), + new ParseField(CHOICES_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD)); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD)); + PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.UsageParser.parse(p), + null, + new ParseField(USAGE_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private static class ChoiceParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + CHOICE_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta) args[0], + (String) args[1], + (int) args[2] + ) + ); + + static { + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.DeltaParser.parse(p), + new ParseField(DELTA_FIELD) + ); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } + + private static class DeltaParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta, + Void> PARSER = new ConstructingObjectParser<>( + DELTA_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + (String) args[0], + (String) args[1], + (String) args[2], + (List) args[3] + ) + ); + + static { + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), + new ParseField(TOOL_CALLS_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class ToolCallParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall, + Void> PARSER = new ConstructingObjectParser<>( + "tool_call", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + (int) args[0], + (String) args[1], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2], + (String) args[3] + ) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ID_FIELD)); + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.FunctionParser.parse(p), + new ParseField(FUNCTION_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TYPE_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class FunctionParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function, + Void> PARSER = new ConstructingObjectParser<>( + FUNCTION_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + (String) args[0], + (String) args[1] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( + XContentParser parser + ) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class UsageParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + USAGE_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(COMPLETION_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(PROMPT_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(TOTAL_TOKENS_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + } + + private Deque singleItem( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk result + ) { + var deque = new ArrayDeque(1); + deque.offer(result); + return deque; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java index 80770d63ef139..b1af18d03dda4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; @@ -27,13 +27,13 @@ public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest { private static final String ALT_PARAM = "alt"; private static final String SSE_VALUE = "sse"; - private final DocumentsOnlyInput input; + private final ChatCompletionInput input; private final LazyInitializable uri; private final GoogleAiStudioCompletionModel model; - public GoogleAiStudioCompletionRequest(DocumentsOnlyInput input, GoogleAiStudioCompletionModel model) { + public GoogleAiStudioCompletionRequest(ChatCompletionInput input, GoogleAiStudioCompletionModel model) { this.input = Objects.requireNonNull(input); this.model = Objects.requireNonNull(model); this.uri = new LazyInitializable<>(() -> model.uri(input.stream())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java deleted file mode 100644 index 867a7ca80cbcb..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public class OpenAiChatCompletionRequestEntity implements ToXContentObject { - - private static final String MESSAGES_FIELD = "messages"; - private static final String MODEL_FIELD = "model"; - - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; - - private static final String ROLE_FIELD = "role"; - private static final String USER_FIELD = "user"; - private static final String CONTENT_FIELD = "content"; - private static final String STREAM_FIELD = "stream"; - - private final List messages; - private final String model; - - private final String user; - private final boolean stream; - - public OpenAiChatCompletionRequestEntity(List messages, String model, String user, boolean stream) { - Objects.requireNonNull(messages); - Objects.requireNonNull(model); - - this.messages = messages; - this.model = model; - this.user = user; - this.stream = stream; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startArray(MESSAGES_FIELD); - { - for (String message : messages) { - builder.startObject(); - - { - builder.field(ROLE_FIELD, USER_FIELD); - builder.field(CONTENT_FIELD, message); - } - - builder.endObject(); - } - } - builder.endArray(); - - builder.field(MODEL_FIELD, model); - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (Strings.isNullOrEmpty(user) == false) { - builder.field(USER_FIELD, user); - } - - if (stream) { - builder.field(STREAM_FIELD, true); - } - - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java similarity index 80% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 99a025e70d003..2e6bdb748fd33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -13,6 +13,7 @@ import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -21,24 +22,21 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; -public class OpenAiChatCompletionRequest implements OpenAiRequest { +public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest { private final OpenAiAccount account; - private final List input; private final OpenAiChatCompletionModel model; - private final boolean stream; + private final UnifiedChatInput unifiedChatInput; - public OpenAiChatCompletionRequest(List input, OpenAiChatCompletionModel model, boolean stream) { - this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); + public OpenAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { + this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); - this.stream = stream; } @Override @@ -46,9 +44,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream) - ).getBytes(StandardCharsets.UTF_8) + Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -87,7 +83,7 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return stream; + return unifiedChatInput.stream(); } public static URI buildDefaultUri() throws URISyntaxException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java new file mode 100644 index 0000000000000..50339bf851f7d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { + + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String USER_FIELD = "user"; + public static final String STREAM_FIELD = "stream"; + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + private static final String MODEL_FIELD = "model"; + public static final String MESSAGES_FIELD = "messages"; + private static final String ROLE_FIELD = "role"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tools"; + private static final String TEXT_FIELD = "text"; + private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; + + private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; + private final OpenAiChatCompletionModel model; + + public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { + Objects.requireNonNull(unifiedChatInput); + + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { + builder.startObject(); + { + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); + builder.endObject(); + } + builder.endArray(); + } + } + + builder.field(ROLE_FIELD, message.role()); + if (message.name() != null) { + builder.field(NAME_FIELD, message.name()); + } + if (message.toolCallId() != null) { + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray(TOOL_CALLS_FIELD); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + builder.startObject(); + { + builder.field(ID_FIELD, toolCall.id()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); + } + builder.endObject(); + builder.field(TYPE_FIELD, toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + if (unifiedRequest.maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); + } + + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { + builder.field(STOP_FIELD, unifiedRequest.stop()); + } + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); + } + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field( + NAME_FIELD, + ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); + } + builder.endObject(); + } + builder.endObject(); + } + } + if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { + builder.startObject(); + { + builder.field(TYPE_FIELD, t.type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(DESCRIPTION_FIELD, t.function().description()); + builder.field(NAME_FIELD, t.function().name()); + builder.field(PARAMETERS_FIELD, t.function().parameters()); + if (t.function().strict() != null) { + builder.field(STRICT_FIELD, t.function().strict()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); + } + + if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { + builder.field(USER_FIELD, model.getTaskSettings().user()); + } + + builder.field(STREAM_FIELD, stream); + if (stream) { + builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java new file mode 100644 index 0000000000000..f2bfa72ec617a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java @@ -0,0 +1,226 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.highlight; + +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; +import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}. + * This highlighter extracts semantic queries and evaluates them against each chunk produced by the semantic text field. + * It returns the top-scoring chunks as snippets, optionally sorted by their scores. + */ +public class SemanticTextHighlighter implements Highlighter { + public static final String NAME = "semantic"; + + private record OffsetAndScore(int offset, float score) {} + + @Override + public boolean canHighlight(MappedFieldType fieldType) { + if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) { + return true; + } + return false; + } + + @Override + public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException { + SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType; + if (fieldType.getEmbeddingsField() == null) { + // nothing indexed yet + return null; + } + + final List queries = switch (fieldType.getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> extractSparseVectorQueries( + (SparseVectorFieldType) fieldType.getEmbeddingsField().fieldType(), + fieldContext.query + ); + case TEXT_EMBEDDING -> extractDenseVectorQueries( + (DenseVectorFieldType) fieldType.getEmbeddingsField().fieldType(), + fieldContext.query + ); + default -> throw new IllegalStateException( + "Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]" + ); + }; + if (queries.isEmpty()) { + // nothing to highlight + return null; + } + + int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0 + ? 1 // we return the best fragment by default + : fieldContext.field.fieldOptions().numberOfFragments(); + + List chunks = extractOffsetAndScores( + fieldContext.context.getSearchExecutionContext(), + fieldContext.hitContext.reader(), + fieldType, + fieldContext.hitContext.docId(), + queries + ); + if (chunks.size() == 0) { + return null; + } + + chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed()); + int size = Math.min(chunks.size(), numberOfFragments); + if (fieldContext.field.fieldOptions().scoreOrdered() == false) { + chunks = chunks.subList(0, size); + chunks.sort(Comparator.comparingInt(c -> c.offset)); + } + Text[] snippets = new Text[size]; + List> nestedSources = XContentMapValues.extractNestedSources( + fieldType.getChunksField().fullPath(), + fieldContext.hitContext.source().source() + ); + for (int i = 0; i < size; i++) { + var chunk = chunks.get(i); + if (nestedSources.size() <= chunk.offset) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + "Invalid content detected for field [%s]: the chunks size is [%d], " + + "but a reference to offset [%d] was found in the result.", + fieldType.name(), + nestedSources.size(), + chunk.offset + ) + ); + } + String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD); + if (content == null) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + + "Invalid content detected for field [%s]: missing text for the chunk at offset [%d].", + fieldType.name(), + chunk.offset + ) + ); + } + snippets[i] = new Text(content); + } + return new HighlightField(fieldContext.fieldName, snippets); + } + + private List extractOffsetAndScores( + SearchExecutionContext context, + LeafReader reader, + SemanticTextFieldMapper.SemanticTextFieldType fieldType, + int docId, + List leafQueries + ) throws IOException { + var bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext()); + int previousParent = docId > 0 ? bitSet.prevSetBit(docId - 1) : -1; + + BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER); + leafQueries.stream().forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD)); + Weight weight = new IndexSearcher(reader).createWeight(bq.build(), ScoreMode.COMPLETE, 1); + Scorer scorer = weight.scorer(reader.getContext()); + if (previousParent != -1) { + if (scorer.iterator().advance(previousParent) == DocIdSetIterator.NO_MORE_DOCS) { + return List.of(); + } + } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) { + return List.of(); + } + List results = new ArrayList<>(); + int offset = 0; + while (scorer.docID() < docId) { + results.add(new OffsetAndScore(offset++, scorer.score())); + if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + } + return results; + } + + private List extractDenseVectorQueries(DenseVectorFieldType fieldType, Query querySection) { + // TODO: Handle knn section when semantic text field can be used. + List queries = new ArrayList<>(); + querySection.visit(new QueryVisitor() { + @Override + public boolean acceptField(String field) { + return fieldType.name().equals(field); + } + + @Override + public void consumeTerms(Query query, Term... terms) { + super.consumeTerms(query, terms); + } + + @Override + public void visitLeaf(Query query) { + if (query instanceof KnnFloatVectorQuery knnQuery) { + queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(knnQuery.getTargetCopy()), null)); + } else if (query instanceof KnnByteVectorQuery knnQuery) { + queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null)); + } + } + }); + return queries; + } + + private List extractSparseVectorQueries(SparseVectorFieldType fieldType, Query querySection) { + List queries = new ArrayList<>(); + querySection.visit(new QueryVisitor() { + @Override + public boolean acceptField(String field) { + return fieldType.name().equals(field); + } + + @Override + public void consumeTerms(Query query, Term... terms) { + super.consumeTerms(query, terms); + } + + @Override + public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) { + if (parent instanceof SparseVectorQueryWrapper sparseVectorQuery) { + queries.add(sparseVectorQuery.getTermsQuery()); + } + return this; + } + }); + return queries; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceField.java new file mode 100644 index 0000000000000..d8339f1004da2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceField.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.FieldType; +import org.apache.lucene.index.IndexOptions; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.search.DocIdSetIterator; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Represents a {@link Field} that stores a {@link Term} along with its start and end offsets. + * Note: The {@link Charset} used to calculate these offsets is not associated with this field. + * It is the responsibility of the consumer to handle the appropriate {@link Charset}. + */ +public final class OffsetSourceField extends Field { + private static final FieldType FIELD_TYPE = new FieldType(); + + static { + FIELD_TYPE.setTokenized(false); + FIELD_TYPE.setOmitNorms(true); + FIELD_TYPE.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS); + } + + private int startOffset; + private int endOffset; + + public OffsetSourceField(String fieldName, String sourceFieldName, int startOffset, int endOffset) { + super(fieldName, sourceFieldName, FIELD_TYPE); + this.startOffset = startOffset; + this.endOffset = endOffset; + } + + public void setValues(String fieldName, int startOffset, int endOffset) { + this.fieldsData = fieldName; + this.startOffset = startOffset; + this.endOffset = endOffset; + } + + @Override + public TokenStream tokenStream(Analyzer analyzer, TokenStream reuse) { + OffsetTokenStream stream; + if (reuse instanceof OffsetTokenStream) { + stream = (OffsetTokenStream) reuse; + } else { + stream = new OffsetTokenStream(); + } + + stream.setValues((String) fieldsData, startOffset, endOffset); + return stream; + } + + public static OffsetSourceLoader loader(Terms terms) throws IOException { + return new OffsetSourceLoader(terms); + } + + private static final class OffsetTokenStream extends TokenStream { + private final CharTermAttribute termAttribute = addAttribute(CharTermAttribute.class); + private final OffsetAttribute offsetAttribute = addAttribute(OffsetAttribute.class); + private boolean used = true; + private String term = null; + private int startOffset = 0; + private int endOffset = 0; + + private OffsetTokenStream() {} + + /** Sets the values */ + void setValues(String term, int startOffset, int endOffset) { + this.term = term; + this.startOffset = startOffset; + this.endOffset = endOffset; + } + + @Override + public boolean incrementToken() { + if (used) { + return false; + } + clearAttributes(); + termAttribute.append(term); + offsetAttribute.setOffset(startOffset, endOffset); + used = true; + return true; + } + + @Override + public void reset() { + used = false; + } + + @Override + public void close() { + term = null; + } + } + + public static class OffsetSourceLoader { + private final Map postingsEnums = new LinkedHashMap<>(); + + private OffsetSourceLoader(Terms terms) throws IOException { + var termsEnum = terms.iterator(); + while (termsEnum.next() != null) { + var postings = termsEnum.postings(null, PostingsEnum.OFFSETS); + if (postings.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + postingsEnums.put(termsEnum.term().utf8ToString(), postings); + } + } + } + + public OffsetSourceFieldMapper.OffsetSource advanceTo(int doc) throws IOException { + for (var it = postingsEnums.entrySet().iterator(); it.hasNext();) { + var entry = it.next(); + var postings = entry.getValue(); + if (postings.docID() < doc) { + if (postings.advance(doc) == DocIdSetIterator.NO_MORE_DOCS) { + it.remove(); + continue; + } + } + if (postings.docID() == doc) { + assert postings.freq() == 1; + postings.nextPosition(); + return new OffsetSourceFieldMapper.OffsetSource(entry.getKey(), postings.startOffset(), postings.endOffset()); + } + } + return null; + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java new file mode 100644 index 0000000000000..e612076f1aaf2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapper.java @@ -0,0 +1,253 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Query; +import org.elasticsearch.index.fielddata.FieldDataContext; +import org.elasticsearch.index.fielddata.IndexFieldData; +import org.elasticsearch.index.mapper.DocumentParserContext; +import org.elasticsearch.index.mapper.FieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperBuilderContext; +import org.elasticsearch.index.mapper.TextSearchInfo; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.fetch.StoredFieldsSpec; +import org.elasticsearch.search.lookup.Source; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; + +/** + * A {@link FieldMapper} that maps a field name to its start and end offsets. + * The {@link CharsetFormat} used to compute the offsets is specified via the charset parameter. + * Currently, only {@link CharsetFormat#UTF_16} is supported, aligning with Java's {@code String} charset + * for simpler internal usage and integration. + * + * Each document can store at most one value in this field. + * + * Note: This mapper is not yet documented and is intended exclusively for internal use by + * {@link SemanticTextFieldMapper}. If exposing this mapper directly to users becomes necessary, + * extending charset compatibility should be considered, as the current default (and sole supported charset) + * was chosen for ease of Java integration. + */ +public class OffsetSourceFieldMapper extends FieldMapper { + public static final String CONTENT_TYPE = "offset_source"; + + private static final String SOURCE_NAME_FIELD = "field"; + private static final String START_OFFSET_FIELD = "start"; + private static final String END_OFFSET_FIELD = "end"; + + public record OffsetSource(String field, int start, int end) implements ToXContentObject { + public OffsetSource { + if (start < 0 || end < 0) { + throw new IllegalArgumentException("Illegal offsets, expected positive numbers, got: " + start + ":" + end); + } + if (start > end) { + throw new IllegalArgumentException("Illegal offsets, expected start < end, got: " + start + " > " + end); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(SOURCE_NAME_FIELD, field); + builder.field(START_OFFSET_FIELD, start); + builder.field(END_OFFSET_FIELD, end); + return builder.endObject(); + } + } + + private static final ConstructingObjectParser OFFSET_SOURCE_PARSER = new ConstructingObjectParser<>( + CONTENT_TYPE, + true, + args -> new OffsetSource((String) args[0], (int) args[1], (int) args[2]) + ); + + static { + OFFSET_SOURCE_PARSER.declareString(constructorArg(), new ParseField(SOURCE_NAME_FIELD)); + OFFSET_SOURCE_PARSER.declareInt(constructorArg(), new ParseField(START_OFFSET_FIELD)); + OFFSET_SOURCE_PARSER.declareInt(constructorArg(), new ParseField(END_OFFSET_FIELD)); + } + + public enum CharsetFormat { + UTF_16(StandardCharsets.UTF_16); + + private Charset charSet; + + CharsetFormat(Charset charSet) { + this.charSet = charSet; + } + } + + public static class Builder extends FieldMapper.Builder { + private final Parameter charset = Parameter.enumParam( + "charset", + false, + i -> CharsetFormat.UTF_16, + CharsetFormat.UTF_16, + CharsetFormat.class + ); + private final Parameter> meta = Parameter.metaParam(); + + public Builder(String name) { + super(name); + } + + @Override + protected Parameter[] getParameters() { + return new Parameter[] { meta, charset }; + } + + @Override + public OffsetSourceFieldMapper build(MapperBuilderContext context) { + return new OffsetSourceFieldMapper( + leafName(), + new OffsetSourceFieldType(context.buildFullName(leafName()), charset.get(), meta.getValue()), + builderParams(this, context) + ); + } + } + + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n)); + + public static final class OffsetSourceFieldType extends MappedFieldType { + private final CharsetFormat charset; + + public OffsetSourceFieldType(String name, CharsetFormat charset, Map meta) { + super(name, true, false, false, TextSearchInfo.NONE, meta); + this.charset = charset; + } + + public Charset getCharset() { + return charset.charSet; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + @Override + public boolean fieldHasValue(FieldInfos fieldInfos) { + return fieldInfos.fieldInfo(name()) != null; + } + + @Override + public IndexFieldData.Builder fielddataBuilder(FieldDataContext fieldDataContext) { + throw new IllegalArgumentException("[offset_source] fields do not support sorting, scripting or aggregating"); + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return new ValueFetcher() { + OffsetSourceField.OffsetSourceLoader offsetLoader; + + @Override + public void setNextReader(LeafReaderContext context) { + try { + var terms = context.reader().terms(name()); + offsetLoader = terms != null ? OffsetSourceField.loader(terms) : null; + } catch (IOException exc) { + throw new UncheckedIOException(exc); + } + } + + @Override + public List fetchValues(Source source, int doc, List ignoredValues) throws IOException { + var offsetSource = offsetLoader != null ? offsetLoader.advanceTo(doc) : null; + return offsetSource != null ? List.of(offsetSource) : null; + } + + @Override + public StoredFieldsSpec storedFieldsSpec() { + return StoredFieldsSpec.NO_REQUIREMENTS; + } + }; + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + throw new IllegalArgumentException("Queries on [offset_source] fields are not supported"); + } + + @Override + public boolean isSearchable() { + return false; + } + } + + /** + * @param simpleName the leaf name of the mapper + * @param mappedFieldType + * @param params initialization params for this field mapper + */ + protected OffsetSourceFieldMapper(String simpleName, MappedFieldType mappedFieldType, BuilderParams params) { + super(simpleName, mappedFieldType, params); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + protected boolean supportsParsingObject() { + return true; + } + + @Override + protected void parseCreateField(DocumentParserContext context) throws IOException { + var parser = context.parser(); + if (parser.currentToken() == XContentParser.Token.VALUE_NULL) { + // skip + return; + } + + if (context.doc().getByKey(fullPath()) != null) { + throw new IllegalArgumentException( + "[offset_source] fields do not support indexing multiple values for the same field [" + + fullPath() + + "] in the same document" + ); + } + + // make sure that we don't expand dots in field names while parsing + boolean isWithinLeafObject = context.path().isWithinLeafObject(); + context.path().setWithinLeafObject(true); + try { + var offsetSource = OFFSET_SOURCE_PARSER.parse(parser, null); + context.doc() + .addWithKey( + fieldType().name(), + new OffsetSourceField(fullPath(), offsetSource.field, offsetSource.start, offsetSource.end) + ); + context.addToFieldNames(fieldType().name()); + } finally { + context.path().setWithinLeafObject(isWithinLeafObject); + } + } + + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder(leafName()).init(this); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index e60e95b58770f..0f26f6577860f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -61,7 +61,7 @@ public record SemanticTextField(String fieldName, List originalValues, I static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; static final String CHUNKS_FIELD = "chunks"; static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; - static final String CHUNKED_TEXT_FIELD = "text"; + public static final String CHUNKED_TEXT_FIELD = "text"; static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String TASK_TYPE_FIELD = "task_type"; static final String DIMENSIONS_FIELD = "dimensions"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 3744bf2a6dbed..683bb5a53028b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -46,7 +46,6 @@ import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.SimilarityMeasure; @@ -57,6 +56,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import java.io.IOException; import java.util.ArrayList; @@ -529,17 +529,15 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer req ); } - // TODO: Use WeightedTokensQueryBuilder TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults; - var boolQuery = QueryBuilders.boolQuery(); - for (var weightedToken : textExpansionResults.getWeightedTokens()) { - boolQuery.should( - QueryBuilders.termQuery(inferenceResultsFieldName, weightedToken.token()).boost(weightedToken.weight()) - ); - } - boolQuery.minimumShouldMatch(1); - - yield boolQuery; + yield new SparseVectorQueryBuilder( + inferenceResultsFieldName, + textExpansionResults.getWeightedTokens(), + null, + null, + null, + null + ); } case TEXT_EMBEDDING -> { if (inferenceResults instanceof MlTextEmbeddingResults == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java index fdb5503e491eb..15d41301d0a3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/random/RandomRankBuilder.java @@ -85,7 +85,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.RANDOM_RERANKER_RETRIEVER; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java index d208623e53324..7ad3e8eea0538 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankDoc.java @@ -98,6 +98,6 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.TEXT_SIMILARITY_RERANKER_QUERY_REWRITE; + return TransportVersions.V_8_16_0; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java index c239319b6283a..fd2427dc8ac6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankRetrieverBuilder.java @@ -129,7 +129,10 @@ public TextSimilarityRankRetrieverBuilder( } @Override - protected TextSimilarityRankRetrieverBuilder clone(List newChildRetrievers) { + protected TextSimilarityRankRetrieverBuilder clone( + List newChildRetrievers, + List newPreFilterQueryBuilders + ) { return new TextSimilarityRankRetrieverBuilder( newChildRetrievers, inferenceId, @@ -138,7 +141,7 @@ protected TextSimilarityRankRetrieverBuilder clone(List newChil rankWindowSize, minScore, retrieverName, - preFilterQueryBuilders + newPreFilterQueryBuilders ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index e72e68052f648..d911158e82296 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; @@ -21,27 +22,32 @@ import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; abstract class BaseInferenceAction extends BaseRestHandler { - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String inferenceEntityId; - TaskType taskType; + static Params parseParams(RestRequest restRequest) { if (restRequest.hasParam(INFERENCE_ID)) { - inferenceEntityId = restRequest.param(INFERENCE_ID); - taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + var inferenceEntityId = restRequest.param(INFERENCE_ID); + var taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + return new Params(inferenceEntityId, taskType); } else { - inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); - taskType = TaskType.ANY; + return new Params(restRequest.param(TASK_TYPE_OR_INFERENCE_ID), TaskType.ANY); } + } + + record Params(String inferenceEntityId, TaskType taskType) {} + + static TimeValue parseTimeout(RestRequest restRequest) { + return restRequest.paramAsTime(InferenceAction.Request.TIMEOUT.getPreferredName(), InferenceAction.Request.DEFAULT_TIMEOUT); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = parseParams(restRequest); InferenceAction.Request.Builder requestBuilder; try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); + requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); } - var inferTimeout = restRequest.paramAsTime( - InferenceAction.Request.TIMEOUT.getPreferredName(), - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var inferTimeout = parseTimeout(restRequest); requestBuilder.setInferenceTimeout(inferTimeout); var request = prepareInferenceRequest(requestBuilder); return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 55d6443b43c03..c46f211bb26af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -30,6 +30,12 @@ public final class Paths { + "}/{" + INFERENCE_ID + "}/_stream"; + static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified"; + static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_unified"; private Paths() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java new file mode 100644 index 0000000000000..5c71b560a6b9d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.Scope; +import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; + +@ServerlessScope(Scope.PUBLIC) +public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { + @Override + public String getName() { + return "unified_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, UNIFIED_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = BaseInferenceAction.parseParams(restRequest); + + var inferTimeout = BaseInferenceAction.parseTimeout(restRequest); + + UnifiedCompletionAction.Request request; + try (var parser = restRequest.contentParser()) { + request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); + } + + return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 8e2dac1ef9db2..e9b75e9ec7796 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.services; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceService; @@ -17,11 +19,15 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import java.io.IOException; import java.util.EnumSet; @@ -61,11 +67,31 @@ public void infer( ActionListener listener ) { init(); - if (query != null) { - doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener); - } else { - doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener); - } + var inferenceInput = createInput(model, input, query, stream); + doInfer(model, inferenceInput, taskSettings, inputType, timeout, listener); + } + + private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { + return switch (model.getTaskType()) { + case COMPLETION -> new ChatCompletionInput(input, stream); + case RERANK -> new QueryAndDocsInputs(query, input, stream); + case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream); + default -> throw new ElasticsearchStatusException( + Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), + RestStatus.BAD_REQUEST + ); + }; + } + + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + init(); + doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener); } @Override @@ -92,6 +118,13 @@ protected abstract void doInfer( ActionListener listener ); + protected abstract void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ); + protected abstract void doChunkedInfer( Model model, DocumentsOnlyInput inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index ec4b8d9bb4d3d..7d05bac363fb1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -776,5 +776,9 @@ public static T nonNullOrDefault(@Nullable T requestValue, @Nullable T origi return requestValue == null ? originalSettingsValue : requestValue; } + public static void throwUnsupportedUnifiedCompletionOperation(String serviceName) { + throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName)); + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index d7ac7caed7efc..ffd26b9ac534d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; @@ -57,14 +58,13 @@ import java.util.Map; import java.util.stream.Stream; -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HOST; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME; @@ -261,6 +261,16 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta ); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, @@ -359,7 +369,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } public static class Configuration { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java index 3500bdf814e16..f6ddac34a2b27 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceSettings.java @@ -163,7 +163,7 @@ public ToXContentObject getFilteredXContentObject() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java index 631ec8a8648e8..a299cf5b655c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionServiceSettings.java @@ -74,7 +74,7 @@ public ToXContentObject getFilteredXContentObject() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java index 05b5873a81d8d..7883e7b1d90df 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/completion/AlibabaCloudSearchCompletionTaskSettings.java @@ -115,7 +115,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java index 8896e983d3e7f..8f40ce2a8b8b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsServiceSettings.java @@ -135,7 +135,7 @@ public ToXContentObject getFilteredXContentObject() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java index 9a431717d9fb9..a08ca6cce66d6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/embeddings/AlibabaCloudSearchEmbeddingsTaskSettings.java @@ -151,7 +151,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java index 42c7238aefa7f..40e645074f61c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankServiceSettings.java @@ -74,7 +74,7 @@ public ToXContentObject getFilteredXContentObject() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java index 40c3dee00d6c7..2a7806f4beab3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/rerank/AlibabaCloudSearchRerankTaskSettings.java @@ -85,7 +85,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java index fe44c936c4e61..0a55d2aba6cea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseServiceSettings.java @@ -74,7 +74,7 @@ public ToXContentObject getFilteredXContentObject() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java index 0f4ebce920167..17c5b178c2a13 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/sparse/AlibabaCloudSearchSparseTaskSettings.java @@ -164,7 +164,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_ALIBABACLOUD_SEARCH_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 48b3c3df03e11..d224e50bb650d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; @@ -89,6 +91,16 @@ public AmazonBedrockService( this.amazonBedrockSender = amazonBedrockFactory.createSender(); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index b3d503de8e3eb..f1840af18779f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -52,6 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class AnthropicService extends SenderService { public static final String NAME = "anthropic"; @@ -192,6 +194,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index bba331fc0b5df..f8ea11e4b15a5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -63,6 +64,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD; @@ -81,6 +83,16 @@ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents super(factory, serviceComponents); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 16c94dfa9ad94..a38c265d2613c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; @@ -233,6 +235,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index b3d8b3b6efce3..ccb8d79dacd6c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class CohereService extends SenderService { @@ -232,6 +234,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java index a3d2483a068e2..78178466f9f3a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java @@ -92,7 +92,7 @@ public CohereRerankServiceSettings(@Nullable String url, @Nullable String modelI public CohereRerankServiceSettings(StreamInput in) throws IOException { this.uri = createOptionalUri(in.readOptionalString()); - if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED)) { + if (in.getTransportVersion().before(TransportVersions.V_8_16_0)) { // An older node sends these fields, so we need to skip them to progress through the serialized data in.readOptionalEnum(SimilarityMeasure.class); in.readOptionalVInt(); @@ -162,7 +162,7 @@ public void writeTo(StreamOutput out) throws IOException { var uriToWrite = uri != null ? uri.toString() : null; out.writeOptionalString(uriToWrite); - if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_COHERE_UNUSED_RERANK_SETTINGS_REMOVED)) { + if (out.getTransportVersion().before(TransportVersions.V_8_16_0)) { // An old node expects this data to be present, so we need to send at least the booleans // indicating that the fields are not set out.writeOptionalEnum(null); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 1f08c06edaa91..fe8ee52eb8816 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class ElasticInferenceService extends SenderService { @@ -76,6 +78,16 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, @@ -229,7 +241,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED; + return TransportVersions.V_8_16_0; } private ElasticInferenceServiceModel createModelFromPersistent( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java index bbda1bb716794..3af404aeef36b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java @@ -113,7 +113,7 @@ public RateLimitSettings rateLimitSettings() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_EIS_INTEGRATION_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 2ec3a9d629434..8cb91782e238e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption; @@ -77,6 +78,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_THREADS; @@ -569,6 +571,16 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE ); } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void infer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 962c939146ef2..244108edc3dd4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -157,19 +157,17 @@ public ElasticsearchInternalServiceSettings(ElasticsearchInternalServiceSettings } public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { - if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.numAllocations = in.readOptionalVInt(); } else { this.numAllocations = in.readVInt(); } this.numThreads = in.readVInt(); this.modelId = in.readString(); - this.adaptiveAllocationsSettings = in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + this.adaptiveAllocationsSettings = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) : null; - this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT) - ? in.readOptionalString() - : null; + this.deploymentId = in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0) ? in.readOptionalString() : null; } public void setNumAllocations(Integer numAllocations) { @@ -178,17 +176,15 @@ public void setNumAllocations(Integer numAllocations) { @Override public void writeTo(StreamOutput out) throws IOException { - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalVInt(getNumAllocations()); } else { out.writeVInt(getNumAllocations()); } out.writeVInt(getNumThreads()); out.writeString(modelId()); - if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeOptionalWriteable(getAdaptiveAllocationsSettings()); - } - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ATTACH_TO_EXISTSING_DEPLOYMENT)) { out.writeOptionalString(deploymentId); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 57a8a66a3f3a6..b681722a82136 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class GoogleAiStudioService extends SenderService { @@ -282,9 +284,8 @@ protected void doInfer( ) { if (model instanceof GoogleAiStudioCompletionModel completionModel) { var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool()); - var docsOnly = DocumentsOnlyInput.of(inputs); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - completionModel.uri(docsOnly.stream()), + completionModel.uri(inputs.stream()), "Google AI Studio completion" ); var action = new SingleInputSenderExecutableAction( @@ -308,6 +309,16 @@ protected void doInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 857d475499aae..87a2d98dca92c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; @@ -206,6 +208,16 @@ protected void doInfer( action.execute(inputs, timeout, listener); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 51cca72f26054..b74ec01cd76e7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SettingsConfiguration; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -47,6 +49,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; @@ -139,6 +142,16 @@ protected void doChunkedInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 75920efa251f2..5b038781b96af 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; @@ -49,6 +50,7 @@ import java.util.Map; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; public class HuggingFaceElserService extends HuggingFaceBaseService { @@ -81,6 +83,16 @@ protected HuggingFaceModel createModel( }; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index ea263fb77a2da..cc66d5fd7ee74 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; @@ -223,7 +225,7 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED; + return TransportVersions.V_8_16_0; } @Override @@ -276,6 +278,16 @@ protected void doInfer( action.execute(input, timeout, listener); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java index 53d5c6c8bb5e8..3a9625aef31c7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsServiceSettings.java @@ -207,7 +207,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_IBM_WATSONX_EMBEDDINGS_ADDED; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index fe0edb851902b..881e7d36f2a21 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; public class MistralService extends SenderService { @@ -88,6 +90,16 @@ protected void doInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 20ff1c617d21f..7b51b068708ca 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -32,10 +32,13 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -53,6 +56,8 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.COMPLETION_ERROR_PREFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -257,6 +262,28 @@ public void doInfer( action.execute(inputs, timeout, listener); } + @Override + public void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenAiChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; + + var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()); + var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); + var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); + + action.execute(inputs, timeout, listener); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index e721cd2955cf3..7d79d64b3a771 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -24,6 +25,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; @@ -38,6 +40,26 @@ public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); } + public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new OpenAiChatCompletionServiceSettings( + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.uri(), + originalModelServiceSettings.organizationId(), + originalModelServiceSettings.maxInputTokens(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new OpenAiChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getTaskSettings(), + model.getSecretSettings() + ); + } + public OpenAiChatCompletionModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java index 8029d8579baba..7ef7f85d71a6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java @@ -48,5 +48,4 @@ public static OpenAiChatCompletionRequestTaskSettings fromMap(Map TaskType.fromStringOrStatusException(null)); + assertThat(exception.getMessage(), Matchers.is("Task type must not be null")); + + exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException("blah")); + assertThat(exception.getMessage(), Matchers.is("Unknown task_type [blah]")); + + assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 5abb9000f4d04..9395ae222e9ba 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.common.Truncator; @@ -160,9 +161,11 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName var mockConfigs = mock(ModelConfigurations.class); when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId); when(mockConfigs.getService()).thenReturn(serviceName); + when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); var mockModel = mock(Model.class); when(mockModel.getConfigurations()).thenReturn(mockConfigs); + when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); return mockModel; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java new file mode 100644 index 0000000000000..47f3a0e0b57aa --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -0,0 +1,364 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public abstract class BaseTransportInferenceActionTestCase extends ESTestCase { + private ModelRegistry modelRegistry; + private StreamingTaskManager streamingTaskManager; + private BaseTransportInferenceAction action; + + protected static final String serviceId = "serviceId"; + protected static final TaskType taskType = TaskType.COMPLETION; + protected static final String inferenceId = "inferenceEntityId"; + protected InferenceServiceRegistry serviceRegistry; + protected InferenceStats inferenceStats; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportService transportService = mock(); + ActionFilters actionFilters = mock(); + modelRegistry = mock(); + serviceRegistry = mock(); + inferenceStats = new InferenceStats(mock(), mock()); + streamingTaskManager = mock(); + action = createAction(transportService, actionFilters, modelRegistry, serviceRegistry, inferenceStats, streamingTaskManager); + } + + protected abstract BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ); + + protected abstract Request createRequest(); + + public void testMetricsAfterModelRegistryError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onFailure(expectedException); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = doExecute(taskType); + verify(listener).onFailure(same(expectedException)); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + protected ActionListener doExecute(TaskType taskType) { + return doExecute(taskType, false); + } + + protected ActionListener doExecute(TaskType taskType, boolean stream) { + Request request = createRequest(); + when(request.getInferenceEntityId()).thenReturn(inferenceId); + when(request.getTaskType()).thenReturn(taskType); + when(request.isStreaming()).thenReturn(stream); + ActionListener listener = mock(); + action.doExecute(mock(), request, listener); + return listener; + } + + public void testMetricsAfterMissingService() { + mockModelRegistry(taskType); + + when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); + + var listener = doExecute(taskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. ")); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + protected void mockModelRegistry(TaskType expectedTaskType) { + var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + } + + public void testMetricsAfterUnknownTaskType() { + var modelTaskType = TaskType.RERANK; + var requestTaskType = TaskType.SPARSE_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is( + "Incompatible task_type, the requested type [" + + requestTaskType + + "] does not match the model type [" + + modelTaskType + + "]" + ) + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterInferError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockService(listener -> listener.onFailure(expectedException)); + + var listener = doExecute(taskType); + + verify(listener).onFailure(same(expectedException)); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamUnsupported() { + var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; + var expectedError = String.valueOf(expectedStatus.getStatus()); + mockService(l -> {}); + + var listener = doExecute(taskType, true); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + var ese = (ElasticsearchStatusException) e; + assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); + assertThat(ese.status(), is(expectedStatus)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterInferSuccess() { + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferSuccess() { + mockStreamResponse(Flow.Subscriber::onComplete); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferFailure() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockStreamResponse(subscriber -> { + subscriber.subscribe(mock()); + subscriber.onError(expectedException); + }); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamCancel() { + var response = mockStreamResponse(s -> s.onSubscribe(mock())); + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscription.cancel(); + } + + @Override + public void onNext(ChunkedToXContent item) { + + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + public void onComplete() { + + } + }); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + protected Flow.Publisher mockStreamResponse(Consumer> action) { + mockService(true, Set.of(), listener -> { + Flow.Processor taskProcessor = mock(); + doAnswer(innerAns -> { + action.accept(innerAns.getArgument(0)); + return null; + }).when(taskProcessor).subscribe(any()); + when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); + var inferenceServiceResults = mock(InferenceServiceResults.class); + when(inferenceServiceResults.publisher()).thenReturn(mock()); + listener.onResponse(inferenceServiceResults); + }); + + var listener = doExecute(taskType, true); + var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); + verify(listener).onResponse(captor.capture()); + assertTrue(captor.getValue().isStreaming()); + assertNotNull(captor.getValue().publisher()); + return captor.getValue().publisher(); + } + + protected void mockService(Consumer> listenerAction) { + mockService(false, Set.of(), listenerAction); + } + + protected void mockService( + boolean stream, + Set supportedStreamingTasks, + Consumer> listenerAction + ) { + InferenceService service = mock(); + Model model = mockModel(); + when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); + when(service.name()).thenReturn(serviceId); + + when(service.canStream(any())).thenReturn(stream); + when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(7)); + return null; + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(3)); + return null; + }).when(service).unifiedCompletionInfer(any(), any(), any(), any()); + mockModelAndServiceRegistry(service); + } + + protected Model mockModel() { + Model model = mock(); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getService()).thenReturn(serviceId); + when(model.getConfigurations()).thenReturn(modelConfigurations); + when(model.getTaskType()).thenReturn(taskType); + when(model.getServiceSettings()).thenReturn(mock()); + return model; + } + + protected void mockModelAndServiceRegistry(InferenceService service) { + var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 0ed9cbf56b3fa..e54175cb27009 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -7,66 +7,28 @@ package org.elasticsearch.xpack.inference.action; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.Flow; -import java.util.function.Consumer; - -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.isA; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.assertArg; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -public class TransportInferenceActionTests extends ESTestCase { - private static final String serviceId = "serviceId"; - private static final TaskType taskType = TaskType.COMPLETION; - private static final String inferenceId = "inferenceEntityId"; - private ModelRegistry modelRegistry; - private InferenceServiceRegistry serviceRegistry; - private InferenceStats inferenceStats; - private StreamingTaskManager streamingTaskManager; - private TransportInferenceAction action; +public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase { - @Before - public void setUp() throws Exception { - super.setUp(); - TransportService transportService = mock(); - ActionFilters actionFilters = mock(); - modelRegistry = mock(); - serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); - streamingTaskManager = mock(); - action = new TransportInferenceAction( + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportInferenceAction( transportService, actionFilters, modelRegistry, @@ -76,279 +38,8 @@ public void setUp() throws Exception { ); } - public void testMetricsAfterModelRegistryError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onFailure(expectedException); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - var listener = doExecute(taskType); - verify(listener).onFailure(same(expectedException)); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), nullValue()); - assertThat(attributes.get("task_type"), nullValue()); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - private ActionListener doExecute(TaskType taskType) { - return doExecute(taskType, false); - } - - private ActionListener doExecute(TaskType taskType, boolean stream) { - InferenceAction.Request request = mock(); - when(request.getInferenceEntityId()).thenReturn(inferenceId); - when(request.getTaskType()).thenReturn(taskType); - when(request.isStreaming()).thenReturn(stream); - ActionListener listener = mock(); - action.doExecute(mock(), request, listener); - return listener; - } - - public void testMetricsAfterMissingService() { - mockModelRegistry(taskType); - - when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); - - var listener = doExecute(taskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. ")); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - private void mockModelRegistry(TaskType expectedTaskType) { - var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - } - - public void testMetricsAfterUnknownTaskType() { - var modelTaskType = TaskType.RERANK; - var requestTaskType = TaskType.SPARSE_EMBEDDING; - mockModelRegistry(modelTaskType); - when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); - - var listener = doExecute(requestTaskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is( - "Incompatible task_type, the requested type [" - + requestTaskType - + "] does not match the model type [" - + modelTaskType - + "]" - ) - ); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(modelTaskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - public void testMetricsAfterInferError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockService(listener -> listener.onFailure(expectedException)); - - var listener = doExecute(taskType); - - verify(listener).onFailure(same(expectedException)); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamUnsupported() { - var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; - var expectedError = String.valueOf(expectedStatus.getStatus()); - mockService(l -> {}); - - var listener = doExecute(taskType, true); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - var ese = (ElasticsearchStatusException) e; - assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); - assertThat(ese.status(), is(expectedStatus)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterInferSuccess() { - mockService(listener -> listener.onResponse(mock())); - - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferSuccess() { - mockStreamResponse(Flow.Subscriber::onComplete); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferFailure() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockStreamResponse(subscriber -> { - subscriber.subscribe(mock()); - subscriber.onError(expectedException); - }); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamCancel() { - var response = mockStreamResponse(s -> s.onSubscribe(mock())); - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscription.cancel(); - } - - @Override - public void onNext(ChunkedToXContent item) { - - } - - @Override - public void onError(Throwable throwable) { - - } - - @Override - public void onComplete() { - - } - }); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - private Flow.Publisher mockStreamResponse(Consumer> action) { - mockService(true, Set.of(), listener -> { - Flow.Processor taskProcessor = mock(); - doAnswer(innerAns -> { - action.accept(innerAns.getArgument(0)); - return null; - }).when(taskProcessor).subscribe(any()); - when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); - var inferenceServiceResults = mock(InferenceServiceResults.class); - when(inferenceServiceResults.publisher()).thenReturn(mock()); - listener.onResponse(inferenceServiceResults); - }); - - var listener = doExecute(taskType, true); - var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); - verify(listener).onResponse(captor.capture()); - assertTrue(captor.getValue().isStreaming()); - assertNotNull(captor.getValue().publisher()); - return captor.getValue().publisher(); - } - - private void mockService(Consumer> listenerAction) { - mockService(false, Set.of(), listenerAction); - } - - private void mockService( - boolean stream, - Set supportedStreamingTasks, - Consumer> listenerAction - ) { - InferenceService service = mock(); - Model model = mockModel(); - when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); - when(service.name()).thenReturn(serviceId); - - when(service.canStream(any())).thenReturn(stream); - when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); - doAnswer(ans -> { - listenerAction.accept(ans.getArgument(7)); - return null; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - mockModelAndServiceRegistry(service); - } - - private Model mockModel() { - Model model = mock(); - ModelConfigurations modelConfigurations = mock(); - when(modelConfigurations.getService()).thenReturn(serviceId); - when(model.getConfigurations()).thenReturn(modelConfigurations); - when(model.getTaskType()).thenReturn(taskType); - when(model.getServiceSettings()).thenReturn(mock()); - return model; - } - - private void mockModelAndServiceRegistry(InferenceService service) { - var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + @Override + protected InferenceAction.Request createRequest() { + return mock(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java new file mode 100644 index 0000000000000..4c943599ce523 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +import java.util.Optional; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportUnifiedCompletionActionTests extends BaseTransportInferenceActionTestCase { + + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportUnifiedCompletionInferenceAction( + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager + ); + } + + @Override + protected UnifiedCompletionAction.Request createRequest() { + return mock(); + } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { + var modelTaskType = TaskType.TEXT_EMBEDDING; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingRequestIsAny_ModelIsTextEmbedding() { + var modelTaskType = TaskType.ANY; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterUnifiedInferSuccess_WithRequestTaskTypeAny() { + mockModelRegistry(TaskType.COMPLETION); + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(TaskType.ANY); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index d4ab9b1f1e19a..9e7c58b0ca79e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -61,25 +61,11 @@ public void testOneInputIsValid() { assertTrue("Test failed to call listener.", testRan.get()); } - public void testInvalidInputType() { - var badInput = mock(InferenceInputs.class); - var actualException = new AtomicReference(); - - executableAction.execute( - badInput, - mock(TimeValue.class), - ActionListener.wrap(shouldNotSucceed -> fail("Test failed."), actualException::set) - ); - - assertThat(actualException.get(), notNullValue()); - assertThat(actualException.get().getMessage(), is("Invalid inference input type")); - assertThat(actualException.get(), instanceOf(ElasticsearchStatusException.class)); - assertThat(((ElasticsearchStatusException) actualException.get()).status(), is(RestStatus.INTERNAL_SERVER_ERROR)); - } - public void testMoreThanOneInput() { var badInput = mock(DocumentsOnlyInput.class); - when(badInput.getInputs()).thenReturn(List.of("one", "two")); + var input = List.of("one", "two"); + when(badInput.getInputs()).thenReturn(input); + when(badInput.inputSize()).thenReturn(input.size()); var actualException = new AtomicReference(); executableAction.execute( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java index 87d3a82b4aae6..e7543aa6ba9e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; @@ -130,7 +131,7 @@ public void testCompletionRequestAction() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); @@ -163,7 +164,7 @@ public void testChatCompletionRequestAction_HandlesException() throws IOExceptio ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java index a3114300c5ddc..f0de37ceaaf98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -49,6 +49,7 @@ import static org.mockito.Mockito.mock; public class AnthropicActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -103,7 +104,7 @@ public void testCreate_ChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -168,7 +169,7 @@ public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() thro var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java index fca2e316af17f..2065a726b7589 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -113,7 +113,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,7 +149,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -170,7 +170,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -187,7 +187,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +229,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 8792234102a94..210fab457de10 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -160,7 +161,7 @@ public void testChatCompletionRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 45a2fb0954c79..7e1e3e55caed8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -475,7 +476,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept var action = actionCreator.create(model, taskSettingsWithUserOverride); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -531,7 +532,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -589,7 +590,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java index 4c7683c882816..dca12dfda9c98 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -111,7 +111,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction("resource", "deployment", "apiversion", user, apiKey, sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -142,7 +142,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -177,7 +177,7 @@ public void testExecute_ThrowsException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 9ec34e7d8e5c5..3a512de25a39c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -197,7 +198,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -257,7 +258,7 @@ public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOEx var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index ba839e0d7c5e9..c5871adb34864 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -26,8 +26,8 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.CohereCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; @@ -120,7 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -181,7 +181,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws var action = createAction(getUrl(webServer), "secret", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -214,7 +214,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -235,7 +235,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -256,7 +256,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -270,7 +270,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -284,7 +284,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -334,7 +334,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java index 72b5ffa45a0dd..ff17bbf66e02a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -128,7 +128,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -159,7 +159,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +180,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -260,7 +260,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index b6d7eb673b7f0..fe076eb721ea2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -330,7 +331,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -345,11 +346,12 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -393,7 +395,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -408,10 +410,11 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(3)); + assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -455,7 +458,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -470,11 +473,12 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO assertNull(request.getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -523,7 +527,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -542,11 +546,12 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( assertNull(webServer.requests().get(0).getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index d84b2b5bb324a..ba74d2ab42c21 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; @@ -119,7 +119,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -134,11 +134,12 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -159,7 +160,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +181,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -215,7 +216,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +230,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -273,7 +274,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index e68beaf4c1eb5..929aefeeef6b9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; @@ -67,8 +68,15 @@ public void send( ActionListener listener ) { sendCounter++; - var docsInput = (DocumentsOnlyInput) inferenceInputs; - inputs.add(docsInput.getInputs()); + if (inferenceInputs instanceof DocumentsOnlyInput docsInput) { + inputs.add(docsInput.getInputs()); + } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) { + inputs.add(chatCompletionInput.getInputs()); + } else { + throw new IllegalArgumentException( + "Invalid inference inputs received in mock sender: " + inferenceInputs.getClass().getSimpleName() + ); + } if (results.isEmpty()) { listener.onFailure(new ElasticsearchException("No results found")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index 7fa8a09d5bf12..a8f37aedcece3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -107,7 +108,7 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws PlainActionFuture listener = new PlainActionFuture<>(); var requestManager = new AmazonBedrockChatCompletionRequestManager(model, threadPool, new TimeValue(30, TimeUnit.SECONDS)); - sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + sender.send(requestManager, new ChatCompletionInput(List.of("abc")), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test response text")))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java new file mode 100644 index 0000000000000..f0da67a982374 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class InferenceInputsTests extends ESTestCase { + public void testCastToSucceeds() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + assertThat(inputs.castTo(DocumentsOnlyInput.class), Matchers.instanceOf(DocumentsOnlyInput.class)); + + var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); + assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); + assertThat( + new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), + Matchers.instanceOf(QueryAndDocsInputs.class) + ); + } + + public void testCastToFails() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); + assertThat( + exception.getMessage(), + Matchers.containsString( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", DocumentsOnlyInput.class, QueryAndDocsInputs.class) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java new file mode 100644 index 0000000000000..42e1b18168aec --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class UnifiedChatInputTests extends ESTestCase { + + public void testConvertsStringInputToMessages() { + var a = new UnifiedChatInput(List.of("hello", "awesome"), "a role", true); + + assertThat(a.inputSize(), Matchers.is(2)); + assertThat( + a.getRequest(), + Matchers.is( + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("hello"), + "a role", + null, + null, + null + ), + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("awesome"), + "a role", + null, + null, + null + ) + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java new file mode 100644 index 0000000000000..0f127998f9c54 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -0,0 +1,383 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; + +import java.io.IOException; +import java.util.List; + +public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase { + + public void testJsonLiteral() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": null, + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool_call_id", + "function": { + "arguments": "example_arguments", + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 0 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": { + "completion_tokens": 50, + "prompt_tokens": 20, + "total_tokens": 70 + } + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(50, chunk.getUsage().completionTokens()); + assertEquals(20, chunk.getUsage().promptTokens()); + assertEquals(70, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals("example_content", choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals("assistant", choice.delta().getRole()); + assertEquals("stop", choice.finishReason()); + assertEquals(0, choice.index()); + + List toolCalls = choice.delta().getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertEquals("tool_call_id", toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertEquals("example_arguments", toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testJsonLiteralCornerCases() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": null, + "refusal": null, + "role": "assistant", + "tool_calls": [] + }, + "finish_reason": null, + "index": 0 + }, + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "user", + "tool_calls": [ + { + "index": 1, + "function": { + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 1 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": null + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(2, choices.size()); + + // First choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); + assertNull(firstChoice.delta().getContent()); + assertNull(firstChoice.delta().getRefusal()); + assertEquals("assistant", firstChoice.delta().getRole()); + assertTrue(firstChoice.delta().getToolCalls().isEmpty()); + assertNull(firstChoice.finishReason()); + assertEquals(0, firstChoice.index()); + + // Second choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1); + assertEquals("example_content", secondChoice.delta().getContent()); + assertEquals("example_refusal", secondChoice.delta().getRefusal()); + assertEquals("user", secondChoice.delta().getRole()); + assertEquals("stop", secondChoice.finishReason()); + assertEquals(1, secondChoice.index()); + + List toolCalls = secondChoice.delta() + .getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertNull(toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertNull(toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { + // Generate random values for the JSON fields + int toolCallIndex = randomIntBetween(0, 10); + String toolCallId = randomAlphaOfLength(5); + String toolCallFunctionName = randomAlphaOfLength(8); + String toolCallFunctionArguments = randomAlphaOfLength(10); + String toolCallType = "function"; + String toolCallJson = createToolCallJson(toolCallIndex, toolCallId, toolCallFunctionName, toolCallFunctionArguments, toolCallType); + + String choiceContent = randomAlphaOfLength(10); + String choiceRole = randomFrom("system", "user", "assistant", "tool"); + String choiceFinishReason = randomFrom("stop", "length", "tool_calls", "content_filter", "function_call", null); + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(choiceContent, null, choiceRole, toolCallJson, choiceFinishReason, choiceIndex); + + int usageCompletionTokens = randomIntBetween(1, 100); + int usagePromptTokens = randomIntBetween(1, 100); + int usageTotalTokens = randomIntBetween(1, 200); + String usageJson = createUsageJson(usageCompletionTokens, usagePromptTokens, usageTotalTokens); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + usageJson + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens()); + assertEquals(usagePromptTokens, chunk.getUsage().promptTokens()); + assertEquals(usageTotalTokens, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals(choiceContent, choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals(choiceRole, choice.delta().getRole()); + assertEquals(choiceFinishReason, choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + + List toolCalls = choice.delta().getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(toolCallIndex, toolCall.getIndex()); + assertEquals(toolCallId, toolCall.getId()); + assertEquals(toolCallFunctionName, toolCall.getFunction().getName()); + assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments()); + assertEquals(toolCallType, toolCall.getType()); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IOException { + // JSON with null fields + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(null, null, null, "", null, choiceIndex); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + null + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertNull(choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertNull(choice.delta().getRole()); + assertNull(choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + assertTrue(choice.delta().getToolCalls().isEmpty()); + } + } + + private String createToolCallJson(int index, String id, String functionName, String functionArguments, String type) { + return Strings.format(""" + { + "index": %d, + "id": "%s", + "function": { + "name": "%s", + "arguments": "%s" + }, + "type": "%s" + } + """, index, id, functionName, functionArguments, type); + } + + private String createChoiceJson(String content, String refusal, String role, String toolCallsJson, String finishReason, int index) { + if (role == null) { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + toolCallsJson, + finishReason != null ? "\"" + finishReason + "\"" : "null", + index + ); + } else { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "role": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + role != null ? "\"" + role + "\"" : "null", + toolCallsJson, + finishReason != null ? "\"" + finishReason + "\"" : "null", + index + ); + } + } + + private String createChatCompletionChunkJson(String id, String choicesJson, String model, String object, String usageJson) { + if (usageJson != null) { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s", + "usage": %s + } + """, id, choicesJson, model, object, usageJson); + } else { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s" + } + """, id, choicesJson, model, object); + } + } + + private String createUsageJson(int completionTokens, int promptTokens, int totalTokens) { + return Strings.format(""" + { + "completion_tokens": %d, + "prompt_tokens": %d, + "total_tokens": %d + } + """, completionTokens, promptTokens, totalTokens); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java index 7ffa8940ad6be..065dfee577a82 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java @@ -10,7 +10,7 @@ import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; @@ -72,7 +72,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - private static DocumentsOnlyInput listOf(String... input) { - return new DocumentsOnlyInput(List.of(input)); + private static ChatCompletionInput listOf(String... input) { + return new ChatCompletionInput(List.of(input)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java deleted file mode 100644 index 9d5492f9e9516..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class OpenAiChatCompletionRequestEntityTests extends ESTestCase { - - public void testXContent_WritesUserWhenDefined() throws IOException { - var entity = new OpenAiChatCompletionRequestEntity(List.of("abc"), "model", "user", false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"messages":[{"role":"user","content":"abc"}],"model":"model","n":1,"user":"user"}""")); - - } - - public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException { - var entity = new OpenAiChatCompletionRequestEntity(List.of("abc"), "model", null, false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"messages":[{"role":"user","content":"abc"}],"model":"model","n":1}""")); - } - - public void testXContent_ThrowsIfModelIsNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(List.of("abc"), null, "user", false)); - } - - public void testXContent_ThrowsIfMessagesAreNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(null, "model", "user", false)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..f945c154ea234 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,856 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.equalTo; + +public class OpenAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + // 1. Basic Serialization + // Test with minimal required fields to ensure basic serialization works. + public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 2. Serialization with All Fields + // Test with all possible fields populated to ensure complete serialization. + public void testSerializationWithAllFields() throws IOException { + // Create a message with all fields populated + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name", + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + // Create a tool with all fields populated + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with all fields populated + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "name": "name", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + // 3. Serialization with Null Optional Fields + // Test with optional fields set to null to ensure they are correctly omitted from the output. + public void testSerializationWithNullOptionalFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + // Create the unified request with optional fields set to null + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 4. Serialization with Empty Lists + // Test with fields that are lists set to empty lists to ensure they are correctly serialized. + public void testSerializationWithEmptyLists() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with empty lists + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + Collections.emptyList(), // empty stop list + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 5. Serialization with Nested Objects + // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. + public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + // Generate random values + String randomContent = "Hello, world! " + random.nextInt(1000); + String randomName = "name" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + // Create a message with nested toolCalls + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + randomName, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + // Create a tool with nested function fields + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with nested objects + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // maxCompletionTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + // Expected JSON should be dynamically generated based on random values + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "name": "%s", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_completion_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, + randomContent, + randomName, + randomToolCallId, + randomArguments, + randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + // 6. Serialization with Different Content Types + // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + // Generate random values for ContentString + String randomContentString = "Hello, world! " + random.nextInt(1000); + + // Generate random values for ContentObjects + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + // Create messages with different content types + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( + contentObjects, + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + // Create the unified request with both types of messages + UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + // 7. Serialization with Special Characters + // Test with special characters in string fields to ensure they are correctly escaped and serialized. + public void testSerializationWithSpecialCharacters() throws IOException { + // Create a message with special characters + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 8. Serialization with Boolean Fields + // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. + public void testSerializationWithBooleanFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Test with stream set to true + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); + + XContentBuilder builderTrue = JsonXContent.contentBuilder(); + entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); + + String jsonStringTrue = Strings.toString(builderTrue); + String expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + // Test with stream set to false + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputFalse, model); + + XContentBuilder builderFalse = JsonXContent.contentBuilder(); + entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); + + String jsonStringFalse = Strings.toString(builderFalse); + String expectedJsonFalse = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + // 9. Serialization with Missing Required Fields + // Test with missing required fields to ensure appropriate exceptions are thrown. + public void testSerializationWithMissingRequiredFields() { + // Create a message with missing content (required field) + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, // missing content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Attempt to serialize to XContent and expect an exception + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to missing required fields"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + // 10. Serialization with Mixed Valid and Invalid Data + // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. + public void testSerializationWithMixedValidAndInvalidData() throws IOException { + // Create a valid message + UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Valid content"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "validName", + "validToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "validId", + new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), + "validType" + ) + ) + ); + + // Create an invalid message with null content + UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( + null, // invalid content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "invalidName", + "invalidToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "invalidId", + new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), + "invalidType" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(validMessage); + messageList.add(invalidMessage); + // Create the unified request with both valid and invalid messages + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model-name", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList( + new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ) + ), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent and verify + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to invalid data"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + public static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } + + private void assertJsonEquals(String actual, String expected) throws IOException { + try ( + var actualParser = createParser(JsonXContent.jsonXContent, actual); + var expectedParser = createParser(JsonXContent.jsonXContent, expected) + ) { + assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java similarity index 75% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index b6ebfd02941f3..2be12c9b12e0b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import java.io.IOException; @@ -20,16 +21,16 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest.buildDefaultUri; +import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest.buildDefaultUri; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class OpenAiChatCompletionRequestTests extends ESTestCase { +public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { - var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); + var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -41,15 +42,27 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertRequestMapWithUser(requestMap, "user"); + } + + private void assertRequestMapWithoutUser(Map requestMap) { + assertRequestMapWithUser(requestMap, null); + } + + private void assertRequestMapWithUser(Map requestMap, @Nullable String user) { + assertThat(requestMap, aMapWithSize(user != null ? 6 : 5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); + if (user != null) { + assertThat(requestMap.get("user"), is(user)); + } assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException { - var request = createRequest(null, "org", "secret", "abc", "model", "user"); + var request = createRequest(null, "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -61,33 +74,27 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); - assertThat(requestMap.get("n"), is(1)); + assertRequestMapWithUser(requestMap, "user"); + } public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abc", "model", null); + var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("n"), is(1)); + assertRequestMapWithoutUser(requestMap); } - public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException { + public void testCreateRequest_WithStreaming() throws IOException { var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); @@ -99,29 +106,31 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep } public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap, aMapWithSize(5)); // We do not truncate for OpenAi chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); assertNull(request.getTruncationInfo()); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -132,7 +141,7 @@ public static OpenAiChatCompletionRequest createRequest( return createRequest(url, org, apiKey, input, model, user, false); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -142,7 +151,7 @@ public static OpenAiChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); - return new OpenAiChatCompletionRequest(List.of(input), chatCompletionModel, stream); + return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java new file mode 100644 index 0000000000000..7dc4d99e06acc --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java @@ -0,0 +1,288 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.highlight; + +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.mapper.SourceToParse; +import org.elasticsearch.index.query.NestedQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchContext; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.lookup.Source; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.junit.Before; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.zip.GZIPInputStream; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.mockito.Mockito.mock; + +public class SemanticTextHighlighterTests extends MapperServiceTestCase { + private static final String SEMANTIC_FIELD_E5 = "body-e5"; + private static final String SEMANTIC_FIELD_ELSER = "body-elser"; + + private Map queries; + + @Override + protected Collection getPlugins() { + return List.of(new InferencePlugin(Settings.EMPTY)); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json")); + this.queries = XContentHelper.convertToMap(input, false, XContentType.JSON).v2(); + } + + @SuppressWarnings("unchecked") + public void testDenseVector() throws Exception { + var mapperService = createDefaultMapperService(); + Map queryMap = (Map) queries.get("dense_vector_1"); + float[] vector = readDenseVector(queryMap.get("embeddings")); + var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5); + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max); + var shardRequest = createShardSearchRequest(nestedQueryBuilder); + var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); + + String[] expectedScorePassages = ((List) queryMap.get("expected_by_score")).toArray(String[]::new); + for (int i = 0; i < expectedScorePassages.length; i++) { + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_E5, + i + 1, + HighlightBuilder.Order.SCORE, + Arrays.copyOfRange(expectedScorePassages, 0, i + 1) + ); + } + + String[] expectedOffsetPassages = ((List) queryMap.get("expected_by_offset")).toArray(String[]::new); + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_E5, + expectedOffsetPassages.length, + HighlightBuilder.Order.NONE, + expectedOffsetPassages + ); + } + + @SuppressWarnings("unchecked") + public void testSparseVector() throws Exception { + var mapperService = createDefaultMapperService(); + Map queryMap = (Map) queries.get("sparse_vector_1"); + List tokens = readSparseVector(queryMap.get("embeddings")); + var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_ELSER); + SparseVectorQueryBuilder sparseQuery = new SparseVectorQueryBuilder( + fieldType.getEmbeddingsField().fullPath(), + tokens, + null, + null, + null, + null + ); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), sparseQuery, ScoreMode.Max); + var shardRequest = createShardSearchRequest(nestedQueryBuilder); + var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); + + String[] expectedScorePassages = ((List) queryMap.get("expected_by_score")).toArray(String[]::new); + for (int i = 0; i < expectedScorePassages.length; i++) { + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_ELSER, + i + 1, + HighlightBuilder.Order.SCORE, + Arrays.copyOfRange(expectedScorePassages, 0, i + 1) + ); + } + + String[] expectedOffsetPassages = ((List) queryMap.get("expected_by_offset")).toArray(String[]::new); + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_ELSER, + expectedOffsetPassages.length, + HighlightBuilder.Order.NONE, + expectedOffsetPassages + ); + } + + private MapperService createDefaultMapperService() throws IOException { + var mappings = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("mappings.json")); + return createMapperService(mappings.utf8ToString()); + } + + private float[] readDenseVector(Object value) { + if (value instanceof List lst) { + float[] res = new float[lst.size()]; + int pos = 0; + for (var obj : lst) { + if (obj instanceof Number number) { + res[pos++] = number.floatValue(); + } else { + throw new IllegalArgumentException("Expected number, got " + obj.getClass().getSimpleName()); + } + } + return res; + } + throw new IllegalArgumentException("Expected list, got " + value.getClass().getSimpleName()); + } + + private List readSparseVector(Object value) { + if (value instanceof Map map) { + List res = new ArrayList<>(); + for (var entry : map.entrySet()) { + if (entry.getValue() instanceof Number number) { + res.add(new WeightedToken((String) entry.getKey(), number.floatValue())); + } else { + throw new IllegalArgumentException("Expected number, got " + entry.getValue().getClass().getSimpleName()); + } + } + return res; + } + throw new IllegalArgumentException("Expected map, got " + value.getClass().getSimpleName()); + } + + private void assertHighlightOneDoc( + MapperService mapperService, + ShardSearchRequest request, + SourceToParse source, + String fieldName, + int numFragments, + HighlightBuilder.Order order, + String[] expectedPassages + ) throws Exception { + SemanticTextFieldMapper fieldMapper = (SemanticTextFieldMapper) mapperService.mappingLookup().getMapper(fieldName); + var doc = mapperService.documentMapper().parse(source); + assertNull(doc.dynamicMappingsUpdate()); + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwc); + iw.addDocuments(doc.docs()); + try (DirectoryReader reader = wrapInMockESDirectoryReader(iw.getReader())) { + IndexSearcher searcher = newSearcher(reader); + iw.close(); + TopDocs topDocs = searcher.search(Queries.newNonNestedFilter(IndexVersion.current()), 1, Sort.INDEXORDER); + assertThat(topDocs.totalHits.value(), equalTo(1L)); + int docID = topDocs.scoreDocs[0].doc; + SemanticTextHighlighter highlighter = new SemanticTextHighlighter(); + var execContext = createSearchExecutionContext(mapperService); + var luceneQuery = execContext.toQuery(request.source().query()).query(); + FetchContext fetchContext = mock(FetchContext.class); + Mockito.when(fetchContext.highlight()).thenReturn(new SearchHighlightContext(Collections.emptyList())); + Mockito.when(fetchContext.query()).thenReturn(luceneQuery); + Mockito.when(fetchContext.getSearchExecutionContext()).thenReturn(execContext); + + FetchSubPhase.HitContext hitContext = new FetchSubPhase.HitContext( + new SearchHit(docID), + getOnlyLeafReader(reader).getContext(), + docID, + Map.of(), + Source.fromBytes(source.source()), + new RankDoc(docID, Float.NaN, 0) + ); + try { + var highlightContext = new HighlightBuilder().field(fieldName, 0, numFragments) + .order(order) + .highlighterType(SemanticTextHighlighter.NAME) + .build(execContext); + + for (var fieldContext : highlightContext.fields()) { + FieldHighlightContext context = new FieldHighlightContext( + fieldName, + fieldContext, + fieldMapper.fieldType(), + fetchContext, + hitContext, + luceneQuery, + new HashMap<>() + ); + var result = highlighter.highlight(context); + assertThat(result.fragments().length, equalTo(expectedPassages.length)); + for (int i = 0; i < result.fragments().length; i++) { + assertThat(result.fragments()[i].string(), equalTo(expectedPassages[i])); + } + } + } finally { + hitContext.hit().decRef(); + } + } + } + } + + private SearchRequest createSearchRequest(QueryBuilder queryBuilder) { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.allowPartialSearchResults(false); + request.source().query(queryBuilder); + return request; + } + + private ShardSearchRequest createShardSearchRequest(QueryBuilder queryBuilder) { + SearchRequest request = createSearchRequest(queryBuilder); + return new ShardSearchRequest(OriginalIndices.NONE, request, new ShardId("index", "index", 0), 0, 1, AliasFilter.EMPTY, 1, 0, null); + } + + private BytesReference readSampleDoc(String fileName) throws IOException { + try (var in = new GZIPInputStream(SemanticTextHighlighterTests.class.getResourceAsStream(fileName))) { + return new BytesArray(new BytesRef(in.readAllBytes())); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapperTests.java new file mode 100644 index 0000000000000..40140d6da5eb5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldMapperTests.java @@ -0,0 +1,216 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.analysis.TokenStream; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.OffsetAttribute; +import org.apache.lucene.index.IndexableField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.mapper.DocumentMapper; +import org.elasticsearch.index.mapper.DocumentParsingException; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperTestCase; +import org.elasticsearch.index.mapper.ParsedDocument; +import org.elasticsearch.index.mapper.SourceToParse; +import org.elasticsearch.index.mapper.ValueFetcher; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.lookup.Source; +import org.elasticsearch.search.lookup.SourceProvider; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.junit.AssumptionViolatedException; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class OffsetSourceFieldMapperTests extends MapperTestCase { + @Override + protected Collection getPlugins() { + return List.of(new InferencePlugin(Settings.EMPTY)); + } + + @Override + protected void minimalMapping(XContentBuilder b) throws IOException { + b.field("type", "offset_source"); + } + + @Override + protected Object getSampleValueForDocument() { + return getSampleObjectForDocument(); + } + + @Override + protected Object getSampleObjectForDocument() { + return Map.of("field", "foo", "start", 100, "end", 300); + } + + @Override + protected Object generateRandomInputValue(MappedFieldType ft) { + return new OffsetSourceFieldMapper.OffsetSource("field", randomIntBetween(0, 100), randomIntBetween(101, 1000)); + } + + @Override + protected IngestScriptSupport ingestScriptSupport() { + throw new AssumptionViolatedException("not supported"); + } + + @Override + protected void registerParameters(ParameterChecker checker) throws IOException {} + + @Override + protected void assertSearchable(MappedFieldType fieldType) { + assertFalse(fieldType.isSearchable()); + } + + @Override + protected boolean supportsStoredFields() { + return false; + } + + @Override + protected boolean supportsEmptyInputArray() { + return false; + } + + @Override + protected boolean supportsCopyTo() { + return false; + } + + @Override + protected boolean supportsIgnoreMalformed() { + return false; + } + + @Override + protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { + return new SyntheticSourceSupport() { + @Override + public SyntheticSourceExample example(int maxValues) { + return new SyntheticSourceExample(getSampleValueForDocument(), getSampleValueForDocument(), null, b -> minimalMapping(b)); + } + + @Override + public List invalidExample() { + return List.of(); + } + }; + } + + @Override + public void testSyntheticSourceKeepArrays() { + // This mapper doesn't support multiple values (array of objects). + } + + public void testDefaults() throws Exception { + DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); + assertEquals(Strings.toString(fieldMapping(this::minimalMapping)), mapper.mappingSource().toString()); + + ParsedDocument doc1 = mapper.parse( + source(b -> b.startObject("field").field("field", "foo").field("start", 0).field("end", 128).endObject()) + ); + List fields = doc1.rootDoc().getFields("field"); + assertEquals(1, fields.size()); + assertThat(fields.get(0), instanceOf(OffsetSourceField.class)); + OffsetSourceField offsetField1 = (OffsetSourceField) fields.get(0); + + ParsedDocument doc2 = mapper.parse( + source(b -> b.startObject("field").field("field", "bar").field("start", 128).field("end", 512).endObject()) + ); + OffsetSourceField offsetField2 = (OffsetSourceField) doc2.rootDoc().getFields("field").get(0); + + assertTokenStream(offsetField1.tokenStream(null, null), "foo", 0, 128); + assertTokenStream(offsetField2.tokenStream(null, null), "bar", 128, 512); + } + + private void assertTokenStream(TokenStream tk, String expectedTerm, int expectedStartOffset, int expectedEndOffset) throws IOException { + CharTermAttribute termAttribute = tk.addAttribute(CharTermAttribute.class); + OffsetAttribute offsetAttribute = tk.addAttribute(OffsetAttribute.class); + tk.reset(); + assertTrue(tk.incrementToken()); + assertThat(new String(termAttribute.buffer(), 0, termAttribute.length()), equalTo(expectedTerm)); + assertThat(offsetAttribute.startOffset(), equalTo(expectedStartOffset)); + assertThat(offsetAttribute.endOffset(), equalTo(expectedEndOffset)); + assertFalse(tk.incrementToken()); + } + + @Override + protected void assertFetch(MapperService mapperService, String field, Object value, String format) throws IOException { + MappedFieldType ft = mapperService.fieldType(field); + MappedFieldType.FielddataOperation fdt = MappedFieldType.FielddataOperation.SEARCH; + SourceToParse source = source(b -> b.field(ft.name(), value)); + SearchExecutionContext searchExecutionContext = mock(SearchExecutionContext.class); + when(searchExecutionContext.isSourceEnabled()).thenReturn(true); + when(searchExecutionContext.sourcePath(field)).thenReturn(Set.of(field)); + when(searchExecutionContext.getForField(ft, fdt)).thenAnswer(inv -> fieldDataLookup(mapperService).apply(ft, () -> { + throw new UnsupportedOperationException(); + }, fdt)); + ValueFetcher nativeFetcher = ft.valueFetcher(searchExecutionContext, format); + ParsedDocument doc = mapperService.documentMapper().parse(source); + withLuceneIndex(mapperService, iw -> iw.addDocuments(doc.docs()), ir -> { + Source s = SourceProvider.fromStoredFields().getSource(ir.leaves().get(0), 0); + nativeFetcher.setNextReader(ir.leaves().get(0)); + List fromNative = nativeFetcher.fetchValues(s, 0, new ArrayList<>()); + assertThat(fromNative.size(), equalTo(1)); + assertThat("fetching " + value, fromNative.get(0), equalTo(value)); + }); + } + + @Override + protected void assertFetchMany(MapperService mapperService, String field, Object value, String format, int count) throws IOException { + assumeFalse("[offset_source] currently don't support multiple values in the same field", false); + } + + public void testInvalidCharset() { + var exc = expectThrows(Exception.class, () -> createDocumentMapper(mapping(b -> { + b.startObject("field").field("type", "offset_source").field("charset", "utf_8").endObject(); + }))); + assertThat(exc.getCause().getMessage(), containsString("Unknown value [utf_8] for field [charset]")); + } + + public void testRejectMultiValuedFields() throws IOException { + DocumentMapper mapper = createDocumentMapper(mapping(b -> { b.startObject("field").field("type", "offset_source").endObject(); })); + + DocumentParsingException exc = expectThrows(DocumentParsingException.class, () -> mapper.parse(source(b -> { + b.startArray("field"); + { + b.startObject().field("field", "bar1").field("start", 128).field("end", 512).endObject(); + b.startObject().field("field", "bar2").field("start", 128).field("end", 512).endObject(); + } + b.endArray(); + }))); + assertThat(exc.getCause().getMessage(), containsString("[offset_source] fields do not support indexing multiple values")); + } + + public void testInvalidOffsets() throws IOException { + DocumentMapper mapper = createDocumentMapper(mapping(b -> { b.startObject("field").field("type", "offset_source").endObject(); })); + + DocumentParsingException exc = expectThrows(DocumentParsingException.class, () -> mapper.parse(source(b -> { + b.startArray("field"); + { + b.startObject().field("field", "bar1").field("start", -1).field("end", 512).endObject(); + } + b.endArray(); + }))); + assertThat(exc.getCause().getCause().getCause().getMessage(), containsString("Illegal offsets")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldTests.java new file mode 100644 index 0000000000000..4d86263e446f8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldTests.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.document.Document; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.test.ESTestCase; + +public class OffsetSourceFieldTests extends ESTestCase { + public void testBasics() throws Exception { + Directory dir = newDirectory(); + RandomIndexWriter writer = new RandomIndexWriter( + random(), + dir, + newIndexWriterConfig().setMergePolicy(newLogMergePolicy(random().nextBoolean())) + ); + Document doc = new Document(); + OffsetSourceField field1 = new OffsetSourceField("field1", "foo", 1, 10); + doc.add(field1); + writer.addDocument(doc); + + field1.setValues("bar", 10, 128); + writer.addDocument(doc); + + writer.addDocument(new Document()); // gap + + field1.setValues("foo", 50, 256); + writer.addDocument(doc); + + writer.addDocument(new Document()); // double gap + writer.addDocument(new Document()); + + field1.setValues("baz", 32, 512); + writer.addDocument(doc); + + writer.forceMerge(1); + var reader = writer.getReader(); + writer.close(); + + var searcher = newSearcher(reader); + var context = searcher.getIndexReader().leaves().get(0); + + var terms = context.reader().terms("field1"); + assertNotNull(terms); + OffsetSourceField.OffsetSourceLoader loader = OffsetSourceField.loader(terms); + + var offset = loader.advanceTo(0); + assertEquals(new OffsetSourceFieldMapper.OffsetSource("foo", 1, 10), offset); + + offset = loader.advanceTo(1); + assertEquals(new OffsetSourceFieldMapper.OffsetSource("bar", 10, 128), offset); + + assertNull(loader.advanceTo(2)); + + offset = loader.advanceTo(3); + assertEquals(new OffsetSourceFieldMapper.OffsetSource("foo", 50, 256), offset); + + offset = loader.advanceTo(6); + assertEquals(new OffsetSourceFieldMapper.OffsetSource("baz", 32, 512), offset); + + assertNull(loader.advanceTo(189)); + + IOUtils.close(reader, dir); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldTypeTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldTypeTests.java new file mode 100644 index 0000000000000..ccb696515a060 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/OffsetSourceFieldTypeTests.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.mapper; + +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.elasticsearch.index.mapper.FieldTypeTestCase; +import org.elasticsearch.index.mapper.MappedFieldType; + +import java.util.Collections; + +public class OffsetSourceFieldTypeTests extends FieldTypeTestCase { + public void testIsNotAggregatable() { + MappedFieldType fieldType = getMappedFieldType(); + assertFalse(fieldType.isAggregatable()); + } + + @Override + public void testFieldHasValue() { + MappedFieldType fieldType = getMappedFieldType(); + FieldInfos fieldInfos = new FieldInfos(new FieldInfo[] { getFieldInfoWithName(fieldType.name()) }); + assertTrue(fieldType.fieldHasValue(fieldInfos)); + } + + @Override + public void testFieldHasValueWithEmptyFieldInfos() { + MappedFieldType fieldType = getMappedFieldType(); + assertFalse(fieldType.fieldHasValue(FieldInfos.EMPTY)); + } + + @Override + public MappedFieldType getMappedFieldType() { + return new OffsetSourceFieldMapper.OffsetSourceFieldType( + "field", + OffsetSourceFieldMapper.CharsetFormat.UTF_16, + Collections.emptyMap() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index 71ff9fc7d84cf..c6a492dfcf4e9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.mapper; -import org.apache.lucene.document.FeatureField; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexableField; @@ -47,6 +46,7 @@ import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.XFeatureField; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; import org.elasticsearch.inference.Model; @@ -61,6 +61,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.junit.AssumptionViolatedException; @@ -1110,7 +1111,12 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook } queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + return new ESToParentBlockJoinQuery( + new SparseVectorQueryWrapper(fieldName, queryBuilder.build()), + parentFilter, + ScoreMode.Total, + null + ); } private static void assertChildLeafNestedDocument( @@ -1130,7 +1136,7 @@ private static void assertChildLeafNestedDocument( private static void assertSparseFeatures(LuceneDocument doc, String fieldName, int expectedCount) { int count = 0; for (IndexableField field : doc.getFields()) { - if (field instanceof FeatureField featureField) { + if (field instanceof XFeatureField featureField) { assertThat(featureField.name(), equalTo(fieldName)); ++count; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b8bcb766b53e1..36aa2200eceae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -45,12 +45,14 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.XPackClientPlugin; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -114,7 +116,7 @@ public void setUp() throws Exception { @Override protected Collection> getPlugins() { - return List.of(InferencePlugin.class, FakeMlPlugin.class); + return List.of(XPackClientPlugin.class, InferencePlugin.class, FakeMlPlugin.class); } @Override @@ -194,9 +196,11 @@ protected void doAssertLuceneQuery(SemanticQueryBuilder queryBuilder, Query quer private void assertSparseEmbeddingLuceneQuery(Query query) { Query innerQuery = assertOuterBooleanQuery(query); - assertThat(innerQuery, instanceOf(BooleanQuery.class)); + assertThat(innerQuery, instanceOf(SparseVectorQueryWrapper.class)); + var sparseQuery = (SparseVectorQueryWrapper) innerQuery; + assertThat(((SparseVectorQueryWrapper) innerQuery).getTermsQuery(), instanceOf(BooleanQuery.class)); - BooleanQuery innerBooleanQuery = (BooleanQuery) innerQuery; + BooleanQuery innerBooleanQuery = (BooleanQuery) sparseQuery.getTermsQuery(); assertThat(innerBooleanQuery.clauses().size(), equalTo(queryTokenCount)); innerBooleanQuery.forEach(c -> { assertThat(c.occur(), equalTo(SHOULD)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 05a8d52be5df4..5528c80066b0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -8,11 +8,14 @@ package org.elasticsearch.xpack.inference.rest; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestRequestTests; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; @@ -26,6 +29,10 @@ import java.util.Map; import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseParams; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; +import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -56,6 +63,42 @@ private static String route(String param) { return "_route/" + param; } + public void testParseParams_ExtractsInferenceIdAndTaskType() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id", TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("id", TaskType.COMPLETION))); + } + + public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("completion", TaskType.ANY))); + } + + public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() { + var e = expectThrows( + ElasticsearchStatusException.class, + () -> parseParams(RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id"))) + ); + assertThat(e.getMessage(), is("Task type must not be null")); + } + + public void testParseTimeout_ReturnsTimeout() { + var timeout = parseTimeout( + RestRequestTests.contentRestRequest("{}", Map.of(InferenceAction.Request.TIMEOUT.getPreferredName(), "4s")) + ); + + assertThat(timeout, is(TimeValue.timeValueSeconds(4))); + } + + public void testParseTimeout_ReturnsDefaultTimeout() { + var timeout = parseTimeout(RestRequestTests.contentRestRequest("{}", Map.of())); + + assertThat(timeout, is(TimeValue.timeValueSeconds(30))); + } + public void testUsesDefaultTimeout() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java new file mode 100644 index 0000000000000..5acfe67b175df --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new RestUnifiedCompletionInferenceAction()); + } + + public void testStreamIsTrue() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class)); + + var request = (UnifiedCompletionAction.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + var requestBody = """ + { + "messages": [ + { + "content": "abc", + "role": "user" + } + ] + } + """; + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/completion/test/_unified") + .withContent(new BytesArray(requestBody), XContentType.JSON) + .build(); + + final SetOnce responseSetOnce = new SetOnce<>(); + dispatchRequest(inferenceRequest, new AbstractRestChannel(inferenceRequest, true) { + @Override + public void sendResponse(RestResponse response) { + responseSetOnce.set(response); + } + }); + + // the response content will be null when there is no error + assertNull(responseSetOnce.get().content()); + assertThat(executeCalled.get(), equalTo(true)); + } + + private void dispatchRequest(final RestRequest request, final RestChannel channel) { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + controller().dispatchRequest(request, channel, threadContext); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 47a96bf78dda1..6768583598b2d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.junit.After; import org.junit.Before; @@ -119,6 +120,14 @@ protected void doInfer( } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) {} + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 76b5d6fee2c59..159b77789482d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -920,6 +921,68 @@ public void testInfer_SendsRequest() throws IOException { } } + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":"stop"\ + }\ + ],\ + "usage":{\ + "prompt_tokens": 16,\ + "completion_tokens": 28,\ + "total_tokens": 44,\ + "prompt_tokens_details": {\ + "cached_tokens": 0,\ + "audio_tokens": 0\ + },\ + "completion_tokens_details": {\ + "reasoning_tokens": 0,\ + "audio_tokens": 0,\ + "accepted_prediction_tokens": 0,\ + "rejected_prediction_tokens": 0\ + }\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """ + "model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """ + "usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}"""); + } + } + public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index ab1786f0a5843..e7ac4cf879e92 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -10,9 +10,11 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; @@ -42,10 +44,48 @@ public void testOverrideWith_EmptyMap() { public void testOverrideWith_NullMap() { var model = createChatCompletionModel("url", "org", "api_key", "model_name", null); - var overriddenModel = OpenAiChatCompletionModel.of(model, null); + var overriddenModel = OpenAiChatCompletionModel.of(model, (Map) null); assertThat(overriddenModel, sameInstance(model)); } + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "different_model", "user")) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "model_name", "user")) + ); + } + public static OpenAiChatCompletionModel createChatCompletionModel( String url, @Nullable String org, diff --git a/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/mappings.json b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/mappings.json new file mode 100644 index 0000000000000..9841ee0aed6e2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/mappings.json @@ -0,0 +1,27 @@ +{ + "_doc": { + "properties": { + "body": { + "type": "text", + "copy_to": ["body-elser", "body-e5"] + }, + "body-e5": { + "type": "semantic_text", + "inference_id": ".multilingual-e5-small-elasticsearch", + "model_settings": { + "task_type": "text_embedding", + "dimensions": 384, + "similarity": "cosine", + "element_type": "float" + } + }, + "body-elser": { + "type": "semantic_text", + "inference_id": ".elser-2-elasticsearch", + "model_settings": { + "task_type": "sparse_embedding" + } + } + } + } +} \ No newline at end of file diff --git a/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/queries.json b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/queries.json new file mode 100644 index 0000000000000..6227f3f498854 --- /dev/null +++ b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/queries.json @@ -0,0 +1,467 @@ +{ + "dense_vector_1": { + "embeddings": [ + 0.09475211, + 0.044564713, + -0.04378501, + -0.07908551, + 0.04332011, + -0.03891992, + -0.0062305215, + 0.024245035, + -0.008976331, + 0.032832284, + 0.052760173, + 0.008123907, + 0.09049037, + -0.01637332, + -0.054353267, + 0.00771307, + 0.08545496, + -0.079716265, + -0.045666866, + -0.04369993, + 0.009189822, + -0.013782891, + -0.07701858, + 0.037278354, + 0.049807206, + 0.078036495, + -0.059533164, + 0.051413406, + 0.040234447, + -0.038139492, + -0.085189626, + -0.045546446, + 0.0544375, + -0.05604156, + 0.057408098, + 0.041913517, + -0.037348013, + -0.025998272, + 0.08486864, + -0.046678443, + 0.0041820924, + 0.007514462, + 0.06424746, + 0.044233218, + 0.103267275, + 0.014130771, + -0.049954403, + 0.04226959, + -0.08346965, + -0.01639249, + -0.060537644, + 0.04546336, + 0.012866155, + 0.05375096, + 0.036775924, + -0.0762226, + -0.037304543, + -0.05692274, + -0.055807598, + 0.0040082196, + 0.059259634, + 0.012022011, + -8.0863154E-4, + 0.0070405705, + 0.050255686, + 0.06810016, + 0.017190414, + 0.051975194, + -0.051436286, + 0.023408439, + -0.029802637, + 0.034137156, + -0.004660689, + -0.0442122, + 0.019065322, + 0.030806554, + 0.0064652697, + -0.066789865, + 0.057111286, + 0.009412479, + -0.041444767, + -0.06807582, + -0.085881524, + 0.04901128, + -0.047871742, + 0.06328623, + 0.040418074, + -0.081432894, + 0.058384005, + 0.006206527, + 0.045801315, + 0.037274595, + -0.054337103, + -0.06755516, + -0.07396888, + -0.043732334, + -0.052053086, + 0.03210978, + 0.048101492, + -0.083828256, + 0.05205026, + -0.048474856, + 0.029116616, + -0.10924888, + 0.003796487, + 0.030567763, + 0.026949523, + -0.052353345, + 0.043198872, + -0.09456988, + -0.05711594, + -2.2292069E-4, + 0.032972734, + 0.054394923, + -0.0767535, + -0.02710579, + -0.032135617, + -0.01732382, + 0.059442326, + -0.07686165, + 0.07104082, + -0.03090021, + -0.05450075, + -0.038997203, + -0.07045443, + 0.00483161, + 0.010933604, + 0.020874644, + 0.037941266, + 0.019729063, + 0.06178368, + 0.013503478, + -0.008584046, + 0.045592044, + 0.05528768, + 0.11568184, + 0.0041300594, + 0.015404516, + -3.8067883E-4, + -0.06365399, + -0.07826643, + 0.061575573, + -0.060548335, + 0.05706082, + 0.042301804, + 0.052173313, + 0.07193179, + -0.03839231, + 0.0734415, + -0.045380164, + 0.02832276, + 0.003745178, + 0.058844633, + 0.04307504, + 0.037800383, + -0.031050054, + -0.06856359, + -0.059114788, + -0.02148857, + 0.07854358, + -0.03253363, + -0.04566468, + -0.019933948, + -0.057993464, + -0.08677458, + -0.06626883, + 0.031657256, + 0.101128764, + -0.08050056, + -0.050226066, + -0.014335166, + 0.050344367, + -0.06851419, + 0.008698909, + -0.011893435, + 0.07741272, + -0.059579294, + 0.03250109, + 0.058700256, + 0.046834726, + -0.035081457, + -0.0043140925, + -0.09764087, + -0.0034994273, + -0.034056358, + -0.019066337, + -0.034376107, + 0.012964423, + 0.029291175, + -0.012090671, + 0.021585712, + 0.028859599, + -0.04391145, + -0.071166754, + -0.031040335, + 0.02808108, + -0.05621317, + 0.06543945, + 0.10094665, + 0.041057374, + -0.03222324, + -0.063366964, + 0.064944476, + 0.023641933, + 0.06806713, + 0.06806097, + -0.08220105, + 0.04148528, + -0.09254079, + 0.044620737, + 0.05526614, + -0.03849534, + -0.04722273, + 0.0670776, + -0.024274077, + -0.016903497, + 0.07584147, + 0.04760533, + -0.038843267, + -0.028365409, + 0.08022705, + -0.039916333, + 0.049067073, + -0.030701574, + -0.057169467, + 0.043025102, + 0.07109674, + -0.047296863, + -0.047463104, + 0.040868305, + -0.04409507, + -0.034977127, + -0.057109762, + -0.08616165, + -0.03486079, + -0.046201482, + 0.025963873, + 0.023392359, + 0.09594902, + -0.007847159, + -0.021231368, + 0.009007263, + 0.0032713825, + -0.06876065, + 0.03169641, + -7.2582875E-4, + -0.07049708, + 0.03900843, + -0.0075472407, + 0.05184822, + 0.06452079, + -0.09832754, + -0.012775799, + -0.03925948, + -0.029761659, + 0.0065437574, + 0.0815465, + 0.0411695, + -0.0702844, + -0.009533786, + 0.07024532, + 0.0098710675, + 0.09915362, + 0.0415453, + 0.050641853, + 0.047463298, + -0.058609713, + -0.029499197, + -0.05100956, + -0.03441709, + -0.06348122, + 0.014784361, + 0.056317374, + -0.10280704, + -0.04008354, + -0.018926824, + 0.08832836, + 0.124804, + -0.047645308, + -0.07122146, + -9.886527E-4, + 0.03850324, + 0.048501793, + 0.07072816, + 0.06566776, + -0.013678872, + 0.010010848, + 0.06483413, + -0.030036367, + -0.029748922, + -0.007482364, + -0.05180385, + 0.03698522, + -0.045453787, + 0.056604166, + 0.029394176, + 0.028589265, + -0.012185886, + -0.06919616, + 0.0711641, + -0.034055933, + -0.053101335, + 0.062319, + 0.021600349, + -0.038718067, + 0.060814686, + 0.05087301, + -0.020297311, + 0.016493896, + 0.032162152, + 0.046740912, + 0.05461355, + -0.07024665, + 0.025609337, + -0.02504801, + 0.06765588, + -0.032994855, + -0.037897404, + -0.045783922, + -0.05689299, + -0.040437017, + -0.07904339, + -0.031415287, + -0.029216278, + 0.017395392, + 0.03449264, + -0.025653394, + -0.06283088, + 0.049027324, + 0.016229525, + -0.00985347, + -0.053974394, + -0.030257035, + 0.04325515, + -0.012293731, + -0.002446129, + -0.05567076, + 0.06374684, + -0.03153897, + -0.04475149, + 0.018582936, + 0.025716115, + -0.061778374, + 0.04196277, + -0.04134671, + -0.07396272, + 0.05846184, + 0.006558759, + -0.09745666, + 0.07587805, + 0.0137483915, + -0.100933895, + 0.032008193, + 0.04293283, + 0.017870268, + 0.032806385, + -0.0635923, + -0.019672254, + 0.022225974, + 0.04304554, + -0.06043949, + -0.0285274, + 0.050868835, + 0.057003833, + 0.05740866, + 0.020068677, + -0.034312245, + -0.021671802, + 0.014769731, + -0.07328285, + -0.009586734, + 0.036420938, + -0.022188472, + -0.008200541, + -0.010765854, + -0.06949713, + -0.07555878, + 0.045306854, + -0.05424466, + -0.03647476, + 0.06266633, + 0.08346125, + 0.060288202, + 0.0548457 + ], + "expected_by_score": [ + "The ancient oppidum that corresponds to the modern city of Paris was first mentioned in the mid-1st century BC by Julius Caesar as Luteciam Parisiorum ('Lutetia of the Parisii') and is later attested as Parision in the 5th century AD, then as Paris in 1265. During the Roman period, it was commonly known as Lutetia or Lutecia in Latin, and as Leukotekía in Greek, which is interpreted as either stemming from the Celtic root *lukot- ('mouse'), or from *luto- ('marsh, swamp').\n\n\nThe name Paris is derived from its early inhabitants, the Parisii, a Gallic tribe from the Iron Age and the Roman period. The meaning of the Gaulish ethnonym remains debated. According to Xavier Delamarre, it may derive from the Celtic root pario- ('cauldron'). Alfred Holder interpreted the name as 'the makers' or 'the commanders', by comparing it to the Welsh peryff ('lord, commander'), both possibly descending from a Proto-Celtic form reconstructed as *kwar-is-io-. Alternatively, Pierre-Yves Lambert proposed to translate Parisii as the 'spear people', by connecting the first element to the Old Irish carr ('spear'), derived from an earlier *kwar-sā. In any case, the city's name is not related to the Paris of Greek mythology.\n\n\nResidents of the city are known in English as Parisians and in French as Parisiens ( ⓘ). They are also pejoratively called Parigots ( ⓘ).\n\n\nHistory\n\nOrigins\n\n", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n", + "In March 2001, Bertrand Delanoë became the first socialist mayor. He was re-elected in March 2008. In 2007, in an effort to reduce car traffic, he introduced the Vélib', a system which rents bicycles. Bertrand Delanoë also transformed a section of the highway along the Left Bank of the Seine into an urban promenade and park, the Promenade des Berges de la Seine, which he inaugurated in June 2013.\n\n\nIn 2007, President Nicolas Sarkozy launched the Grand Paris project, to integrate Paris more closely with the towns in the region around it. After many modifications, the new area, named the Metropolis of Grand Paris, with a population of 6.7 million, was created on 1 January 2016. In 2011, the City of Paris and the national government approved the plans for the Grand Paris Express, totalling 205 km (127 mi) of automated metro lines to connect Paris, the innermost three departments around Paris, airports and high-speed rail (TGV) stations, at an estimated cost of €35 billion. The system is scheduled to be completed by 2030.\n\n\nIn January 2015, Al-Qaeda in the Arabian Peninsula claimed attacks across the Paris region. 1.5 million people marched in Paris in a show of solidarity against terrorism and in support of freedom of speech. In November of the same year, terrorist attacks, claimed by ISIL, killed 130 people and injured more than 350.\n\n\n", + "\nParis (.mw-parser-output .IPA-label-small{font-size:85%}.mw-parser-output .references .IPA-label-small,.mw-parser-output .infobox .IPA-label-small,.mw-parser-output .navbox .IPA-label-small{font-size:100%}French pronunciation: ⓘ) is the capital and largest city of France. With an estimated population of 2,102,650 residents in January 2023 in an area of more than 105 km2 (41 sq mi), Paris is the fourth-largest city in the European Union and the 30th most densely populated city in the world in 2022. Since the 17th century, Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy. Because of its leading role in the arts and sciences and its early adaptation of extensive street lighting, it became known as the City of Light in the 19th century.\n\n\nThe City of Paris is the centre of the Île-de-France region, or Paris Region, with an official estimated population of 12,271,794 inhabitants in January 2023, or about 19% of the population of France. The Paris Region had a nominal GDP of €765 billion (US$1.064 trillion when adjusted for PPP) in 2021, the highest in the European Union. According to the Economist Intelligence Unit Worldwide Cost of Living Survey, in 2022, Paris was the city with the ninth-highest cost of living in the world.\n\n\n", + "Bal-musette is a style of French music and dance that first became popular in Paris in the 1870s and 1880s; by 1880 Paris had some 150 dance halls. Patrons danced the bourrée to the accompaniment of the cabrette (a bellows-blown bagpipe locally called a \"musette\") and often the vielle à roue (hurdy-gurdy) in the cafés and bars of the city. Parisian and Italian musicians who played the accordion adopted the style and established themselves in Auvergnat bars, and Paris became a major centre for jazz and still attracts jazz musicians from all around the world to its clubs and cafés.\n\n\nParis is the spiritual home of gypsy jazz in particular, and many of the Parisian jazzmen who developed in the first half of the 20th century began by playing Bal-musette in the city. Django Reinhardt rose to fame in Paris, having moved to the 18th arrondissement in a caravan as a young boy, and performed with violinist Stéphane Grappelli and their Quintette du Hot Club de France in the 1930s and 1940s.\n\n\nImmediately after the War the Saint-Germain-des-Pres quarter and the nearby Saint-Michel quarter became home to many small jazz clubs, including the Caveau des Lorientais, the Club Saint-Germain, the Rose Rouge, the Vieux-Colombier, and the most famous, Le Tabou. They introduced Parisians to the music of Claude Luter, Boris Vian, Sydney Bechet, Mezz Mezzrow, and Henri Salvador. " + ], + "expected_by_offset": [ + "\nParis (.mw-parser-output .IPA-label-small{font-size:85%}.mw-parser-output .references .IPA-label-small,.mw-parser-output .infobox .IPA-label-small,.mw-parser-output .navbox .IPA-label-small{font-size:100%}French pronunciation: ⓘ) is the capital and largest city of France. With an estimated population of 2,102,650 residents in January 2023 in an area of more than 105 km2 (41 sq mi), Paris is the fourth-largest city in the European Union and the 30th most densely populated city in the world in 2022. Since the 17th century, Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy. Because of its leading role in the arts and sciences and its early adaptation of extensive street lighting, it became known as the City of Light in the 19th century.\n\n\nThe City of Paris is the centre of the Île-de-France region, or Paris Region, with an official estimated population of 12,271,794 inhabitants in January 2023, or about 19% of the population of France. The Paris Region had a nominal GDP of €765 billion (US$1.064 trillion when adjusted for PPP) in 2021, the highest in the European Union. According to the Economist Intelligence Unit Worldwide Cost of Living Survey, in 2022, Paris was the city with the ninth-highest cost of living in the world.\n\n\n", + "The ancient oppidum that corresponds to the modern city of Paris was first mentioned in the mid-1st century BC by Julius Caesar as Luteciam Parisiorum ('Lutetia of the Parisii') and is later attested as Parision in the 5th century AD, then as Paris in 1265. During the Roman period, it was commonly known as Lutetia or Lutecia in Latin, and as Leukotekía in Greek, which is interpreted as either stemming from the Celtic root *lukot- ('mouse'), or from *luto- ('marsh, swamp').\n\n\nThe name Paris is derived from its early inhabitants, the Parisii, a Gallic tribe from the Iron Age and the Roman period. The meaning of the Gaulish ethnonym remains debated. According to Xavier Delamarre, it may derive from the Celtic root pario- ('cauldron'). Alfred Holder interpreted the name as 'the makers' or 'the commanders', by comparing it to the Welsh peryff ('lord, commander'), both possibly descending from a Proto-Celtic form reconstructed as *kwar-is-io-. Alternatively, Pierre-Yves Lambert proposed to translate Parisii as the 'spear people', by connecting the first element to the Old Irish carr ('spear'), derived from an earlier *kwar-sā. In any case, the city's name is not related to the Paris of Greek mythology.\n\n\nResidents of the city are known in English as Parisians and in French as Parisiens ( ⓘ). They are also pejoratively called Parigots ( ⓘ).\n\n\nHistory\n\nOrigins\n\n", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n", + "In March 2001, Bertrand Delanoë became the first socialist mayor. He was re-elected in March 2008. In 2007, in an effort to reduce car traffic, he introduced the Vélib', a system which rents bicycles. Bertrand Delanoë also transformed a section of the highway along the Left Bank of the Seine into an urban promenade and park, the Promenade des Berges de la Seine, which he inaugurated in June 2013.\n\n\nIn 2007, President Nicolas Sarkozy launched the Grand Paris project, to integrate Paris more closely with the towns in the region around it. After many modifications, the new area, named the Metropolis of Grand Paris, with a population of 6.7 million, was created on 1 January 2016. In 2011, the City of Paris and the national government approved the plans for the Grand Paris Express, totalling 205 km (127 mi) of automated metro lines to connect Paris, the innermost three departments around Paris, airports and high-speed rail (TGV) stations, at an estimated cost of €35 billion. The system is scheduled to be completed by 2030.\n\n\nIn January 2015, Al-Qaeda in the Arabian Peninsula claimed attacks across the Paris region. 1.5 million people marched in Paris in a show of solidarity against terrorism and in support of freedom of speech. In November of the same year, terrorist attacks, claimed by ISIL, killed 130 people and injured more than 350.\n\n\n", + "Bal-musette is a style of French music and dance that first became popular in Paris in the 1870s and 1880s; by 1880 Paris had some 150 dance halls. Patrons danced the bourrée to the accompaniment of the cabrette (a bellows-blown bagpipe locally called a \"musette\") and often the vielle à roue (hurdy-gurdy) in the cafés and bars of the city. Parisian and Italian musicians who played the accordion adopted the style and established themselves in Auvergnat bars, and Paris became a major centre for jazz and still attracts jazz musicians from all around the world to its clubs and cafés.\n\n\nParis is the spiritual home of gypsy jazz in particular, and many of the Parisian jazzmen who developed in the first half of the 20th century began by playing Bal-musette in the city. Django Reinhardt rose to fame in Paris, having moved to the 18th arrondissement in a caravan as a young boy, and performed with violinist Stéphane Grappelli and their Quintette du Hot Club de France in the 1930s and 1940s.\n\n\nImmediately after the War the Saint-Germain-des-Pres quarter and the nearby Saint-Michel quarter became home to many small jazz clubs, including the Caveau des Lorientais, the Club Saint-Germain, the Rose Rouge, the Vieux-Colombier, and the most famous, Le Tabou. They introduced Parisians to the music of Claude Luter, Boris Vian, Sydney Bechet, Mezz Mezzrow, and Henri Salvador. " + ] + }, + "sparse_vector_1": { + "embeddings": { + "paris": 2.9709616, + "date": 2.1960778, + "founded": 2.0555024, + "foundation": 1.412623, + "early": 1.2162757, + "founder": 1.1271698, + "french": 0.9213378, + "france": 0.86253893, + "city": 0.82978916, + "founding": 0.79722786, + "established": 0.7967043, + "ancient": 0.7392465, + "when": 0.71705, + "built": 0.6977878, + "treaty": 0.6846069, + "created": 0.68127465, + "century": 0.58926934, + "for": 0.55019474, + "was": 0.52475905, + "origin": 0.48785052, + "expedition": 0.48757303, + "history": 0.47960007, + "mint": 0.47878903, + "historical": 0.4714338, + "capital": 0.42984143, + "timeline": 0.4222377, + "colony": 0.3876187, + "tower": 0.3474891, + "medieval": 0.3272666, + "geography": 0.32456368, + "colonial": 0.30613664, + "location": 0.29013386, + "francisco": 0.22840048, + "orleans": 0.21971667, + "earlier": 0.20318772, + "jackson": 0.18424438, + "exact": 0.17109296, + "rome": 0.16320735, + "civilization": 0.15931238, + "spanish": 0.12759624, + "museum": 0.113024555, + "latin": 0.11201205, + "european": 0.10277243, + "architect": 0.0796932, + "united": 0.031233707 + }, + "expected_by_score": [ + "Clovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. As the Frankish domination of Gaul began, there was a gradual immigration by the Franks to Paris and the Parisian Francien dialects were born. Fortification of the Île de la Cité failed to avert sacking by Vikings in 845, but Paris's strategic importance—with its bridges preventing ships from passing—was established by successful defence in the Siege of Paris (885–886), for which the then Count of Paris (comte de Paris), Odo of France, was elected king of West Francia. From the Capetian dynasty that began with the 987 election of Hugh Capet, Count of Paris and Duke of the Franks (duc des Francs), as king of a unified West Francia, Paris gradually became the largest and most prosperous city in France.\n\n\nHigh and Late Middle Ages to Louis XIV\n\nBy the end of the 12th century, Paris had become the political, economic, religious, and cultural capital of France. The Palais de la Cité, the royal residence, was located at the western end of the Île de la Cité. In 1163, during the reign of Louis VII, Maurice de Sully, bishop of Paris, undertook the construction of the Notre Dame Cathedral at its eastern extremity.\n\n\nAfter the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. ", + "\nThe Parisii, a sub-tribe of the Celtic Senones, inhabited the Paris area from around the middle of the 3rd century BC. One of the area's major north–south trade routes crossed the Seine on the Île de la Cité, which gradually became an important trading centre. The Parisii traded with many river towns (some as far away as the Iberian Peninsula) and minted their own coins.\n\n\nThe Romans conquered the Paris Basin in 52 BC and began their settlement on Paris's Left Bank. The Roman town was originally called Lutetia (more fully, Lutetia Parisiorum, \"Lutetia of the Parisii\", modern French Lutèce). It became a prosperous city with a forum, baths, temples, theatres, and an amphitheatre.\n\n\nBy the end of the Western Roman Empire, the town was known as Parisius, a Latin name that would later become Paris in French. Christianity was introduced in the middle of the 3rd century AD by Saint Denis, the first Bishop of Paris: according to legend, when he refused to renounce his faith before the Roman occupiers, he was beheaded on the hill which became known as Mons Martyrum (Latin \"Hill of Martyrs\"), later \"Montmartre\", from where he walked headless to the north of the city; the place where he fell and was buried became an important religious shrine, the Basilica of Saint-Denis, and many French kings are buried there.\n\n\nClovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. ", + "\nDuring the Hundred Years' War, Paris was occupied by England-friendly Burgundian forces from 1418, before being occupied outright by the English when Henry V of England entered the French capital in 1420; in spite of a 1429 effort by Joan of Arc to liberate the city, it would remain under English occupation until 1436.\n\n\nIn the late 16th-century French Wars of Religion, Paris was a stronghold of the Catholic League, the organisers of 24 August 1572 St. Bartholomew's Day massacre in which thousands of French Protestants were killed. The conflicts ended when pretender to the throne Henry IV, after converting to Catholicism to gain entry to the capital, entered the city in 1594 to claim the crown of France. This king made several improvements to the capital during his reign: he completed the construction of Paris's first uncovered, sidewalk-lined bridge, the Pont Neuf, built a Louvre extension connecting it to the Tuileries Palace, and created the first Paris residential square, the Place Royale, now Place des Vosges. In spite of Henry IV's efforts to improve city circulation, the narrowness of Paris's streets was a contributing factor in his assassination near Les Halles marketplace in 1610.\n\n\nDuring the 17th century, Cardinal Richelieu, chief minister of Louis XIII, was determined to make Paris the most beautiful city in Europe. He built five new bridges, a new chapel for the College of Sorbonne, and a palace for himself, the Palais-Cardinal. ", + "Diderot and D'Alembert published their Encyclopédie in 1751, before the Montgolfier Brothers launched the first manned flight in a hot air balloon on 21 November 1783. Paris was the financial capital of continental Europe, as well the primary European centre for book publishing, fashion and the manufacture of fine furniture and luxury goods. On 22 October 1797, Paris was also the site of the first parachute jump in history, by Garnerin.\n\n\nIn the summer of 1789, Paris became the centre stage of the French Revolution. On 14 July, a mob seized the arsenal at the Invalides, acquiring thousands of guns, with which it stormed the Bastille, a principal symbol of royal authority. The first independent Paris Commune, or city council, met in the Hôtel de Ville and elected a Mayor, the astronomer Jean Sylvain Bailly, on 15 July.\n\n\nLouis XVI and the royal family were brought to Paris and incarcerated in the Tuileries Palace. In 1793, as the revolution turned increasingly radical, the king, queen and mayor were beheaded by guillotine in the Reign of Terror, along with more than 16,000 others throughout France. The property of the aristocracy and the church was nationalised, and the city's churches were closed, sold or demolished. A succession of revolutionary factions ruled Paris until 9 November 1799 (coup d'état du 18 brumaire), when Napoleon Bonaparte seized power as First Consul.\n\n\n", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n" + ], + "expected_by_offset": [ + "\nThe Parisii, a sub-tribe of the Celtic Senones, inhabited the Paris area from around the middle of the 3rd century BC. One of the area's major north–south trade routes crossed the Seine on the Île de la Cité, which gradually became an important trading centre. The Parisii traded with many river towns (some as far away as the Iberian Peninsula) and minted their own coins.\n\n\nThe Romans conquered the Paris Basin in 52 BC and began their settlement on Paris's Left Bank. The Roman town was originally called Lutetia (more fully, Lutetia Parisiorum, \"Lutetia of the Parisii\", modern French Lutèce). It became a prosperous city with a forum, baths, temples, theatres, and an amphitheatre.\n\n\nBy the end of the Western Roman Empire, the town was known as Parisius, a Latin name that would later become Paris in French. Christianity was introduced in the middle of the 3rd century AD by Saint Denis, the first Bishop of Paris: according to legend, when he refused to renounce his faith before the Roman occupiers, he was beheaded on the hill which became known as Mons Martyrum (Latin \"Hill of Martyrs\"), later \"Montmartre\", from where he walked headless to the north of the city; the place where he fell and was buried became an important religious shrine, the Basilica of Saint-Denis, and many French kings are buried there.\n\n\nClovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. ", + "Clovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. As the Frankish domination of Gaul began, there was a gradual immigration by the Franks to Paris and the Parisian Francien dialects were born. Fortification of the Île de la Cité failed to avert sacking by Vikings in 845, but Paris's strategic importance—with its bridges preventing ships from passing—was established by successful defence in the Siege of Paris (885–886), for which the then Count of Paris (comte de Paris), Odo of France, was elected king of West Francia. From the Capetian dynasty that began with the 987 election of Hugh Capet, Count of Paris and Duke of the Franks (duc des Francs), as king of a unified West Francia, Paris gradually became the largest and most prosperous city in France.\n\n\nHigh and Late Middle Ages to Louis XIV\n\nBy the end of the 12th century, Paris had become the political, economic, religious, and cultural capital of France. The Palais de la Cité, the royal residence, was located at the western end of the Île de la Cité. In 1163, during the reign of Louis VII, Maurice de Sully, bishop of Paris, undertook the construction of the Notre Dame Cathedral at its eastern extremity.\n\n\nAfter the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. ", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n", + "\nDuring the Hundred Years' War, Paris was occupied by England-friendly Burgundian forces from 1418, before being occupied outright by the English when Henry V of England entered the French capital in 1420; in spite of a 1429 effort by Joan of Arc to liberate the city, it would remain under English occupation until 1436.\n\n\nIn the late 16th-century French Wars of Religion, Paris was a stronghold of the Catholic League, the organisers of 24 August 1572 St. Bartholomew's Day massacre in which thousands of French Protestants were killed. The conflicts ended when pretender to the throne Henry IV, after converting to Catholicism to gain entry to the capital, entered the city in 1594 to claim the crown of France. This king made several improvements to the capital during his reign: he completed the construction of Paris's first uncovered, sidewalk-lined bridge, the Pont Neuf, built a Louvre extension connecting it to the Tuileries Palace, and created the first Paris residential square, the Place Royale, now Place des Vosges. In spite of Henry IV's efforts to improve city circulation, the narrowness of Paris's streets was a contributing factor in his assassination near Les Halles marketplace in 1610.\n\n\nDuring the 17th century, Cardinal Richelieu, chief minister of Louis XIII, was determined to make Paris the most beautiful city in Europe. He built five new bridges, a new chapel for the College of Sorbonne, and a palace for himself, the Palais-Cardinal. ", + "Diderot and D'Alembert published their Encyclopédie in 1751, before the Montgolfier Brothers launched the first manned flight in a hot air balloon on 21 November 1783. Paris was the financial capital of continental Europe, as well the primary European centre for book publishing, fashion and the manufacture of fine furniture and luxury goods. On 22 October 1797, Paris was also the site of the first parachute jump in history, by Garnerin.\n\n\nIn the summer of 1789, Paris became the centre stage of the French Revolution. On 14 July, a mob seized the arsenal at the Invalides, acquiring thousands of guns, with which it stormed the Bastille, a principal symbol of royal authority. The first independent Paris Commune, or city council, met in the Hôtel de Ville and elected a Mayor, the astronomer Jean Sylvain Bailly, on 15 July.\n\n\nLouis XVI and the royal family were brought to Paris and incarcerated in the Tuileries Palace. In 1793, as the revolution turned increasingly radical, the king, queen and mayor were beheaded by guillotine in the Reign of Terror, along with more than 16,000 others throughout France. The property of the aristocracy and the church was nationalised, and the city's churches were closed, sold or demolished. A succession of revolutionary factions ruled Paris until 9 November 1799 (coup d'état du 18 brumaire), when Napoleon Bonaparte seized power as First Consul.\n\n\n" + ] + } +} \ No newline at end of file diff --git a/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/sample-doc.json.gz b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/sample-doc.json.gz new file mode 100644 index 0000000000000..881524e46e186 Binary files /dev/null and b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/sample-doc.json.gz differ diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml new file mode 100644 index 0000000000000..25cd1b5aec48a --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml @@ -0,0 +1,242 @@ +setup: + - requires: + cluster_features: "semantic_text.highlighter" + reason: a new highlighter for semantic text field + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64", + "similarity": "COSINE" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-sparse-index + body: + mappings: + properties: + body: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: test-dense-index + body: + mappings: + properties: + body: + type: semantic_text + inference_id: dense-inference-id + +--- +"Highlighting using a sparse embedding model": + - do: + index: + index: test-sparse-index + id: doc_1 + body: + body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + refresh: true + + - match: { result: created } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + +--- +"Highlighting using a dense embedding model": + - do: + index: + index: test-dense-index + id: doc_1 + body: + body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + refresh: true + + - match: { result: created } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + diff --git a/x-pack/plugin/logsdb/src/javaRestTest/java/org/elasticsearch/xpack/logsdb/LogsdbRestIT.java b/x-pack/plugin/logsdb/src/javaRestTest/java/org/elasticsearch/xpack/logsdb/LogsdbRestIT.java index 2bf8b00cf551c..ef9480681f559 100644 --- a/x-pack/plugin/logsdb/src/javaRestTest/java/org/elasticsearch/xpack/logsdb/LogsdbRestIT.java +++ b/x-pack/plugin/logsdb/src/javaRestTest/java/org/elasticsearch/xpack/logsdb/LogsdbRestIT.java @@ -10,6 +10,8 @@ import org.elasticsearch.client.Request; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.time.DateFormatter; +import org.elasticsearch.common.time.FormatNames; import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; @@ -17,6 +19,7 @@ import org.junit.ClassRule; import java.io.IOException; +import java.time.Instant; import java.util.List; import java.util.Map; @@ -108,4 +111,118 @@ public void testLogsdbSourceModeForLogsIndex() throws IOException { assertNull(settings.get("index.mapping.source.mode")); } + public void testEsqlRuntimeFields() throws IOException { + String mappings = """ + { + "runtime": { + "message_length": { + "type": "long" + }, + "log.offset": { + "type": "long" + } + }, + "dynamic": false, + "properties": { + "@timestamp": { + "type": "date" + }, + "log" : { + "properties": { + "level": { + "type": "keyword" + }, + "file": { + "type": "keyword" + } + } + } + } + } + """; + String indexName = "test-foo"; + createIndex(indexName, Settings.builder().put("index.mode", "logsdb").build(), mappings); + + int numDocs = 500; + var sb = new StringBuilder(); + var now = Instant.now(); + + var expectedMinTimestamp = now; + for (int i = 0; i < numDocs; i++) { + String level = randomBoolean() ? "info" : randomBoolean() ? "warning" : randomBoolean() ? "error" : "fatal"; + String msg = randomAlphaOfLength(20); + String path = randomAlphaOfLength(8); + String messageLength = Integer.toString(msg.length()); + String offset = Integer.toString(randomNonNegativeInt()); + sb.append("{ \"create\": {} }").append('\n'); + if (randomBoolean()) { + sb.append( + """ + {"@timestamp":"$now","message":"$msg","message_length":$l,"file":{"level":"$level","offset":5,"file":"$path"}} + """.replace("$now", formatInstant(now)) + .replace("$level", level) + .replace("$msg", msg) + .replace("$path", path) + .replace("$l", messageLength) + .replace("$o", offset) + ); + } else { + sb.append(""" + {"@timestamp": "$now", "message": "$msg", "message_length": $l} + """.replace("$now", formatInstant(now)).replace("$msg", msg).replace("$l", messageLength)); + } + sb.append('\n'); + if (i != numDocs - 1) { + now = now.plusSeconds(1); + } + } + var expectedMaxTimestamp = now; + + var bulkRequest = new Request("POST", "/" + indexName + "/_bulk"); + bulkRequest.setJsonEntity(sb.toString()); + bulkRequest.addParameter("refresh", "true"); + var bulkResponse = client().performRequest(bulkRequest); + var bulkResponseBody = responseAsMap(bulkResponse); + assertThat(bulkResponseBody, Matchers.hasEntry("errors", false)); + + var forceMergeRequest = new Request("POST", "/" + indexName + "/_forcemerge"); + forceMergeRequest.addParameter("max_num_segments", "1"); + var forceMergeResponse = client().performRequest(forceMergeRequest); + assertOK(forceMergeResponse); + + String query = "FROM test-foo | STATS count(*), min(@timestamp), max(@timestamp), min(message_length), max(message_length)" + + " ,sum(message_length), avg(message_length), min(log.offset), max(log.offset) | LIMIT 1"; + final Request esqlRequest = new Request("POST", "/_query"); + esqlRequest.setJsonEntity(""" + { + "query": "$query" + } + """.replace("$query", query)); + var esqlResponse = client().performRequest(esqlRequest); + assertOK(esqlResponse); + Map esqlResponseBody = responseAsMap(esqlResponse); + + List values = (List) esqlResponseBody.get("values"); + assertThat(values, Matchers.not(Matchers.empty())); + var count = ((List) values.getFirst()).get(0); + assertThat(count, equalTo(numDocs)); + logger.warn("VALUES: {}", values); + + var minTimestamp = ((List) values.getFirst()).get(1); + assertThat(minTimestamp, equalTo(formatInstant(expectedMinTimestamp))); + var maxTimestamp = ((List) values.getFirst()).get(2); + assertThat(maxTimestamp, equalTo(formatInstant(expectedMaxTimestamp))); + + var minLength = ((List) values.getFirst()).get(3); + assertThat(minLength, equalTo(20)); + var maxLength = ((List) values.getFirst()).get(4); + assertThat(maxLength, equalTo(20)); + var sumLength = ((List) values.getFirst()).get(5); + assertThat(sumLength, equalTo(20 * numDocs)); + } + + static String formatInstant(Instant instant) { + return DateFormatter.forPattern(FormatNames.STRICT_DATE_OPTIONAL_TIME.getName()).format(instant); + } + } diff --git a/x-pack/plugin/logsdb/src/main/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseService.java b/x-pack/plugin/logsdb/src/main/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseService.java index 26a672fb1c903..e629f9b3998bb 100644 --- a/x-pack/plugin/logsdb/src/main/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseService.java +++ b/x-pack/plugin/logsdb/src/main/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseService.java @@ -29,7 +29,7 @@ final class SyntheticSourceLicenseService { // You can only override this property if you received explicit approval from Elastic. static final String CUTOFF_DATE_SYS_PROP_NAME = "es.mapping.synthetic_source_fallback_to_stored_source.cutoff_date_restricted_override"; private static final Logger LOGGER = LogManager.getLogger(SyntheticSourceLicenseService.class); - static final long DEFAULT_CUTOFF_DATE = LocalDateTime.of(2025, 2, 1, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + static final long DEFAULT_CUTOFF_DATE = LocalDateTime.of(2025, 2, 4, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); /** * A setting that determines whether source mode should always be stored source. Regardless of licence. diff --git a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/LegacyLicenceIntegrationTests.java b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/LegacyLicenceIntegrationTests.java index 890bc464a2579..f8f307b572f33 100644 --- a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/LegacyLicenceIntegrationTests.java +++ b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/LegacyLicenceIntegrationTests.java @@ -69,7 +69,8 @@ public void testSyntheticSourceUsageWithLegacyLicense() { } public void testSyntheticSourceUsageWithLegacyLicensePastCutoff() throws Exception { - long startPastCutoff = LocalDateTime.of(2025, 11, 12, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + // One day after default cutoff date + long startPastCutoff = LocalDateTime.of(2025, 2, 5, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); putLicense(createGoldOrPlatinumLicense(startPastCutoff)); ensureGreen(); diff --git a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceIndexSettingsProviderLegacyLicenseTests.java b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceIndexSettingsProviderLegacyLicenseTests.java index eda0d87868745..c871a7d0216ed 100644 --- a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceIndexSettingsProviderLegacyLicenseTests.java +++ b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceIndexSettingsProviderLegacyLicenseTests.java @@ -98,7 +98,7 @@ public void testGetAdditionalIndexSettingsTsdb() throws IOException { } public void testGetAdditionalIndexSettingsTsdbAfterCutoffDate() throws Exception { - long start = LocalDateTime.of(2025, 2, 2, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + long start = LocalDateTime.of(2025, 2, 5, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); License license = createGoldOrPlatinumLicense(start); long time = LocalDateTime.of(2024, 12, 31, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); var licenseState = new XPackLicenseState(() -> time, new XPackLicenseStatus(license.operationMode(), true, null)); diff --git a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java index 90a13b16c028e..0eb0d21ff2e78 100644 --- a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java +++ b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java @@ -41,6 +41,7 @@ public void setup() throws Exception { public void testLicenseAllowsSyntheticSource() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -53,6 +54,7 @@ public void testLicenseAllowsSyntheticSource() { public void testLicenseAllowsSyntheticSourceTemplateValidation() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -65,6 +67,7 @@ public void testLicenseAllowsSyntheticSourceTemplateValidation() { public void testDefaultDisallow() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -77,6 +80,7 @@ public void testDefaultDisallow() { public void testFallback() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -95,6 +99,7 @@ public void testGoldOrPlatinumLicense() throws Exception { when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY))).thenReturn(true); licenseService.setLicenseState(licenseState); @@ -103,6 +108,8 @@ public void testGoldOrPlatinumLicense() throws Exception { "legacy licensed usage is allowed, so not fallback to stored source", licenseService.fallbackToStoredSource(false, true) ); + Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE)); + Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY)); Mockito.verify(licenseState, Mockito.times(1)).featureUsed(any()); } @@ -112,6 +119,7 @@ public void testGoldOrPlatinumLicenseLegacyLicenseNotAllowed() throws Exception when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); licenseService.setLicenseState(licenseState); @@ -125,14 +133,16 @@ public void testGoldOrPlatinumLicenseLegacyLicenseNotAllowed() throws Exception } public void testGoldOrPlatinumLicenseBeyondCutoffDate() throws Exception { - long start = LocalDateTime.of(2025, 1, 1, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + long start = LocalDateTime.of(2025, 2, 5, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); License license = createGoldOrPlatinumLicense(start); mockLicenseService = mock(LicenseService.class); when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); + when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); assertTrue("beyond cutoff date, so fallback to stored source", licenseService.fallbackToStoredSource(false, true)); @@ -143,19 +153,21 @@ public void testGoldOrPlatinumLicenseBeyondCutoffDate() throws Exception { public void testGoldOrPlatinumLicenseCustomCutoffDate() throws Exception { licenseService = new SyntheticSourceLicenseService(Settings.EMPTY, "2025-01-02T00:00"); - long start = LocalDateTime.of(2025, 1, 1, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + long start = LocalDateTime.of(2025, 1, 3, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); License license = createGoldOrPlatinumLicense(start); mockLicenseService = mock(LicenseService.class); when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); + when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); - assertFalse("custom cutoff date, so fallback to stored source", licenseService.fallbackToStoredSource(false, true)); - Mockito.verify(licenseState, Mockito.times(1)).featureUsed(any()); - Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY)); + assertTrue("custom cutoff date, so fallback to stored source", licenseService.fallbackToStoredSource(false, true)); + Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE)); + Mockito.verify(licenseState, Mockito.never()).featureUsed(any()); } static License createEnterpriseLicense() throws Exception { diff --git a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java index 3b68fc9995b57..62716e11f1720 100644 --- a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java +++ b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java @@ -51,7 +51,10 @@ protected Collection> nodePlugins() { public void testNonExistentDataStream() { String nonExistentDataStreamName = randomAlphaOfLength(50); - ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest(nonExistentDataStreamName); + ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest( + ReindexDataStreamAction.Mode.UPGRADE, + nonExistentDataStreamName + ); assertThrows( ResourceNotFoundException.class, () -> client().execute(new ActionType(ReindexDataStreamAction.NAME), reindexDataStreamRequest) @@ -61,7 +64,10 @@ public void testNonExistentDataStream() { public void testAlreadyUpToDateDataStream() throws Exception { String dataStreamName = randomAlphaOfLength(50).toLowerCase(Locale.ROOT); - ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest(dataStreamName); + ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest( + ReindexDataStreamAction.Mode.UPGRADE, + dataStreamName + ); createDataStream(dataStreamName); ReindexDataStreamResponse response = client().execute( new ActionType(ReindexDataStreamAction.NAME), diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java index 118cd69ece4d6..ac9e38da07421 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java @@ -11,21 +11,30 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.settings.SettingsFilter; import org.elasticsearch.common.settings.SettingsModule; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.persistent.PersistentTaskParams; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestHandler; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction; import org.elasticsearch.xpack.migrate.action.ReindexDataStreamTransportAction; +import org.elasticsearch.xpack.migrate.rest.RestMigrationReindexAction; import org.elasticsearch.xpack.migrate.task.ReindexDataStreamPersistentTaskExecutor; import org.elasticsearch.xpack.migrate.task.ReindexDataStreamPersistentTaskState; import org.elasticsearch.xpack.migrate.task.ReindexDataStreamStatus; @@ -34,47 +43,80 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.REINDEX_DATA_STREAM_FEATURE_FLAG; public class MigratePlugin extends Plugin implements ActionPlugin, PersistentTaskPlugin { + @Override + public List getRestHandlers( + Settings unused, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + List handlers = new ArrayList<>(); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + handlers.add(new RestMigrationReindexAction()); + } + return handlers; + } + @Override public List> getActions() { List> actions = new ArrayList<>(); - actions.add(new ActionHandler<>(ReindexDataStreamAction.INSTANCE, ReindexDataStreamTransportAction.class)); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + actions.add(new ActionHandler<>(ReindexDataStreamAction.INSTANCE, ReindexDataStreamTransportAction.class)); + } return actions; } @Override public List getNamedXContent() { - return List.of( - new NamedXContentRegistry.Entry( - PersistentTaskState.class, - new ParseField(ReindexDataStreamPersistentTaskState.NAME), - ReindexDataStreamPersistentTaskState::fromXContent - ), - new NamedXContentRegistry.Entry( - PersistentTaskParams.class, - new ParseField(ReindexDataStreamTaskParams.NAME), - ReindexDataStreamTaskParams::fromXContent - ) - ); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + return List.of( + new NamedXContentRegistry.Entry( + PersistentTaskState.class, + new ParseField(ReindexDataStreamPersistentTaskState.NAME), + ReindexDataStreamPersistentTaskState::fromXContent + ), + new NamedXContentRegistry.Entry( + PersistentTaskParams.class, + new ParseField(ReindexDataStreamTaskParams.NAME), + ReindexDataStreamTaskParams::fromXContent + ) + ); + } else { + return List.of(); + } } @Override public List getNamedWriteables() { - return List.of( - new NamedWriteableRegistry.Entry( - PersistentTaskState.class, - ReindexDataStreamPersistentTaskState.NAME, - ReindexDataStreamPersistentTaskState::new - ), - new NamedWriteableRegistry.Entry( - PersistentTaskParams.class, - ReindexDataStreamTaskParams.NAME, - ReindexDataStreamTaskParams::new - ), - new NamedWriteableRegistry.Entry(Task.Status.class, ReindexDataStreamStatus.NAME, ReindexDataStreamStatus::new) - ); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + return List.of( + new NamedWriteableRegistry.Entry( + PersistentTaskState.class, + ReindexDataStreamPersistentTaskState.NAME, + ReindexDataStreamPersistentTaskState::new + ), + new NamedWriteableRegistry.Entry( + PersistentTaskParams.class, + ReindexDataStreamTaskParams.NAME, + ReindexDataStreamTaskParams::new + ), + new NamedWriteableRegistry.Entry(Task.Status.class, ReindexDataStreamStatus.NAME, ReindexDataStreamStatus::new) + ); + } else { + return List.of(); + } } @Override @@ -85,6 +127,12 @@ public List> getPersistentTasksExecutor( SettingsModule settingsModule, IndexNameExpressionResolver expressionResolver ) { - return List.of(new ReindexDataStreamPersistentTaskExecutor(client, clusterService, ReindexDataStreamTask.TASK_NAME, threadPool)); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + return List.of( + new ReindexDataStreamPersistentTaskExecutor(client, clusterService, ReindexDataStreamTask.TASK_NAME, threadPool) + ); + } else { + return List.of(); + } } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java index 1785e6971f824..eb7a910df8c0c 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java @@ -11,23 +11,41 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.FeatureFlag; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Locale; import java.util.Objects; +import java.util.function.Predicate; public class ReindexDataStreamAction extends ActionType { + public static final FeatureFlag REINDEX_DATA_STREAM_FEATURE_FLAG = new FeatureFlag("reindex_data_stream"); public static final ReindexDataStreamAction INSTANCE = new ReindexDataStreamAction(); public static final String NAME = "indices:admin/data_stream/reindex"; + public static final ParseField MODE_FIELD = new ParseField("mode"); + public static final ParseField SOURCE_FIELD = new ParseField("source"); + public static final ParseField INDEX_FIELD = new ParseField("index"); public ReindexDataStreamAction() { super(NAME); } + public enum Mode { + UPGRADE + } + public static class ReindexDataStreamResponse extends ActionResponse implements ToXContentObject { private final String taskId; @@ -49,7 +67,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field("task", getTaskId()); + builder.field("acknowledged", true); builder.endObject(); return builder; } @@ -70,22 +88,52 @@ public boolean equals(Object other) { } - public static class ReindexDataStreamRequest extends ActionRequest { + public static class ReindexDataStreamRequest extends ActionRequest implements IndicesRequest, ToXContent { + private final Mode mode; private final String sourceDataStream; - public ReindexDataStreamRequest(String sourceDataStream) { - super(); + public ReindexDataStreamRequest(Mode mode, String sourceDataStream) { + this.mode = mode; this.sourceDataStream = sourceDataStream; } public ReindexDataStreamRequest(StreamInput in) throws IOException { super(in); + this.mode = Mode.valueOf(in.readString()); this.sourceDataStream = in.readString(); } + private static final ConstructingObjectParser> PARSER = + new ConstructingObjectParser<>("migration_reindex", objects -> { + Mode mode = Mode.valueOf(((String) objects[0]).toUpperCase(Locale.ROOT)); + String source = (String) objects[1]; + return new ReindexDataStreamRequest(mode, source); + }); + + private static final ConstructingObjectParser SOURCE_PARSER = new ConstructingObjectParser<>( + SOURCE_FIELD.getPreferredName(), + false, + (a, id) -> (String) a[0] + ); + + static { + SOURCE_PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), MODE_FIELD); + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (parser, id) -> SOURCE_PARSER.apply(parser, null), + SOURCE_FIELD + ); + } + + public static ReindexDataStreamRequest fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + out.writeString(mode.name()); out.writeString(sourceDataStream); } @@ -103,15 +151,42 @@ public String getSourceDataStream() { return sourceDataStream; } + public Mode getMode() { + return mode; + } + @Override public int hashCode() { - return Objects.hashCode(sourceDataStream); + return Objects.hash(mode, sourceDataStream); } @Override public boolean equals(Object other) { - return other instanceof ReindexDataStreamRequest - && sourceDataStream.equals(((ReindexDataStreamRequest) other).sourceDataStream); + return other instanceof ReindexDataStreamRequest otherRequest + && mode.equals(otherRequest.mode) + && sourceDataStream.equals(otherRequest.sourceDataStream); + } + + @Override + public String[] indices() { + return new String[] { sourceDataStream }; + } + + @Override + public IndicesOptions indicesOptions() { + return IndicesOptions.strictSingleIndexNoExpandForbidClosed(); + } + + /* + * This only exists for the sake of testing the xcontent parser + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODE_FIELD.getPreferredName(), mode); + builder.startObject(SOURCE_FIELD.getPreferredName()); + builder.field(INDEX_FIELD.getPreferredName(), sourceDataStream); + builder.endObject(); + return builder; } } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java index d532b001f5aaa..7f68007f821ba 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportAction.java @@ -20,6 +20,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamRequest; import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamResponse; import org.elasticsearch.xpack.migrate.task.ReindexDataStreamTask; @@ -72,7 +73,8 @@ protected void doExecute(Task task, ReindexDataStreamRequest request, ActionList sourceDataStreamName, transportService.getThreadPool().absoluteTimeInMillis(), totalIndices, - totalIndicesToBeUpgraded + totalIndicesToBeUpgraded, + ClientHelper.getPersistableSafeSecurityHeaders(transportService.getThreadPool().getThreadContext(), clusterService.state()) ); String persistentTaskId = getPersistentTaskId(sourceDataStreamName); persistentTasksService.sendStartRequest( diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java new file mode 100644 index 0000000000000..a7f630d68234d --- /dev/null +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.migrate.rest; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.rest.action.RestBuilderListener; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction; +import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamResponse; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; + +public class RestMigrationReindexAction extends BaseRestHandler { + + @Override + public String getName() { + return "migration_reindex"; + } + + @Override + public List routes() { + return List.of(new Route(POST, "/_migration/reindex")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + ReindexDataStreamAction.ReindexDataStreamRequest reindexRequest; + try (XContentParser parser = request.contentParser()) { + reindexRequest = ReindexDataStreamAction.ReindexDataStreamRequest.fromXContent(parser); + } + return channel -> client.execute( + ReindexDataStreamAction.INSTANCE, + reindexRequest, + new ReindexDataStreamRestToXContentListener(channel) + ); + } + + static class ReindexDataStreamRestToXContentListener extends RestBuilderListener { + + ReindexDataStreamRestToXContentListener(RestChannel channel) { + super(channel); + } + + @Override + public RestResponse buildResponse(ReindexDataStreamResponse response, XContentBuilder builder) throws Exception { + response.toXContent(builder, channel.request()); + return new RestResponse(RestStatus.OK, builder); + } + } +} diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ExecuteWithHeadersClient.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ExecuteWithHeadersClient.java new file mode 100644 index 0000000000000..a8962f56468bc --- /dev/null +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ExecuteWithHeadersClient.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.migrate.task; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.support.AbstractClient; +import org.elasticsearch.xpack.core.ClientHelper; + +import java.util.Map; + +public class ExecuteWithHeadersClient extends AbstractClient { + + private final Client client; + private final Map headers; + + public ExecuteWithHeadersClient(Client client, Map headers) { + super(client.settings(), client.threadPool()); + this.client = client; + this.headers = headers; + } + + @Override + protected void doExecute( + ActionType action, + Request request, + ActionListener listener + ) { + ClientHelper.executeWithHeadersAsync(headers, null, client, action, request, listener); + } + +} diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java index e2a41ea186643..fc471cfa89f26 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamPersistentTaskExecutor.java @@ -51,7 +51,6 @@ protected ReindexDataStreamTask createTask( params.startTime(), params.totalIndices(), params.totalIndicesToBeUpgraded(), - threadPool, id, type, action, @@ -67,16 +66,19 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask GetDataStreamAction.Request request = new GetDataStreamAction.Request(TimeValue.MAX_VALUE, new String[] { sourceDataStream }); assert task instanceof ReindexDataStreamTask; final ReindexDataStreamTask reindexDataStreamTask = (ReindexDataStreamTask) task; - client.execute(GetDataStreamAction.INSTANCE, request, ActionListener.wrap(response -> { + ExecuteWithHeadersClient reindexClient = new ExecuteWithHeadersClient(client, params.headers()); + reindexClient.execute(GetDataStreamAction.INSTANCE, request, ActionListener.wrap(response -> { List dataStreamInfos = response.getDataStreams(); if (dataStreamInfos.size() == 1) { List indices = dataStreamInfos.getFirst().getDataStream().getIndices(); List indicesToBeReindexed = indices.stream() .filter(index -> clusterService.state().getMetadata().index(index).getCreationVersion().isLegacyIndexVersion()) .toList(); - reindexDataStreamTask.setPendingIndices(indicesToBeReindexed.stream().map(Index::getName).toList()); + reindexDataStreamTask.setPendingIndicesCount(indicesToBeReindexed.size()); for (Index index : indicesToBeReindexed) { + reindexDataStreamTask.incrementInProgressIndicesCount(); // TODO This is just a placeholder. This is where the real data stream reindex logic will go + reindexDataStreamTask.reindexSucceeded(); } completeSuccessfulPersistentTask(reindexDataStreamTask); @@ -87,12 +89,12 @@ protected void nodeOperation(AllocatedPersistentTask task, ReindexDataStreamTask } private void completeSuccessfulPersistentTask(ReindexDataStreamTask persistentTask) { - persistentTask.reindexSucceeded(); + persistentTask.allReindexesCompleted(); threadPool.schedule(persistentTask::markAsCompleted, getTimeToLive(persistentTask), threadPool.generic()); } private void completeFailedPersistentTask(ReindexDataStreamTask persistentTask, Exception e) { - persistentTask.reindexFailed(e); + persistentTask.taskFailed(e); threadPool.schedule(() -> persistentTask.markAsFailed(e), getTimeToLive(persistentTask), threadPool.generic()); } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java index 722b30d9970db..72ddb87e9dea5 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTask.java @@ -10,29 +10,27 @@ import org.elasticsearch.core.Tuple; import org.elasticsearch.persistent.AllocatedPersistentTask; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.threadpool.ThreadPool; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; public class ReindexDataStreamTask extends AllocatedPersistentTask { public static final String TASK_NAME = "reindex-data-stream"; private final long persistentTaskStartTime; private final int totalIndices; private final int totalIndicesToBeUpgraded; - private final ThreadPool threadPool; private boolean complete = false; private Exception exception; - private List inProgress = new ArrayList<>(); - private List pending = List.of(); + private AtomicInteger inProgress = new AtomicInteger(0); + private AtomicInteger pending = new AtomicInteger(); private List> errors = new ArrayList<>(); public ReindexDataStreamTask( long persistentTaskStartTime, int totalIndices, int totalIndicesToBeUpgraded, - ThreadPool threadPool, long id, String type, String action, @@ -44,7 +42,6 @@ public ReindexDataStreamTask( this.persistentTaskStartTime = persistentTaskStartTime; this.totalIndices = totalIndices; this.totalIndicesToBeUpgraded = totalIndicesToBeUpgraded; - this.threadPool = threadPool; } @Override @@ -55,30 +52,36 @@ public ReindexDataStreamStatus getStatus() { totalIndicesToBeUpgraded, complete, exception, - inProgress.size(), - pending.size(), + inProgress.get(), + pending.get(), errors ); } - public void reindexSucceeded() { + public void allReindexesCompleted() { this.complete = true; } - public void reindexFailed(Exception e) { + public void taskFailed(Exception e) { this.complete = true; this.exception = e; } - public void setInProgressIndices(List inProgressIndices) { - this.inProgress = inProgressIndices; + public void reindexSucceeded() { + inProgress.decrementAndGet(); + } + + public void reindexFailed(String index, Exception error) { + this.errors.add(Tuple.tuple(index, error)); + inProgress.decrementAndGet(); } - public void setPendingIndices(List pendingIndices) { - this.pending = pendingIndices; + public void incrementInProgressIndicesCount() { + inProgress.incrementAndGet(); + pending.decrementAndGet(); } - public void addErrorIndex(String index, Exception error) { - this.errors.add(Tuple.tuple(index, error)); + public void setPendingIndicesCount(int size) { + pending.set(size); } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParams.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParams.java index 0f26713a75184..7c4b0007bb632 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParams.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParams.java @@ -9,41 +9,65 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.persistent.PersistentTaskParams; import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; -public record ReindexDataStreamTaskParams(String sourceDataStream, long startTime, int totalIndices, int totalIndicesToBeUpgraded) - implements - PersistentTaskParams { +public record ReindexDataStreamTaskParams( + String sourceDataStream, + long startTime, + int totalIndices, + int totalIndicesToBeUpgraded, + Map headers +) implements PersistentTaskParams { + + private static final String API_CONTEXT = Metadata.XContentContext.API.toString(); public static final String NAME = ReindexDataStreamTask.TASK_NAME; private static final String SOURCE_DATA_STREAM_FIELD = "source_data_stream"; private static final String START_TIME_FIELD = "start_time"; private static final String TOTAL_INDICES_FIELD = "total_indices"; private static final String TOTAL_INDICES_TO_BE_UPGRADED_FIELD = "total_indices_to_be_upgraded"; + private static final String HEADERS_FIELD = "headers"; + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( NAME, true, - args -> new ReindexDataStreamTaskParams((String) args[0], (long) args[1], (int) args[2], (int) args[3]) + args -> new ReindexDataStreamTaskParams( + (String) args[0], + (long) args[1], + (int) args[2], + (int) args[3], + args[4] == null ? Map.of() : (Map) args[4] + ) ); static { PARSER.declareString(constructorArg(), new ParseField(SOURCE_DATA_STREAM_FIELD)); PARSER.declareLong(constructorArg(), new ParseField(START_TIME_FIELD)); PARSER.declareInt(constructorArg(), new ParseField(TOTAL_INDICES_FIELD)); PARSER.declareInt(constructorArg(), new ParseField(TOTAL_INDICES_TO_BE_UPGRADED_FIELD)); + PARSER.declareField( + ConstructingObjectParser.optionalConstructorArg(), + XContentParser::mapStrings, + new ParseField(HEADERS_FIELD), + ObjectParser.ValueType.OBJECT + ); } + @SuppressWarnings("unchecked") public ReindexDataStreamTaskParams(StreamInput in) throws IOException { - this(in.readString(), in.readLong(), in.readInt(), in.readInt()); + this(in.readString(), in.readLong(), in.readInt(), in.readInt(), (Map) in.readGenericValue()); } @Override @@ -62,16 +86,22 @@ public void writeTo(StreamOutput out) throws IOException { out.writeLong(startTime); out.writeInt(totalIndices); out.writeInt(totalIndicesToBeUpgraded); + out.writeGenericValue(headers); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.startObject() + builder.startObject() .field(SOURCE_DATA_STREAM_FIELD, sourceDataStream) .field(START_TIME_FIELD, startTime) .field(TOTAL_INDICES_FIELD, totalIndices) - .field(TOTAL_INDICES_TO_BE_UPGRADED_FIELD, totalIndicesToBeUpgraded) - .endObject(); + .field(TOTAL_INDICES_TO_BE_UPGRADED_FIELD, totalIndicesToBeUpgraded); + if (API_CONTEXT.equals(params.param(Metadata.CONTEXT_MODE_PARAM, API_CONTEXT)) == false) { + // This makes sure that we don't return the headers to an api request, like _cluster/state + builder.stringStringMap(HEADERS_FIELD, headers); + } + builder.endObject(); + return builder; } public String getSourceDataStream() { @@ -81,4 +111,8 @@ public String getSourceDataStream() { public static ReindexDataStreamTaskParams fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } + + public Map getHeaders() { + return headers; + } } diff --git a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamRequestTests.java b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamRequestTests.java new file mode 100644 index 0000000000000..9c7bf87b6cff0 --- /dev/null +++ b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamRequestTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.migrate.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractXContentSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamRequest; + +import java.io.IOException; + +public class ReindexDataStreamRequestTests extends AbstractXContentSerializingTestCase { + + @Override + protected ReindexDataStreamRequest createTestInstance() { + return new ReindexDataStreamRequest(ReindexDataStreamAction.Mode.UPGRADE, randomAlphaOfLength(40)); + } + + @Override + protected ReindexDataStreamRequest mutateInstance(ReindexDataStreamRequest instance) { + // There is currently only one possible value for mode, so we can't change it + return new ReindexDataStreamRequest(instance.getMode(), randomAlphaOfLength(50)); + } + + @Override + protected ReindexDataStreamRequest doParseInstance(XContentParser parser) throws IOException { + return ReindexDataStreamRequest.fromXContent(parser); + } + + @Override + protected Writeable.Reader instanceReader() { + return ReindexDataStreamRequest::new; + } +} diff --git a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java index 06844577c4e36..d886fc660d7a8 100644 --- a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java +++ b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java @@ -43,7 +43,7 @@ public void testToXContent() throws IOException { builder.humanReadable(true); response.toXContent(builder, EMPTY_PARAMS); try (XContentParser parser = createParser(JsonXContent.jsonXContent, BytesReference.bytes(builder))) { - assertThat(parser.map(), equalTo(Map.of("task", response.getTaskId()))); + assertThat(parser.map(), equalTo(Map.of("acknowledged", true))); } } } diff --git a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParamsTests.java b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParamsTests.java index fc39b5d8cb703..67ade297f27ad 100644 --- a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParamsTests.java +++ b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/task/ReindexDataStreamTaskParamsTests.java @@ -7,11 +7,14 @@ package org.elasticsearch.xpack.migrate.task; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractXContentSerializingTestCase; +import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; @@ -29,7 +32,26 @@ protected Writeable.Reader instanceReader() { @Override protected ReindexDataStreamTaskParams createTestInstance() { - return new ReindexDataStreamTaskParams(randomAlphaOfLength(50), randomLong(), randomNonNegativeInt(), randomNonNegativeInt()); + return createTestInstance(randomBoolean()); + } + + @Override + protected ReindexDataStreamTaskParams createXContextTestInstance(XContentType xContentType) { + /* + * Since we filter out headers from xcontent in some cases, we can't use them in the standard xcontent round trip testing. + * Headers are covered in testToXContentContextMode + */ + return createTestInstance(false); + } + + private ReindexDataStreamTaskParams createTestInstance(boolean withHeaders) { + return new ReindexDataStreamTaskParams( + randomAlphaOfLength(50), + randomLong(), + randomNonNegativeInt(), + randomNonNegativeInt(), + getTestHeaders(withHeaders) + ); } @Override @@ -38,14 +60,16 @@ protected ReindexDataStreamTaskParams mutateInstance(ReindexDataStreamTaskParams long startTime = instance.startTime(); int totalIndices = instance.totalIndices(); int totalIndicesToBeUpgraded = instance.totalIndicesToBeUpgraded(); - switch (randomIntBetween(0, 3)) { + Map headers = instance.headers(); + switch (randomIntBetween(0, 4)) { case 0 -> sourceDataStream = randomAlphaOfLength(50); case 1 -> startTime = randomLong(); case 2 -> totalIndices = totalIndices + 1; case 3 -> totalIndices = totalIndicesToBeUpgraded + 1; + case 4 -> headers = headers.isEmpty() ? getTestHeaders(true) : getTestHeaders(); default -> throw new UnsupportedOperationException(); } - return new ReindexDataStreamTaskParams(sourceDataStream, startTime, totalIndices, totalIndicesToBeUpgraded); + return new ReindexDataStreamTaskParams(sourceDataStream, startTime, totalIndices, totalIndicesToBeUpgraded, headers); } @Override @@ -53,6 +77,18 @@ protected ReindexDataStreamTaskParams doParseInstance(XContentParser parser) { return ReindexDataStreamTaskParams.fromXContent(parser); } + private Map getTestHeaders() { + return getTestHeaders(randomBoolean()); + } + + private Map getTestHeaders(boolean nonEmpty) { + if (nonEmpty) { + return Map.of(randomAlphaOfLength(20), randomAlphaOfLength(30)); + } else { + return Map.of(); + } + } + public void testToXContent() throws IOException { ReindexDataStreamTaskParams params = createTestInstance(); try (XContentBuilder builder = XContentBuilder.builder(JsonXContent.jsonXContent)) { @@ -65,4 +101,41 @@ public void testToXContent() throws IOException { } } } + + public void testToXContentContextMode() throws IOException { + ReindexDataStreamTaskParams params = createTestInstance(true); + + // We do not expect to get headers if the "content_mode" is "api" + try (XContentBuilder builder = XContentBuilder.builder(JsonXContent.jsonXContent)) { + builder.humanReadable(true); + ToXContent.Params xContentParams = new ToXContent.MapParams( + Map.of(Metadata.CONTEXT_MODE_PARAM, Metadata.XContentContext.API.toString()) + ); + params.toXContent(builder, xContentParams); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, BytesReference.bytes(builder))) { + Map parserMap = parser.map(); + assertThat(parserMap.get("source_data_stream"), equalTo(params.sourceDataStream())); + assertThat(((Number) parserMap.get("start_time")).longValue(), equalTo(params.startTime())); + assertThat(parserMap.containsKey("headers"), equalTo(false)); + } + } + + // We do expect to get headers if the "content_mode" is anything but "api" + try (XContentBuilder builder = XContentBuilder.builder(JsonXContent.jsonXContent)) { + builder.humanReadable(true); + ToXContent.Params xContentParams = new ToXContent.MapParams( + Map.of( + Metadata.CONTEXT_MODE_PARAM, + randomFrom(Metadata.XContentContext.GATEWAY.toString(), Metadata.XContentContext.SNAPSHOT.toString()) + ) + ); + params.toXContent(builder, xContentParams); + try (XContentParser parser = createParser(JsonXContent.jsonXContent, BytesReference.bytes(builder))) { + Map parserMap = parser.map(); + assertThat(parserMap.get("source_data_stream"), equalTo(params.sourceDataStream())); + assertThat(((Number) parserMap.get("start_time")).longValue(), equalTo(params.startTime())); + assertThat(parserMap.get("headers"), equalTo(params.getHeaders())); + } + } + } } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DatafeedCcsIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DatafeedCcsIT.java index 139d1b074c7b2..e437c91c8e50e 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DatafeedCcsIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/DatafeedCcsIT.java @@ -94,7 +94,7 @@ protected Collection> nodePlugins(String clusterAlias) { } @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 9187969fc25a4..c6f1ebcc10780 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -631,12 +631,15 @@ synchronized void forcefullyStopProcess() { logger.debug(() -> format("[%s] Forcefully stopping process", task.getDeploymentId())); prepareInternalStateForShutdown(); - if (priorityProcessWorker.isShutdown()) { - // most likely there was a crash or exception that caused the - // thread to stop. Notify any waiting requests in the work queue - handleAlreadyShuttingDownWorker(); - } else { - priorityProcessWorker.shutdown(); + priorityProcessWorker.shutdownNow(); + try { + // wait for any currently executing work to finish + if (priorityProcessWorker.awaitTermination(10L, TimeUnit.SECONDS)) { + priorityProcessWorker.notifyQueueRunnables(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.info(Strings.format("[%s] Interrupted waiting for process worker after shutdownNow", PROCESS_NAME)); } killProcessIfPresent(); @@ -649,12 +652,6 @@ private void prepareInternalStateForShutdown() { stateStreamer.cancel(); } - private void handleAlreadyShuttingDownWorker() { - logger.debug(() -> format("[%s] Process worker was already marked for shutdown", task.getDeploymentId())); - - priorityProcessWorker.notifyQueueRunnables(); - } - private void killProcessIfPresent() { try { if (process.get() == null) { @@ -675,15 +672,7 @@ private void closeNlpTaskProcessor() { private synchronized void stopProcessAfterCompletingPendingWork() { logger.debug(() -> format("[%s] Stopping process after completing its pending work", task.getDeploymentId())); prepareInternalStateForShutdown(); - - if (priorityProcessWorker.isShutdown()) { - // most likely there was a crash or exception that caused the - // thread to stop. Notify any waiting requests in the work queue - handleAlreadyShuttingDownWorker(); - } else { - signalAndWaitForWorkerTermination(); - } - + signalAndWaitForWorkerTermination(); stopProcessGracefully(); closeNlpTaskProcessor(); } @@ -707,6 +696,8 @@ private void awaitTerminationAfterCompletingWork() throws TimeoutException { throw new TimeoutException( Strings.format("Timed out waiting for process worker to complete for process %s", PROCESS_NAME) ); + } else { + priorityProcessWorker.notifyQueueRunnables(); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java index 46edcf1f63c01..b59ef0c40e4f9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java @@ -304,7 +304,7 @@ public String getWriteableName() { @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.LTR_SERVERLESS_RELEASE; + return TransportVersions.V_8_16_0; } @Override diff --git a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/type/DataTypes.java b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/type/DataTypes.java index 6aa47f7c817a7..c67d943b11e22 100644 --- a/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/type/DataTypes.java +++ b/x-pack/plugin/ql/src/main/java/org/elasticsearch/xpack/ql/type/DataTypes.java @@ -112,41 +112,21 @@ public static DataType fromEs(String name) { } public static DataType fromJava(Object value) { - if (value == null) { - return NULL; - } - if (value instanceof Integer) { - return INTEGER; - } - if (value instanceof Long) { - return LONG; - } - if (value instanceof BigInteger) { - return UNSIGNED_LONG; - } - if (value instanceof Boolean) { - return BOOLEAN; - } - if (value instanceof Double) { - return DOUBLE; - } - if (value instanceof Float) { - return FLOAT; - } - if (value instanceof Byte) { - return BYTE; - } - if (value instanceof Short) { - return SHORT; - } - if (value instanceof ZonedDateTime) { - return DATETIME; - } - if (value instanceof String || value instanceof Character) { - return KEYWORD; - } - - return null; + return switch (value) { + case null -> NULL; + case Integer i -> INTEGER; + case Long l -> LONG; + case BigInteger bigInteger -> UNSIGNED_LONG; + case Boolean b -> BOOLEAN; + case Double v -> DOUBLE; + case Float v -> FLOAT; + case Byte b -> BYTE; + case Short s -> SHORT; + case ZonedDateTime zonedDateTime -> DATETIME; + case String s -> KEYWORD; + case Character c -> KEYWORD; + default -> null; + }; } public static boolean isUnsupported(DataType from) { diff --git a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java index 37e1807d138aa..ae35153b6f39f 100644 --- a/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java +++ b/x-pack/plugin/rank-rrf/src/internalClusterTest/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderIT.java @@ -33,6 +33,7 @@ import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; import org.elasticsearch.search.vectors.QueryVectorBuilder; +import org.elasticsearch.search.vectors.TestQueryVectorBuilderPlugin; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.hamcrest.ElasticsearchAssertions; import org.elasticsearch.xcontent.XContentBuilder; @@ -57,7 +58,6 @@ public class RRFRetrieverBuilderIT extends ESIntegTestCase { protected static String INDEX = "test_index"; - protected static final String ID_FIELD = "_id"; protected static final String DOC_FIELD = "doc"; protected static final String TEXT_FIELD = "text"; protected static final String VECTOR_FIELD = "vector"; @@ -743,6 +743,42 @@ public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder expectThrows(UnsupportedOperationException.class, () -> client().prepareSearch(INDEX).setSource(source).get()); } + public void testRRFFiltersPropagatedToKnnQueryVectorBuilder() { + final int rankWindowSize = 100; + final int rankConstant = 10; + SearchSourceBuilder source = new SearchSourceBuilder(); + // this will retriever all but 7 only due to top-level filter + StandardRetrieverBuilder standardRetriever = new StandardRetrieverBuilder(QueryBuilders.matchAllQuery()); + // this will too retrieve just doc 7 + KnnRetrieverBuilder knnRetriever = new KnnRetrieverBuilder( + "vector", + null, + new TestQueryVectorBuilderPlugin.TestQueryVectorBuilder(new float[] { 3 }), + 10, + 10, + null + ); + source.retriever( + new RRFRetrieverBuilder( + Arrays.asList( + new CompoundRetrieverBuilder.RetrieverSource(standardRetriever, null), + new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null) + ), + rankWindowSize, + rankConstant + ) + ); + source.retriever().getPreFilterQueryBuilders().add(QueryBuilders.boolQuery().must(QueryBuilders.termQuery(DOC_FIELD, "doc_7"))); + source.size(10); + SearchRequestBuilder req = client().prepareSearch(INDEX).setSource(source); + ElasticsearchAssertions.assertResponse(req, resp -> { + assertNull(resp.pointInTimeId()); + assertNotNull(resp.getHits().getTotalHits()); + assertThat(resp.getHits().getTotalHits().value(), equalTo(1L)); + assertThat(resp.getHits().getHits()[0].getId(), equalTo("doc_7")); + }); + } + public void testRewriteOnce() { final float[] vector = new float[] { 1 }; AtomicInteger numAsyncCalls = new AtomicInteger(); diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java index bbc0f622724a3..bb61fa951948d 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFFeatures.java @@ -12,6 +12,7 @@ import java.util.Set; +import static org.elasticsearch.search.retriever.CompoundRetrieverBuilder.INNER_RETRIEVERS_FILTER_SUPPORT; import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverBuilder.RRF_RETRIEVER_COMPOSITION_SUPPORTED; /** @@ -23,4 +24,9 @@ public class RRFFeatures implements FeatureSpecification { public Set getFeatures() { return Set.of(RRFRetrieverBuilder.RRF_RETRIEVER_SUPPORTED, RRF_RETRIEVER_COMPOSITION_SUPPORTED); } + + @Override + public Set getTestFeatures() { + return Set.of(INNER_RETRIEVERS_FILTER_SUPPORT); + } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java index 4cd10801b298c..84961f8442163 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRankDoc.java @@ -62,7 +62,7 @@ public RRFRankDoc(StreamInput in) throws IOException { rank = in.readVInt(); positions = in.readIntArray(); scores = in.readFloatArray(); - if (in.getTransportVersion().onOrAfter(TransportVersions.RRF_QUERY_REWRITE)) { + if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { this.rankConstant = in.readVInt(); } else { this.rankConstant = DEFAULT_RANK_CONSTANT; @@ -119,7 +119,7 @@ public void doWriteTo(StreamOutput out) throws IOException { out.writeVInt(rank); out.writeIntArray(positions); out.writeFloatArray(scores); - if (out.getTransportVersion().onOrAfter(TransportVersions.RRF_QUERY_REWRITE)) { + if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_16_0)) { out.writeVInt(rankConstant); } } @@ -173,6 +173,6 @@ protected void doToXContent(XContentBuilder builder, Params params) throws IOExc @Override public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.RRF_QUERY_REWRITE; + return TransportVersions.V_8_16_0; } } diff --git a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java index 792ff4eac3893..f1171b74f7468 100644 --- a/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java +++ b/x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.util.Maps; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.license.LicenseUtils; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; @@ -108,8 +109,10 @@ public String getName() { } @Override - protected RRFRetrieverBuilder clone(List newRetrievers) { - return new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); + protected RRFRetrieverBuilder clone(List newRetrievers, List newPreFilterQueryBuilders) { + RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.rankWindowSize, this.rankConstant); + clone.preFilterQueryBuilders = newPreFilterQueryBuilders; + return clone; } @Override diff --git a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml index 42c01f0b9636c..cb30542d80003 100644 --- a/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml +++ b/x-pack/plugin/rank-rrf/src/yamlRestTest/resources/rest-api-spec/test/rrf/700_rrf_retriever_search_api_compatibility.yml @@ -1071,3 +1071,77 @@ setup: - match: { hits.hits.2.inner_hits.nested_data_field.hits.total.value: 0 } - match: { hits.hits.2.inner_hits.nested_vector_field.hits.total.value: 0 } + + +--- +"rrf retriever with filters to be passed to nested rrf retrievers": + - requires: + cluster_features: 'inner_retrievers_filter_support' + reason: 'requires fix for properly propagating filters to nested sub-retrievers' + + - do: + search: + _source: false + index: test + body: + retriever: + { + rrf: + { + filter: { + term: { + keyword: "technology" + } + }, + retrievers: [ + { + rrf: { + retrievers: [ + { + # this should only return docs 3 and 5 due to top level filter + standard: { + query: { + knn: { + field: vector, + query_vector: [ 4.0 ], + k: 3 + } + } + } }, + { + # this should return no docs as no docs match both biology and technology + standard: { + query: { + term: { + keyword: "biology" + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + }, + # this should only return doc 5 + { + standard: { + query: { + term: { + text: "term5" + } + } + } + } + ], + rank_window_size: 10, + rank_constant: 10 + } + } + size: 10 + + - match: { hits.total.value: 2 } + - match: { hits.hits.0._id: "5" } + - match: { hits.hits.1._id: "3" } + + diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java index 21b24db6ce8d5..23e414c0dc1bf 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/SearchableSnapshotsCanMatchOnCoordinatorIntegTests.java @@ -42,7 +42,6 @@ import org.elasticsearch.snapshots.SnapshotId; import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.NodeRoles; -import org.elasticsearch.test.junit.annotations.TestIssueLogging; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xpack.core.searchablesnapshots.MountSearchableSnapshotAction; @@ -379,7 +378,7 @@ public void testSearchableSnapshotShardsAreSkippedBySearchRequestWithoutQuerying } if (searchShardsResponse != null) { for (SearchShardsGroup group : searchShardsResponse.getGroups()) { - assertFalse("no shard should be marked as skipped", group.skipped()); + assertTrue("the shard is skipped because index value is outside the query time range", group.skipped()); } } } @@ -788,11 +787,6 @@ public void testQueryPhaseIsExecutedInAnAvailableNodeWhenAllShardsCanBeSkipped() * Can match against searchable snapshots is tested via both the Search API and the SearchShards (transport-only) API. * The latter is a way to do only a can-match rather than all search phases. */ - @TestIssueLogging( - issueUrl = "https://github.com/elastic/elasticsearch/issues/97878", - value = "org.elasticsearch.snapshots:DEBUG,org.elasticsearch.indices.recovery:DEBUG,org.elasticsearch.action.search:DEBUG" - ) - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105339") public void testSearchableSnapshotShardsThatHaveMatchingDataAreNotSkippedOnTheCoordinatingNode() throws Exception { internalCluster().startMasterOnlyNode(); internalCluster().startCoordinatingOnlyNode(Settings.EMPTY); diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 8df10037affdb..c91314716cf9e 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -386,6 +386,7 @@ public class Constants { "cluster:monitor/xpack/esql/stats/dist", "cluster:monitor/xpack/inference", "cluster:monitor/xpack/inference/get", + "cluster:monitor/xpack/inference/unified", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get", "cluster:monitor/xpack/info", diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java index fef1a98ca67e9..b56ea7ae3e456 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java @@ -327,8 +327,8 @@ public void testInvalidateNotValidAccessTokens() throws Exception { ResponseException.class, () -> invalidateAccessToken( tokenService.prependVersionAndEncodeAccessToken( - TransportVersions.V_7_3_2, - tokenService.getRandomTokenBytes(TransportVersions.V_7_3_2, randomBoolean()).v1() + TransportVersions.MINIMUM_COMPATIBLE, + tokenService.getRandomTokenBytes(TransportVersions.MINIMUM_COMPATIBLE, randomBoolean()).v1() ) ) ); @@ -347,7 +347,7 @@ public void testInvalidateNotValidAccessTokens() throws Exception { byte[] longerAccessToken = new byte[randomIntBetween(17, 24)]; random().nextBytes(longerAccessToken); invalidateResponse = invalidateAccessToken( - tokenService.prependVersionAndEncodeAccessToken(TransportVersions.V_7_3_2, longerAccessToken) + tokenService.prependVersionAndEncodeAccessToken(TransportVersions.MINIMUM_COMPATIBLE, longerAccessToken) ); assertThat(invalidateResponse.invalidated(), equalTo(0)); assertThat(invalidateResponse.previouslyInvalidated(), equalTo(0)); @@ -365,7 +365,7 @@ public void testInvalidateNotValidAccessTokens() throws Exception { byte[] shorterAccessToken = new byte[randomIntBetween(12, 15)]; random().nextBytes(shorterAccessToken); invalidateResponse = invalidateAccessToken( - tokenService.prependVersionAndEncodeAccessToken(TransportVersions.V_7_3_2, shorterAccessToken) + tokenService.prependVersionAndEncodeAccessToken(TransportVersions.MINIMUM_COMPATIBLE, shorterAccessToken) ); assertThat(invalidateResponse.invalidated(), equalTo(0)); assertThat(invalidateResponse.previouslyInvalidated(), equalTo(0)); @@ -394,8 +394,8 @@ public void testInvalidateNotValidAccessTokens() throws Exception { invalidateResponse = invalidateAccessToken( tokenService.prependVersionAndEncodeAccessToken( - TransportVersions.V_7_3_2, - tokenService.getRandomTokenBytes(TransportVersions.V_7_3_2, randomBoolean()).v1() + TransportVersions.MINIMUM_COMPATIBLE, + tokenService.getRandomTokenBytes(TransportVersions.MINIMUM_COMPATIBLE, randomBoolean()).v1() ) ); assertThat(invalidateResponse.invalidated(), equalTo(0)); @@ -420,8 +420,8 @@ public void testInvalidateNotValidRefreshTokens() throws Exception { ResponseException.class, () -> invalidateRefreshToken( TokenService.prependVersionAndEncodeRefreshToken( - TransportVersions.V_7_3_2, - tokenService.getRandomTokenBytes(TransportVersions.V_7_3_2, true).v2() + TransportVersions.MINIMUM_COMPATIBLE, + tokenService.getRandomTokenBytes(TransportVersions.MINIMUM_COMPATIBLE, true).v2() ) ) ); @@ -441,7 +441,7 @@ public void testInvalidateNotValidRefreshTokens() throws Exception { byte[] longerRefreshToken = new byte[randomIntBetween(17, 24)]; random().nextBytes(longerRefreshToken); invalidateResponse = invalidateRefreshToken( - TokenService.prependVersionAndEncodeRefreshToken(TransportVersions.V_7_3_2, longerRefreshToken) + TokenService.prependVersionAndEncodeRefreshToken(TransportVersions.MINIMUM_COMPATIBLE, longerRefreshToken) ); assertThat(invalidateResponse.invalidated(), equalTo(0)); assertThat(invalidateResponse.previouslyInvalidated(), equalTo(0)); @@ -459,7 +459,7 @@ public void testInvalidateNotValidRefreshTokens() throws Exception { byte[] shorterRefreshToken = new byte[randomIntBetween(12, 15)]; random().nextBytes(shorterRefreshToken); invalidateResponse = invalidateRefreshToken( - TokenService.prependVersionAndEncodeRefreshToken(TransportVersions.V_7_3_2, shorterRefreshToken) + TokenService.prependVersionAndEncodeRefreshToken(TransportVersions.MINIMUM_COMPATIBLE, shorterRefreshToken) ); assertThat(invalidateResponse.invalidated(), equalTo(0)); assertThat(invalidateResponse.previouslyInvalidated(), equalTo(0)); @@ -488,8 +488,8 @@ public void testInvalidateNotValidRefreshTokens() throws Exception { invalidateResponse = invalidateRefreshToken( TokenService.prependVersionAndEncodeRefreshToken( - TransportVersions.V_7_3_2, - tokenService.getRandomTokenBytes(TransportVersions.V_7_3_2, true).v2() + TransportVersions.MINIMUM_COMPATIBLE, + tokenService.getRandomTokenBytes(TransportVersions.MINIMUM_COMPATIBLE, true).v2() ) ); assertThat(invalidateResponse.invalidated(), equalTo(0)); @@ -758,18 +758,11 @@ public void testAuthenticateWithWrongToken() throws Exception { assertAuthenticateWithToken(response.accessToken(), TEST_USER_NAME); // Now attempt to authenticate with an invalid access token string assertUnauthorizedToken(randomAlphaOfLengthBetween(0, 128)); - // Now attempt to authenticate with an invalid access token with valid structure (pre 7.2) + // Now attempt to authenticate with an invalid access token with valid structure (after 8.0 pre 8.10) assertUnauthorizedToken( tokenService.prependVersionAndEncodeAccessToken( - TransportVersions.V_7_1_0, - tokenService.getRandomTokenBytes(TransportVersions.V_7_1_0, randomBoolean()).v1() - ) - ); - // Now attempt to authenticate with an invalid access token with valid structure (after 7.2 pre 8.10) - assertUnauthorizedToken( - tokenService.prependVersionAndEncodeAccessToken( - TransportVersions.V_7_4_0, - tokenService.getRandomTokenBytes(TransportVersions.V_7_4_0, randomBoolean()).v1() + TransportVersions.V_8_0_0, + tokenService.getRandomTokenBytes(TransportVersions.V_8_0_0, randomBoolean()).v1() ) ); // Now attempt to authenticate with an invalid access token with valid structure (current version) diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/ProfileIntegTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/ProfileIntegTests.java index 437fb76351176..3b55295c1efce 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/ProfileIntegTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/profile/ProfileIntegTests.java @@ -557,8 +557,11 @@ public void testSuggestProfilesWithHint() throws IOException { equalTo(profileHits4.subList(2, profileHits4.size())) ); + // Exclude profile for "*" space since that can match _all_ profiles, if the full name is a substring of "user" or the name of + // another profile + final List nonWildcardProfiles = profiles.stream().filter(p -> false == p.user().fullName().endsWith("*")).toList(); // A record will not be included if name does not match even when it has matching hint - final Profile hintedProfile5 = randomFrom(profiles); + final Profile hintedProfile5 = randomFrom(nonWildcardProfiles); final List profileHits5 = Arrays.stream( doSuggest( Set.of(), diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java index 03558e72fdca3..c1be25b27c51e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/ApiKeyService.java @@ -14,6 +14,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.DocWriteRequest; @@ -138,7 +139,6 @@ import java.util.function.Supplier; import java.util.stream.Collectors; -import static org.elasticsearch.TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING; import static org.elasticsearch.transport.RemoteClusterPortSettings.TRANSPORT_VERSION_ADVANCED_REMOTE_CLUSTER_SECURITY; @@ -430,17 +430,17 @@ private boolean validateRoleDescriptorsForMixedCluster( listener.onFailure( new IllegalArgumentException( "all nodes must have version [" - + ROLE_REMOTE_CLUSTER_PRIVS + + ROLE_REMOTE_CLUSTER_PRIVS.toReleaseVersion() + "] or higher to support remote cluster privileges for API keys" ) ); return false; } - if (transportVersion.before(ADD_MANAGE_ROLES_PRIVILEGE) && hasGlobalManageRolesPrivilege(roleDescriptors)) { + if (transportVersion.before(TransportVersions.V_8_16_0) && hasGlobalManageRolesPrivilege(roleDescriptors)) { listener.onFailure( new IllegalArgumentException( "all nodes must have version [" - + ADD_MANAGE_ROLES_PRIVILEGE + + TransportVersions.V_8_16_0.toReleaseVersion() + "] or higher to support the manage roles privilege for API keys" ) ); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index 4f7ba7808b823..900436a1fd874 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -48,9 +48,7 @@ import org.elasticsearch.common.cache.CacheBuilder; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.InputStreamStreamInput; -import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; @@ -59,7 +57,6 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.Nullable; -import org.elasticsearch.core.Streams; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; @@ -93,10 +90,8 @@ import org.elasticsearch.xpack.security.support.SecurityIndexManager; import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; -import java.io.OutputStream; import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; @@ -132,7 +127,6 @@ import javax.crypto.Cipher; import javax.crypto.CipherInputStream; -import javax.crypto.CipherOutputStream; import javax.crypto.NoSuchPaddingException; import javax.crypto.SecretKey; import javax.crypto.SecretKeyFactory; @@ -201,14 +195,8 @@ public class TokenService { // UUIDs are 16 bytes encoded base64 without padding, therefore the length is (16 / 3) * 4 + ((16 % 3) * 8 + 5) / 6 chars private static final int TOKEN_LENGTH = 22; private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_"; - static final int LEGACY_MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; static final int MINIMUM_BYTES = VERSION_BYTES + TOKEN_LENGTH + 1; - static final int LEGACY_MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * LEGACY_MINIMUM_BYTES) / 3)).intValue(); public static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); - static final TransportVersion VERSION_HASHED_TOKENS = TransportVersions.V_7_2_0; - static final TransportVersion VERSION_TOKENS_INDEX_INTRODUCED = TransportVersions.V_7_2_0; - static final TransportVersion VERSION_ACCESS_TOKENS_AS_UUIDS = TransportVersions.V_7_2_0; - static final TransportVersion VERSION_MULTIPLE_CONCURRENT_REFRESHES = TransportVersions.V_7_2_0; static final TransportVersion VERSION_CLIENT_AUTH_FOR_REFRESH = TransportVersions.V_8_2_0; static final TransportVersion VERSION_GET_TOKEN_DOC_FOR_REFRESH = TransportVersions.V_8_10_X; @@ -273,8 +261,7 @@ public TokenService( /** * Creates an access token and optionally a refresh token as well, based on the provided authentication and metadata with - * auto-generated values. The created tokens are stored in the security index for versions up to - * {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a specific security tokens index for later versions. + * auto-generated values. The created tokens are stored a specific security tokens index. */ public void createOAuth2Tokens( Authentication authentication, @@ -291,8 +278,7 @@ public void createOAuth2Tokens( /** * Creates an access token and optionally a refresh token as well from predefined values, based on the provided authentication and - * metadata. The created tokens are stored in the security index for versions up to {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a - * specific security tokens index for later versions. + * metadata. The created tokens are stored in a specific security tokens index. */ // public for testing public void createOAuth2Tokens( @@ -314,21 +300,15 @@ public void createOAuth2Tokens( * * @param accessTokenBytes The predefined seed value for the access token. This will then be *
    - *
  • Encrypted before stored for versions before {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • - *
  • Hashed before stored for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • - *
  • Stored in the security index for versions up to {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • - *
  • Stored in a specific security tokens index for versions after - * {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Hashed before stored
  • + *
  • Stored in a specific security tokens index
  • *
  • Prepended with a version ID and Base64 encoded before returned to the caller of the APIs
  • *
* @param refreshTokenBytes The predefined seed value for the access token. This will then be *
    - *
  • Hashed before stored for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • - *
  • Stored in the security index for versions up to {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • - *
  • Stored in a specific security tokens index for versions after - * {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • - *
  • Prepended with a version ID and encoded with Base64 before returned to the caller of the APIs - * for versions after {@link #VERSION_TOKENS_INDEX_INTRODUCED}
  • + *
  • Hashed before stored
  • + *
  • Stored in a specific security tokens index
  • + *
  • Prepended with a version ID and Base64 encoded before returned to the caller of the APIs
  • *
* @param tokenVersion The version of the nodes with which these tokens will be compatible. * @param authentication The authentication object representing the user for which the tokens are created @@ -384,7 +364,7 @@ private void createOAuth2Tokens( } else { refreshTokenToStore = refreshTokenToReturn = null; } - } else if (tokenVersion.onOrAfter(VERSION_HASHED_TOKENS)) { + } else { assert accessTokenBytes.length == RAW_TOKEN_BYTES_LENGTH; userTokenId = hashTokenString(Strings.BASE_64_NO_PADDING_URL_ENCODER.encodeToString(accessTokenBytes)); accessTokenToStore = null; @@ -395,18 +375,6 @@ private void createOAuth2Tokens( } else { refreshTokenToStore = refreshTokenToReturn = null; } - } else { - assert accessTokenBytes.length == RAW_TOKEN_BYTES_LENGTH; - userTokenId = Strings.BASE_64_NO_PADDING_URL_ENCODER.encodeToString(accessTokenBytes); - accessTokenToStore = null; - if (refreshTokenBytes != null) { - assert refreshTokenBytes.length == RAW_TOKEN_BYTES_LENGTH; - refreshTokenToStore = refreshTokenToReturn = Strings.BASE_64_NO_PADDING_URL_ENCODER.encodeToString( - refreshTokenBytes - ); - } else { - refreshTokenToStore = refreshTokenToReturn = null; - } } UserToken userToken = new UserToken(userTokenId, tokenVersion, tokenAuth, getExpirationTime(), metadata); tokenDocument = createTokenDocument(userToken, accessTokenToStore, refreshTokenToStore, originatingClientAuth); @@ -419,23 +387,22 @@ private void createOAuth2Tokens( final RefreshPolicy tokenCreationRefreshPolicy = tokenVersion.onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH) ? RefreshPolicy.NONE : RefreshPolicy.WAIT_UNTIL; - final SecurityIndexManager tokensIndex = getTokensIndexForVersion(tokenVersion); logger.debug( () -> format( "Using refresh policy [%s] when creating token doc [%s] in the security index [%s]", tokenCreationRefreshPolicy, documentId, - tokensIndex.aliasName() + securityTokensIndex.aliasName() ) ); - final IndexRequest indexTokenRequest = client.prepareIndex(tokensIndex.aliasName()) + final IndexRequest indexTokenRequest = client.prepareIndex(securityTokensIndex.aliasName()) .setId(documentId) .setOpType(OpType.CREATE) .setSource(tokenDocument, XContentType.JSON) .setRefreshPolicy(tokenCreationRefreshPolicy) .request(); - tokensIndex.prepareIndexIfNeededThenExecute( - ex -> listener.onFailure(traceLog("prepare tokens index [" + tokensIndex.aliasName() + "]", documentId, ex)), + securityTokensIndex.prepareIndexIfNeededThenExecute( + ex -> listener.onFailure(traceLog("prepare tokens index [" + securityTokensIndex.aliasName() + "]", documentId, ex)), () -> executeAsyncWithOrigin( client, SECURITY_ORIGIN, @@ -554,17 +521,16 @@ private void getTokenDocById( @Nullable String storedRefreshToken, ActionListener listener ) { - final SecurityIndexManager tokensIndex = getTokensIndexForVersion(tokenVersion); - final SecurityIndexManager frozenTokensIndex = tokensIndex.defensiveCopy(); + final SecurityIndexManager frozenTokensIndex = securityTokensIndex.defensiveCopy(); if (frozenTokensIndex.isAvailable(PRIMARY_SHARDS) == false) { - logger.warn("failed to get access token [{}] because index [{}] is not available", tokenId, tokensIndex.aliasName()); + logger.warn("failed to get access token [{}] because index [{}] is not available", tokenId, securityTokensIndex.aliasName()); listener.onFailure(frozenTokensIndex.getUnavailableReason(PRIMARY_SHARDS)); return; } - final GetRequest getRequest = client.prepareGet(tokensIndex.aliasName(), getTokenDocumentId(tokenId)).request(); + final GetRequest getRequest = client.prepareGet(securityTokensIndex.aliasName(), getTokenDocumentId(tokenId)).request(); final Consumer onFailure = ex -> listener.onFailure(traceLog("get token from id", tokenId, ex)); - tokensIndex.checkIndexVersionThenExecute( - ex -> listener.onFailure(traceLog("prepare tokens index [" + tokensIndex.aliasName() + "]", tokenId, ex)), + securityTokensIndex.checkIndexVersionThenExecute( + ex -> listener.onFailure(traceLog("prepare tokens index [" + securityTokensIndex.aliasName() + "]", tokenId, ex)), () -> executeAsyncWithOrigin( client.threadPool().getThreadContext(), SECURITY_ORIGIN, @@ -610,7 +576,11 @@ private void getTokenDocById( // if the index or the shard is not there / available we assume that // the token is not valid if (isShardNotAvailableException(e)) { - logger.warn("failed to get token doc [{}] because index [{}] is not available", tokenId, tokensIndex.aliasName()); + logger.warn( + "failed to get token doc [{}] because index [{}] is not available", + tokenId, + securityTokensIndex.aliasName() + ); } else { logger.error(() -> "failed to get token doc [" + tokenId + "]", e); } @@ -650,7 +620,7 @@ void decodeToken(String token, boolean validateUserToken, ActionListener VERSION_ACCESS_TOKENS_UUIDS cluster if (in.available() < MINIMUM_BYTES) { logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BYTES); @@ -660,41 +630,6 @@ void decodeToken(String token, boolean validateUserToken, ActionListener { - if (decodeKey != null) { - try { - final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt); - final String tokenId = decryptTokenId(encryptedTokenId, cipher, version); - getAndValidateUserToken(tokenId, version, null, validateUserToken, listener); - } catch (IOException | GeneralSecurityException e) { - // could happen with a token that is not ours - logger.warn("invalid token", e); - listener.onResponse(null); - } - } else { - // could happen with a token that is not ours - listener.onResponse(null); - } - }, listener::onFailure)); - } else { - logger.debug(() -> format("invalid key %s key: %s", passphraseHash, keyCache.cache.keySet())); - listener.onResponse(null); - } } } catch (Exception e) { // could happen with a token that is not ours @@ -852,11 +787,7 @@ private void indexInvalidation( final Set idsOfOlderTokens = new HashSet<>(); boolean anyOlderTokensBeforeRefreshViaGet = false; for (UserToken userToken : userTokens) { - if (userToken.getTransportVersion().onOrAfter(VERSION_TOKENS_INDEX_INTRODUCED)) { - idsOfRecentTokens.add(userToken.getId()); - } else { - idsOfOlderTokens.add(userToken.getId()); - } + idsOfRecentTokens.add(userToken.getId()); anyOlderTokensBeforeRefreshViaGet |= userToken.getTransportVersion().before(VERSION_GET_TOKEN_DOC_FOR_REFRESH); } final RefreshPolicy tokensInvalidationRefreshPolicy = anyOlderTokensBeforeRefreshViaGet @@ -1124,7 +1055,7 @@ private void findTokenFromRefreshToken(String refreshToken, Iterator ); getTokenDocById(userTokenId, version, null, storedRefreshToken, listener); } - } else if (version.onOrAfter(VERSION_HASHED_TOKENS)) { + } else { final String unencodedRefreshToken = in.readString(); if (unencodedRefreshToken.length() != TOKEN_LENGTH) { logger.debug("Decoded refresh token [{}] with version [{}] is invalid.", unencodedRefreshToken, version); @@ -1133,9 +1064,6 @@ private void findTokenFromRefreshToken(String refreshToken, Iterator final String hashedRefreshToken = hashTokenString(unencodedRefreshToken); findTokenFromRefreshToken(hashedRefreshToken, securityTokensIndex, backoff, listener); } - } else { - logger.debug("Unrecognized refresh token version [{}].", version); - listener.onResponse(null); } } catch (IOException e) { logger.debug(() -> "Could not decode refresh token [" + refreshToken + "].", e); @@ -1250,7 +1178,6 @@ private void innerRefresh( return; } final RefreshTokenStatus refreshTokenStatus = checkRefreshResult.v1(); - final SecurityIndexManager refreshedTokenIndex = getTokensIndexForVersion(refreshTokenStatus.getTransportVersion()); if (refreshTokenStatus.isRefreshed()) { logger.debug( "Token document [{}] was recently refreshed, when a new token document was generated. Reusing that result.", @@ -1258,31 +1185,29 @@ private void innerRefresh( ); final Tuple parsedTokens = parseTokensFromDocument(tokenDoc.sourceAsMap(), null); Authentication authentication = parsedTokens.v1().getAuthentication(); - decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, refreshedTokenIndex, authentication, listener); + decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, securityTokensIndex, authentication, listener); } else { final TransportVersion newTokenVersion = getTokenVersionCompatibility(); final Tuple newTokenBytes = getRandomTokenBytes(newTokenVersion, true); final Map updateMap = new HashMap<>(); updateMap.put("refreshed", true); - if (newTokenVersion.onOrAfter(VERSION_MULTIPLE_CONCURRENT_REFRESHES)) { - updateMap.put("refresh_time", clock.instant().toEpochMilli()); - try { - final byte[] iv = getRandomBytes(IV_BYTES); - final byte[] salt = getRandomBytes(SALT_BYTES); - String encryptedAccessAndRefreshToken = encryptSupersedingTokens( - newTokenBytes.v1(), - newTokenBytes.v2(), - refreshToken, - iv, - salt - ); - updateMap.put("superseding.encrypted_tokens", encryptedAccessAndRefreshToken); - updateMap.put("superseding.encryption_iv", Base64.getEncoder().encodeToString(iv)); - updateMap.put("superseding.encryption_salt", Base64.getEncoder().encodeToString(salt)); - } catch (GeneralSecurityException e) { - logger.warn("could not encrypt access token and refresh token string", e); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } + updateMap.put("refresh_time", clock.instant().toEpochMilli()); + try { + final byte[] iv = getRandomBytes(IV_BYTES); + final byte[] salt = getRandomBytes(SALT_BYTES); + String encryptedAccessAndRefreshToken = encryptSupersedingTokens( + newTokenBytes.v1(), + newTokenBytes.v2(), + refreshToken, + iv, + salt + ); + updateMap.put("superseding.encrypted_tokens", encryptedAccessAndRefreshToken); + updateMap.put("superseding.encryption_iv", Base64.getEncoder().encodeToString(iv)); + updateMap.put("superseding.encryption_salt", Base64.getEncoder().encodeToString(salt)); + } catch (GeneralSecurityException e) { + logger.warn("could not encrypt access token and refresh token string", e); + onFailure.accept(invalidGrantException("could not refresh the requested token")); } assert tokenDoc.seqNo() != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number"; assert tokenDoc.primaryTerm() != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term"; @@ -1293,17 +1218,17 @@ private void innerRefresh( "Using refresh policy [%s] when updating token doc [%s] for refresh in the security index [%s]", tokenRefreshUpdateRefreshPolicy, tokenDoc.id(), - refreshedTokenIndex.aliasName() + securityTokensIndex.aliasName() ) ); - final UpdateRequestBuilder updateRequest = client.prepareUpdate(refreshedTokenIndex.aliasName(), tokenDoc.id()) + final UpdateRequestBuilder updateRequest = client.prepareUpdate(securityTokensIndex.aliasName(), tokenDoc.id()) .setDoc("refresh_token", updateMap) .setFetchSource(logger.isDebugEnabled()) .setRefreshPolicy(tokenRefreshUpdateRefreshPolicy) .setIfSeqNo(tokenDoc.seqNo()) .setIfPrimaryTerm(tokenDoc.primaryTerm()); - refreshedTokenIndex.prepareIndexIfNeededThenExecute( - ex -> listener.onFailure(traceLog("prepare index [" + refreshedTokenIndex.aliasName() + "]", ex)), + securityTokensIndex.prepareIndexIfNeededThenExecute( + ex -> listener.onFailure(traceLog("prepare index [" + securityTokensIndex.aliasName() + "]", ex)), () -> executeAsyncWithOrigin( client.threadPool().getThreadContext(), SECURITY_ORIGIN, @@ -1349,7 +1274,7 @@ private void innerRefresh( if (cause instanceof VersionConflictEngineException) { // The document has been updated by another thread, get it again. logger.debug("version conflict while updating document [{}], attempting to get it again", tokenDoc.id()); - getTokenDocAsync(tokenDoc.id(), refreshedTokenIndex, true, new ActionListener<>() { + getTokenDocAsync(tokenDoc.id(), securityTokensIndex, true, new ActionListener<>() { @Override public void onResponse(GetResponse response) { if (response.isExists()) { @@ -1368,7 +1293,7 @@ public void onFailure(Exception e) { logger.info("could not get token document [{}] for refresh, retrying", tokenDoc.id()); client.threadPool() .schedule( - () -> getTokenDocAsync(tokenDoc.id(), refreshedTokenIndex, true, this), + () -> getTokenDocAsync(tokenDoc.id(), securityTokensIndex, true, this), backoff.next(), client.threadPool().generic() ); @@ -1689,17 +1614,13 @@ private static Optional checkMultipleRefreshes( RefreshTokenStatus refreshTokenStatus ) { if (refreshTokenStatus.isRefreshed()) { - if (refreshTokenStatus.getTransportVersion().onOrAfter(VERSION_MULTIPLE_CONCURRENT_REFRESHES)) { - if (refreshRequested.isAfter(refreshTokenStatus.getRefreshInstant().plus(30L, ChronoUnit.SECONDS))) { - return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past")); - } - if (refreshRequested.isBefore(refreshTokenStatus.getRefreshInstant().minus(30L, ChronoUnit.SECONDS))) { - return Optional.of( - invalidGrantException("token has been refreshed more than 30 seconds in the future, clock skew too great") - ); - } - } else { - return Optional.of(invalidGrantException("token has already been refreshed")); + if (refreshRequested.isAfter(refreshTokenStatus.getRefreshInstant().plus(30L, ChronoUnit.SECONDS))) { + return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past")); + } + if (refreshRequested.isBefore(refreshTokenStatus.getRefreshInstant().minus(30L, ChronoUnit.SECONDS))) { + return Optional.of( + invalidGrantException("token has been refreshed more than 30 seconds in the future, clock skew too great") + ); } } return Optional.empty(); @@ -1979,21 +1900,6 @@ private void ensureEnabled() { } } - /** - * In version {@code #VERSION_TOKENS_INDEX_INTRODUCED} security tokens were moved into a separate index, away from the other entities in - * the main security index, due to their ephemeral nature. They moved "seamlessly" - without manual user intervention. In this way, new - * tokens are created in the new index, while the existing ones were left in place - to be accessed from the old index - and due to be - * removed automatically by the {@code ExpiredTokenRemover} periodic job. Therefore, in general, when searching for a token we need to - * consider both the new and the old indices. - */ - private SecurityIndexManager getTokensIndexForVersion(TransportVersion version) { - if (version.onOrAfter(VERSION_TOKENS_INDEX_INTRODUCED)) { - return securityTokensIndex; - } else { - return securityMainIndex; - } - } - public TimeValue getExpirationDelay() { return expirationDelay; } @@ -2022,41 +1928,13 @@ public String prependVersionAndEncodeAccessToken(TransportVersion version, byte[ out.writeByteArray(accessTokenBytes); return Base64.getEncoder().encodeToString(out.bytes().toBytesRef().bytes); } - } else if (version.onOrAfter(VERSION_ACCESS_TOKENS_AS_UUIDS)) { + } else { try (BytesStreamOutput out = new BytesStreamOutput(MINIMUM_BASE64_BYTES)) { out.setTransportVersion(version); TransportVersion.writeVersion(version, out); out.writeString(Strings.BASE_64_NO_PADDING_URL_ENCODER.encodeToString(accessTokenBytes)); return Base64.getEncoder().encodeToString(out.bytes().toBytesRef().bytes); } - } else { - // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly - try ( - ByteArrayOutputStream os = new ByteArrayOutputStream(LEGACY_MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64) - ) { - out.setTransportVersion(version); - KeyAndCache keyAndCache = keyCache.activeKeyCache; - TransportVersion.writeVersion(version, out); - out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = getRandomBytes(IV_BYTES); - out.writeByteArray(initializationVector); - try ( - CipherOutputStream encryptedOutput = new CipherOutputStream( - out, - getEncryptionCipher(initializationVector, keyAndCache, version) - ); - StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput) - ) { - encryptedStreamOutput.setTransportVersion(version); - encryptedStreamOutput.writeString(Strings.BASE_64_NO_PADDING_URL_ENCODER.encodeToString(accessTokenBytes)); - // StreamOutput needs to be closed explicitly because it wraps CipherOutputStream - encryptedStreamOutput.close(); - return new String(os.toByteArray(), StandardCharsets.UTF_8); - } - } } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java index 4ae17a679d205..23a1fc188e4a0 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java @@ -481,10 +481,10 @@ private Exception validateRoleDescriptor(RoleDescriptor role) { ); } else if (Arrays.stream(role.getConditionalClusterPrivileges()) .anyMatch(privilege -> privilege instanceof ConfigurableClusterPrivileges.ManageRolesPrivilege) - && clusterService.state().getMinTransportVersion().before(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE)) { + && clusterService.state().getMinTransportVersion().before(TransportVersions.V_8_16_0)) { return new IllegalStateException( "all nodes must have version [" - + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE.toReleaseVersion() + + TransportVersions.V_8_16_0.toReleaseVersion() + "] or higher to support the manage roles privilege" ); } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 75c2507a1dc5f..702af75141093 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -126,7 +126,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; @@ -148,7 +147,6 @@ public class TokenServiceTests extends ESTestCase { private SecurityIndexManager securityMainIndex; private SecurityIndexManager securityTokensIndex; private ClusterService clusterService; - private DiscoveryNode pre72OldNode; private DiscoveryNode pre8500040OldNode; private Settings tokenServiceEnabledSettings = Settings.builder() .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), true) @@ -228,31 +226,12 @@ public void setupClient() { licenseState = mock(MockLicenseState.class); when(licenseState.isAllowed(Security.TOKEN_SERVICE_FEATURE)).thenReturn(true); - if (randomBoolean()) { - // version 7.2 was an "inflection" point in the Token Service development (access_tokens as UUIDS, multiple concurrent - // refreshes, - // tokens docs on a separate index) - pre72OldNode = addAnother7071DataNode(this.clusterService); - } if (randomBoolean()) { // before refresh tokens used GET, i.e. TokenService#VERSION_GET_TOKEN_DOC_FOR_REFRESH pre8500040OldNode = addAnotherPre8500DataNode(this.clusterService); } } - private static DiscoveryNode addAnother7071DataNode(ClusterService clusterService) { - Version version; - TransportVersion transportVersion; - if (randomBoolean()) { - version = Version.V_7_0_0; - transportVersion = TransportVersions.V_7_0_0; - } else { - version = Version.V_7_1_0; - transportVersion = TransportVersions.V_7_1_0; - } - return addAnotherDataNodeWithVersion(clusterService, version, transportVersion); - } - private static DiscoveryNode addAnotherPre8500DataNode(ClusterService clusterService) { Version version; TransportVersion transportVersion; @@ -301,53 +280,6 @@ public static void shutdownThreadpool() { threadPool = null; } - public void testAttachAndGetToken() throws Exception { - TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - // This test only makes sense in mixed clusters with pre v7.2.0 nodes where the Token Service Key is used (to encrypt tokens) - if (null == pre72OldNode) { - pre72OldNode = addAnother7071DataNode(this.clusterService); - } - Authentication authentication = AuthenticationTestHelper.builder() - .user(new User("joe", "admin")) - .realmRef(new RealmRef("native_realm", "native", "node1")) - .build(false); - PlainActionFuture tokenFuture = new PlainActionFuture<>(); - Tuple newTokenBytes = tokenService.getRandomTokenBytes(randomBoolean()); - tokenService.createOAuth2Tokens( - newTokenBytes.v1(), - newTokenBytes.v2(), - authentication, - authentication, - Collections.emptyMap(), - tokenFuture - ); - final String accessToken = tokenFuture.get().getAccessToken(); - assertNotNull(accessToken); - mockGetTokenFromAccessTokenBytes(tokenService, newTokenBytes.v1(), authentication, false, null); - - ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + accessToken); - - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - tokenService.tryAuthenticateToken(bearerToken, future); - UserToken serialized = future.get(); - assertAuthentication(authentication, serialized.getAuthentication()); - } - - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - // verify a second separate token service with its own salt can also verify - TokenService anotherService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - anotherService.refreshMetadata(tokenService.getTokenMetadata()); - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - anotherService.tryAuthenticateToken(bearerToken, future); - UserToken fromOtherService = future.get(); - assertAuthentication(authentication, fromOtherService.getAuthentication()); - } - } - public void testInvalidAuthorizationHeader() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -364,89 +296,6 @@ public void testInvalidAuthorizationHeader() throws Exception { } } - public void testPassphraseWorks() throws Exception { - TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - // This test only makes sense in mixed clusters with pre v7.1.0 nodes where the Key is actually used - if (null == pre72OldNode) { - pre72OldNode = addAnother7071DataNode(this.clusterService); - } - Authentication authentication = AuthenticationTestHelper.builder() - .user(new User("joe", "admin")) - .realmRef(new RealmRef("native_realm", "native", "node1")) - .build(false); - PlainActionFuture tokenFuture = new PlainActionFuture<>(); - Tuple newTokenBytes = tokenService.getRandomTokenBytes(randomBoolean()); - tokenService.createOAuth2Tokens( - newTokenBytes.v1(), - newTokenBytes.v2(), - authentication, - authentication, - Collections.emptyMap(), - tokenFuture - ); - final String accessToken = tokenFuture.get().getAccessToken(); - assertNotNull(accessToken); - mockGetTokenFromAccessTokenBytes(tokenService, newTokenBytes.v1(), authentication, false, null); - - ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader(requestContext, accessToken); - - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - tokenService.tryAuthenticateToken(bearerToken, future); - UserToken serialized = future.get(); - assertAuthentication(authentication, serialized.getAuthentication()); - } - - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - // verify a second separate token service with its own passphrase cannot verify - TokenService anotherService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - anotherService.tryAuthenticateToken(bearerToken, future); - assertNull(future.get()); - } - } - - public void testGetTokenWhenKeyCacheHasExpired() throws Exception { - TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - // This test only makes sense in mixed clusters with pre v7.1.0 nodes where the Key is actually used - if (null == pre72OldNode) { - pre72OldNode = addAnother7071DataNode(this.clusterService); - } - Authentication authentication = AuthenticationTestHelper.builder() - .user(new User("joe", "admin")) - .realmRef(new RealmRef("native_realm", "native", "node1")) - .build(false); - - PlainActionFuture tokenFuture = new PlainActionFuture<>(); - Tuple newTokenBytes = tokenService.getRandomTokenBytes(randomBoolean()); - tokenService.createOAuth2Tokens( - newTokenBytes.v1(), - newTokenBytes.v2(), - authentication, - authentication, - Collections.emptyMap(), - tokenFuture - ); - String accessToken = tokenFuture.get().getAccessToken(); - assertThat(accessToken, notNullValue()); - - tokenService.clearActiveKeyCache(); - - tokenService.createOAuth2Tokens( - newTokenBytes.v1(), - newTokenBytes.v2(), - authentication, - authentication, - Collections.emptyMap(), - tokenFuture - ); - accessToken = tokenFuture.get().getAccessToken(); - assertThat(accessToken, notNullValue()); - } - public void testAuthnWithInvalidatedToken() throws Exception { when(securityMainIndex.indexExists()).thenReturn(true); TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); @@ -820,57 +669,6 @@ public void testMalformedRefreshTokens() throws Exception { } } - public void testNonExistingPre72Token() throws Exception { - TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - // mock another random token so that we don't find a token in TokenService#getUserTokenFromId - Authentication authentication = AuthenticationTestHelper.builder() - .user(new User("joe", "admin")) - .realmRef(new RealmRef("native_realm", "native", "node1")) - .build(false); - mockGetTokenFromAccessTokenBytes(tokenService, tokenService.getRandomTokenBytes(randomBoolean()).v1(), authentication, false, null); - ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - storeTokenHeader( - requestContext, - tokenService.prependVersionAndEncodeAccessToken( - TransportVersions.V_7_1_0, - tokenService.getRandomTokenBytes(TransportVersions.V_7_1_0, randomBoolean()).v1() - ) - ); - - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - tokenService.tryAuthenticateToken(bearerToken, future); - assertNull(future.get()); - } - } - - public void testNonExistingUUIDToken() throws Exception { - TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); - // mock another random token so that we don't find a token in TokenService#getUserTokenFromId - Authentication authentication = AuthenticationTestHelper.builder() - .user(new User("joe", "admin")) - .realmRef(new RealmRef("native_realm", "native", "node1")) - .build(false); - mockGetTokenFromAccessTokenBytes(tokenService, tokenService.getRandomTokenBytes(randomBoolean()).v1(), authentication, false, null); - ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - TransportVersion uuidTokenVersion = randomFrom(TransportVersions.V_7_2_0, TransportVersions.V_7_3_2); - storeTokenHeader( - requestContext, - tokenService.prependVersionAndEncodeAccessToken( - uuidTokenVersion, - tokenService.getRandomTokenBytes(uuidTokenVersion, randomBoolean()).v1() - ) - ); - - try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { - PlainActionFuture future = new PlainActionFuture<>(); - final SecureString bearerToken = Authenticator.extractBearerTokenFromHeader(requestContext); - tokenService.tryAuthenticateToken(bearerToken, future); - assertNull(future.get()); - } - } - public void testNonExistingLatestTokenVersion() throws Exception { TokenService tokenService = createTokenService(tokenServiceEnabledSettings, systemUTC()); // mock another random token so that we don't find a token in TokenService#getUserTokenFromId @@ -925,18 +723,11 @@ public void testIndexNotAvailable() throws Exception { return Void.TYPE; }).when(client).get(any(GetRequest.class), anyActionListener()); - final SecurityIndexManager tokensIndex; - if (pre72OldNode != null) { - tokensIndex = securityMainIndex; - when(securityTokensIndex.isAvailable(SecurityIndexManager.Availability.PRIMARY_SHARDS)).thenReturn(false); - when(securityTokensIndex.indexExists()).thenReturn(false); - when(securityTokensIndex.defensiveCopy()).thenReturn(securityTokensIndex); - } else { - tokensIndex = securityTokensIndex; - when(securityMainIndex.isAvailable(SecurityIndexManager.Availability.PRIMARY_SHARDS)).thenReturn(false); - when(securityMainIndex.indexExists()).thenReturn(false); - when(securityMainIndex.defensiveCopy()).thenReturn(securityMainIndex); - } + final SecurityIndexManager tokensIndex = securityTokensIndex; + when(securityMainIndex.isAvailable(SecurityIndexManager.Availability.PRIMARY_SHARDS)).thenReturn(false); + when(securityMainIndex.indexExists()).thenReturn(false); + when(securityMainIndex.defensiveCopy()).thenReturn(securityMainIndex); + try (ThreadContext.StoredContext ignore = requestContext.newStoredContextPreservingResponseHeaders()) { PlainActionFuture future = new PlainActionFuture<>(); final SecureString bearerToken3 = Authenticator.extractBearerTokenFromHeader(requestContext); @@ -988,7 +779,6 @@ public void testGetAuthenticationWorksWithExpiredUserToken() throws Exception { } public void testSupersedingTokenEncryption() throws Exception { - assumeTrue("Superseding tokens are only created in post 7.2 clusters", pre72OldNode == null); TokenService tokenService = createTokenService(tokenServiceEnabledSettings, Clock.systemUTC()); Authentication authentication = AuthenticationTests.randomAuthentication(null, null); PlainActionFuture tokenFuture = new PlainActionFuture<>(); @@ -1023,13 +813,11 @@ public void testSupersedingTokenEncryption() throws Exception { authentication, tokenFuture ); - if (version.onOrAfter(TokenService.VERSION_ACCESS_TOKENS_AS_UUIDS)) { - // previous versions serialized the access token encrypted and the cipher text was different each time (due to different IVs) - assertThat( - tokenService.prependVersionAndEncodeAccessToken(version, newTokenBytes.v1()), - equalTo(tokenFuture.get().getAccessToken()) - ); - } + + assertThat( + tokenService.prependVersionAndEncodeAccessToken(version, newTokenBytes.v1()), + equalTo(tokenFuture.get().getAccessToken()) + ); assertThat( TokenService.prependVersionAndEncodeRefreshToken(version, newTokenBytes.v2()), equalTo(tokenFuture.get().getRefreshToken()) @@ -1158,10 +946,8 @@ public static String tokenDocIdFromAccessTokenBytes(byte[] accessTokenBytes, Tra MessageDigest userTokenIdDigest = sha256(); userTokenIdDigest.update(accessTokenBytes, RAW_TOKEN_BYTES_LENGTH, RAW_TOKEN_DOC_ID_BYTES_LENGTH); return Base64.getUrlEncoder().withoutPadding().encodeToString(userTokenIdDigest.digest()); - } else if (tokenVersion.onOrAfter(TokenService.VERSION_ACCESS_TOKENS_AS_UUIDS)) { - return TokenService.hashTokenString(Base64.getUrlEncoder().withoutPadding().encodeToString(accessTokenBytes)); } else { - return Base64.getUrlEncoder().withoutPadding().encodeToString(accessTokenBytes); + return TokenService.hashTokenString(Base64.getUrlEncoder().withoutPadding().encodeToString(accessTokenBytes)); } } @@ -1178,12 +964,9 @@ private void mockTokenForRefreshToken( if (userToken.getTransportVersion().onOrAfter(VERSION_GET_TOKEN_DOC_FOR_REFRESH)) { storedAccessToken = Base64.getUrlEncoder().withoutPadding().encodeToString(sha256().digest(accessTokenBytes)); storedRefreshToken = Base64.getUrlEncoder().withoutPadding().encodeToString(sha256().digest(refreshTokenBytes)); - } else if (userToken.getTransportVersion().onOrAfter(TokenService.VERSION_HASHED_TOKENS)) { - storedAccessToken = null; - storedRefreshToken = TokenService.hashTokenString(Base64.getUrlEncoder().withoutPadding().encodeToString(refreshTokenBytes)); } else { storedAccessToken = null; - storedRefreshToken = Base64.getUrlEncoder().withoutPadding().encodeToString(refreshTokenBytes); + storedRefreshToken = TokenService.hashTokenString(Base64.getUrlEncoder().withoutPadding().encodeToString(refreshTokenBytes)); } final RealmRef realmRef = new RealmRef( refreshTokenStatus == null ? randomAlphaOfLength(6) : refreshTokenStatus.getAssociatedRealm(), diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml index 26e3c8ed0ef47..81f65668722fc 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml @@ -92,7 +92,7 @@ setup: - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} # Testing for the entire function set isn't feasbile, so we just check that we return the correct count as an approximation. - - length: {esql.functions: 127} # check the "sister" test below for a likely update to the same esql.functions length check + - length: {esql.functions: 128} # check the "sister" test below for a likely update to the same esql.functions length check --- "Basic ESQL usage output (telemetry) non-snapshot version": @@ -163,4 +163,4 @@ setup: - match: {esql.functions.cos: $functions_cos} - gt: {esql.functions.to_long: $functions_to_long} - match: {esql.functions.coalesce: $functions_coalesce} - - length: {esql.functions: 123} # check the "sister" test above for a likely update to the same esql.functions length check + - length: {esql.functions: 124} # check the "sister" test above for a likely update to the same esql.functions length check diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml new file mode 100644 index 0000000000000..01a41b3aa8c94 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml @@ -0,0 +1,89 @@ +--- +setup: + - do: + cluster.health: + wait_for_status: yellow + +--- +"Test Reindex With Unsupported Mode": + - do: + catch: /illegal_argument_exception/ + migrate.reindex: + body: | + { + "mode": "unsupported_mode", + "source": { + "index": "my-data-stream" + } + } + +--- +"Test Reindex With Nonexistent Data Stream": + - do: + catch: /resource_not_found_exception/ + migrate.reindex: + body: | + { + "mode": "upgrade", + "source": { + "index": "my-data-stream" + } + } + + - do: + catch: /resource_not_found_exception/ + migrate.reindex: + body: | + { + "mode": "upgrade", + "source": { + "index": "my-data-stream1,my-data-stream2" + } + } + + +--- +"Test Reindex With Bad Data Stream Name": + - do: + catch: /illegal_argument_exception/ + migrate.reindex: + body: | + { + "mode": "upgrade", + "source": { + "index": "my-data-stream*" + } + } + +--- +"Test Reindex With Existing Data Stream": + - do: + indices.put_index_template: + name: my-template1 + body: + index_patterns: [my-data-stream*] + template: + mappings: + properties: + '@timestamp': + type: date + 'foo': + type: keyword + data_stream: {} + + - do: + indices.create_data_stream: + name: my-data-stream + - is_true: acknowledged + +# Uncomment once the cancel API is in place +# - do: +# migrate.reindex: +# body: | +# { +# "mode": "upgrade", +# "source": { +# "index": "my-data-stream" +# } +# } +# - match: { task: "reindex-data-stream-my-data-stream" } diff --git a/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCCSCanMatchIT.java b/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCCSCanMatchIT.java index 208da4177fd4c..e4e577299d0d7 100644 --- a/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCCSCanMatchIT.java +++ b/x-pack/plugin/transform/src/internalClusterTest/java/org/elasticsearch/xpack/transform/checkpoint/TransformCCSCanMatchIT.java @@ -385,7 +385,7 @@ protected NamedXContentRegistry xContentRegistry() { } @Override - protected Collection remoteClusterAlias() { + protected List remoteClusterAlias() { return List.of(REMOTE_CLUSTER); } diff --git a/x-pack/plugin/watcher/src/internalClusterTest/java/org/elasticsearch/xpack/watcher/test/integration/HistoryIntegrationTests.java b/x-pack/plugin/watcher/src/internalClusterTest/java/org/elasticsearch/xpack/watcher/test/integration/HistoryIntegrationTests.java index 0070554d99d27..1bcdd060994ce 100644 --- a/x-pack/plugin/watcher/src/internalClusterTest/java/org/elasticsearch/xpack/watcher/test/integration/HistoryIntegrationTests.java +++ b/x-pack/plugin/watcher/src/internalClusterTest/java/org/elasticsearch/xpack/watcher/test/integration/HistoryIntegrationTests.java @@ -130,7 +130,7 @@ public void testFailedInputResultWithDotsInFieldNameGetsStored() throws Exceptio String chainedPath = SINGLE_MAPPING_NAME + ".properties.result.properties.input.properties.chain.properties.chained.properties.search" + ".properties.request.properties.body.enabled"; - assertThat(source.getValue(chainedPath), is(false)); + assertThat(source.getValue(chainedPath), nullValue()); } else { String path = SINGLE_MAPPING_NAME + ".properties.result.properties.input.properties.search.properties.request.properties.body.enabled"; @@ -168,11 +168,11 @@ public void testPayloadInputWithDotsInFieldNameWorks() throws Exception { XContentType.JSON ); - // lets make sure the body fields are disabled + // let's make sure the body fields are disabled or, in the case of chained, the whole object is not indexed if (useChained) { String path = SINGLE_MAPPING_NAME + ".properties.result.properties.input.properties.chain.properties.chained.properties.payload.enabled"; - assertThat(source.getValue(path), is(false)); + assertThat(source.getValue(path), nullValue()); } else { String path = SINGLE_MAPPING_NAME + ".properties.result.properties.input.properties.payload.enabled"; assertThat(source.getValue(path), is(false)); diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java index ea1b2cdac5a1f..54b7ff6fa484c 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/RolesBackwardsCompatibilityIT.java @@ -158,8 +158,8 @@ public void testRolesWithDescription() throws Exception { public void testRolesWithManageRoles() throws Exception { assumeTrue( - "The manage roles privilege is supported after transport version: " + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE, - minimumTransportVersion().before(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE) + "The manage roles privilege is supported after transport version: " + TransportVersions.V_8_16_0, + minimumTransportVersion().before(TransportVersions.V_8_16_0) ); switch (CLUSTER_TYPE) { case OLD -> { @@ -190,7 +190,7 @@ public void testRolesWithManageRoles() throws Exception { } case MIXED -> { try { - this.createClientsByVersion(TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE); + this.createClientsByVersion(TransportVersions.V_8_16_0); // succeed when role manage roles is not provided final String initialRole = randomRoleDescriptorSerialized(); createRole(client(), "my-valid-mixed-role", initialRole); @@ -232,7 +232,7 @@ public void testRolesWithManageRoles() throws Exception { e.getMessage(), containsString( "all nodes must have version [" - + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE.toReleaseVersion() + + TransportVersions.V_8_16_0.toReleaseVersion() + "] or higher to support the manage roles privilege" ) ); @@ -246,7 +246,7 @@ public void testRolesWithManageRoles() throws Exception { e.getMessage(), containsString( "all nodes must have version [" - + TransportVersions.ADD_MANAGE_ROLES_PRIVILEGE.toReleaseVersion() + + TransportVersions.V_8_16_0.toReleaseVersion() + "] or higher to support the manage roles privilege" ) );