diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java index 7cef2bed04ce3..c5fa5d0975483 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlTasks.java @@ -229,6 +229,10 @@ public static SnapshotUpgradeState getSnapshotUpgradeState(@Nullable PersistentT public static DatafeedState getDatafeedState(String datafeedId, @Nullable PersistentTasksCustomMetadata tasks) { PersistentTasksCustomMetadata.PersistentTask task = getDatafeedTask(datafeedId, tasks); + return getDatafeedState(task); + } + + public static DatafeedState getDatafeedState(PersistentTasksCustomMetadata.PersistentTask task) { if (task == null) { // If we haven't started a datafeed then there will be no persistent task, // which is the same as if the datafeed was't started diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java index fd2f3627e3fb1..826b0785aa563 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfo.java @@ -86,6 +86,10 @@ public int getTargetAllocations() { return targetAllocations; } + public int getFailedAllocations() { + return state == RoutingState.FAILED ? 
targetAllocations : 0; + } + public RoutingState getState() { return state; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index d27d325a5c596..8147dabda7b48 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -287,6 +287,10 @@ public int totalTargetAllocations() { return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getTargetAllocations).sum(); } + public int totalFailedAllocations() { + return nodeRoutingTable.values().stream().mapToInt(RoutingInfo::getFailedAllocations).sum(); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java index 28ebf8b2445c5..830f7dde7c7d8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/RoutingInfoTests.java @@ -69,4 +69,17 @@ public void testIsRoutable_GivenStartedWithNonZeroAllocations() { RoutingInfo routingInfo = new RoutingInfo(randomIntBetween(1, 10), 1, RoutingState.STARTED, ""); assertThat(routingInfo.isRoutable(), is(true)); } + + public void testGetFailedAllocations() { + int targetAllocations = randomIntBetween(1, 10); + RoutingInfo routingInfo = new RoutingInfo( + randomIntBetween(0, targetAllocations), + targetAllocations, + randomFrom(RoutingState.STARTING, RoutingState.STARTED, RoutingState.STOPPING), + "" + ); 
+ assertThat(routingInfo.getFailedAllocations(), is(0)); + routingInfo = new RoutingInfo(randomIntBetween(0, targetAllocations), targetAllocations, RoutingState.FAILED, ""); + assertThat(routingInfo.getFailedAllocations(), is(targetAllocations)); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 538f02b3f9092..77622523362f9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -72,6 +72,7 @@ import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.telemetry.TelemetryProvider; import org.elasticsearch.threadpool.ExecutorBuilder; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.threadpool.ThreadPool; @@ -884,6 +885,7 @@ public Collection createComponents(PluginServices services) { Environment environment = services.environment(); NamedXContentRegistry xContentRegistry = services.xContentRegistry(); IndexNameExpressionResolver indexNameExpressionResolver = services.indexNameExpressionResolver(); + TelemetryProvider telemetryProvider = services.telemetryProvider(); if (enabled == false) { // Holders for @link(MachineLearningFeatureSetUsage) which needs access to job manager and ML extension, @@ -1220,6 +1222,14 @@ public Collection createComponents(PluginServices services) { machineLearningExtension.get().isNlpEnabled() ); + MlMetrics mlMetrics = new MlMetrics( + telemetryProvider.getMeterRegistry(), + clusterService, + settings, + autodetectProcessManager, + dataFrameAnalyticsManager + ); + return List.of( mlLifeCycleService, new MlControllerHolder(mlController), @@ -1251,7 +1261,8 @@ public Collection createComponents(PluginServices services) { 
trainedModelAllocationClusterServiceSetOnce.get(), deploymentManager.get(), nodeAvailabilityZoneMapper, - new MachineLearningExtensionHolder(machineLearningExtension.get()) + new MachineLearningExtensionHolder(machineLearningExtension.get()), + mlMetrics ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlMetrics.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlMetrics.java new file mode 100644 index 0000000000000..2eeb5947ff591 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlMetrics.java @@ -0,0 +1,424 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml; + +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.cluster.node.DiscoveryNodeRole; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.gateway.GatewayService; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.telemetry.metric.LongWithAttributes; +import org.elasticsearch.telemetry.metric.MeterRegistry; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsManager; +import 
org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata; +import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; +import org.elasticsearch.xpack.ml.utils.NativeMemoryCalculator; + +import java.util.Map; +import java.util.Optional; + +import static org.elasticsearch.xpack.core.ml.MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT; +import static org.elasticsearch.xpack.core.ml.MlTasks.DATAFEED_TASK_NAME; +import static org.elasticsearch.xpack.core.ml.MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME; +import static org.elasticsearch.xpack.core.ml.MlTasks.JOB_SNAPSHOT_UPGRADE_TASK_NAME; +import static org.elasticsearch.xpack.core.ml.MlTasks.JOB_TASK_NAME; +import static org.elasticsearch.xpack.ml.MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD; + +public class MlMetrics implements ClusterStateListener { + + private final ClusterService clusterService; + private final AutodetectProcessManager autodetectProcessManager; + private final DataFrameAnalyticsManager dataFrameAnalyticsManager; + private final boolean hasMasterRole; + private final boolean hasMlRole; + + private static final Map MASTER_TRUE_MAP = Map.of("is_master", Boolean.TRUE); + private static final Map MASTER_FALSE_MAP = Map.of("is_master", Boolean.FALSE); + private volatile Map isMasterMap = MASTER_FALSE_MAP; + private volatile boolean firstTime = true; + + private volatile MlTaskStatusCounts mlTaskStatusCounts = MlTaskStatusCounts.EMPTY; + private volatile TrainedModelAllocationCounts trainedModelAllocationCounts = TrainedModelAllocationCounts.EMPTY; + + private volatile long nativeMemLimit; + private volatile long nativeMemAdUsage; + private volatile long nativeMemDfaUsage; + private volatile long nativeMemTrainedModelUsage; + private volatile long nativeMemFree; + + public MlMetrics( + MeterRegistry meterRegistry, + ClusterService clusterService, + Settings settings, + AutodetectProcessManager autodetectProcessManager, + DataFrameAnalyticsManager 
dataFrameAnalyticsManager
+    ) {
+        this.clusterService = clusterService;
+        this.autodetectProcessManager = autodetectProcessManager;
+        this.dataFrameAnalyticsManager = dataFrameAnalyticsManager;
+        hasMasterRole = DiscoveryNode.hasRole(settings, DiscoveryNodeRole.MASTER_ROLE);
+        if (hasMasterRole) {
+            registerMasterNodeMetrics(meterRegistry);
+        }
+        hasMlRole = DiscoveryNode.hasRole(settings, DiscoveryNodeRole.ML_ROLE);
+        if (hasMlRole) {
+            registerMlNodeMetrics(meterRegistry);
+        }
+        if (hasMasterRole || hasMlRole) { // cluster state is only observed when at least one group of gauges was registered
+            clusterService.addListener(this);
+        }
+    }
+
+    private void registerMlNodeMetrics(MeterRegistry meterRegistry) {
+        // Ignore the AutoCloseable warnings here - the registry is responsible for closing these gauges
+        meterRegistry.registerLongGauge(
+            "es.ml.native_memory.limit",
+            "ML native memory limit on this node.",
+            "bytes",
+            () -> new LongWithAttributes(nativeMemLimit, Map.of())
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.native_memory.usage.anomaly_detectors",
+            "ML native memory used by anomaly detection jobs on this node.",
+            "bytes",
+            () -> new LongWithAttributes(nativeMemAdUsage, Map.of())
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.native_memory.usage.data_frame_analytics",
+            "ML native memory used by data frame analytics jobs on this node.",
+            "bytes",
+            () -> new LongWithAttributes(nativeMemDfaUsage, Map.of())
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.native_memory.usage.trained_models",
+            "ML native memory used by trained models on this node.",
+            "bytes",
+            () -> new LongWithAttributes(nativeMemTrainedModelUsage, Map.of())
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.native_memory.free",
+            "Free ML native memory on this node.",
+            "bytes",
+            () -> new LongWithAttributes(nativeMemFree, Map.of())
+        );
+    }
+
+    private void registerMasterNodeMetrics(MeterRegistry meterRegistry) {
+        // Ignore the AutoCloseable warnings here - the registry is responsible for closing these gauges
+        meterRegistry.registerLongGauge(
+            "es.ml.anomaly_detectors.opening.count",
+            "Count of anomaly detection jobs in the opening state cluster-wide.",
+            "jobs",
+            () -> new LongWithAttributes(mlTaskStatusCounts.adOpeningCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.anomaly_detectors.opened.count",
+            "Count of anomaly detection jobs in the opened state cluster-wide.",
+            "jobs",
+            () -> new LongWithAttributes(mlTaskStatusCounts.adOpenedCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.anomaly_detectors.closing.count",
+            "Count of anomaly detection jobs in the closing state cluster-wide.",
+            "jobs",
+            () -> new LongWithAttributes(mlTaskStatusCounts.adClosingCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.anomaly_detectors.failed.count",
+            "Count of anomaly detection jobs in the failed state cluster-wide.",
+            "jobs",
+            () -> new LongWithAttributes(mlTaskStatusCounts.adFailedCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge( // datafeed gauges are named "es.ml.datafeeds.*" so the name agrees with description and unit
+            "es.ml.datafeeds.starting.count",
+            "Count of datafeeds in the starting state cluster-wide.",
+            "datafeeds",
+            () -> new LongWithAttributes(mlTaskStatusCounts.datafeedStartingCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.datafeeds.started.count",
+            "Count of datafeeds in the started state cluster-wide.",
+            "datafeeds",
+            () -> new LongWithAttributes(mlTaskStatusCounts.datafeedStartedCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.datafeeds.stopping.count",
+            "Count of datafeeds in the stopping state cluster-wide.",
+            "datafeeds",
+            () -> new LongWithAttributes(mlTaskStatusCounts.datafeedStoppingCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.data_frame_analytics.starting.count",
+            "Count of data frame analytics jobs in the starting state cluster-wide.",
+            "jobs",
+            () -> new LongWithAttributes(mlTaskStatusCounts.dfaStartingCount, isMasterMap)
+        );
+        meterRegistry.registerLongGauge(
+            "es.ml.data_frame_analytics.started.count",
+ "Count of data frame analytics jobs in the started state cluster-wide.", + "jobs", + () -> new LongWithAttributes(mlTaskStatusCounts.dfaStartedCount, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.data_frame_analytics.reindexing.count", + "Count of data frame analytics jobs in the reindexing state cluster-wide.", + "jobs", + () -> new LongWithAttributes(mlTaskStatusCounts.dfaReindexingCount, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.data_frame_analytics.analyzing.count", + "Count of data frame analytics jobs in the analyzing state cluster-wide.", + "jobs", + () -> new LongWithAttributes(mlTaskStatusCounts.dfaAnalyzingCount, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.data_frame_analytics.stopping.count", + "Count of data frame analytics jobs in the stopping state cluster-wide.", + "jobs", + () -> new LongWithAttributes(mlTaskStatusCounts.dfaStoppingCount, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.data_frame_analytics.failed.count", + "Count of data frame analytics jobs in the failed state cluster-wide.", + "jobs", + () -> new LongWithAttributes(mlTaskStatusCounts.dfaFailedCount, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.trained_models.deployment.target_allocations.count", + "Sum of target trained model allocations across all deployments cluster-wide.", + "allocations", + () -> new LongWithAttributes(trainedModelAllocationCounts.trainedModelsTargetAllocations, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.trained_models.deployment.current_allocations.count", + "Sum of current trained model allocations across all deployments cluster-wide.", + "allocations", + () -> new LongWithAttributes(trainedModelAllocationCounts.trainedModelsCurrentAllocations, isMasterMap) + ); + meterRegistry.registerLongGauge( + "es.ml.trained_models.deployment.failed_allocations.count", + "Sum of failed trained model allocations across all deployments cluster-wide.", + "allocations", 
+ () -> new LongWithAttributes(trainedModelAllocationCounts.trainedModelsFailedAllocations, isMasterMap) + ); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + isMasterMap = event.localNodeMaster() ? MASTER_TRUE_MAP : MASTER_FALSE_MAP; + + if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { + // Wait until the gateway has recovered from disk. + return; + } + + boolean recalculateFreeMem = false; + + final ClusterState currentState = event.state(); + final ClusterState previousState = event.previousState(); + + if (firstTime || event.metadataChanged()) { + final PersistentTasksCustomMetadata tasks = currentState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + final PersistentTasksCustomMetadata oldTasks = previousState.getMetadata().custom(PersistentTasksCustomMetadata.TYPE); + if (tasks != null && tasks.equals(oldTasks) == false) { + if (hasMasterRole) { + mlTaskStatusCounts = findTaskStatuses(tasks); + } + if (hasMlRole) { + nativeMemAdUsage = findAdMemoryUsage(autodetectProcessManager); + nativeMemDfaUsage = findDfaMemoryUsage(dataFrameAnalyticsManager, tasks); + recalculateFreeMem = true; + } + } + } + + final TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.fromState(currentState); + final TrainedModelAssignmentMetadata previousMetadata = TrainedModelAssignmentMetadata.fromState(previousState); + if (currentMetadata != null && currentMetadata.equals(previousMetadata) == false) { + if (hasMasterRole) { + trainedModelAllocationCounts = findTrainedModelAllocationCounts(currentMetadata); + } + if (hasMlRole) { + nativeMemTrainedModelUsage = findTrainedModelMemoryUsage(currentMetadata, currentState.nodes().getLocalNode().getId()); + recalculateFreeMem = true; + } + } + + if (firstTime) { + firstTime = false; + nativeMemLimit = findNativeMemoryLimit(currentState.nodes().getLocalNode(), clusterService.getClusterSettings()); + recalculateFreeMem = true; + 
clusterService.getClusterSettings() + .addSettingsUpdateConsumer(USE_AUTO_MACHINE_MEMORY_PERCENT, s -> memoryLimitClusterSettingUpdated()); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, s -> memoryLimitClusterSettingUpdated()); + } + + if (recalculateFreeMem) { + nativeMemFree = findNativeMemoryFree(nativeMemLimit, nativeMemAdUsage, nativeMemDfaUsage, nativeMemTrainedModelUsage); + } + } + + private void memoryLimitClusterSettingUpdated() { + nativeMemLimit = findNativeMemoryLimit(clusterService.localNode(), clusterService.getClusterSettings()); + nativeMemFree = findNativeMemoryFree(nativeMemLimit, nativeMemAdUsage, nativeMemDfaUsage, nativeMemTrainedModelUsage); + } + + static MlTaskStatusCounts findTaskStatuses(PersistentTasksCustomMetadata tasks) { + + int adOpeningCount = 0; + int adOpenedCount = 0; + int adClosingCount = 0; + int adFailedCount = 0; + int datafeedStartingCount = 0; + int datafeedStartedCount = 0; + int datafeedStoppingCount = 0; + int dfaStartingCount = 0; + int dfaStartedCount = 0; + int dfaReindexingCount = 0; + int dfaAnalyzingCount = 0; + int dfaStoppingCount = 0; + int dfaFailedCount = 0; + + for (PersistentTasksCustomMetadata.PersistentTask task : tasks.tasks()) { + switch (task.getTaskName()) { + case JOB_TASK_NAME: + switch (MlTasks.getJobStateModifiedForReassignments(task)) { + case OPENING -> ++adOpeningCount; + case OPENED -> ++adOpenedCount; + case CLOSING -> ++adClosingCount; + case FAILED -> ++adFailedCount; + } + break; + case DATAFEED_TASK_NAME: + switch (MlTasks.getDatafeedState(task)) { + case STARTING -> ++datafeedStartingCount; + case STARTED -> ++datafeedStartedCount; + case STOPPING -> ++datafeedStoppingCount; + } + break; + case DATA_FRAME_ANALYTICS_TASK_NAME: + switch (MlTasks.getDataFrameAnalyticsState(task)) { + case STARTING -> ++dfaStartingCount; + case STARTED -> ++dfaStartedCount; + case REINDEXING -> ++dfaReindexingCount; + case ANALYZING -> 
++dfaAnalyzingCount;
+                        case STOPPING -> ++dfaStoppingCount;
+                        case FAILED -> ++dfaFailedCount;
+                    }
+                    break;
+                case JOB_SNAPSHOT_UPGRADE_TASK_NAME:
+                    // Not currently tracked
+                    // TODO: consider in the future, especially when we're at the stage of needing to upgrade serverless model snapshots
+                    break;
+            }
+        }
+
+        return new MlTaskStatusCounts(
+            adOpeningCount,
+            adOpenedCount,
+            adClosingCount,
+            adFailedCount,
+            datafeedStartingCount,
+            datafeedStartedCount,
+            datafeedStoppingCount,
+            dfaStartingCount,
+            dfaStartedCount,
+            dfaReindexingCount,
+            dfaAnalyzingCount,
+            dfaStoppingCount,
+            dfaFailedCount
+        );
+    }
+
+    static long findAdMemoryUsage(AutodetectProcessManager autodetectProcessManager) {
+        return autodetectProcessManager.getOpenProcessMemoryUsage().getBytes();
+    }
+
+    static long findDfaMemoryUsage(DataFrameAnalyticsManager dataFrameAnalyticsManager, PersistentTasksCustomMetadata tasks) {
+        return dataFrameAnalyticsManager.getActiveTaskMemoryUsage(tasks).getBytes();
+    }
+
+    static TrainedModelAllocationCounts findTrainedModelAllocationCounts(TrainedModelAssignmentMetadata metadata) {
+        int trainedModelsTargetAllocations = 0;
+        int trainedModelsCurrentAllocations = 0;
+        int trainedModelsFailedAllocations = 0;
+
+        for (TrainedModelAssignment trainedModelAssignment : metadata.allAssignments().values()) { // sum over every assignment
+            trainedModelsTargetAllocations += trainedModelAssignment.totalTargetAllocations(); // target total <- target allocations
+            trainedModelsCurrentAllocations += trainedModelAssignment.totalCurrentAllocations(); // current total <- current allocations
+            trainedModelsFailedAllocations += trainedModelAssignment.totalFailedAllocations();
+        }
+
+        return new TrainedModelAllocationCounts(
+            trainedModelsTargetAllocations,
+            trainedModelsCurrentAllocations,
+            trainedModelsFailedAllocations
+        );
+    }
+
+    static long findTrainedModelMemoryUsage(TrainedModelAssignmentMetadata metadata, String localNodeId) {
+        long trainedModelMemoryUsageBytes = 0;
+        for (TrainedModelAssignment assignment : metadata.allAssignments().values()) {
+            if
(Optional.ofNullable(assignment.getNodeRoutingTable().get(localNodeId))
+                .map(RoutingInfo::getState)
+                .orElse(RoutingState.STOPPED)
+                .consumesMemory()) {
+                trainedModelMemoryUsageBytes += assignment.getTaskParams().estimateMemoryUsageBytes();
+            }
+        }
+        return trainedModelMemoryUsageBytes;
+    }
+
+    static long findNativeMemoryLimit(DiscoveryNode localNode, ClusterSettings settings) {
+        return NativeMemoryCalculator.allowedBytesForMl(localNode, settings).orElse(0L);
+    }
+
+    static long findNativeMemoryFree(long nativeMemLimit, long nativeMemAdUsage, long nativeMemDfaUsage, long nativeMemTrainedModelUsage) {
+        long totalUsage = nativeMemAdUsage + nativeMemDfaUsage + nativeMemTrainedModelUsage; // usage is the sum of all three consumers
+        if (totalUsage > 0) { // the native code overhead is only added once something is using memory
+            totalUsage += NATIVE_EXECUTABLE_CODE_OVERHEAD.getBytes();
+        }
+        return nativeMemLimit - totalUsage; // free = limit minus total usage
+    }
+
+    record MlTaskStatusCounts(
+        int adOpeningCount,
+        int adOpenedCount,
+        int adClosingCount,
+        int adFailedCount,
+        int datafeedStartingCount,
+        int datafeedStartedCount,
+        int datafeedStoppingCount,
+        int dfaStartingCount,
+        int dfaStartedCount,
+        int dfaReindexingCount,
+        int dfaAnalyzingCount,
+        int dfaStoppingCount,
+        int dfaFailedCount
+    ) {
+        static final MlTaskStatusCounts EMPTY = new MlTaskStatusCounts(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
+    }
+
+    record TrainedModelAllocationCounts(
+        int trainedModelsTargetAllocations,
+        int trainedModelsCurrentAllocations,
+        int trainedModelsFailedAllocations
+    ) {
+        static final TrainedModelAllocationCounts EMPTY = new TrainedModelAllocationCounts(0, 0, 0);
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
index 8ad7cd92a8e73..b1eba9a513568 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java
+++
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java @@ -20,12 +20,16 @@ import org.elasticsearch.cluster.metadata.MappingMetadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.IndexNotFoundException; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.ml.MlStatsIndex; +import org.elasticsearch.xpack.core.ml.MlTasks; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -45,7 +49,10 @@ import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService; +import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.core.Strings.format; @@ -72,6 +79,8 @@ public class DataFrameAnalyticsManager { /** Indicates whether the node is shutting down. 
*/ private final AtomicBoolean nodeShuttingDown = new AtomicBoolean(); + private final Map memoryLimitById = new ConcurrentHashMap<>(); + public DataFrameAnalyticsManager( Settings settings, NodeClient client, @@ -101,6 +110,7 @@ public DataFrameAnalyticsManager( public void execute(DataFrameAnalyticsTask task, ClusterState clusterState, TimeValue masterNodeTimeout) { // With config in hand, determine action to take ActionListener configListener = ActionListener.wrap(config -> { + memoryLimitById.put(config.getId(), config.getModelMemoryLimit()); // Check if existing destination index is incompatible. // If it is, we delete it and start from reindexing. IndexMetadata destIndex = clusterState.getMetadata().index(config.getDest().getIndex()); @@ -224,6 +234,7 @@ private void executeStep(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig c case FINAL -> { LOGGER.info("[{}] Marking task completed", config.getId()); task.markAsCompleted(); + memoryLimitById.remove(config.getId()); } default -> task.markAsFailed(ExceptionsHelper.serverError("Unknown step [{}]", step)); } @@ -291,4 +302,34 @@ public boolean isNodeShuttingDown() { public void markNodeAsShuttingDown() { nodeShuttingDown.set(true); } + + /** + * Get the memory limit for a data frame analytics job if known. + * The memory limit will only be known if it is running on the + * current node, or has been very recently. + * @param id Data frame analytics job ID. + * @return The {@link ByteSizeValue} representing the memory limit, if known, otherwise {@link Optional#empty}. + */ + public Optional getMemoryLimitIfKnown(String id) { + return Optional.ofNullable(memoryLimitById.get(id)); + } + + /** + * Finds the memory used by data frame analytics jobs that are active on the current node. + * This includes jobs that are in the reindexing state, even though they don't have a running + * process, because we want to ensure that when they get as far as needing to run a process + * there'll be space for it. 
+ * @param tasks Persistent tasks metadata. + * @return Memory used by data frame analytics jobs that are active on the current node. + */ + public ByteSizeValue getActiveTaskMemoryUsage(PersistentTasksCustomMetadata tasks) { + long memoryUsedBytes = 0; + for (Map.Entry entry : memoryLimitById.entrySet()) { + DataFrameAnalyticsState state = MlTasks.getDataFrameAnalyticsState(entry.getKey(), tasks); + if (state.consumesMemory()) { + memoryUsedBytes += entry.getValue().getBytes() + DataFrameAnalyticsConfig.PROCESS_MEMORY_OVERHEAD.getBytes(); + } + } + return ByteSizeValue.ofBytes(memoryUsedBytes); + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java index 8deac327c065e..f685afc788adc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java @@ -1062,4 +1062,23 @@ public void clusterChanged(ClusterChangedEvent event) { resetInProgress = MlMetadata.getMlMetadata(event.state()).isResetMode(); } + /** + * Finds the memory used by open autodetect processes on the current node. + * @return Memory used by open autodetect processes on the current node. 
+ */ + public ByteSizeValue getOpenProcessMemoryUsage() { + long memoryUsedBytes = 0; + for (ProcessContext processContext : processByAllocation.values()) { + if (processContext.getState() == ProcessContext.ProcessStateName.RUNNING) { + ModelSizeStats modelSizeStats = processContext.getAutodetectCommunicator().getModelSizeStats(); + memoryUsedBytes += switch (modelSizeStats.getAssignmentMemoryBasis()) { + case MODEL_MEMORY_LIMIT -> modelSizeStats.getModelBytesMemoryLimit(); + case CURRENT_MODEL_BYTES -> modelSizeStats.getModelBytes(); + case PEAK_MODEL_BYTES -> Optional.ofNullable(modelSizeStats.getPeakModelBytes()).orElse(modelSizeStats.getModelBytes()); + }; + memoryUsedBytes += Job.PROCESS_MEMORY_OVERHEAD.getBytes(); + } + } + return ByteSizeValue.ofBytes(memoryUsedBytes); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java index 36828423ce8e9..ea6b6fb5a4d65 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -52,6 +53,7 @@ import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts; import 
org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats; +import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats.AssignmentMemoryBasis; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSnapshot; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.Quantiles; import org.elasticsearch.xpack.ml.MachineLearning; @@ -814,6 +816,35 @@ public void testCreate_givenNonZeroCountsAndNoModelSnapshotNorQuantiles() { verifyNoMoreInteractions(auditor); } + public void testGetOpenProcessMemoryUsage() { + modelSnapshot = null; + quantiles = null; + dataCounts = new DataCounts("foo"); + dataCounts.setLatestRecordTimeStamp(new Date(0L)); + dataCounts.incrementProcessedRecordCount(42L); + long modelMemoryLimitBytes = ByteSizeValue.ofMb(randomIntBetween(10, 1000)).getBytes(); + long peakModelBytes = randomLongBetween(100000, modelMemoryLimitBytes - 1); + long modelBytes = randomLongBetween(1, peakModelBytes - 1); + AssignmentMemoryBasis assignmentMemoryBasis = randomFrom(AssignmentMemoryBasis.values()); + modelSizeStats = new ModelSizeStats.Builder("foo").setModelBytesMemoryLimit(modelMemoryLimitBytes) + .setPeakModelBytes(peakModelBytes) + .setModelBytes(modelBytes) + .setAssignmentMemoryBasis(assignmentMemoryBasis) + .build(); + when(autodetectCommunicator.getModelSizeStats()).thenReturn(modelSizeStats); + AutodetectProcessManager manager = createSpyManager(); + JobTask jobTask = mock(JobTask.class); + when(jobTask.getJobId()).thenReturn("foo"); + manager.openJob(jobTask, clusterState, DEFAULT_MASTER_NODE_TIMEOUT, (e, b) -> {}); + + long expectedSizeBytes = Job.PROCESS_MEMORY_OVERHEAD.getBytes() + switch (assignmentMemoryBasis) { + case MODEL_MEMORY_LIMIT -> modelMemoryLimitBytes; + case CURRENT_MODEL_BYTES -> modelBytes; + case PEAK_MODEL_BYTES -> peakModelBytes; + }; + assertThat(manager.getOpenProcessMemoryUsage(), equalTo(ByteSizeValue.ofBytes(expectedSizeBytes))); + } + private 
AutodetectProcessManager createNonSpyManager(String jobId) { ExecutorService executorService = mock(ExecutorService.class); when(threadPool.executor(anyString())).thenReturn(executorService);