Skip to content

Commit

Permalink
Apply patch to 8.13
Browse files Browse the repository at this point in the history
  • Loading branch information
mjanuszkiewicz-tt committed Aug 29, 2024
1 parent 2207d57 commit ce5fbe0
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class TraveltimePlugin extends Plugin implements SearchPlugin {

private void cleanUpAndReschedule(ThreadPool threadPool, TimeValue cleanupSeconds) {
TraveltimeCache.INSTANCE.cleanUp();
TraveltimeCache.DISTANCE.cleanUp();
threadPool.scheduleUnlessShuttingDown(cleanupSeconds, threadPool.generic(), () -> cleanUpAndReschedule(threadPool, cleanupSeconds));
}

Expand All @@ -60,6 +61,7 @@ public Collection<?> createComponents(PluginServices pluginServices) {
Integer cacheSize = CACHE_SIZE.get(pluginServices.environment().settings());

TraveltimeCache.INSTANCE.setUp(cacheSize, cacheExpiry);
TraveltimeCache.DISTANCE.setUp(cacheSize, cacheExpiry);
cleanUpAndReschedule(pluginServices.threadPool(), cleanupSeconds);

return super.createComponents(pluginServices);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) {
if (traveltimeQuery == null) return null;
TraveltimeQueryParameters params = traveltimeQuery.getParams();
final String output = traveltimeQuery.getOutput();
final String distanceOutput = traveltimeQuery.getDistanceOutput();

FieldFetcher fieldFetcher = FieldFetcher.create(fetchContext.getSearchExecutionContext(), List.of(new FieldAndFormat(params.getField(), null)));

Expand All @@ -61,10 +62,19 @@ public void setNextReader(LeafReaderContext readerContext) {
public void process(HitContext hitContext) throws IOException {
val docValues = hitContext.reader().getSortedNumericDocValues(params.getField());
docValues.advance(hitContext.docId());
Integer tt = TraveltimeCache.INSTANCE.get(params, docValues.nextValue());
val point = docValues.nextValue();
if (!output.isEmpty()) {
Integer tt = TraveltimeCache.INSTANCE.get(params, point);
if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
}
}

if (tt >= 0) {
hitContext.hit().setDocumentField(output, new DocumentField(output, List.of(tt)));
if (!distanceOutput.isEmpty()) {
Integer td = TraveltimeCache.DISTANCE.get(params, point);
if (td >= 0) {
hitContext.hit().setDocumentField(distanceOutput, new DocumentField(distanceOutput, List.of(td)));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {
throw new IllegalStateException("Traveltime query requires either 'mode' field to be present or a default mode to be set in the config");
}
}
if(params.isIncludeDistance() && !Util.canUseDistance(params.getMode())) {
throw new IllegalStateException("Traveltime query with distance output cannot be used with public transportation mode");
}
if (params.getCountry() == null) {
if (defaultCountry.isPresent()) {
params = params.withCountry(defaultCountry.get());
Expand All @@ -152,7 +155,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException {

Query prefilterQuery = prefilter != null ? prefilter.toQuery(context) : null;

return new TraveltimeSearchQuery(params, prefilterQuery, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, prefilterQuery, output, distanceOutput, appUri, appId, apiKey);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
private final ParseField requestType = new ParseField("requestType");
private final ParseField prefilter = new ParseField("prefilter");
private final ParseField output = new ParseField("output");
private final ParseField distanceOutput = new ParseField("distance_output");

private final ContextParser<Void, QueryBuilder> prefilterParser = (p, c) -> TraveltimeQueryBuilder.parseInnerQueryBuilder(p);

Expand All @@ -35,9 +36,10 @@ public class TraveltimeQueryParser implements QueryParser<TraveltimeQueryBuilder
queryParser.declareInt(TraveltimeQueryBuilder::setLimit, limit);
queryParser.declareString((qb, s) -> qb.setMode(findByNameOrError("transportation mode", s, Util::findModeByName)), mode);
queryParser.declareString((qb, s) -> qb.setCountry(findByNameOrError("country", s, Util::findCountryByName)), country);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("country", s, Util::findRequestTypeByName)), requestType);
queryParser.declareString((qb, s) -> qb.setRequestType(findByNameOrError("request mode", s, Util::findRequestTypeByName)), requestType);
queryParser.declareObject(TraveltimeQueryBuilder::setPrefilter, prefilterParser, prefilter);
queryParser.declareString(TraveltimeQueryBuilder::setOutput, output);
queryParser.declareString(TraveltimeQueryBuilder::setDistanceOutput, distanceOutput);

queryParser.declareRequiredFieldSet(field.toString());
queryParser.declareRequiredFieldSet(origin.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ public class TraveltimeSearchQuery extends Query {
private final TraveltimeQueryParameters params;
private final Query prefilter;
private final String output;
private final String distanceOutput;
private final URI appUri;
private final String appId;
private final String apiKey;
Expand Down Expand Up @@ -44,7 +45,7 @@ public Query rewrite(IndexSearcher reader) throws IOException {
if (newPrefilter == prefilter) {
return super.rewrite(reader);
} else {
return new TraveltimeSearchQuery(params, newPrefilter, output, appUri, appId, apiKey);
return new TraveltimeSearchQuery(params, newPrefilter, output, distanceOutput, appUri, appId, apiKey);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,45 @@ public Scorer scorer(LeafReaderContext context) throws IOException {

val pointToTime = new Long2IntOpenHashMap(valueArray.size());

val results = protoFetcher.getTimes(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
ttQuery.getParams().getMode(),
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

for (int index = 0; index < results.size(); index++) {
if(results.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), results.get(index).intValue());
if (ttQuery.getParams().isIncludeDistance()) {
val pointToDistance = new Long2IntOpenHashMap(valueArray.size());

val mode = Util.unsafeCastToDistanceTransportation(ttQuery.getParams().getMode());

val timeDistance = protoFetcher.getTimesAndDistances(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
mode,
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

val times = timeDistance.getLeft();
val distances = timeDistance.getRight();

for (int index = 0; index < times.size(); index++) {
if (times.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), times.get(index).intValue());
pointToDistance.put(valueArray.getLong(index), distances.get(index).intValue());
}
}

TraveltimeCache.DISTANCE.add(ttQuery.getParams(), pointToDistance);
} else {
val results = protoFetcher.getTimes(
ttQuery.getParams().getOrigin(),
decodedArray,
ttQuery.getParams().getLimit(),
ttQuery.getParams().getMode(),
ttQuery.getParams().getCountry(),
ttQuery.getParams().getRequestType()
);

for (int index = 0; index < results.size(); index++) {
if (results.get(index) >= 0) {
pointToTime.put(valueArray.getLong(index), results.get(index).intValue());
}
}
}

Expand Down

0 comments on commit ce5fbe0

Please sign in to comment.