Commit

Make use of ActionListener#delegateFailureAndWrap in more spots in ML codebase (elastic#105882)

Found a bunch more spots where this shortcut helps save both memory and brainpower when thinking through potential leaks. Made use of it in those spots and, in some cases, also inlined a couple of local variables for readability.
original-brownbear authored Mar 5, 2024
1 parent 93fd12d commit ca10472
Showing 22 changed files with 261 additions and 270 deletions.
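
The refactoring applied throughout this diff replaces the two-argument ActionListener.wrap(onSuccess, listener::onFailure) idiom with listener.delegateFailureAndWrap(...), which forwards failures to the outer listener automatically and hands that listener to the success callback. Below is a minimal sketch of the before/after shape, assuming the org.elasticsearch.action.ActionListener API is on the classpath; the fetchValue producer and the String/Integer types are hypothetical stand-ins, not code from this commit.

import org.elasticsearch.action.ActionListener;

class DelegateFailureAndWrapSketch {

    // Before: ActionListener.wrap builds a fresh listener that captures the
    // outer listener twice (once in the success lambda, once as the failure
    // handler), and the reader must verify every failure path reaches onFailure.
    void before(ActionListener<String> listener) {
        ActionListener<Integer> handler = ActionListener.wrap(
            value -> listener.onResponse("result-" + value),
            listener::onFailure
        );
        fetchValue(handler);
    }

    // After: delegateFailureAndWrap forwards failures to the outer listener
    // and passes it into the lambda as `l`, so only the success path is
    // spelled out (and needs reviewing for leaks).
    void after(ActionListener<String> listener) {
        fetchValue(listener.delegateFailureAndWrap((l, value) -> l.onResponse("result-" + value)));
    }

    // Hypothetical async producer standing in for the config/search lookups in the diff.
    void fetchValue(ActionListener<Integer> valueListener) {
        valueListener.onResponse(42);
    }
}

The same transformation also nests, as in deleteDatafeed below, where the delegated listener is itself delegated again for the inner deletion step.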
@@ -121,9 +121,8 @@ public void putDatafeed(
final RoleDescriptor.IndicesPrivileges.Builder indicesPrivilegesBuilder = RoleDescriptor.IndicesPrivileges.builder()
.indices(indices);

ActionListener<HasPrivilegesResponse> privResponseListener = ActionListener.wrap(
r -> handlePrivsResponse(username, request, r, state, threadPool, listener),
listener::onFailure
ActionListener<HasPrivilegesResponse> privResponseListener = listener.delegateFailureAndWrap(
(l, r) -> handlePrivsResponse(username, request, r, state, threadPool, l)
);

ActionListener<GetRollupIndexCapsAction.Response> getRollupIndexCapsActionHandler = ActionListener.wrap(response -> {
@@ -173,15 +172,14 @@ public void getDatafeeds(
request.getDatafeedId(),
request.allowNoMatch(),
parentTaskId,
ActionListener.wrap(
datafeedBuilders -> listener.onResponse(
listener.delegateFailureAndWrap(
(l, datafeedBuilders) -> l.onResponse(
new QueryPage<>(
datafeedBuilders.stream().map(DatafeedConfig.Builder::build).collect(Collectors.toList()),
datafeedBuilders.size(),
DatafeedConfig.RESULTS_FIELD
)
),
listener::onFailure
)
)
);
}
@@ -222,10 +220,7 @@ public void updateDatafeed(
request.getUpdate(),
headers,
jobConfigProvider::validateDatafeedJob,
ActionListener.wrap(
updatedConfig -> listener.onResponse(new PutDatafeedAction.Response(updatedConfig)),
listener::onFailure
)
listener.delegateFailureAndWrap((l, updatedConfig) -> l.onResponse(new PutDatafeedAction.Response(updatedConfig)))
);
});

@@ -254,19 +249,18 @@ public void deleteDatafeed(DeleteDatafeedAction.Request request, ClusterState st

String datafeedId = request.getDatafeedId();

datafeedConfigProvider.getDatafeedConfig(datafeedId, null, ActionListener.wrap(datafeedConfigBuilder -> {
datafeedConfigProvider.getDatafeedConfig(datafeedId, null, listener.delegateFailureAndWrap((delegate, datafeedConfigBuilder) -> {
String jobId = datafeedConfigBuilder.build().getJobId();
JobDataDeleter jobDataDeleter = new JobDataDeleter(client, jobId);
jobDataDeleter.deleteDatafeedTimingStats(
ActionListener.wrap(
unused1 -> datafeedConfigProvider.deleteDatafeedConfig(
delegate.delegateFailureAndWrap(
(l, unused1) -> datafeedConfigProvider.deleteDatafeedConfig(
datafeedId,
ActionListener.wrap(unused2 -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure)
),
listener::onFailure
l.delegateFailureAndWrap((ll, unused2) -> ll.onResponse(AcknowledgedResponse.TRUE))
)
)
);
}, listener::onFailure));
}));

}

@@ -316,7 +310,7 @@ private void putDatafeed(
CheckedConsumer<Boolean, Exception> mappingsUpdated = ok -> datafeedConfigProvider.putDatafeedConfig(
request.getDatafeed(),
headers,
ActionListener.wrap(response -> listener.onResponse(new PutDatafeedAction.Response(response.v1())), listener::onFailure)
listener.delegateFailureAndWrap((l, response) -> l.onResponse(new PutDatafeedAction.Response(response.v1())))
);

CheckedConsumer<Boolean, Exception> validationOk = ok -> {
@@ -345,16 +339,19 @@ private void putDatafeed(
}

private void checkJobDoesNotHaveADatafeed(String jobId, ActionListener<Boolean> listener) {
datafeedConfigProvider.findDatafeedIdsForJobIds(Collections.singletonList(jobId), ActionListener.wrap(datafeedIds -> {
if (datafeedIds.isEmpty()) {
listener.onResponse(Boolean.TRUE);
} else {
listener.onFailure(
ExceptionsHelper.conflictStatusException(
"A datafeed [" + datafeedIds.iterator().next() + "] already exists for job [" + jobId + "]"
)
);
}
}, listener::onFailure));
datafeedConfigProvider.findDatafeedIdsForJobIds(
Collections.singletonList(jobId),
listener.delegateFailureAndWrap((delegate, datafeedIds) -> {
if (datafeedIds.isEmpty()) {
delegate.onResponse(Boolean.TRUE);
} else {
delegate.onFailure(
ExceptionsHelper.conflictStatusException(
"A datafeed [" + datafeedIds.iterator().next() + "] already exists for job [" + jobId + "]"
)
);
}
})
);
}
}
@@ -59,13 +59,12 @@ static void create(
) {
final boolean hasAggs = datafeed.hasAggregations();
final boolean isComposite = hasAggs && datafeed.hasCompositeAgg(xContentRegistry);
ActionListener<DataExtractorFactory> factoryHandler = ActionListener.wrap(
factory -> listener.onResponse(
ActionListener<DataExtractorFactory> factoryHandler = listener.delegateFailureAndWrap(
(l, factory) -> l.onResponse(
datafeed.getChunkingConfig().isEnabled()
? new ChunkedDataExtractorFactory(datafeed, job, xContentRegistry, factory)
: factory
),
listener::onFailure
)
);

ActionListener<GetRollupIndexCapsAction.Response> getRollupIndexCapsActionHandler = ActionListener.wrap(response -> {
@@ -223,7 +223,7 @@ public void findDatafeedIdsForJobIds(Collection<String> jobIds, ActionListener<S
client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(response -> {
listener.<SearchResponse>delegateFailureAndWrap((delegate, response) -> {
Set<String> datafeedIds = new HashSet<>();
// There cannot be more than one datafeed per job
assert response.getHits().getTotalHits().value <= jobIds.size();
@@ -233,8 +233,8 @@ public void findDatafeedIdsForJobIds(Collection<String> jobIds, ActionListener<S
datafeedIds.add(hit.field(DatafeedConfig.ID.getPreferredName()).getValue());
}

listener.onResponse(datafeedIds);
}, listener::onFailure),
delegate.onResponse(datafeedIds);
}),
client::search
);
}
@@ -256,7 +256,7 @@ public void findDatafeedsByJobIds(
client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(response -> {
listener.<SearchResponse>delegateFailureAndWrap((delegate, response) -> {
Map<String, DatafeedConfig.Builder> datafeedsByJobId = new HashMap<>();
// There cannot be more than one datafeed per job
assert response.getHits().getTotalHits().value <= jobIds.size();
@@ -265,8 +265,8 @@ public void findDatafeedsByJobIds(
DatafeedConfig.Builder builder = parseLenientlyFromSource(hit.getSourceRef());
datafeedsByJobId.put(builder.getJobId(), builder);
}
listener.onResponse(datafeedsByJobId);
}, listener::onFailure),
delegate.onResponse(datafeedsByJobId);
}),
client::search
);
}
@@ -440,7 +440,7 @@ public void expandDatafeedIds(
client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(response -> {
listener.<SearchResponse>delegateFailureAndWrap((delegate, response) -> {
SortedSet<String> datafeedIds = new TreeSet<>();
SearchHit[] hits = response.getHits().getHits();
for (SearchHit hit : hits) {
@@ -453,12 +453,12 @@ public void expandDatafeedIds(
requiredMatches.filterMatchedIds(datafeedIds);
if (requiredMatches.hasUnmatchedIds()) {
// some required datafeeds were not found
listener.onFailure(ExceptionsHelper.missingDatafeedException(requiredMatches.unmatchedIdsString()));
delegate.onFailure(ExceptionsHelper.missingDatafeedException(requiredMatches.unmatchedIdsString()));
return;
}

listener.onResponse(datafeedIds);
}, listener::onFailure),
delegate.onResponse(datafeedIds);
}),
client::search
);

@@ -502,7 +502,7 @@ public void expandDatafeedConfigs(
client.threadPool().getThreadContext(),
ML_ORIGIN,
searchRequest,
ActionListener.<SearchResponse>wrap(response -> {
listener.<SearchResponse>delegateFailureAndWrap((delegate, response) -> {
List<DatafeedConfig.Builder> datafeeds = new ArrayList<>();
Set<String> datafeedIds = new HashSet<>();
SearchHit[] hits = response.getHits().getHits();
@@ -521,12 +521,12 @@ public void expandDatafeedConfigs(
requiredMatches.filterMatchedIds(datafeedIds);
if (requiredMatches.hasUnmatchedIds()) {
// some required datafeeds were not found
listener.onFailure(ExceptionsHelper.missingDatafeedException(requiredMatches.unmatchedIdsString()));
delegate.onFailure(ExceptionsHelper.missingDatafeedException(requiredMatches.unmatchedIdsString()));
return;
}

listener.onResponse(datafeeds);
}, listener::onFailure),
delegate.onResponse(datafeeds);
}),
client::search
);

@@ -33,7 +33,6 @@
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
import org.elasticsearch.xpack.ml.dataframe.inference.InferenceRunner;
import org.elasticsearch.xpack.ml.dataframe.persistence.DataFrameAnalyticsConfigProvider;
@@ -171,9 +170,8 @@ public void execute(DataFrameAnalyticsTask task, ClusterState clusterState, Time
}, task::setFailed);

// Retrieve configuration
ActionListener<Boolean> statsIndexListener = ActionListener.wrap(
aBoolean -> configProvider.get(task.getParams().getId(), configListener),
configListener::onFailure
ActionListener<Boolean> statsIndexListener = configListener.delegateFailureAndWrap(
(l, aBoolean) -> configProvider.get(task.getParams().getId(), l)
);

// Make sure the stats index and alias exist
@@ -203,25 +201,22 @@ private void createStatsIndexAndUpdateMappingsIfNecessary(
TimeValue masterNodeTimeout,
ActionListener<Boolean> listener
) {
ActionListener<Boolean> createIndexListener = ActionListener.wrap(
aBoolean -> ElasticsearchMappings.addDocMappingIfMissing(
MlStatsIndex.writeAlias(),
MlStatsIndex::wrappedMapping,
clientToUse,
clusterState,
masterNodeTimeout,
listener,
MlStatsIndex.STATS_INDEX_MAPPINGS_VERSION
),
listener::onFailure
);

MlStatsIndex.createStatsIndexAndAliasIfNecessary(
clientToUse,
clusterState,
expressionResolver,
masterNodeTimeout,
createIndexListener
listener.delegateFailureAndWrap(
(l, aBoolean) -> ElasticsearchMappings.addDocMappingIfMissing(
MlStatsIndex.writeAlias(),
MlStatsIndex::wrappedMapping,
clientToUse,
clusterState,
masterNodeTimeout,
l,
MlStatsIndex.STATS_INDEX_MAPPINGS_VERSION
)
)
);
}

@@ -306,25 +301,25 @@ private void executeJobInMiddleOfReindexing(DataFrameAnalyticsTask task, DataFra

private void buildInferenceStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, ActionListener<InferenceStep> listener) {
ParentTaskAssigningClient parentTaskClient = new ParentTaskAssigningClient(client, task.getParentTaskId());

ActionListener<ExtractedFieldsDetector> extractedFieldsDetectorListener = ActionListener.wrap(extractedFieldsDetector -> {
ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
InferenceRunner inferenceRunner = new InferenceRunner(
settings,
parentTaskClient,
modelLoadingService,
resultsPersisterService,
task.getParentTaskId(),
config,
extractedFields,
task.getStatsHolder().getProgressTracker(),
task.getStatsHolder().getDataCountsTracker()
);
InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner);
listener.onResponse(inferenceStep);
}, listener::onFailure);

new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(config, extractedFieldsDetectorListener);
new ExtractedFieldsDetectorFactory(parentTaskClient).createFromDest(
config,
listener.delegateFailureAndWrap((delegate, extractedFieldsDetector) -> {
ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
InferenceRunner inferenceRunner = new InferenceRunner(
settings,
parentTaskClient,
modelLoadingService,
resultsPersisterService,
task.getParentTaskId(),
config,
extractedFields,
task.getStatsHolder().getProgressTracker(),
task.getStatsHolder().getDataCountsTracker()
);
InferenceStep inferenceStep = new InferenceStep(client, task, auditor, config, threadPool, inferenceRunner);
delegate.onResponse(inferenceStep);
})
);
}

public boolean isNodeShuttingDown() {
@@ -134,9 +134,11 @@ private static void prepareCreateIndexRequest(
AtomicReference<Settings> settingsHolder = new AtomicReference<>();
AtomicReference<MappingMetadata> mappingsHolder = new AtomicReference<>();

ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesListener = ActionListener.wrap(fieldCapabilitiesResponse -> {
listener.onResponse(createIndexRequest(clock, config, settingsHolder.get(), mappingsHolder.get(), fieldCapabilitiesResponse));
}, listener::onFailure);
ActionListener<FieldCapabilitiesResponse> fieldCapabilitiesListener = listener.delegateFailureAndWrap(
(l, fieldCapabilitiesResponse) -> l.onResponse(
createIndexRequest(clock, config, settingsHolder.get(), mappingsHolder.get(), fieldCapabilitiesResponse)
)
);

ActionListener<MappingMetadata> mappingsListener = ActionListener.wrap(mappings -> {
mappingsHolder.set(mappings);
@@ -147,22 +147,22 @@ public static void createForDestinationIndex(
ActionListener<DataFrameDataExtractorFactory> listener
) {
ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(client);
extractedFieldsDetectorFactory.createFromDest(config, ActionListener.wrap(extractedFieldsDetector -> {
extractedFieldsDetectorFactory.createFromDest(config, listener.delegateFailureAndWrap((delegate, extractedFieldsDetector) -> {
ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();

DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(
client,
config.getId(),
Collections.singletonList(config.getDest().getIndex()),
config.getSource().getParsedQuery(),
extractedFields,
config.getAnalysis().getRequiredFields(),
config.getHeaders(),
config.getAnalysis().supportsMissingValues(),
createTrainTestSplitterFactory(client, config, extractedFields),
Collections.emptyMap()
delegate.onResponse(
new DataFrameDataExtractorFactory(
client,
config.getId(),
Collections.singletonList(config.getDest().getIndex()),
config.getSource().getParsedQuery(),
extractedFields,
config.getAnalysis().getRequiredFields(),
config.getHeaders(),
config.getAnalysis().supportsMissingValues(),
createTrainTestSplitterFactory(client, config, extractedFields),
Collections.emptyMap()
)
);
listener.onResponse(extractorFactory);
}, listener::onFailure));
}));
}
}
@@ -112,11 +112,6 @@
return;
}

ActionListener<SearchResponse> searchListener = ActionListener.wrap(
searchResponse -> buildFieldCardinalitiesMap(config, searchResponse, listener),
listener::onFailure
);

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0)
.query(config.getSource().getParsedQuery())
.runtimeMappings(config.getSource().getRuntimeMappings());
@@ -147,7 +142,7 @@ private void getCardinalitiesForFieldsWithConstraints(
client,
TransportSearchAction.TYPE,
searchRequest,
searchListener
listener.delegateFailureAndWrap((l, searchResponse) -> buildFieldCardinalitiesMap(config, searchResponse, l))
);
}
