diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java index d0c837d..1153fc8 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/TraveltimePlugin.java @@ -50,6 +50,7 @@ public class TraveltimePlugin extends Plugin implements SearchPlugin { private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) { TraveltimeCache.INSTANCE.cleanUp(); + TraveltimeCache.DISTANCE.cleanUp(); threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds)); } @@ -60,6 +61,7 @@ public Collection createComponents(PluginServices pluginServices) { Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings()); TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry); + TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry); cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds); return super.createComponents(pluginServices); diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java index 019a235..2230825 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeFetchPhase.java @@ -47,6 +47,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) { if (traveltimeQuery == null) return null; TraveltimeQueryParameters params = traveltimeQuery.getParams(); final String output = traveltimeQuery.getOutput(); + final String distanceOutput = traveltimeQuery.getDistanceOutput(); FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null))); @@ -61,10 +62,19 @@ public void setNextReader(LeafReaderContext readerContext) { public void process(HitContext hitContext) throws IOException { val docValues = hitContext.reader().getSortedNumericDocValues(params.getField()); docValues.advance(hitContext.docId()); - Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue()); + val point = docValues.nextValue(); + if (!output.isEmpty()) { + Integer tt = TraveltimeCache.INSTANCE.get(params, point); + if (tt >= 0) { + hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + } + } - if (tt >= 0) { - hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt))); + if (!distanceOutput.isEmpty()) { + Integer td = TraveltimeCache.DISTANCE.get(params, point); + if (td >= 0) { + hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td))); + } } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java index 4137a2d..eb28878 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryBuilder.java @@ -131,6 +131,9 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config"); } } + if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) { + throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode"); + } if (params.getCountry() == null) { if (defaultCountry.isPresent()) { params = params.withCountry(defaultCountry.get()); @@ -152,7 +155,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null; - return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey); + return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey); } @Override diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java index 1705d39..ac2a6ab 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeQueryParser.java @@ -24,6 +24,7 @@ public class TraveltimeQueryParser implements QueryParser prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p); @@ -35,9 +36,10 @@ public class TraveltimeQueryParser implements QueryParser qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode); queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country); - queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("country", s, Util::findRequestTypeByName)), requestType); + queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType); queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter); queryParser.declareString(TraveltimeQueryBuilder::setOutput, output); + queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput); queryParser.declareRequiredFieldSet(field.toString()); queryParser.declareRequiredFieldSet(origin.toString()); diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java index 0ee3888..0bc37e5 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeSearchQuery.java @@ -15,6 +15,7 @@ public class TraveltimeSearchQuery extends Query { private final TraveltimeQueryParameters params; private final Query prefilter; private final String output; + private final String distanceOutput; private final URI appUri; private final String appId; private final String apiKey; @@ -44,7 +45,7 @@ public Query rewrite(IndexSearcher reader) throws IOException { if (newPrefilter == prefilter) { return super.rewrite(reader); } else { - return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey); + return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey); } } } diff --git a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java index a990cba..7f365e8 100644 --- a/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java +++ b/8.13/src/main/java/com/traveltime/plugin/elasticsearch/query/TraveltimeWeight.java @@ -116,18 +116,45 @@ public Scorer scorer(LeafReaderContext context) throws IOException { val pointToTime = new Long2IntOpenHashMap(valueArray.size()); - val results = protoFetcher.getTimes( - ttQuery.getParams().getOrigin(), - decodedArray, - ttQuery.getParams().getLimit(), - ttQuery.getParams().getMode(), - ttQuery.getParams().getCountry(), - ttQuery.getParams().getRequestType() - ); - - for (int index = 0; index < results.size(); index++) { - if(results.get(index) >= 0) { - pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + if (ttQuery.getParams().isIncludeDistance()) { + val pointToDistance = new Long2IntOpenHashMap(valueArray.size()); + + val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode()); + + val timeDistance = protoFetcher.getTimesAndDistances( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + mode, + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType() + ); + + val times = timeDistance.getLeft(); + val distances = timeDistance.getRight(); + + for (int index = 0; index < times.size(); index++) { + if (times.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), times.get(index).intValue()); + pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue()); + } + } + + TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance); + } else { + val results = protoFetcher.getTimes( + ttQuery.getParams().getOrigin(), + decodedArray, + ttQuery.getParams().getLimit(), + ttQuery.getParams().getMode(), + ttQuery.getParams().getCountry(), + ttQuery.getParams().getRequestType() + ); + + for (int index = 0; index < results.size(); index++) { + if (results.get(index) >= 0) { + pointToTime.put(valueArray.getLong(index), results.get(index).intValue()); + } } }