Skip to content

Commit

Permalink
forward port flaky test fix in PR #1319 and add forecasting security …
Browse files Browse the repository at this point in the history
…tests

Signed-off-by: Kaituo Li <[email protected]>
  • Loading branch information
kaituo committed Oct 1, 2024
1 parent 062db14 commit 3d5b2d3
Show file tree
Hide file tree
Showing 40 changed files with 93,558 additions and 391 deletions.
8 changes: 4 additions & 4 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -339,12 +339,15 @@ integTest {
filter {
includeTestsMatching "org.opensearch.ad.rest.*IT"
includeTestsMatching "org.opensearch.ad.e2e.*IT"
includeTestsMatching "org.opensearch.forecast.rest.*IT"
includeTestsMatching "org.opensearch.forecast.e2e.*IT"
}
}

if (System.getProperty("https") == null || System.getProperty("https") == "false") {
filter {
excludeTestsMatching "org.opensearch.ad.rest.SecureADRestIT"
excludeTestsMatching "org.opensearch.forecast.rest.SecureForecastRestIT"
}
}

Expand Down Expand Up @@ -468,6 +471,7 @@ task integTestRemote(type: RestIntegTestTask) {
if (System.getProperty("https") == null || System.getProperty("https") == "false") {
filter {
excludeTestsMatching "org.opensearch.ad.rest.SecureADRestIT"
excludeTestsMatching "org.opensearch.forecast.rest.SecureForecastRestIT"
}
}
}
Expand Down Expand Up @@ -696,10 +700,7 @@ List<String> jacocoExclusions = [

// TODO: add test coverage (kaituo)
'org.opensearch.forecast.*',
'org.opensearch.timeseries.ml.TimeSeriesSingleStreamCheckpointDao',
'org.opensearch.timeseries.transport.JobRequest',
'org.opensearch.timeseries.transport.handler.ResultBulkIndexingHandler',
'org.opensearch.timeseries.ml.Inferencer',
'org.opensearch.timeseries.transport.SingleStreamResultRequest',
'org.opensearch.timeseries.rest.handler.IndexJobActionHandler.1',
'org.opensearch.timeseries.transport.SuggestConfigParamResponse',
Expand Down Expand Up @@ -727,7 +728,6 @@ List<String> jacocoExclusions = [
'org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker',
'org.opensearch.timeseries.util.TimeUtil',
'org.opensearch.ad.transport.ADHCImputeTransportAction',
'org.opensearch.timeseries.ml.RealTimeInferencer',
]


Expand Down
8 changes: 6 additions & 2 deletions src/main/java/org/opensearch/ad/ml/ADRealTimeInferencer.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.AD_THREAD_POOL_NAME;

import java.time.Clock;

import org.opensearch.ad.caching.ADCacheProvider;
import org.opensearch.ad.caching.ADPriorityCache;
import org.opensearch.ad.indices.ADIndex;
Expand All @@ -32,7 +34,8 @@ public ADRealTimeInferencer(
ADColdStartWorker coldStartWorker,
ADSaveResultStrategy resultWriteWorker,
ADCacheProvider cache,
ThreadPool threadPool
ThreadPool threadPool,
Clock clock
) {
super(
modelManager,
Expand All @@ -43,7 +46,8 @@ public ADRealTimeInferencer(
resultWriteWorker,
cache,
threadPool,
AD_THREAD_POOL_NAME
AD_THREAD_POOL_NAME,
clock
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.util.ActionListenerExecutor;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -129,14 +128,12 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
return;
}
Config config = configOptional.get();
long windowDelayMillis = ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis();
int featureSize = config.getEnabledFeatureIds().size();
long dataEndMillis = nodeRequest.getRequest().getDataEndMillis();
long dataStartMillis = nodeRequest.getRequest().getDataStartMillis();
long executionEndTime = dataEndMillis + windowDelayMillis;
String taskId = nodeRequest.getRequest().getTaskId();
for (ModelState<ThresholdedRandomCutForest> modelState : cache.get().getAllModels(configId)) {
if (shouldProcessModelState(modelState, executionEndTime, clusterService, hashRing)) {
if (shouldProcessModelState(modelState, dataEndMillis, clusterService, hashRing)) {
double[] nanArray = new double[featureSize];
Arrays.fill(nanArray, Double.NaN);
adInferencer
Expand All @@ -163,8 +160,8 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
* Determines whether the model state should be processed based on various conditions.
*
* Conditions checked:
* - The model's last seen execution end time is not the minimum Instant value.
* - The current execution end time is greater than or equal to the model's last seen execution end time,
* - The model's last seen data end time is not the minimum Instant value. This means the model hasn't been initialized yet.
* - The current data end time is greater than the model's last seen data end time,
* indicating that the model state was updated in previous intervals.
* - The entity associated with the model state is present.
* - The owning node for real-time processing of the entity, with the same local version, is present in the hash ring.
Expand All @@ -175,14 +172,14 @@ protected ADHCImputeNodeResponse nodeOperation(ADHCImputeNodeRequest nodeRequest
* concurrently (e.g., during tests when multiple threads may operate quickly).
*
* @param modelState The current state of the model.
* @param executionEndTime The end time of the current execution interval.
* @param dataEndTime The data end time of current interval.
* @param clusterService The service providing information about the current cluster node.
* @param hashRing The hash ring used to determine the owning node for real-time processing of entities.
* @return true if the model state should be processed; otherwise, false.
*/
private boolean shouldProcessModelState(
ModelState<ThresholdedRandomCutForest> modelState,
long executionEndTime,
long dataEndTime,
ClusterService clusterService,
HashRing hashRing
) {
Expand All @@ -194,8 +191,8 @@ private boolean shouldProcessModelState(
// Check if the model state conditions are met for processing
// We cannot use last used time as it will be updated whenever we update its priority in CacheBuffer.update when there is a
// PriorityCache.get.
return modelState.getLastSeenExecutionEndTime() != Instant.MIN
&& executionEndTime >= modelState.getLastSeenExecutionEndTime().toEpochMilli()
return modelState.getLastSeenDataEndTime() != Instant.MIN
&& dataEndTime > modelState.getLastSeenDataEndTime().toEpochMilli()
&& modelState.getEntity().isPresent()
&& owningNode.isPresent()
&& owningNode.get().getId().equals(clusterService.localNode().getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import static org.opensearch.timeseries.TimeSeriesAnalyticsPlugin.FORECAST_THREAD_POOL_NAME;

import java.time.Clock;

import org.opensearch.forecast.caching.ForecastCacheProvider;
import org.opensearch.forecast.caching.ForecastPriorityCache;
import org.opensearch.forecast.indices.ForecastIndex;
Expand All @@ -32,7 +34,8 @@ public ForecastRealTimeInferencer(
ForecastColdStartWorker coldStartWorker,
ForecastSaveResultStrategy resultWriteWorker,
ForecastCacheProvider cache,
ThreadPool threadPool
ThreadPool threadPool,
Clock clock
) {
super(
modelManager,
Expand All @@ -43,7 +46,8 @@ public ForecastRealTimeInferencer(
resultWriteWorker,
cache,
threadPool,
FORECAST_THREAD_POOL_NAME
FORECAST_THREAD_POOL_NAME,
clock
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
all,
RestHandlerUtils.buildEntity(request, forecasterId)
);

return channel -> client.execute(GetForecasterAction.INSTANCE, getForecasterRequest, new RestToXContentListener<>(channel));
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(Encode.forHtml(e.getMessage()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,6 @@ private QueryBuilder generateBuildInSubFilter(SearchTopForecastResultRequest req
*/
private RangeQueryBuilder generateDateFilter(SearchTopForecastResultRequest request, Forecaster forecaster) {
// forecast from is data end time for forecast
// return QueryBuilders.termQuery(CommonName.DATA_END_TIME_FIELD, request.getForecastFrom().toEpochMilli());
long startInclusive = request.getForecastFrom().toEpochMilli();
long endExclusive = startInclusive + forecaster.getIntervalInMilliseconds();
return QueryBuilders.rangeQuery(CommonName.DATA_END_TIME_FIELD).gte(startInclusive).lt(endExclusive);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
adColdstartQueue,
adSaveResultStrategy,
adCacheProvider,
threadPool
threadPool,
getClock()
);

ADCheckpointReadWorker adCheckpointReadQueue = new ADCheckpointReadWorker(
Expand Down Expand Up @@ -1230,7 +1231,8 @@ public PooledObject<LinkedBuffer> wrap(LinkedBuffer obj) {
forecastColdstartQueue,
forecastSaveResultStrategy,
forecastCacheProvider,
threadPool
threadPool,
getClock()
);

ForecastCheckpointReadWorker forecastCheckpointReadQueue = new ForecastCheckpointReadWorker(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ public ModelState<RCFModelType> get(String modelId, Config config) {
// reset every 60 intervals
return new DoorKeeper(
TimeSeriesSettings.DOOR_KEEPER_FOR_CACHE_MAX_INSERTION,
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ),
clock,
TimeSeriesSettings.CACHE_DOOR_KEEPER_COUNT_THRESHOLD
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ private void coldStart(
// reset every 60 intervals
return new DoorKeeper(
TimeSeriesSettings.DOOR_KEEPER_FOR_COLD_STARTER_MAX_INSERTION,
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ),
config.getIntervalDuration().multipliedBy(TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ),
clock,
TimeSeriesSettings.COLD_START_DOOR_KEEPER_COUNT_THRESHOLD
);
Expand All @@ -251,7 +251,7 @@ private void coldStart(
logger
.info(
"Won't retry real-time cold start within {} intervals for model {}",
TimeSeriesSettings.DOOR_KEEPER_MAINTENANCE_FREQ,
TimeSeriesSettings.EXPIRING_VALUE_MAINTENANCE_FREQ,
modelId
);
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ public <RCFDescriptor extends AnomalyDescriptor> IntermediateResultType score(
throw e;
} finally {
modelState.setLastUsedTime(clock.instant());
modelState.setLastSeenExecutionEndTime(clock.instant());
modelState.setLastSeenDataEndTime(sample.getDataEndTime());
}
return createEmptyResult();
}
Expand Down
12 changes: 6 additions & 6 deletions src/main/java/org/opensearch/timeseries/ml/ModelState.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class ModelState<T> implements org.opensearch.timeseries.ExpiringState {
// time when the ML model was used last time
protected Instant lastUsedTime;
protected Instant lastCheckpointTime;
protected Instant lastSeenExecutionEndTime;
protected Instant lastSeenDataEndTime;
protected Clock clock;
protected float priority;
protected Deque<Sample> samples;
Expand Down Expand Up @@ -75,7 +75,7 @@ public ModelState(
this.priority = priority;
this.entity = entity;
this.samples = samples;
this.lastSeenExecutionEndTime = Instant.MIN;
this.lastSeenDataEndTime = Instant.MIN;
}

/**
Expand Down Expand Up @@ -252,11 +252,11 @@ public Map<String, Object> getModelStateAsMap() {
};
}

public Instant getLastSeenExecutionEndTime() {
return lastSeenExecutionEndTime;
public Instant getLastSeenDataEndTime() {
return lastSeenDataEndTime;
}

public void setLastSeenExecutionEndTime(Instant lastSeenExecutionEndTime) {
this.lastSeenExecutionEndTime = lastSeenExecutionEndTime;
public void setLastSeenDataEndTime(Instant lastSeenExecutionEndTime) {
this.lastSeenDataEndTime = lastSeenExecutionEndTime;
}
}
Loading

0 comments on commit 3d5b2d3

Please sign in to comment.