
Add Forecaster class #920

Merged: 3 commits merged on Jun 8, 2023.
Changes from 2 commits
@@ -246,7 +246,7 @@ protected void runAdJob(
String user = userInfo.getName();
List<String> roles = userInfo.getRoles();

String resultIndex = jobParameter.getResultIndex();
String resultIndex = jobParameter.getCustomResultIndex();


Would it always be a custom result index in this path? Or are all result indexes now custom?

Collaborator Author:

Yes. When this function returns null, we use the default result index.


Makes sense.
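
To illustrate the behavior discussed in this thread, here is a minimal, self-contained sketch of the fallback semantics (the class name, the default index constant, and the demo values are illustrative assumptions, not the plugin's actual code):

// Illustrative sketch only: accessors like getCustomResultIndex() return null when
// no custom result index is configured, and results then go to the default result index.
public class ResultIndexFallbackSketch {

    // Assumed placeholder; substitute the plugin's actual default result index or alias.
    private static final String DEFAULT_RESULT_INDEX = ".opendistro-anomaly-results";

    static String resolveResultIndex(String customResultIndex) {
        // A null custom result index means results are written to the default result index.
        return customResultIndex == null ? DEFAULT_RESULT_INDEX : customResultIndex;
    }

    public static void main(String[] args) {
        System.out.println(resolveResultIndex(null));            // falls back to the default index
        System.out.println(resolveResultIndex("my-ad-results")); // uses the configured custom index
    }
}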

if (resultIndex == null) {
runAnomalyDetectionJob(
jobParameter,
@@ -536,7 +536,7 @@ private void stopAdJob(String detectorId, AnomalyDetectorFunction function) {
Instant.now(),
job.getLockDurationSeconds(),
job.getUser(),
job.getResultIndex()
job.getCustomResultIndex()
);
IndexRequest indexRequest = new IndexRequest(CommonName.JOB_INDEX)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
@@ -175,6 +175,7 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.forecast.model.Forecaster;
import org.opensearch.jobscheduler.spi.JobSchedulerExtension;
import org.opensearch.jobscheduler.spi.ScheduledJobParser;
import org.opensearch.jobscheduler.spi.ScheduledJobRunner;
@@ -955,7 +956,8 @@ public List<NamedXContentRegistry.Entry> getNamedXContent() {
AnomalyDetector.XCONTENT_REGISTRY,
AnomalyResult.XCONTENT_REGISTRY,
DetectorInternalState.XCONTENT_REGISTRY,
AnomalyDetectorJob.XCONTENT_REGISTRY
AnomalyDetectorJob.XCONTENT_REGISTRY,
Forecaster.XCONTENT_REGISTRY
);
}

38 changes: 19 additions & 19 deletions src/main/java/org/opensearch/ad/AnomalyDetectorProfileRunner.java
@@ -149,7 +149,7 @@ private void prepareProfile(
ActionListener<DetectorProfile> listener,
Set<DetectorProfileName> profilesToCollect
) {
String detectorId = detector.getDetectorId();
String detectorId = detector.getId();
GetRequest getRequest = new GetRequest(CommonName.JOB_INDEX, detectorId);
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (getResponse != null && getResponse.isExists()) {
@@ -162,7 +162,7 @@ private void prepareProfile(
AnomalyDetectorJob job = AnomalyDetectorJob.parse(parser);
long enabledTimeMs = job.getEnabledTime().toEpochMilli();

boolean isMultiEntityDetector = detector.isMultientityDetector();
boolean isMultiEntityDetector = detector.isHC();


nit: expand to "isHighCardinality"?

Collaborator Author:

changed


thanks!


int totalResponsesToWait = 0;
if (profilesToCollect.contains(DetectorProfileName.ERROR)) {
@@ -284,8 +284,8 @@ private void prepareProfile(
}

private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorProfile> listener, AnomalyDetector detector) {
List<String> categoryField = detector.getCategoryField();
if (!detector.isMultientityDetector() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) {
List<String> categoryField = detector.getCategoryFields();
if (!detector.isHC() || categoryField.size() > ADNumericSetting.maxCategoricalFields()) {
listener.onResponse(new DetectorProfile.Builder().build());
} else {
if (categoryField.size() == 1) {
@@ -304,7 +304,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
DetectorProfile profile = profileBuilder.totalEntities(value).build();
listener.onResponse(profile);
}, searchException -> {
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId());
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId());
listener.onFailure(searchException);
});
// using the original context in listener as user roles have no permissions for internal operations like fetching a
@@ -313,7 +313,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
.<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
request,
client::search,
detector.getDetectorId(),
detector.getId(),
client,
searchResponseListener
);
@@ -322,7 +322,11 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
AggregationBuilder bucketAggs = AggregationBuilders
.composite(
ADCommonName.TOTAL_ENTITIES,
detector.getCategoryField().stream().map(f -> new TermsValuesSourceBuilder(f).field(f)).collect(Collectors.toList())
detector
.getCategoryFields()
.stream()
.map(f -> new TermsValuesSourceBuilder(f).field(f))
.collect(Collectors.toList())
)
.size(maxTotalEntitiesToTrack);
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().aggregation(bucketAggs).trackTotalHits(false).size(0);
@@ -353,7 +357,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
DetectorProfile profile = profileBuilder.totalEntities(Long.valueOf(compositeAgg.getBuckets().size())).build();
listener.onResponse(profile);
}, searchException -> {
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getDetectorId());
logger.warn(ADCommonMessages.FAIL_TO_GET_TOTAL_ENTITIES + detector.getId());
listener.onFailure(searchException);
});
// using the original context in listener as user roles have no permissions for internal operations like fetching a
Expand All @@ -362,7 +366,7 @@ private void profileEntityStats(MultiResponsesDelegateActionListener<DetectorPro
.<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
searchRequest,
client::search,
detector.getDetectorId(),
detector.getId(),
client,
searchResponseListener
);
@@ -400,7 +404,7 @@ private void profileStateRelated(
Set<DetectorProfileName> profilesToCollect
) {
if (enabled) {
RCFPollingRequest request = new RCFPollingRequest(detector.getDetectorId());
RCFPollingRequest request = new RCFPollingRequest(detector.getId());
client.execute(RCFPollingAction.INSTANCE, request, onPollRCFUpdates(detector, profilesToCollect, listener));
} else {
DetectorProfile.Builder builder = new DetectorProfile.Builder();
@@ -419,7 +423,7 @@ private void profileModels(
MultiResponsesDelegateActionListener<DetectorProfile> listener
) {
DiscoveryNode[] dataNodes = nodeFilter.getEligibleDataNodes();
ProfileRequest profileRequest = new ProfileRequest(detector.getDetectorId(), profiles, forMultiEntityDetector, dataNodes);
ProfileRequest profileRequest = new ProfileRequest(detector.getId(), profiles, forMultiEntityDetector, dataNodes);
client.execute(ProfileAction.INSTANCE, profileRequest, onModelResponse(detector, profiles, job, listener));// get init progress
}

Expand All @@ -429,7 +433,7 @@ private ActionListener<ProfileResponse> onModelResponse(
AnomalyDetectorJob job,
MultiResponsesDelegateActionListener<DetectorProfile> listener
) {
boolean isMultientityDetector = detector.isMultientityDetector();
boolean isMultientityDetector = detector.isHC();
return ActionListener.wrap(profileResponse -> {
DetectorProfile.Builder profile = new DetectorProfile.Builder();
if (profilesToCollect.contains(DetectorProfileName.COORDINATING_NODE)) {
@@ -516,7 +520,7 @@ private ActionListener<SearchResponse> onInittedEver(
logger
.error(
"Fail to find any anomaly result with anomaly score larger than 0 after AD job enabled time for detector {}",
detector.getDetectorId()
detector.getId()
);
listener.onFailure(exception);
}
@@ -565,11 +569,7 @@ private ActionListener<RCFPollingResponse> onPollRCFUpdates(
// data exists.
processInitResponse(detector, profilesToCollect, 0L, true, new DetectorProfile.Builder(), listener);
} else {
logger
.error(
new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getDetectorId()),
exception
);
logger.error(new ParameterizedMessage("Fail to get init progress through messaging for {}", detector.getId()), exception);
listener.onFailure(exception);
}
});
@@ -603,7 +603,7 @@ private void processInitResponse(
InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, 0);
builder.initProgress(initProgress);
} else {
long intervalMins = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMinutes();
long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes();
InitProgressProfile initProgress = computeInitProgressProfile(totalUpdates, intervalMins);
builder.initProgress(initProgress);
}
14 changes: 7 additions & 7 deletions src/main/java/org/opensearch/ad/AnomalyDetectorRunner.java
@@ -72,7 +72,7 @@ public void executeDetector(
ActionListener<List<AnomalyResult>> listener
) throws IOException {
context.restore();
List<String> categoryField = detector.getCategoryField();
List<String> categoryField = detector.getCategoryFields();
if (categoryField != null && !categoryField.isEmpty()) {
featureManager.getPreviewEntities(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(entities -> {

@@ -86,13 +86,13 @@ public void executeDetector(
ActionListener<EntityAnomalyResult> entityAnomalyResultListener = ActionListener
.wrap(
entityAnomalyResult -> { listener.onResponse(entityAnomalyResult.getAnomalyResults()); },
e -> onFailure(e, listener, detector.getDetectorId())
e -> onFailure(e, listener, detector.getId())
);
MultiResponsesDelegateActionListener<EntityAnomalyResult> multiEntitiesResponseListener =
new MultiResponsesDelegateActionListener<EntityAnomalyResult>(
entityAnomalyResultListener,
entities.size(),
String.format(Locale.ROOT, "Fail to get preview result for multi entity detector %s", detector.getDetectorId()),
String.format(Locale.ROOT, "Fail to get preview result for multi entity detector %s", detector.getId()),
true
);
for (Entity entity : entities) {
@@ -113,17 +113,17 @@ public void executeDetector(
}, e -> multiEntitiesResponseListener.onFailure(e))
);
}
}, e -> onFailure(e, listener, detector.getDetectorId())));
}, e -> onFailure(e, listener, detector.getId())));
} else {
featureManager.getPreviewFeatures(detector, startTime.toEpochMilli(), endTime.toEpochMilli(), ActionListener.wrap(features -> {
try {
List<ThresholdingResult> results = modelManager
.getPreviewResults(features.getProcessedFeatures(), detector.getShingleSize());
listener.onResponse(sample(parsePreviewResult(detector, features, results, null), maxPreviewResults));
} catch (Exception e) {
onFailure(e, listener, detector.getDetectorId());
onFailure(e, listener, detector.getId());
}
}, e -> onFailure(e, listener, detector.getDetectorId())));
}, e -> onFailure(e, listener, detector.getId())));
}
}

@@ -184,7 +184,7 @@ private List<AnomalyResult> parsePreviewResult(
);
} else {
result = new AnomalyResult(
detector.getDetectorId(),
detector.getId(),
null,
featureDatas,
Instant.ofEpochMilli(timeRange.getKey()),
8 changes: 4 additions & 4 deletions src/main/java/org/opensearch/ad/EntityProfileRunner.java
@@ -105,7 +105,7 @@ public void profile(
) {
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
AnomalyDetector detector = AnomalyDetector.parse(parser, detectorId);
List<String> categoryFields = detector.getCategoryField();
List<String> categoryFields = detector.getCategoryFields();
int maxCategoryFields = ADNumericSetting.maxCategoricalFields();
if (categoryFields == null || categoryFields.size() == 0) {
listener.onFailure(new IllegalArgumentException(NOT_HC_DETECTOR_ERR_MSG));
@@ -186,7 +186,7 @@ private void validateEntity(
.<SearchRequest, SearchResponse>asyncRequestWithInjectedSecurity(
searchRequest,
client::search,
detector.getDetectorId(),
detector.getId(),
client,
searchResponseListener
);
@@ -277,7 +277,7 @@ private void getJob(
detectorId,
enabledTimeMs,
entityValue,
detector.getResultIndex()
detector.getCustomResultIndex()
);

EntityProfile.Builder builder = new EntityProfile.Builder();
@@ -397,7 +397,7 @@ private void sendInitState(
builder.state(EntityState.INIT);
}
if (profilesToCollect.contains(EntityProfileName.INIT_PROGRESS)) {
long intervalMins = ((IntervalTimeConfiguration) detector.getDetectionInterval()).toDuration().toMinutes();
long intervalMins = ((IntervalTimeConfiguration) detector.getInterval()).toDuration().toMinutes();
InitProgressProfile initProgress = computeInitProgressProfile(updates, intervalMins);
builder.initProgress(initProgress);
}
@@ -92,7 +92,7 @@ public void indexAnomalyResult(
AnomalyResultResponse response,
AnomalyDetector detector
) {
String detectorId = detector.getDetectorId();
String detectorId = detector.getId();
try {
// skipping writing to the result index if not necessary
// For a single-entity detector, the result is not useful if error is null
@@ -124,7 +124,7 @@ public void indexAnomalyResult(
response.getError()
);

String resultIndex = detector.getResultIndex();
String resultIndex = detector.getCustomResultIndex();
anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
updateRealtimeTask(response, detectorId);
} catch (EndRunException e) {
@@ -156,13 +156,7 @@ private void updateRealtimeTask(AnomalyResultResponse response, String detectorI
Runnable profileHCInitProgress = () -> {
client.execute(ProfileAction.INSTANCE, profileRequest, ActionListener.wrap(r -> {
log.debug("Update latest realtime task for HC detector {}, total updates: {}", detectorId, r.getTotalUpdates());
updateLatestRealtimeTask(
detectorId,
null,
r.getTotalUpdates(),
response.getDetectorIntervalInMinutes(),
response.getError()
);
updateLatestRealtimeTask(detectorId, null, r.getTotalUpdates(), response.getIntervalInMinutes(), response.getError());
}, e -> { log.error("Failed to update latest realtime task for " + detectorId, e); }));
};
if (!adTaskManager.isHCRealtimeTaskStartInitializing(detectorId)) {
@@ -181,13 +175,7 @@ private void updateRealtimeTask(AnomalyResultResponse response, String detectorI
detectorId,
response.getRcfTotalUpdates()
);
updateLatestRealtimeTask(
detectorId,
null,
response.getRcfTotalUpdates(),
response.getDetectorIntervalInMinutes(),
response.getError()
);
updateLatestRealtimeTask(detectorId, null, response.getRcfTotalUpdates(), response.getIntervalInMinutes(), response.getError());
}
}

@@ -278,7 +266,7 @@ public void indexAnomalyResultException(
String taskState,
AnomalyDetector detector
) {
String detectorId = detector.getDetectorId();
String detectorId = detector.getId();
try {
IntervalTimeConfiguration windowDelay = (IntervalTimeConfiguration) detector.getWindowDelay();
Instant dataStartTime = detectionStartTime.minus(windowDelay.getInterval(), windowDelay.getUnit());
@@ -299,15 +287,15 @@ public void indexAnomalyResultException(
anomalyDetectionIndices.getSchemaVersion(ADIndex.RESULT),
null // no model id
);
String resultIndex = detector.getResultIndex();
String resultIndex = detector.getCustomResultIndex();
if (resultIndex != null && !anomalyDetectionIndices.doesIndexExist(resultIndex)) {
// Set result index as null, will write exception to default result index.
anomalyResultHandler.index(anomalyResult, detectorId, null);
} else {
anomalyResultHandler.index(anomalyResult, detectorId, resultIndex);
}

if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isMultiCategoryDetector()) {
if (errorMessage.contains(ADCommonMessages.NO_MODEL_ERR_MSG) && !detector.isHC()) {
// single stream detector raises ResourceNotFoundException containing CommonErrorMessages.NO_CHECKPOINT_ERR_MSG
// when there is no checkpoint.
// Delay real time cache update by one minute so we will have trained models by then and update the state
@@ -321,7 +309,7 @@ public void indexAnomalyResultException(
detectorId,
taskState,
totalUpdates,
detector.getDetectorIntervalInMinutes(),
detector.getIntervalInMinutes(),
totalUpdates > 0 ? "" : errorMessage
);
}, e -> {
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/ad/NodeState.java
@@ -58,7 +58,7 @@ public NodeState(String detectorId, Clock clock) {
this.detectorJob = null;
}

public String getDetectorId() {
public String getId() {
return detectorId;
}
