From 5c3b9427774ffbf2ed6e4472cfeaaf6d6bfa9871 Mon Sep 17 00:00:00 2001 From: Marcin Januszkiewicz Date: Fri, 29 Sep 2023 09:24:14 +0200 Subject: [PATCH] Add support for version 8.10 --- 8.10/build.gradle | 6 + .../elasticsearch/TraveltimePlugin.java | 95 ++++++++++ .../query/TraveltimeFetchPhase.java | 77 ++++++++ .../query/TraveltimeQueryBuilder.java | 173 ++++++++++++++++++ .../query/TraveltimeQueryParser.java | 62 +++++++ .../elasticsearch/query/TraveltimeScorer.java | 114 ++++++++++++ .../query/TraveltimeSearchQuery.java | 50 +++++ .../elasticsearch/query/TraveltimeWeight.java | 153 ++++++++++++++++ .../universal/plugin-descriptor.properties | 6 + settings.gradle | 1 + 10 files changed, 737 insertions(+) create mode 100644 8.10/build.gradle create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java create mode 100644 8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java create mode 100644 8.10/src/universal/plugin-descriptor.properties diff --git a/8.10/build.gradle b/8.10/build.gradle new file mode 100644 index 0000000..37f5eb2 --- /dev/null +++ b/8.10/build.gradle @@ -0,0 +1,6 @@ +buildPlugin(this, '8.10', ['0', '1', '2', '3-SNAPSHOT']) + +compileJava { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 +} \ No newline at end of file diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java new file mode 100644 index 0000000..cd76ae5 --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -0,0 +1,95 @@ +package com.traveltime.plugin.elasticsearch; + + +import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; +import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; +import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; +import com.traveltime.plugin.elasticsearch.util.Util; +import com.traveltime.sdk.dto.requests.proto.Country; +import com.traveltime.sdk.dto.requests.proto.Transportation; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.routing.allocation.AllocationService; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.NodeEnvironment; +import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.repositories.RepositoriesService; +import org.elasticsearch.script.ScriptService; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.tracing.Tracer; +import org.elasticsearch.watcher.ResourceWatcherService; +import org.elasticsearch.xcontent.NamedXContentRegistry; + +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + +public class TraveltimePlugin extends Plugin implements SearchPlugin { + public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); + public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + + private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } + + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationService allocationService, + IndicesService indicesService + ) { + TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); + + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); + + return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationService, indicesService); + + } + + @Override + public List> getSettings() { + return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); + } + + @Override + public List> getQueries() { + return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } + + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } +} diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java new file mode 100644 index 0000000..77afd79 --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -0,0 +1,77 @@ +package com.traveltime.plugin.elasticsearch.query; + +import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import lombok.val; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.elasticsearch.common.document.DocumentField; +import org.elasticsearch.search.fetch.FetchContext; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.search.fetch.FetchSubPhaseProcessor; +import org.elasticsearch.search.fetch.StoredFieldsSpec; +import org.elasticsearch.search.fetch.subphase.FieldAndFormat; +import org.elasticsearch.search.fetch.subphase.FieldFetcher; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +public class TraveltimeFetchPhase implements FetchSubPhase { + + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); + + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } + } + } + + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } + + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + + FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + + return new FetchSubPhaseProcessor() { + + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } + + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue()); + + if (tt > 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } + + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } +} diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java new file mode 100644 index 0000000..a006bbb --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -0,0 +1,173 @@ +package com.traveltime.plugin.elasticsearch.query; + +import com.traveltime.plugin.elasticsearch.TraveltimePlugin; +import com.traveltime.sdk.dto.common.Coordinates; +import com.traveltime.sdk.dto.requests.proto.Country; +import com.traveltime.sdk.dto.requests.proto.Transportation; +import lombok.NonNull; +import lombok.Setter; +import org.apache.lucene.search.Query; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.geo.GeoPoint; +import org.elasticsearch.common.geo.GeoUtils; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.index.mapper.GeoPointFieldMapper; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.query.*; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; + +@Setter +public class TraveltimeQueryBuilder extends AbstractQueryBuilder { + @NonNull + private String field; + @NonNull + private GeoPoint origin; + private int limit; + private Transportation mode; + private Country country; + private QueryBuilder prefilter; + @NonNull + private String output = ""; + + public TraveltimeQueryBuilder() { + } + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.class); + country = in.readOptionalEnum(Country.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalEnum(country); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("prefilter", prefilter); + builder.field("output", output); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); + } + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); + } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } + + +} diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java new file mode 100644 index 0000000..7dd3f32 --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -0,0 +1,62 @@ +package com.traveltime.plugin.elasticsearch.query; + +import com.traveltime.plugin.elasticsearch.util.Util; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.geo.GeoUtils; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryParser; +import org.elasticsearch.xcontent.ContextParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; + +public class TraveltimeQueryParser implements QueryParser { + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + + private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + + private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); + queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } + + private static T findByNameOrError(String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } + + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } +} diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java new file mode 100644 index 0000000..58e384b --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -0,0 +1,114 @@ +package com.traveltime.plugin.elasticsearch.query; + +import it.unimi.dsi.fastutil.longs.Long2IntMap; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Scorer; + +import java.io.IOException; + +public class TraveltimeScorer extends Scorer { + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends SortedNumericDocValues { + private final SortedNumericDocValues backing; + + @Getter + private long currentValue = 0; + private boolean currentValueDirty = true; + private void invalidateCurrentValue() { + currentValueDirty = true; + } + private void advanceValue() throws IOException { + if(currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; + } + } + + @Override + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docValueCount() { + return 1; + } + + @Override + public boolean advanceExact(int target) throws IOException { + invalidateCurrentValue(); + return (target == DocIdSetIterator.NO_MORE_DOCS && backing.advanceExact(target)) || + backing.advanceExact(target) && pointToTime.containsKey(nextValue()); + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); + } + return id; + } + + @Override + public int advance(int target) throws IOException { + if (advanceExact(target)) { + return target; + } else { + return nextDoc(); + } + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, SortedNumericDocValues docs, float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + + } + + @Override + public int docID() { + return docs.docID(); + } +} diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java new file mode 100644 index 0000000..0ee3888 --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -0,0 +1,50 @@ +package com.traveltime.plugin.elasticsearch.query; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.apache.lucene.search.*; + +import java.io.IOException; +import java.net.URI; + +@AllArgsConstructor +@EqualsAndHashCode(callSuper = false) +@Getter +public class TraveltimeSearchQuery extends Query { + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final URI appUri; + private final String appId; + private final String apiKey; + + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } + + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } + + @Override + public Query rewrite(IndexSearcher reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey); + } + } +} diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java new file mode 100644 index 0000000..7cc8efc --- /dev/null +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -0,0 +1,153 @@ +package com.traveltime.plugin.elasticsearch.query; + +import com.traveltime.plugin.elasticsearch.FetcherSingleton; +import com.traveltime.plugin.elasticsearch.ProtoFetcher; +import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import com.traveltime.plugin.elasticsearch.util.Util; +import com.traveltime.sdk.dto.common.Coordinates; +import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongArrayList; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.val; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.SortedNumericDocValues; +import org.apache.lucene.search.*; +import org.elasticsearch.SpecialPermission; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +@EqualsAndHashCode(callSuper = false) +public class TraveltimeWeight extends Weight { + @Getter + private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude + private final ProtoFetcher protoFetcher; + + public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator extends SortedNumericDocValues { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docValueCount() { + return this.values.docValueCount(); + } + + public boolean advanceExact(int target) throws IOException { + this.filtered.advance(target); + return this.values.docValueCount() > 0; + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if(preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if(valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); + } + } + + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); + + val results = protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry() + ); + + for (int index = 0; index < results.size(); index++) { + if(results.get(index) > 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } + } + + if(hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } + + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } +} diff --git a/8.10/src/universal/plugin-descriptor.properties b/8.10/src/universal/plugin-descriptor.properties new file mode 100644 index 0000000..4574b21 --- /dev/null +++ b/8.10/src/universal/plugin-descriptor.properties @@ -0,0 +1,6 @@ +description=Traveltime search plugin +version=PLUGIN_VERSION +name=Traveltime +classname=com.traveltime.plugin.elasticsearch.TraveltimePlugin +java.version=1.11 +elasticsearch.version=ES_VERSION diff --git a/settings.gradle b/settings.gradle index b1829cb..1917403 100644 --- a/settings.gradle +++ b/settings.gradle @@ -15,3 +15,4 @@ include '8.2' include '8.3' include '8.4' include '8.5' +include '8.10'