From b9be652096f286314c99757000a06c7fc80a5251 Mon Sep 17 00:00:00 2001 From: Marcin Januszkiewicz Date: Thu, 29 Aug 2024 15:27:06 +0200 Subject: [PATCH] Apply google java formatting --- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 113 +++--- .../query/TraveltimeQueryBuilder.java | 351 +++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 113 +++--- .../query/TraveltimeQueryBuilder.java | 351 +++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 179 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 241 ++++++------ .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 279 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 329 ++++++++-------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 166 ++++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 348 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 168 ++++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 347 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 138 ++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 348 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 138 ++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 348 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 339 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 339 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 147 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 339 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 149 +++++--- .../query/TraveltimeFetchPhase.java | 112 +++--- .../query/TraveltimeQueryBuilder.java | 339 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 163 +++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 349 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 163 +++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 349 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 163 +++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 349 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- .../elasticsearch/TraveltimePlugin.java | 163 +++++--- .../query/TraveltimeFetchPhase.java | 122 +++--- .../query/TraveltimeQueryBuilder.java | 350 ++++++++--------- .../query/TraveltimeQueryParser.java | 106 +++--- .../elasticsearch/query/TraveltimeScorer.java | 178 ++++----- .../query/TraveltimeSearchQuery.java | 72 ++-- .../elasticsearch/query/TraveltimeWeight.java | 270 +++++++------- 154 files changed, 14440 insertions(+), 12713 deletions(-) diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 690201c..b79cd82 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_CLEANUP_INTERVAL, CACHE_EXPIRY, CACHE_SIZE); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_CLEANUP_INTERVAL, + CACHE_EXPIRY, + CACHE_SIZE); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 4865b2d..308b58e 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,71 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.mapperService(), fetchContext.searchLookup(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.mapperService(), + fetchContext.searchLookup(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 7fc0e68..1615fe8 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,184 +22,187 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.*; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - if (in.readBoolean()) { - mode = in.readEnum(Transportation.Modes.class); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + if (in.readBoolean()) { + mode = in.readEnum(Transportation.Modes.class); + } else { + mode = null; + } + if (in.readBoolean()) { + String c = in.readString(); + country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + } else { + country = null; + } + if (in.readBoolean()) { + requestType = in.readEnum(RequestType.class); + } else { + mode = null; + } + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeBoolean(mode != null); + if (mode != null) out.writeEnum(mode); + out.writeBoolean(country != null); + if (country != null) out.writeString(country.getValue()); + out.writeBoolean(requestType != null); + if (requestType != null) out.writeEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + MappedFieldType originMapping = context.fieldMapper(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); } else { - mode = null; + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - if (in.readBoolean()) { - String c = in.readString(); - country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); } else { - country = null; + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (in.readBoolean()) { - requestType = in.readEnum(RequestType.class); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); } else { - mode = null; - } - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeBoolean(mode != null); - if (mode != null) out.writeEnum(mode); - out.writeBoolean(country != null); - if (country != null) out.writeString(country.getValue()); - out.writeBoolean(requestType != null); - if(requestType != null) out.writeEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(QueryShardContext context) throws IOException { - MappedFieldType originMapping = context.fieldMapper(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); - } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index a8bdac1..d43c20c 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; @@ -11,57 +14,68 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index efc8e7c..231698c 100644 --- a/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if (hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 690201c..b79cd82 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_CLEANUP_INTERVAL, CACHE_EXPIRY, CACHE_SIZE); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_CLEANUP_INTERVAL, + CACHE_EXPIRY, + CACHE_SIZE); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index a2590cd..87cffdb 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,71 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getQueryShardContext(), fetchContext.searchLookup(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getQueryShardContext(), + fetchContext.searchLookup(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 96746a3..a74cfbe 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,184 +22,187 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.*; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - if (in.readBoolean()) { - mode = in.readEnum(Transportation.Modes.class); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + if (in.readBoolean()) { + mode = in.readEnum(Transportation.Modes.class); + } else { + mode = null; + } + if (in.readBoolean()) { + String c = in.readString(); + country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + } else { + country = null; + } + if (in.readBoolean()) { + requestType = in.readEnum(RequestType.class); + } else { + mode = null; + } + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeBoolean(mode != null); + if (mode != null) out.writeEnum(mode); + out.writeBoolean(country != null); + if (country != null) out.writeString(country.getValue()); + out.writeBoolean(requestType != null); + if (requestType != null) out.writeEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); } else { - mode = null; + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - if (in.readBoolean()) { - String c = in.readString(); - country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); } else { - country = null; + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (in.readBoolean()) { - requestType = in.readEnum(RequestType.class); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); } else { - mode = null; - } - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeBoolean(mode != null); - if (mode != null) out.writeEnum(mode); - out.writeBoolean(country != null); - if (country != null) out.writeString(country.getValue()); - out.writeBoolean(requestType != null); - if(requestType != null) out.writeEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(QueryShardContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); - } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index a8bdac1..d43c20c 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; @@ -11,57 +14,68 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index efc8e7c..231698c 100644 --- a/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if (hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 690201c..b79cd82 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_CLEANUP_INTERVAL, CACHE_EXPIRY, CACHE_SIZE); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_CLEANUP_INTERVAL, + CACHE_EXPIRY, + CACHE_SIZE); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 8aa58a8..5265e8f 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.*; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index a8bdac1..d43c20c 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; @@ -11,57 +14,68 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index b98e209..3958877 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,101 +1,104 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - @Getter - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + @Getter private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 3a73f7d..b7e8541 100644 --- a/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,132 +24,127 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } - - public int docID() { - return this.filtered.docID(); - } - - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } - - public int advance(int target) throws IOException { - return this.filtered.advance(target); + } + + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); + + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); } + } - public long cost() { - return this.filtered.cost(); - } - } - - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); - - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; - } - - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } - } - - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if(results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 690201c..b79cd82 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_CLEANUP_INTERVAL, CACHE_EXPIRY, CACHE_SIZE); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_CLEANUP_INTERVAL, + CACHE_EXPIRY, + CACHE_SIZE); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 8aa58a8..5265e8f 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.*; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index a8bdac1..d43c20c 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; @@ -11,57 +14,68 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 5358ac9..231698c 100644 --- a/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 25439cb..99dd76a 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_CLEANUP_INTERVAL, CACHE_EXPIRY, CACHE_SIZE); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_CLEANUP_INTERVAL, + CACHE_EXPIRY, + CACHE_SIZE); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 8aa58a8..5265e8f 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.*; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 7e7125d..6912e63 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.common.xcontent.ContextParser; @@ -11,57 +14,68 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 5358ac9..231698c 100644 --- a/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.14/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 25439cb..99dd76a 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_CLEANUP_INTERVAL, CACHE_EXPIRY, CACHE_SIZE); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_CLEANUP_INTERVAL, + CACHE_EXPIRY, + CACHE_SIZE); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 8aa58a8..5265e8f 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.query.*; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 7e7125d..6912e63 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.common.xcontent.ContextParser; @@ -11,57 +14,68 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 5358ac9..231698c 100644 --- a/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.15/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index c71c5d2..fa90b48 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 84237a3..4c467a9 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 5358ac9..231698c 100644 --- a/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.16/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index c71c5d2..fa90b48 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 84237a3..4c467a9 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index c68eab1..45cf72f 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - super.visit(visitor); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + super.visit(visitor); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 5358ac9..231698c 100644 --- a/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/7.17/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,10 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -20,159 +24,154 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public void extractTerms(Set terms) { - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public void extractTerms(Set terms) {} + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - public long cost() { - return this.filtered.cost(); - } - } + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - DocIdSetIterator finalIterator; - - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionDISI.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index c71c5d2..fa90b48 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 84237a3..4c467a9 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.0/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index e598a70..9ce2ced 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 84237a3..4c467a9 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -18,170 +22,173 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.1/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index c26ff2d..c78a845 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.AllocationService; @@ -28,73 +33,114 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents( - Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - Tracer tracer, - AllocationService allocationService, - IndicesService indicesService - ) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationService allocationService, + IndicesService indicesService) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationService, indicesService); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer, + allocationService, + indicesService); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 3225ea4..6e3539a 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -20,181 +24,181 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 0bc37e5..be0b0da 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,51 +1,53 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexSearcher reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexSearcher reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.10/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index c26cc3e..c5dfa90 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.AllocationService; @@ -28,73 +33,116 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents( - Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - TelemetryProvider telemetryProvider, - AllocationService allocationService, - IndicesService indicesService - ) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, + threadPool.generic(), + () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + TelemetryProvider telemetryProvider, + AllocationService allocationService, + IndicesService indicesService) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, telemetryProvider, allocationService, indicesService); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + telemetryProvider, + allocationService, + indicesService); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 7d73359..5332be0 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -21,181 +25,180 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 0bc37e5..be0b0da 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,51 +1,53 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexSearcher reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexSearcher reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.11/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 1153fc8..b88fb83 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,77 +7,102 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.routing.allocation.AllocationService; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.env.Environment; -import org.elasticsearch.env.NodeEnvironment; -import org.elasticsearch.indices.IndicesService; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; -import org.elasticsearch.repositories.RepositoriesService; -import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.fetch.FetchSubPhase; -import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.watcher.ResourceWatcherService; -import org.elasticsearch.xcontent.NamedXContentRegistry; - -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, + threadPool.generic(), + () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - @Override - public Collection createComponents(PluginServices pluginServices) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(pluginServices.environment().settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(pluginServices.environment().settings())); - Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings()); + @Override + public Collection createComponents(PluginServices pluginServices) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds( + CACHE_CLEANUP_INTERVAL.get(pluginServices.environment().settings())); + Duration cacheExpiry = + Duration.ofSeconds(CACHE_EXPIRY.get(pluginServices.environment().settings())); + Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings()); - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds); - return super.createComponents(pluginServices); - } + return super.createComponents(pluginServices); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index d6f78b2..9eb7314 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -21,181 +25,181 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 0bc37e5..be0b0da 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,51 +1,53 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexSearcher reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexSearcher reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.12/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 1153fc8..b88fb83 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,77 +7,102 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.routing.allocation.AllocationService; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.env.Environment; -import org.elasticsearch.env.NodeEnvironment; -import org.elasticsearch.indices.IndicesService; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; -import org.elasticsearch.repositories.RepositoriesService; -import org.elasticsearch.script.ScriptService; import org.elasticsearch.search.fetch.FetchSubPhase; -import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.watcher.ResourceWatcherService; -import org.elasticsearch.xcontent.NamedXContentRegistry; - -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, + threadPool.generic(), + () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - @Override - public Collection createComponents(PluginServices pluginServices) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(pluginServices.environment().settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(pluginServices.environment().settings())); - Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings()); + @Override + public Collection createComponents(PluginServices pluginServices) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds( + CACHE_CLEANUP_INTERVAL.get(pluginServices.environment().settings())); + Duration cacheExpiry = + Duration.ofSeconds(CACHE_EXPIRY.get(pluginServices.environment().settings())); + Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings()); - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds); - return super.createComponents(pluginServices); - } + return super.createComponents(pluginServices); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index d6f78b2..9eb7314 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -21,181 +25,181 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 0bc37e5..be0b0da 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,51 +1,53 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexSearcher reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexSearcher reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index e598a70..9ce2ced 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 6000a45..e5447eb 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -19,175 +23,178 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.V_8_2_0; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_2_0; + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.2/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index e598a70..9ce2ced 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -25,60 +30,108 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of( - new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser()) - ); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 6000a45..e5447eb 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -19,175 +23,178 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.V_8_2_0; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_2_0; + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.3/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 780ae25..6c72f0f 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.service.ClusterService; @@ -26,58 +31,110 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier, Tracer tracer) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 6000a45..e5447eb 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -19,175 +23,178 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.V_8_2_0; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_2_0; + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.4/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index df539c7..7f0a0d9 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; @@ -27,58 +32,112 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, ResourceWatcherService resourceWatcherService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, Environment environment, NodeEnvironment nodeEnvironment, NamedWriteableRegistry namedWriteableRegistry, IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier, Tracer tracer, AllocationDeciders allocationDeciders) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationDeciders allocationDeciders) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationDeciders); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer, + allocationDeciders); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index aab65ac..2e6163b 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -12,69 +15,70 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if(!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if(!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } - }; - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } + }; + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index efd97ea..90013eb 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -19,175 +23,178 @@ import org.elasticsearch.index.query.*; import org.elasticsearch.xcontent.XContentBuilder; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.V_8_2_0; - } + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_2_0; + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 4df2ba9..2cb4ff3 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.AbstractQueryBuilder; @@ -11,57 +14,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> AbstractQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.5/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index e11294e..b37eaab 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; @@ -27,72 +32,112 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents( - Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - Tracer tracer, - AllocationDeciders allocationService - ) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationDeciders allocationService) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationService); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer, + allocationService); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index fa14f86..fad1c77 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -20,181 +24,182 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.V_8_2_0; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_8_2_0; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.6/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 826e2f0..33ff423 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.AllocationService; @@ -27,72 +32,112 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents( - Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - Tracer tracer, - AllocationService allocationService - ) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationService allocationService) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationService); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer, + allocationService); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index d4b6e9e..5fee598 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -20,181 +24,182 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.7/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 826e2f0..33ff423 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.AllocationService; @@ -27,72 +32,112 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents( - Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - Tracer tracer, - AllocationService allocationService - ) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationService allocationService) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationService); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer, + allocationService); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index d4b6e9e..5fee598 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -20,181 +24,182 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } - } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); - } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } - } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 7ec036d..99c1267 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,52 +1,54 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexReader reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexReader reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.8/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index 826e2f0..33ff423 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -1,6 +1,5 @@ package com.traveltime.plugin.elasticsearch; - import com.traveltime.plugin.elasticsearch.query.TraveltimeFetchPhase; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryBuilder; import com.traveltime.plugin.elasticsearch.query.TraveltimeQueryParser; @@ -8,6 +7,12 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.net.URI; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.routing.allocation.AllocationService; @@ -27,72 +32,112 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xcontent.NamedXContentRegistry; -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - public class TraveltimePlugin extends Plugin implements SearchPlugin { - public static final Setting APP_ID = Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); - public static final Setting API_KEY = Setting.simpleString("traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); - public static final Setting> DEFAULT_MODE = new Setting<>("traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); - public static final Setting> DEFAULT_COUNTRY = new Setting<>("traveltime.default.country", s -> "", Util::findCountryByName, Setting.Property.NodeScope); - - public static final Setting> DEFAULT_REQUEST_TYPE = new Setting<>("traveltime.default.request_type", s -> RequestType.ONE_TO_MANY.name(), Util::findRequestTypeByName, Setting.Property.NodeScope); - public static final Setting API_URI = new Setting<>("traveltime.api.uri", s -> "https://proto.api.traveltimeapp.com/api/v2/", URI::create, Setting.Property.NodeScope); + public static final Setting APP_ID = + Setting.simpleString("traveltime.app.id", Setting.Property.NodeScope); + public static final Setting API_KEY = + Setting.simpleString( + "traveltime.api.key", Setting.Property.NodeScope, Setting.Property.Filtered); + public static final Setting> DEFAULT_MODE = + new Setting<>( + "traveltime.default.mode", s -> "", Util::findModeByName, Setting.Property.NodeScope); + public static final Setting> DEFAULT_COUNTRY = + new Setting<>( + "traveltime.default.country", + s -> "", + Util::findCountryByName, + Setting.Property.NodeScope); - private static final Setting CACHE_CLEANUP_INTERVAL = Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); - private static final Setting CACHE_EXPIRY = Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); - private static final Setting CACHE_SIZE = Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); + public static final Setting> DEFAULT_REQUEST_TYPE = + new Setting<>( + "traveltime.default.request_type", + s -> RequestType.ONE_TO_MANY.name(), + Util::findRequestTypeByName, + Setting.Property.NodeScope); + public static final Setting API_URI = + new Setting<>( + "traveltime.api.uri", + s -> "https://proto.api.traveltimeapp.com/api/v2/", + URI::create, + Setting.Property.NodeScope); - private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { - TraveltimeCache.INSTANCE.cleanUp(); - TraveltimeCache.DISTANCE.cleanUp(); - threadPool.scheduleUnlessShuttingDown(cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); - } + private static final Setting CACHE_CLEANUP_INTERVAL = + Setting.intSetting("traveltime.cache.cleanup.interval", 120, 0, Setting.Property.NodeScope); + private static final Setting CACHE_EXPIRY = + Setting.intSetting("traveltime.cache.expiry", 60, 0, Setting.Property.NodeScope); + private static final Setting CACHE_SIZE = + Setting.intSetting("traveltime.cache.size", 50, 0, Setting.Property.NodeScope); - @Override - public Collection createComponents( - Client client, - ClusterService clusterService, - ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, - Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier, - Tracer tracer, - AllocationService allocationService - ) { - TimeValue cleanupSeconds = TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); - Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); - Integer cacheSize = CACHE_SIZE.get(environment.settings()); + private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { + TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); + threadPool.scheduleUnlessShuttingDown( + cleanupSeconds, "generic", () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); + } - TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); - TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); - cleanUpAndReschedule(threadPool, cleanupSeconds); + @Override + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier, + Tracer tracer, + AllocationService allocationService) { + TimeValue cleanupSeconds = + TimeValue.timeValueSeconds(CACHE_CLEANUP_INTERVAL.get(environment.settings())); + Duration cacheExpiry = Duration.ofSeconds(CACHE_EXPIRY.get(environment.settings())); + Integer cacheSize = CACHE_SIZE.get(environment.settings()); - return super.createComponents(client, clusterService, threadPool, resourceWatcherService, scriptService, xContentRegistry, environment, nodeEnvironment, namedWriteableRegistry, indexNameExpressionResolver, repositoriesServiceSupplier, tracer, allocationService); + TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); + cleanUpAndReschedule(threadPool, cleanupSeconds); - } + return super.createComponents( + client, + clusterService, + threadPool, + resourceWatcherService, + scriptService, + xContentRegistry, + environment, + nodeEnvironment, + namedWriteableRegistry, + indexNameExpressionResolver, + repositoriesServiceSupplier, + tracer, + allocationService); + } - @Override - public List> getSettings() { - return List.of(APP_ID, API_KEY, DEFAULT_MODE, DEFAULT_COUNTRY, DEFAULT_REQUEST_TYPE, API_URI, CACHE_SIZE, CACHE_EXPIRY, CACHE_CLEANUP_INTERVAL); - } + @Override + public List> getSettings() { + return List.of( + APP_ID, + API_KEY, + DEFAULT_MODE, + DEFAULT_COUNTRY, + DEFAULT_REQUEST_TYPE, + API_URI, + CACHE_SIZE, + CACHE_EXPIRY, + CACHE_CLEANUP_INTERVAL); + } - @Override - public List> getQueries() { - return List.of(new QuerySpec<>(TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); - } + @Override + public List> getQueries() { + return List.of( + new QuerySpec<>( + TraveltimeQueryParser.NAME, TraveltimeQueryBuilder::new, new TraveltimeQueryParser())); + } - @Override - public List getFetchSubPhases(FetchPhaseConstructionContext context) { - return List.of(new TraveltimeFetchPhase()); - } + @Override + public List getFetchSubPhases(FetchPhaseConstructionContext context) { + return List.of(new TraveltimeFetchPhase()); + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 2230825..7956338 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -1,6 +1,10 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.TraveltimeCache; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; import lombok.val; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; @@ -13,75 +17,75 @@ import org.elasticsearch.search.fetch.subphase.FieldAndFormat; import org.elasticsearch.search.fetch.subphase.FieldFetcher; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - public class TraveltimeFetchPhase implements FetchSubPhase { - private static class ParamFinder extends QueryVisitor { - private final List paramList = new ArrayList<>(); + private static class ParamFinder extends QueryVisitor { + private final List paramList = new ArrayList<>(); - @Override - public void visitLeaf(Query query) { - if (query instanceof TraveltimeSearchQuery) { - if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { - paramList.add(((TraveltimeSearchQuery) query)); - } - } + @Override + public void visitLeaf(Query query) { + if (query instanceof TraveltimeSearchQuery) { + if (!((TraveltimeSearchQuery) query).getOutput().isEmpty()) { + paramList.add(((TraveltimeSearchQuery) query)); + } } + } - public TraveltimeSearchQuery getQuery() { - if (paramList.size() == 1) return paramList.get(0); - else return null; - } - } + public TraveltimeSearchQuery getQuery() { + if (paramList.size() == 1) return paramList.get(0); + else return null; + } + } - @Override - public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { - Query query = fetchContext.query(); - val finder = new ParamFinder(); - query.visit(finder); - TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); - if (traveltimeQuery == null) return null; - TraveltimeQueryParameters params = traveltimeQuery.getParams(); - final String output = traveltimeQuery.getOutput(); - final String distanceOutput = traveltimeQuery.getDistanceOutput(); + @Override + public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { + Query query = fetchContext.query(); + val finder = new ParamFinder(); + query.visit(finder); + TraveltimeSearchQuery traveltimeQuery = finder.getQuery(); + if (traveltimeQuery == null) return null; + TraveltimeQueryParameters params = traveltimeQuery.getParams(); + final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); - FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); + FieldFetcher fieldFetcher = + FieldFetcher.create( + fetchContext.getSearchExecutionContext(), + List.of(new FieldAndFormat(params.getField(), null))); - return new FetchSubPhaseProcessor() { + return new FetchSubPhaseProcessor() { - @Override - public void setNextReader(LeafReaderContext readerContext) { - fieldFetcher.setNextReader(readerContext); - } + @Override + public void setNextReader(LeafReaderContext readerContext) { + fieldFetcher.setNextReader(readerContext); + } - @Override - public void process(HitContext hitContext) throws IOException { - val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); - docValues.advance(hitContext.docId()); - val point = docValues.nextValue(); - if (!output.isEmpty()) { - Integer tt = TraveltimeCache.INSTANCE.get(params, point); - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); - } - } + @Override + public void process(HitContext hitContext) throws IOException { + val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); + docValues.advance(hitContext.docId()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (!distanceOutput.isEmpty()) { - Integer td = TraveltimeCache.DISTANCE.get(params, point); - if (td >= 0) { - hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); - } - } - } + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext + .hit() + .setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } + } + } - @Override - public StoredFieldsSpec storedFieldsSpec() { - return new StoredFieldsSpec(false, false, Set.of(params.getField())); - } - }; - } + @Override + public StoredFieldsSpec storedFieldsSpec() { + return new StoredFieldsSpec(false, false, Set.of(params.getField())); + } + }; + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index dfac8d5..ac3466f 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -6,6 +6,10 @@ import com.traveltime.sdk.dto.requests.proto.Country; import com.traveltime.sdk.dto.requests.proto.RequestType; import com.traveltime.sdk.dto.requests.proto.Transportation; +import java.io.IOException; +import java.net.URI; +import java.util.Objects; +import java.util.Optional; import lombok.NonNull; import lombok.Setter; import org.apache.lucene.search.Query; @@ -20,182 +24,182 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.net.URI; -import java.util.Objects; -import java.util.Optional; - @Setter public class TraveltimeQueryBuilder extends AbstractQueryBuilder { - @NonNull - private String field; - @NonNull - private GeoPoint origin; - private int limit; - private Transportation.Modes mode; - private Country country; - private RequestType requestType; - private QueryBuilder prefilter; - @NonNull - private String output = ""; - @NonNull - private String distanceOutput = ""; - - public TraveltimeQueryBuilder() { - } - - public TraveltimeQueryBuilder(StreamInput in) throws IOException { - super(in); - field = in.readString(); - origin = in.readGeoPoint(); - limit = in.readInt(); - mode = in.readOptionalEnum(Transportation.Modes.class); - String c = in.readOptionalString(); - if(c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); - requestType = in.readOptionalEnum(RequestType.class); - prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); - output = in.readString(); - distanceOutput = in.readString(); - } - - @Override - protected void doWriteTo(StreamOutput out) throws IOException { - out.writeString(field); - out.writeGeoPoint(origin); - out.writeInt(limit); - out.writeOptionalEnum(mode); - out.writeOptionalString(country == null ? null : country.getValue()); - out.writeOptionalEnum(requestType); - out.writeOptionalNamedWriteable(prefilter); - out.writeString(output); - out.writeString(distanceOutput); - } - - @Override - protected void doXContent(XContentBuilder builder, Params params) throws IOException { - builder.field("field", field); - builder.field("origin", origin); - builder.field("limit", limit); - builder.field("mode", mode == null ? null : mode.getValue()); - builder.field("country", country == null ? null : country.getValue()); - builder.field("requestType", requestType == null ? null : requestType.name()); - builder.field("prefilter", prefilter); - builder.field("output", output); - builder.field("distanceOutput", distanceOutput); - } - - @Override - protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { - if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); - return super.doRewrite(queryRewriteContext); - } - - @Override - protected Query doToQuery(SearchExecutionContext context) throws IOException { - MappedFieldType originMapping = context.getFieldType(field); - if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { - throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); - } - - GeoUtils.normalizePoint(origin); - if (!GeoUtils.isValidLatitude(origin.getLat())) { - throw new QueryShardException(context, "latitude invalid for origin " + origin); - } - if (!GeoUtils.isValidLongitude(origin.getLon())) { - throw new QueryShardException(context, "longitude invalid for origin " + origin); - } - - URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); - String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); - String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); - if (appId.isEmpty()) { - throw new IllegalStateException("Traveltime app id must be set in the config"); - } - if (apiKey.isEmpty()) { - throw new IllegalStateException("Traveltime api key must be set in the config"); - } - - Optional defaultMode = TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); - Optional defaultCountry = TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); - Optional defaultRequestType = TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); - - Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); - - boolean includeDistance = !distanceOutput.isEmpty(); - - TraveltimeQueryParameters params = new TraveltimeQueryParameters(field, originCoord, limit, mode, country, requestType, includeDistance); - if (params.getMode() == null) { - if (defaultMode.isPresent()) { - params = params.withMode(defaultMode.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); - } + @NonNull private String field; + @NonNull private GeoPoint origin; + private int limit; + private Transportation.Modes mode; + private Country country; + private RequestType requestType; + private QueryBuilder prefilter; + @NonNull private String output = ""; + @NonNull private String distanceOutput = ""; + + public TraveltimeQueryBuilder() {} + + public TraveltimeQueryBuilder(StreamInput in) throws IOException { + super(in); + field = in.readString(); + origin = in.readGeoPoint(); + limit = in.readInt(); + mode = in.readOptionalEnum(Transportation.Modes.class); + String c = in.readOptionalString(); + if (c != null) country = Util.findCountryByName(c).orElseGet(() -> new Country.Custom(c)); + requestType = in.readOptionalEnum(RequestType.class); + prefilter = in.readOptionalNamedWriteable(QueryBuilder.class); + output = in.readString(); + distanceOutput = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(field); + out.writeGeoPoint(origin); + out.writeInt(limit); + out.writeOptionalEnum(mode); + out.writeOptionalString(country == null ? null : country.getValue()); + out.writeOptionalEnum(requestType); + out.writeOptionalNamedWriteable(prefilter); + out.writeString(output); + out.writeString(distanceOutput); + } + + @Override + protected void doXContent(XContentBuilder builder, Params params) throws IOException { + builder.field("field", field); + builder.field("origin", origin); + builder.field("limit", limit); + builder.field("mode", mode == null ? null : mode.getValue()); + builder.field("country", country == null ? null : country.getValue()); + builder.field("requestType", requestType == null ? null : requestType.name()); + builder.field("prefilter", prefilter); + builder.field("output", output); + builder.field("distanceOutput", distanceOutput); + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + if (this.prefilter != null) this.prefilter = this.prefilter.rewrite(queryRewriteContext); + return super.doRewrite(queryRewriteContext); + } + + @Override + protected Query doToQuery(SearchExecutionContext context) throws IOException { + MappedFieldType originMapping = context.getFieldType(field); + if (!(originMapping instanceof GeoPointFieldMapper.GeoPointFieldType)) { + throw new QueryShardException(context, "field [" + field + "] is not a geo_point field"); + } + + GeoUtils.normalizePoint(origin); + if (!GeoUtils.isValidLatitude(origin.getLat())) { + throw new QueryShardException(context, "latitude invalid for origin " + origin); + } + if (!GeoUtils.isValidLongitude(origin.getLon())) { + throw new QueryShardException(context, "longitude invalid for origin " + origin); + } + + URI appUri = TraveltimePlugin.API_URI.get(context.getIndexSettings().getSettings()); + String appId = TraveltimePlugin.APP_ID.get(context.getIndexSettings().getSettings()); + String apiKey = TraveltimePlugin.API_KEY.get(context.getIndexSettings().getSettings()); + if (appId.isEmpty()) { + throw new IllegalStateException("Traveltime app id must be set in the config"); + } + if (apiKey.isEmpty()) { + throw new IllegalStateException("Traveltime api key must be set in the config"); + } + + Optional defaultMode = + TraveltimePlugin.DEFAULT_MODE.get(context.getIndexSettings().getSettings()); + Optional defaultCountry = + TraveltimePlugin.DEFAULT_COUNTRY.get(context.getIndexSettings().getSettings()); + Optional defaultRequestType = + TraveltimePlugin.DEFAULT_REQUEST_TYPE.get(context.getIndexSettings().getSettings()); + + Coordinates originCoord = Coordinates.builder().lat(origin.lat()).lng(origin.getLon()).build(); + + boolean includeDistance = !distanceOutput.isEmpty(); + + TraveltimeQueryParameters params = + new TraveltimeQueryParameters( + field, originCoord, limit, mode, country, requestType, includeDistance); + if (params.getMode() == null) { + if (defaultMode.isPresent()) { + params = params.withMode(defaultMode.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'mode' field to be present or a default mode to be" + + " set in the config"); } - if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { - throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException( + "Traveltime query with distance output cannot be used with public transportation mode"); + } + if (params.getCountry() == null) { + if (defaultCountry.isPresent()) { + params = params.withCountry(defaultCountry.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'country' field to be present or a default country to" + + " be set in the config"); } - if (params.getCountry() == null) { - if (defaultCountry.isPresent()) { - params = params.withCountry(defaultCountry.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'country' field to be present or a default country to be set in the config"); - } + } + if (params.getRequestType() == null) { + if (defaultRequestType.isPresent()) { + params = params.withRequestType(defaultRequestType.get()); + } else { + throw new IllegalStateException( + "Traveltime query requires either 'requestType' field to be present or a default" + + " request type to be set in the config"); } - if(params.getRequestType() == null) { - if(defaultRequestType.isPresent()) { - params = params.withRequestType(defaultRequestType.get()); - } else { - throw new IllegalStateException("Traveltime query requires either 'requestType' field to be present or a default request type to be set in the config"); - } - - } - if (params.getLimit() <= 0) { - throw new IllegalStateException("Traveltime limit must be greater than zero"); - } - - Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - - return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); - } - - @Override - protected boolean doEquals(TraveltimeQueryBuilder other) { - if (!Objects.equals(this.field, other.field)) return false; - if (!Objects.equals(this.origin, other.origin)) return false; - if (!Objects.equals(this.mode, other.mode)) return false; - if (!Objects.equals(this.country, other.country)) return false; - if (!Objects.equals(this.prefilter, other.prefilter)) return false; - if (!Objects.equals(this.output, other.output)) return false; - return this.limit == other.limit; - } - - @Override - protected int doHashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + this.field.hashCode(); - result = result * PRIME + this.origin.hashCode(); - result = result * PRIME + Objects.hashCode(this.mode); - result = result * PRIME + Objects.hashCode(this.country); - result = result * PRIME + Objects.hashCode(this.prefilter); - result = result * PRIME + Objects.hashCode(this.output); - result = result * PRIME + this.limit; - return result; - } - - @Override - public String getWriteableName() { - return TraveltimeQueryParser.NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersion.MINIMUM_COMPATIBLE; - } - - public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { - return AbstractQueryBuilder.parseInnerQueryBuilder(parser); - } - - + } + if (params.getLimit() <= 0) { + throw new IllegalStateException("Traveltime limit must be greater than zero"); + } + + Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; + + return new TraveltimeSearchQuery( + params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); + } + + @Override + protected boolean doEquals(TraveltimeQueryBuilder other) { + if (!Objects.equals(this.field, other.field)) return false; + if (!Objects.equals(this.origin, other.origin)) return false; + if (!Objects.equals(this.mode, other.mode)) return false; + if (!Objects.equals(this.country, other.country)) return false; + if (!Objects.equals(this.prefilter, other.prefilter)) return false; + if (!Objects.equals(this.output, other.output)) return false; + return this.limit == other.limit; + } + + @Override + protected int doHashCode() { + final int PRIME = 59; + int result = 1; + result = result * PRIME + this.field.hashCode(); + result = result * PRIME + this.origin.hashCode(); + result = result * PRIME + Objects.hashCode(this.mode); + result = result * PRIME + Objects.hashCode(this.country); + result = result * PRIME + Objects.hashCode(this.prefilter); + result = result * PRIME + Objects.hashCode(this.output); + result = result * PRIME + this.limit; + return result; + } + + @Override + public String getWriteableName() { + return TraveltimeQueryParser.NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.MINIMUM_COMPATIBLE; + } + + public static QueryBuilder parseInnerQueryBuilder(XContentParser parser) throws IOException { + return AbstractQueryBuilder.parseInnerQueryBuilder(parser); + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 6f1deee..961c908 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -1,6 +1,9 @@ package com.traveltime.plugin.elasticsearch.query; import com.traveltime.plugin.elasticsearch.util.Util; +import java.io.IOException; +import java.util.Optional; +import java.util.function.Function; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.geo.GeoUtils; import org.elasticsearch.index.query.QueryBuilder; @@ -10,57 +13,68 @@ import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.XContentParser; -import java.io.IOException; -import java.util.Optional; -import java.util.function.Function; - public class TraveltimeQueryParser implements QueryParser { - public static String NAME = "traveltime"; - private final ParseField field = new ParseField("field"); - private final ParseField origin = new ParseField("origin"); - private final ParseField limit = new ParseField("limit"); - private final ParseField mode = new ParseField("mode"); - private final ParseField country = new ParseField("country"); - private final ParseField requestType = new ParseField("requestType"); - private final ParseField prefilter = new ParseField("prefilter"); - private final ParseField output = new ParseField("output"); - private final ParseField distanceOutput = new ParseField("distanceOutput"); + public static String NAME = "traveltime"; + private final ParseField field = new ParseField("field"); + private final ParseField origin = new ParseField("origin"); + private final ParseField limit = new ParseField("limit"); + private final ParseField mode = new ParseField("mode"); + private final ParseField country = new ParseField("country"); + private final ParseField requestType = new ParseField("requestType"); + private final ParseField prefilter = new ParseField("prefilter"); + private final ParseField output = new ParseField("output"); + private final ParseField distanceOutput = new ParseField("distanceOutput"); - private final ContextParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); + private final ContextParser prefilterParser = + (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); - private final ObjectParser queryParser = new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); + private final ObjectParser queryParser = + new ObjectParser<>(NAME, TraveltimeQueryBuilder::new); - { - queryParser.declareString(TraveltimeQueryBuilder::setField, field); - queryParser.declareField(TraveltimeQueryBuilder::setOrigin, (parser, c) -> GeoUtils.parseGeoPoint(parser), origin, ObjectParser.ValueType.VALUE_OBJECT_ARRAY); - queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); - queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); - queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); - queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); - queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); - queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); + { + queryParser.declareString(TraveltimeQueryBuilder::setField, field); + queryParser.declareField( + TraveltimeQueryBuilder::setOrigin, + (parser, c) -> GeoUtils.parseGeoPoint(parser), + origin, + ObjectParser.ValueType.VALUE_OBJECT_ARRAY); + queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit); + queryParser.declareString( + (qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), + mode); + queryParser.declareString( + (qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), + country); + queryParser.declareString( + (qb, s) -> + qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), + requestType); + queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); + queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); - queryParser.declareRequiredFieldSet(field.toString()); - queryParser.declareRequiredFieldSet(origin.toString()); - queryParser.declareRequiredFieldSet(limit.toString()); - } + queryParser.declareRequiredFieldSet(field.toString()); + queryParser.declareRequiredFieldSet(origin.toString()); + queryParser.declareRequiredFieldSet(limit.toString()); + } - private static T findByNameOrError(String what, String name, Function> finder) { - Optional result = finder.apply(name); - if (result.isEmpty()) { - throw new IllegalArgumentException(String.format("Couldn't find a %s with the name %s", what, name)); - } else { - return result.get(); - } - } + private static T findByNameOrError( + String what, String name, Function> finder) { + Optional result = finder.apply(name); + if (result.isEmpty()) { + throw new IllegalArgumentException( + String.format("Couldn't find a %s with the name %s", what, name)); + } else { + return result.get(); + } + } - @Override - public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { - try { - return queryParser.parse(parser, null); - } catch (IllegalArgumentException iae) { - throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); - } - } + @Override + public TraveltimeQueryBuilder fromXContent(XContentParser parser) throws IOException { + try { + return queryParser.parse(parser, null); + } catch (IllegalArgumentException iae) { + throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae); + } + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java index 530f5af..c55b3dc 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeScorer.java @@ -1,99 +1,103 @@ package com.traveltime.plugin.elasticsearch.query; import it.unimi.dsi.fastutil.longs.Long2IntMap; +import java.io.IOException; import lombok.RequiredArgsConstructor; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Scorer; -import java.io.IOException; - public class TraveltimeScorer extends Scorer { - protected final TraveltimeWeight weight; - private final Long2IntMap pointToTime; - private final TraveltimeFilteredDocs docs; - private final float boost; - - @RequiredArgsConstructor - private class TraveltimeFilteredDocs extends DocIdSetIterator { - private final TraveltimeWeight.FilteredIterator backing; - - private long currentValue = 0; - private boolean currentValueDirty = true; - private void invalidateCurrentValue() { - currentValueDirty = true; - } - private void advanceValue() throws IOException { - if(currentValueDirty) { - currentValue = backing.nextValue(); - currentValueDirty = false; - } - } - - public long nextValue() throws IOException { - advanceValue(); - return currentValue; + protected final TraveltimeWeight weight; + private final Long2IntMap pointToTime; + private final TraveltimeFilteredDocs docs; + private final float boost; + + @RequiredArgsConstructor + private class TraveltimeFilteredDocs extends DocIdSetIterator { + private final TraveltimeWeight.FilteredIterator backing; + + private long currentValue = 0; + private boolean currentValueDirty = true; + + private void invalidateCurrentValue() { + currentValueDirty = true; + } + + private void advanceValue() throws IOException { + if (currentValueDirty) { + currentValue = backing.nextValue(); + currentValueDirty = false; } - - @Override - public int docID() { - return backing.docID(); - } - - @Override - public int nextDoc() throws IOException { - int id = backing.nextDoc(); - invalidateCurrentValue(); - while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = backing.nextDoc(); - invalidateCurrentValue(); - } - return id; + } + + public long nextValue() throws IOException { + advanceValue(); + return currentValue; + } + + @Override + public int docID() { + return backing.docID(); + } + + @Override + public int nextDoc() throws IOException { + int id = backing.nextDoc(); + invalidateCurrentValue(); + while (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = backing.nextDoc(); + invalidateCurrentValue(); } - - @Override - public int advance(int target) throws IOException { - int id = backing.advance(target); - invalidateCurrentValue(); - if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { - id = nextDoc(); - } - return id; - } - - @Override - public long cost() { - return backing.cost() * 1000; + return id; + } + + @Override + public int advance(int target) throws IOException { + int id = backing.advance(target); + invalidateCurrentValue(); + if (id != DocIdSetIterator.NO_MORE_DOCS && !pointToTime.containsKey(nextValue())) { + id = nextDoc(); } - } - - public TraveltimeScorer(TraveltimeWeight w, Long2IntMap coordToTime, TraveltimeWeight.FilteredIterator docs, float boost) { - super(w); - this.weight = w; - this.pointToTime = coordToTime; - this.docs = new TraveltimeFilteredDocs(docs); - this.boost = boost; - } - - @Override - public DocIdSetIterator iterator() { - return docs; - } - - @Override - public float getMaxScore(int upTo) { - return 1; - } - - @Override - public float score() throws IOException { - int limit = weight.getTtQuery().getParams().getLimit(); - int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); - return (boost * (limit - tt + 1)) / (limit + 1); - - } - - @Override - public int docID() { - return docs.docID(); - } + return id; + } + + @Override + public long cost() { + return backing.cost() * 1000; + } + } + + public TraveltimeScorer( + TraveltimeWeight w, + Long2IntMap coordToTime, + TraveltimeWeight.FilteredIterator docs, + float boost) { + super(w); + this.weight = w; + this.pointToTime = coordToTime; + this.docs = new TraveltimeFilteredDocs(docs); + this.boost = boost; + } + + @Override + public DocIdSetIterator iterator() { + return docs; + } + + @Override + public float getMaxScore(int upTo) { + return 1; + } + + @Override + public float score() throws IOException { + int limit = weight.getTtQuery().getParams().getLimit(); + int tt = pointToTime.getOrDefault(docs.nextValue(), limit + 1); + return (boost * (limit - tt + 1)) / (limit + 1); + } + + @Override + public int docID() { + return docs.docID(); + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 0bc37e5..be0b0da 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -1,51 +1,53 @@ package com.traveltime.plugin.elasticsearch.query; +import java.io.IOException; +import java.net.URI; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import lombok.Getter; import org.apache.lucene.search.*; -import java.io.IOException; -import java.net.URI; - @AllArgsConstructor @EqualsAndHashCode(callSuper = false) @Getter public class TraveltimeSearchQuery extends Query { - private final TraveltimeQueryParameters params; - private final Query prefilter; - private final String output; - private final String distanceOutput; - private final URI appUri; - private final String appId; - private final String apiKey; + private final TraveltimeQueryParameters params; + private final Query prefilter; + private final String output; + private final String distanceOutput; + private final URI appUri; + private final String appId; + private final String apiKey; - @Override - public void visit(QueryVisitor visitor) { - if (prefilter != null) { - prefilter.visit(visitor); - } - visitor.visitLeaf(this); - } + @Override + public void visit(QueryVisitor visitor) { + if (prefilter != null) { + prefilter.visit(visitor); + } + visitor.visitLeaf(this); + } - @Override - public String toString(String field) { - return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); - } + @Override + public String toString(String field) { + return String.format("TraveltimeSearchQuery(params = %s, prefilter = %s)", params, prefilter); + } - @Override - public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - Weight prefilterWeight = prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; - return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); - } + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) + throws IOException { + Weight prefilterWeight = + prefilter != null ? prefilter.createWeight(searcher, scoreMode, boost) : null; + return new TraveltimeWeight(this, prefilterWeight, !output.isEmpty(), boost); + } - @Override - public Query rewrite(IndexSearcher reader) throws IOException { - Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; - if (newPrefilter == prefilter) { - return super.rewrite(reader); - } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); - } - } + @Override + public Query rewrite(IndexSearcher reader) throws IOException { + Query newPrefilter = prefilter != null ? prefilter.rewrite(reader) : null; + if (newPrefilter == prefilter) { + return super.rewrite(reader); + } else { + return new TraveltimeSearchQuery( + params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); + } + } } diff --git a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index 7f365e8..37ca206 100644 --- a/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.9/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -8,6 +8,9 @@ import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.LongArrayList; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; @@ -19,154 +22,151 @@ import org.apache.lucene.search.*; import org.elasticsearch.SpecialPermission; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - @EqualsAndHashCode(callSuper = false) public class TraveltimeWeight extends Weight { - @Getter - private final TraveltimeSearchQuery ttQuery; - - private final Weight prefilter; - - private final boolean hasOutput; - - private final float boost; - - private final Logger log = LogManager.getLogger(); - - @EqualsAndHashCode.Exclude - private final ProtoFetcher protoFetcher; - - public TraveltimeWeight(TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { - super(q); - ttQuery = q; - this.prefilter = prefilter; - this.hasOutput = hasOutput; - this.boost = boost; - protoFetcher = FetcherSingleton.INSTANCE.getFetcher(q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); - } - - @Override - public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.noMatch("Cannot provide explanation for traveltime matches"); - } - - @RequiredArgsConstructor - public static class FilteredIterator { - private final SortedNumericDocValues values; - private final DocIdSetIterator filtered; - - public long nextValue() throws IOException { - return this.values.nextValue(); + @Getter private final TraveltimeSearchQuery ttQuery; + + private final Weight prefilter; + + private final boolean hasOutput; + + private final float boost; + + private final Logger log = LogManager.getLogger(); + + @EqualsAndHashCode.Exclude private final ProtoFetcher protoFetcher; + + public TraveltimeWeight( + TraveltimeSearchQuery q, Weight prefilter, boolean hasOutput, float boost) { + super(q); + ttQuery = q; + this.prefilter = prefilter; + this.hasOutput = hasOutput; + this.boost = boost; + protoFetcher = + FetcherSingleton.INSTANCE.getFetcher( + q.getAppUri(), q.getAppId(), q.getApiKey(), SpecialPermission::new); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) { + return Explanation.noMatch("Cannot provide explanation for traveltime matches"); + } + + @RequiredArgsConstructor + public static class FilteredIterator { + private final SortedNumericDocValues values; + private final DocIdSetIterator filtered; + + public long nextValue() throws IOException { + return this.values.nextValue(); + } + + public int docID() { + return this.filtered.docID(); + } + + public int nextDoc() throws IOException { + return this.filtered.nextDoc(); + } + + public int advance(int target) throws IOException { + return this.filtered.advance(target); + } + + public long cost() { + return this.filtered.cost(); + } + } + + private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { + val reader = context.reader(); + val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + + DocIdSetIterator finalIterator; + + if (prefilter != null) { + val preScorer = prefilter.scorer(context); + if (preScorer == null) return null; + val prefilterIterator = preScorer.iterator(); + finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); + } else { + finalIterator = backing; + } + + return new FilteredIterator(backing, finalIterator); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + val backing = filteredValues(context); + if (backing == null) return null; + + val valueArray = new LongArrayList(); + val decodedArray = new ArrayList(); + val valueSet = new LongOpenHashSet(); + + while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + long encodedCoords = backing.nextValue(); + if (valueSet.add(encodedCoords)) { + valueArray.add(encodedCoords); + decodedArray.add(Util.decode(encodedCoords)); } + } - public int docID() { - return this.filtered.docID(); - } + val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - public int nextDoc() throws IOException { - return this.filtered.nextDoc(); - } + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - public int advance(int target) throws IOException { - return this.filtered.advance(target); - } - - public long cost() { - return this.filtered.cost(); - } - } + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - private FilteredIterator filteredValues(LeafReaderContext context) throws IOException { - val reader = context.reader(); - val backing = reader.getSortedNumericDocValues(ttQuery.getParams().getField()); + val timeDistance = + protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); - DocIdSetIterator finalIterator; + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); - if (prefilter != null) { - val preScorer = prefilter.scorer(context); - if(preScorer == null) return null; - val prefilterIterator = preScorer.iterator(); - finalIterator = ConjunctionUtils.intersectIterators(List.of(prefilterIterator, backing)); - } else { - finalIterator = backing; + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } } - return new FilteredIterator(backing, finalIterator); - } - - @Override - public Scorer scorer(LeafReaderContext context) throws IOException { - val backing = filteredValues(context); - if (backing == null) return null; - - val valueArray = new LongArrayList(); - val decodedArray = new ArrayList(); - val valueSet = new LongOpenHashSet(); - - while (backing.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { - long encodedCoords = backing.nextValue(); - if(valueSet.add(encodedCoords)) { - valueArray.add(encodedCoords); - decodedArray.add(Util.decode(encodedCoords)); - } + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = + protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType()); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } + } - val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - - if (ttQuery.getParams().isIncludeDistance()) { - val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); - - val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); - - val timeDistance = protoFetcher.getTimesAndDistances( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - mode, - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - val times = timeDistance.getLeft(); - val distances = timeDistance.getRight(); - - for (int index = 0; index < times.size(); index++) { - if (times.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); - pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); - } - } - - TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); - } else { - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if (results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); - } - } - } - - if(hasOutput) { - TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); - } + if (hasOutput) { + TraveltimeCache.INSTANCE.add(ttQuery.getParams(), pointToTime); + } - return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); - } + return new TraveltimeScorer(this, pointToTime, filteredValues(context), boost); + } - @Override - public boolean isCacheable(LeafReaderContext ctx) { - return true; - } + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return true; + } }