From 553f247d4de272b6814ef95cd8fc75eb49c352fb Mon Sep 17 00:00:00 2001 From: Cathal O'Brien Date: Thu, 21 Nov 2024 17:44:56 +0100 Subject: [PATCH 1/8] added red gpu & increased green gpu monitoring (#147) * added red gpu monitoring, increased green gpu monitoring and refactored monitors into their own files * applied feedback * added changelog entry * fixed Changelog entry --- CHANGELOG.md | 1 + .../training/diagnostics/mlflow/logger.py | 53 ++------ .../mlflow/system_metrics/cpu_monitor.py | 41 ++++++ .../mlflow/system_metrics/gpu_monitor.py | 127 ++++++++++++++++++ 4 files changed, 183 insertions(+), 39 deletions(-) create mode 100644 src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py create mode 100644 src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 85ec45b3..3abfe1b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ Keep it human-readable, your future self will thank you! - Feat: Save a gif for longer rollouts in validation [#65](https://github.com/ecmwf/anemoi-training/pull/65) - New limited area config file added, limited_area.yaml. [#134](https://github.com/ecmwf/anemoi-training/pull/134/) - New stretched grid config added, stretched_grid.yaml [#133](https://github.com/ecmwf/anemoi-training/pull/133) +- Custom System monitor for Nvidia and AMD GPUs [#147](https://github.com/ecmwf/anemoi-training/pull/147) ### Changed diff --git a/src/anemoi/training/diagnostics/mlflow/logger.py b/src/anemoi/training/diagnostics/mlflow/logger.py index 7c482ce1..71d5c475 100644 --- a/src/anemoi/training/diagnostics/mlflow/logger.py +++ b/src/anemoi/training/diagnostics/mlflow/logger.py @@ -433,56 +433,31 @@ def experiment(self) -> MLFlowLogger.experiment: def log_system_metrics(self) -> None: """Log system metrics (CPU, GPU, etc).""" import mlflow - import psutil - from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor from mlflow.system_metrics.metrics.disk_monitor import DiskMonitor - from mlflow.system_metrics.metrics.gpu_monitor import GPUMonitor from mlflow.system_metrics.metrics.network_monitor import NetworkMonitor from mlflow.system_metrics.system_metrics_monitor import SystemMetricsMonitor - class CustomCPUMonitor(BaseMetricsMonitor): - """Class for monitoring CPU stats. - - Extends default CPUMonitor, to also measure total \ - memory and a different formula for calculating used memory. - - """ - - def collect_metrics(self) -> None: - # Get CPU metrics. 
- cpu_percent = psutil.cpu_percent() - self._metrics["cpu_utilization_percentage"].append(cpu_percent) - - system_memory = psutil.virtual_memory() - # Change the formula for measuring CPU memory usage - # By default Mlflow uses psutil.virtual_memory().used - # Tests have shown that "used" underreports memory usage by as much as a factor of 2, - # "used" also misses increased memory usage from using a higher prefetch factor - self._metrics["system_memory_usage_megabytes"].append( - (system_memory.total - system_memory.available) / 1e6, - ) - self._metrics["system_memory_usage_percentage"].append(system_memory.percent) - - # QOL: report the total system memory in raw numbers - self._metrics["system_memory_total_megabytes"].append(system_memory.total / 1e6) - - def aggregate_metrics(self) -> dict[str, int]: - return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} + from anemoi.training.diagnostics.mlflow.system_metrics.cpu_monitor import CPUMonitor + from anemoi.training.diagnostics.mlflow.system_metrics.gpu_monitor import GreenGPUMonitor + from anemoi.training.diagnostics.mlflow.system_metrics.gpu_monitor import RedGPUMonitor class CustomSystemMetricsMonitor(SystemMetricsMonitor): def __init__(self, run_id: str, resume_logging: bool = False): super().__init__(run_id, resume_logging=resume_logging) - # Replace the CPUMonitor with custom implementation - self.monitors = [CustomCPUMonitor(), DiskMonitor(), NetworkMonitor()] + self.monitors = [CPUMonitor(), DiskMonitor(), NetworkMonitor()] + + # Try init both and catch the error when one init fails try: - gpu_monitor = GPUMonitor() + gpu_monitor = GreenGPUMonitor() self.monitors.append(gpu_monitor) - except ImportError: - LOGGER.warning( - "`pynvml` is not installed, to log GPU metrics please run `pip install pynvml` \ - to install it", - ) + except (ImportError, RuntimeError) as e: + LOGGER.warning("Failed to init Nvidia GPU Monitor: %s", e) + try: + gpu_monitor = RedGPUMonitor() + self.monitors.append(gpu_monitor) + except (ImportError, RuntimeError) as e: + LOGGER.warning("Failed to init AMD GPU Monitor: %s", e) mlflow.enable_system_metrics_logging() system_monitor = CustomSystemMetricsMonitor( diff --git a/src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py b/src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py new file mode 100644 index 00000000..fbf6b3e5 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/system_metrics/cpu_monitor.py @@ -0,0 +1,41 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import psutil +from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor + + +class CPUMonitor(BaseMetricsMonitor): + """Class for monitoring CPU stats. + + Extends default CPUMonitor, to also measure total \ + memory and a different formula for calculating used memory. + + """ + + def collect_metrics(self) -> None: + # Get CPU metrics. 
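+        # Called without an interval, psutil.cpu_percent() is non-blocking and
+        # returns the average utilization since the previous call, which suits
+        # periodic sampling by the system metrics monitor.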
+ cpu_percent = psutil.cpu_percent() + self._metrics["cpu_utilization_percentage"].append(cpu_percent) + + system_memory = psutil.virtual_memory() + # Change the formula for measuring CPU memory usage + # By default Mlflow uses psutil.virtual_memory().used + # Tests have shown that "used" underreports memory usage by as much as a factor of 2, + # "used" also misses increased memory usage from using a higher prefetch factor + self._metrics["system_memory_usage_megabytes"].append( + (system_memory.total - system_memory.available) / 1e6, + ) + self._metrics["system_memory_usage_percentage"].append(system_memory.percent) + + # QOL: report the total system memory in raw numbers + self._metrics["system_memory_total_megabytes"].append(system_memory.total / 1e6) + + def aggregate_metrics(self) -> dict[str, int]: + return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} diff --git a/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py b/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py new file mode 100644 index 00000000..b5a2c132 --- /dev/null +++ b/src/anemoi/training/diagnostics/mlflow/system_metrics/gpu_monitor.py @@ -0,0 +1,127 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import contextlib +import sys + +from mlflow.system_metrics.metrics.base_metrics_monitor import BaseMetricsMonitor + +with contextlib.suppress(ImportError): + import pynvml +with contextlib.suppress(ImportError): + from pyrsmi import rocml + + +class GreenGPUMonitor(BaseMetricsMonitor): + """Class for monitoring Nvidia GPU stats. + + Requires pynvml to be installed. + Extends default GPUMonitor, to also measure total \ + memory + + """ + + def __init__(self): + if "pynvml" not in sys.modules: + # Only instantiate if `pynvml` is installed. + import_error_msg = "`pynvml` is not installed, if you are running on an Nvidia GPU \ + and want to log GPU metrics please run `pip install pynvml`." + raise ImportError(import_error_msg) + try: + # `nvmlInit()` will fail if no GPU is found. + pynvml.nvmlInit() + except pynvml.NVMLError as e: + runtime_error_msg = "Failed to initalize Nvidia GPU monitor: " + raise RuntimeError(runtime_error_msg) from e + + super().__init__() + self.num_gpus = pynvml.nvmlDeviceGetCount() + self.gpu_handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in range(self.num_gpus)] + + def collect_metrics(self) -> None: + # Get GPU metrics. 
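+        # collect_metrics() is invoked once per sampling interval by MLflow's
+        # SystemMetricsMonitor; aggregate_metrics() then averages the collected
+        # samples before they are logged to the run.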
+ for i, handle in enumerate(self.gpu_handles): + memory = pynvml.nvmlDeviceGetMemoryInfo(handle) + self._metrics[f"gpu_{i}_memory_usage_percentage"].append( + round(memory.used / memory.total * 100, 1), + ) + self._metrics[f"gpu_{i}_memory_usage_megabytes"].append(memory.used / 1e6) + + # Only record total device memory on GPU 0 to prevent spam + # Unlikely for GPUs on the same node to have different total memory + if i == 0: + self._metrics["gpu_memory_total_megabytes"].append(memory.total / 1e6) + + # Monitor PCIe usage + tx_kilobytes = pynvml.nvmlDeviceGetPcieThroughput(handle, pynvml.NVML_PCIE_UTIL_TX_BYTES) + rx_kilobytes = pynvml.nvmlDeviceGetPcieThroughput(handle, pynvml.NVML_PCIE_UTIL_RX_BYTES) + self._metrics[f"gpu_{i}_pcie_tx_megabytes"].append(tx_kilobytes / 1e3) + self._metrics[f"gpu_{i}_pcie_rx_megabytes"].append(rx_kilobytes / 1e3) + + device_utilization = pynvml.nvmlDeviceGetUtilizationRates(handle) + self._metrics[f"gpu_{i}_utilization_percentage"].append(device_utilization.gpu) + + power_milliwatts = pynvml.nvmlDeviceGetPowerUsage(handle) + power_capacity_milliwatts = pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) + self._metrics[f"gpu_{i}_power_usage_watts"].append(power_milliwatts / 1000) + self._metrics[f"gpu_{i}_power_usage_percentage"].append( + (power_milliwatts / power_capacity_milliwatts) * 100, + ) + + def aggregate_metrics(self) -> dict[str, int]: + return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} + + +class RedGPUMonitor(BaseMetricsMonitor): + """Class for monitoring AMD GPU stats. + + Requires that pyrsmi is installed + Logs utilization and memory usage. + + """ + + def __init__(self): + if "pyrsmi" not in sys.modules: + import_error_msg = "`pyrsmi` is not installed, if you are running on an AMD GPU \ + and want to log GPU metrics please run `pip install pyrsmi`." + # Only instantiate if `pyrsmi` is installed. + raise ImportError(import_error_msg) + try: + # `rocml.smi_initialize()()` will fail if no GPU is found. + rocml.smi_initialize() + except RuntimeError as e: + runtime_error_msg = "Failed to initalize AMD GPU monitor: " + raise RuntimeError(runtime_error_msg) from e + + super().__init__() + self.num_gpus = rocml.smi_get_device_count() + + def collect_metrics(self) -> None: + # Get GPU metrics. 
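+        # Mirrors GreenGPUMonitor.collect_metrics(): per-device usage and
+        # utilization every sample, with the device total reported only once
+        # (from device 0).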
+ for device in range(self.num_gpus): + memory_used = rocml.smi_get_device_memory_used(device) + memory_total = rocml.smi_get_device_memory_total(device) + memory_busy = rocml.smi_get_device_memory_busy(device) + self._metrics[f"gpu_{device}_memory_usage_percentage"].append( + round(memory_used / memory_total * 100, 1), + ) + self._metrics[f"gpu_{device}_memory_usage_megabytes"].append(memory_used / 1e6) + + self._metrics[f"gpu_{device}_memory_busy_percentage"].append(memory_busy) + + # Only record total device memory on GPU 0 to prevent spam + # Unlikely for GPUs on the same node to have different total memory + if device == 0: + self._metrics["gpu_memory_total_megabytes"].append(memory_total / 1e6) + + utilization = rocml.smi_get_device_utilization(device) + self._metrics[f"gpu_{device}_utilization_percentage"].append(utilization) + + def aggregate_metrics(self) -> dict[str, int]: + return {k: round(sum(v) / len(v), 1) for k, v in self._metrics.items()} From 25abf5e143a29d5931ccb4ac42a5f83c5cd26851 Mon Sep 17 00:00:00 2001 From: Sara Hahner <44293258+sahahner@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:24:17 +0100 Subject: [PATCH 2/8] Feature/mask NaNs in training loss function (#72) * feat: mask NaNs in training loss function --------- Co-authored-by: Jakob Schloer Co-authored-by: Harrison Cook --- CHANGELOG.md | 1 + .../training/config/training/default.yaml | 4 ++- src/anemoi/training/train/forecaster.py | 27 ++++++++++++++++++- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3abfe1b5..205b92d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -113,6 +113,7 @@ Keep it human-readable, your future self will thank you! - Feature: Support training for datasets with missing time steps [#48](https://github.com/ecmwf/anemoi-training/pulls/48) - Feature: `AnemoiMlflowClient`, an mlflow client with authentication support [#86](https://github.com/ecmwf/anemoi-training/pull/86) - Long Rollout Plots +- Mask NaN values in training loss function [#72](https://github.com/ecmwf/anemoi-training/pull/72) and [#271](https://github.com/ecmwf-lab/aifs-mono/issues/271) ### Fixed diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index b471034e..1c103827 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -48,7 +48,9 @@ training_loss: # Scalars to include in loss calculation # Available scalars include: # - 'variable': See `variable_loss_scaling` for more information - scalars: ['variable'] + # - 'loss_weights_mask': Giving imputed NaNs a zero weight in the loss function + scalars: ['variable', 'loss_weights_mask'] + ignore_nans: False loss_gradient_scaling: False diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 80459b8f..5c9f5e84 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -103,7 +103,10 @@ def __init__( # Kwargs to pass to the loss function loss_kwargs = {"node_weights": self.node_weights} # Scalars to include in the loss function, must be of form (dim, scalar) - scalars = {"variable": (-1, variable_scaling)} + # Add mask multiplying NaN locations with zero. At this stage at [[1]]. + # Filled after first application of preprocessor. dimension=[-2, -1] (latlon, n_outputs). 
+ scalars = {"variable": (-1, variable_scaling), "loss_weights_mask": ((-2, -1), torch.ones((1, 1)))} + self.updated_loss_mask = False self.loss = self.get_loss_function(config.training.training_loss, scalars=scalars, **loss_kwargs) @@ -217,6 +220,24 @@ def get_loss_function( return loss_function + def training_weights_for_imputed_variables( + self, + batch: torch.Tensor, + ) -> None: + """Update the loss weights mask for imputed variables.""" + if "loss_weights_mask" in self.loss.scalar: + loss_weights_mask = torch.ones((1, 1), device=batch.device) + # iterate over all pre-processors and check if they have a loss_mask_training attribute + for pre_processor in self.model.pre_processors.processors.values(): + if hasattr(pre_processor, "loss_mask_training"): + loss_weights_mask = loss_weights_mask * pre_processor.loss_mask_training + # if transform_loss_mask function exists for preprocessor apply it + if hasattr(pre_processor, "transform_loss_mask"): + loss_weights_mask = pre_processor.transform_loss_mask(loss_weights_mask) + # update scaler with loss_weights_mask retrieved from preprocessors + self.loss.update_scalar(scalar=loss_weights_mask.cpu(), name="loss_weights_mask") + self.updated_loss_mask = True + @staticmethod def get_val_metric_ranges(config: DictConfig, data_indices: IndexCollection) -> tuple[dict, dict]: @@ -361,6 +382,10 @@ def rollout_step( # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) + if not self.updated_loss_mask: + # update loss scalar after first application and initialization of preprocessors + self.training_weights_for_imputed_variables(batch) + # start rollout of preprocessed batch x = batch[ :, From 11ded7db6dfbb84f3654426aff62d98cc376b2f8 Mon Sep 17 00:00:00 2001 From: gabrieloks <116646686+gabrieloks@users.noreply.github.com> Date: Fri, 22 Nov 2024 13:48:50 +0000 Subject: [PATCH 3/8] [FIX] Power spectra bug on n320 (LAM?) (#149) * Update plots.py To plot the power spectra, we need to create a regular grid (n_pix_lat x n_pix_lon) and interpolate the data on it. The way n_pix_lat and n_pix_lon were previously defined is not robust and might lead to errors. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update plots.py * Update CHANGELOG.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 ++ src/anemoi/training/diagnostics/plots.py | 8 +++----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 205b92d5..953f2b8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) ### Fixed +Fixed bug in power spectra plotting for the n320 resolution. 
+ ### Added ### Changed diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index e0b44d1c..6d1d1c8c 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -171,15 +171,13 @@ def plot_power_spectrum( pc_lon = np.array(pc_lon) pc_lat = np.array(pc_lat) - # Calculate delta_lon and delta_lat on the projected grid - delta_lon = abs(np.diff(pc_lon)) - non_zero_delta_lon = delta_lon[delta_lon != 0] + # Calculate delta_lat on the projected grid delta_lat = abs(np.diff(pc_lat)) non_zero_delta_lat = delta_lat[delta_lat != 0] # Define a regular grid for interpolation - n_pix_lon = int(np.floor(abs(pc_lon.max() - pc_lon.min()) / abs(np.min(non_zero_delta_lon)))) # around 400 for O96 - n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) # around 192 for O96 + n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))) + n_pix_lon = (n_pix_lat - 1) * 2 + 1 # 2*lmax + 1 regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon) regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat) grid_pc_lon, grid_pc_lat = np.meshgrid(regular_pc_lon, regular_pc_lat) From cf53a6e3eea81a789a2ffc3b9d867afc1ce132d6 Mon Sep 17 00:00:00 2001 From: Jan Polster Date: Fri, 22 Nov 2024 17:11:32 +0100 Subject: [PATCH 4/8] Feature/Improve Dataloader Memory with Read Groups (#76) * feat: improve dataloader memory - Add reader groups to support sharded reading of batches - Add dataloader.read_group_size in config to control read behaviour - Add GraphForecaster.allgather_batch() to reconstruct full batch from shards - Refactor callbacks to call allgather on batches as needed * docs: update docstring with instructions on reader group usage * refactor: rank computations via SLURM_PROCID - Pass model/reader group information from DDPGroupStrategy instead --------- Co-authored-by: gabrieloks <116646686+gabrieloks@users.noreply.github.com> Co-authored-by: Harrison Cook Co-authored-by: sahahner --- CHANGELOG.md | 2 + docs/user-guide/distributed.rst | 4 + .../config/dataloader/native_grid.yaml | 11 ++ src/anemoi/training/data/datamodule.py | 29 ---- src/anemoi/training/data/dataset.py | 82 +++++++++-- .../diagnostics/callbacks/evaluation.py | 4 +- .../training/diagnostics/callbacks/plot.py | 16 ++- src/anemoi/training/distributed/strategy.py | 130 +++++++++++++++--- src/anemoi/training/train/forecaster.py | 81 +++++++++-- src/anemoi/training/train/train.py | 4 +- 10 files changed, 283 insertions(+), 80 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 953f2b8a..88488369 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ Fixed bug in power spectra plotting for the n320 resolution. ### Added +- Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) + ### Changed ## [0.3.0 - Loss & Callback Refactors](https://github.com/ecmwf/anemoi-training/compare/0.2.2...0.3.0) - 2024-11-14 diff --git a/docs/user-guide/distributed.rst b/docs/user-guide/distributed.rst index 40ee4d65..68d7697a 100644 --- a/docs/user-guide/distributed.rst +++ b/docs/user-guide/distributed.rst @@ -45,6 +45,10 @@ number of GPUs you wish to shard the model across. It is recommended to only shard if the model does not fit in GPU memory, as data distribution is a much more efficient way to parallelise the training. 
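+
+For example, sharding each model instance across two GPUs is configured
+with (value shown for illustration only):
+
+.. code:: yaml
+
+   hardware:
+     num_gpus_per_model: 2
+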
+When using model sharding, ``config.dataloader.read_group_size`` allows +for sharded data loading in subgroups. This should be set to the number +of GPUs per model for optimal performance. + ********* Example ********* diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index d7aa4f6d..9513ecc7 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -1,6 +1,17 @@ prefetch_factor: 2 pin_memory: True +# ============ +# read_group_size: +# Form subgroups of model comm groups that read data together. +# Each reader in the group only reads 1/read_group_size of the data +# which is then all-gathered between the group. +# This can reduce CPU memory usage as well as increase dataloader throughput. +# The number of GPUs per model must be divisible by read_group_size. +# To disable, set to 1. +# ============ +read_group_size: ${hardware.num_gpus_per_model} + num_workers: training: 8 validation: 8 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 303266fc..6d8e6da0 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -9,7 +9,6 @@ import logging -import os from functools import cached_property from typing import Callable @@ -43,31 +42,6 @@ def __init__(self, config: DictConfig) -> None: self.config = config - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) # global rank - self.model_comm_group_id = ( - self.global_rank // self.config.hardware.num_gpus_per_model - ) # id of the model communication group the rank is participating in - self.model_comm_group_rank = ( - self.global_rank % self.config.hardware.num_gpus_per_model - ) # rank within one model communication group - total_gpus = self.config.hardware.num_gpus_per_node * self.config.hardware.num_nodes - assert ( - total_gpus - ) % self.config.hardware.num_gpus_per_model == 0, ( - f"GPUs per model {self.config.hardware.num_gpus_per_model} does not divide total GPUs {total_gpus}" - ) - self.model_comm_num_groups = ( - self.config.hardware.num_gpus_per_node - * self.config.hardware.num_nodes - // self.config.hardware.num_gpus_per_model - ) # number of model communication groups - LOGGER.debug( - "Rank %d model communication group number %d, with local model communication group rank %d", - self.global_rank, - self.model_comm_group_id, - self.model_comm_group_rank, - ) - # Set the maximum rollout to be expected self.rollout = ( self.config.training.rollout.max @@ -182,9 +156,6 @@ def _get_dataset( rollout=r, multistep=self.config.training.multistep_input, timeincrement=self.timeincrement, - model_comm_group_rank=self.model_comm_group_rank, - model_comm_group_id=self.model_comm_group_id, - model_comm_num_groups=self.model_comm_num_groups, shuffle=shuffle, label=label, ) diff --git a/src/anemoi/training/data/dataset.py b/src/anemoi/training/data/dataset.py index 9e368f9c..40065e06 100644 --- a/src/anemoi/training/data/dataset.py +++ b/src/anemoi/training/data/dataset.py @@ -36,9 +36,6 @@ def __init__( rollout: int = 1, multistep: int = 1, timeincrement: int = 1, - model_comm_group_rank: int = 0, - model_comm_group_id: int = 0, - model_comm_num_groups: int = 1, shuffle: bool = True, label: str = "generic", ) -> None: @@ -54,12 +51,6 @@ def __init__( time increment between samples, by default 1 multistep : int, optional collate (t-1, ... 
t - multistep) into the input state vector, by default 1 - model_comm_group_rank : int, optional - process rank in the torch.distributed group (important when running on multiple GPUs), by default 0 - model_comm_group_id: int, optional - device group ID, default 0 - model_comm_num_groups : int, optional - total number of device groups, by default 1 shuffle : bool, optional Shuffle batches, by default True label : str, optional @@ -77,11 +68,14 @@ def __init__( self.n_samples_per_epoch_total: int = 0 self.n_samples_per_epoch_per_worker: int = 0 - # DDP-relevant info - self.model_comm_group_rank = model_comm_group_rank - self.model_comm_num_groups = model_comm_num_groups - self.model_comm_group_id = model_comm_group_id - self.global_rank = int(os.environ.get("SLURM_PROCID", "0")) + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 + self.model_comm_group_id = 0 + self.global_rank = 0 + + self.reader_group_rank = 0 + self.reader_group_size = 1 # additional state vars (lazy init) self.n_samples_per_worker = 0 @@ -93,6 +87,8 @@ def __init__( assert self.multi_step > 0, "Multistep value must be greater than zero." self.ensemble_dim: int = 2 self.ensemble_size = self.data.shape[self.ensemble_dim] + self.grid_dim: int = -1 + self.grid_size = self.data.shape[self.grid_dim] @cached_property def statistics(self) -> dict: @@ -128,6 +124,58 @@ def valid_date_indices(self) -> np.ndarray: """ return get_usable_indices(self.data.missing, len(self.data), self.rollout, self.multi_step, self.timeincrement) + def set_comm_group_info( + self, + global_rank: int, + model_comm_group_id: int, + model_comm_group_rank: int, + model_comm_num_groups: int, + reader_group_rank: int, + reader_group_size: int, + ) -> None: + """Set model and reader communication group information (called by DDPGroupStrategy). + + Parameters + ---------- + global_rank : int + Global rank + model_comm_group_id : int + Model communication group ID + model_comm_group_rank : int + Model communication group rank + model_comm_num_groups : int + Number of model communication groups + reader_group_rank : int + Reader group rank + reader_group_size : int + Reader group size + """ + self.global_rank = global_rank + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size + + if self.reader_group_size > 1: + # get the grid shard size and start/end indices + grid_shard_size = self.grid_size // self.reader_group_size + self.grid_start = self.reader_group_rank * grid_shard_size + if self.reader_group_rank == self.reader_group_size - 1: + self.grid_end = self.grid_size + else: + self.grid_end = (self.reader_group_rank + 1) * grid_shard_size + + LOGGER.debug( + "NativeGridDataset.set_group_info(): global_rank %d, model_comm_group_id %d, " + "model_comm_group_rank %d, model_comm_num_groups %d, reader_group_rank %d", + global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + ) + def per_worker_init(self, n_workers: int, worker_id: int) -> None: """Called by worker_init_func on each copy of dataset. 
@@ -233,7 +281,11 @@ def __iter__(self) -> torch.Tensor: start = i - (self.multi_step - 1) * self.timeincrement end = i + (self.rollout + 1) * self.timeincrement - x = self.data[start : end : self.timeincrement] + if self.reader_group_size > 1: # read only a subset of the grid + x = self.data[start : end : self.timeincrement, :, :, self.grid_start : self.grid_end] + else: # read the full grid + x = self.data[start : end : self.timeincrement, :, :, :] + x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables") self.ensemble_dim = 1 diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index fc812121..cbc929d6 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -15,7 +15,6 @@ import torch from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_only if TYPE_CHECKING: import pytorch_lightning as pl @@ -103,7 +102,6 @@ def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, rank_zero_only=True, ) - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -114,6 +112,8 @@ def on_validation_batch_end( ) -> None: del outputs # outputs are not used if batch_idx % self.every_n_batches == 0: + batch = pl_module.allgather_batch(batch) + precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index b13d8727..08f9d28b 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -241,7 +241,6 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None): super().__init__(config) self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -251,7 +250,16 @@ def on_validation_batch_end( batch_idx: int, **kwargs, ) -> None: + if ( + self.config.diagnostics.plot.asynchronous + and self.config.dataloader.read_group_size > 1 + and pl_module.local_rank == 0 + ): + LOGGER.warning("Asynchronous plotting can result in NCCL timeouts with reader_group_size > 1.") + if batch_idx % self.every_n_batches == 0: + batch = pl_module.allgather_batch(batch) + self.plot( trainer, pl_module, @@ -383,7 +391,6 @@ def __init__( every_n_epochs, ) - @rank_zero_only def _plot( self, trainer: pl.Trainer, @@ -480,6 +487,7 @@ def _plot( LOGGER.info("Time taken to plot/animate samples for longer rollout: %d seconds", int(time.time() - start_time)) + @rank_zero_only def _plot_rollout_step( self, pl_module: pl.LightningModule, @@ -539,6 +547,7 @@ def _store_video_frame_data( vmax[:] = np.maximum(vmax, np.nanmax(data_over_time[-1], axis=1)) return data_over_time, vmin, vmax + @rank_zero_only def _generate_video_rollout( self, data_0: np.ndarray, @@ -595,7 +604,6 @@ def _generate_video_rollout( tag=f"gnn_pred_val_animation_{variable_name}_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", ) - @rank_zero_only def on_validation_batch_end( self, trainer: pl.Trainer, @@ -605,6 +613,8 @@ def on_validation_batch_end( batch_idx: int, ) -> None: if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0: + batch = pl_module.allgather_batch(batch) + precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, diff --git 
a/src/anemoi/training/distributed/strategy.py b/src/anemoi/training/distributed/strategy.py index c6509795..32c96dc6 100644 --- a/src/anemoi/training/distributed/strategy.py +++ b/src/anemoi/training/distributed/strategy.py @@ -9,7 +9,6 @@ import logging -import os import numpy as np import pytorch_lightning as pl @@ -27,19 +26,22 @@ class DDPGroupStrategy(DDPStrategy): """Distributed Data Parallel strategy with group communication.""" - def __init__(self, num_gpus_per_model: int, **kwargs: dict) -> None: + def __init__(self, num_gpus_per_model: int, read_group_size: int, **kwargs: dict) -> None: """Initialize the distributed strategy. Parameters ---------- num_gpus_per_model : int Number of GPUs per model to shard over. + read_group_size : int + Number of GPUs per reader group. **kwargs : dict Additional keyword arguments. """ super().__init__(**kwargs) self.model_comm_group_size = num_gpus_per_model + self.read_group_size = read_group_size def setup(self, trainer: pl.Trainer) -> None: assert self.accelerator is not None, "Accelerator is not initialized for distributed strategy" @@ -60,18 +62,56 @@ def setup(self, trainer: pl.Trainer) -> None: torch.distributed.new_group(x) for x in model_comm_group_ranks ] # every rank has to create all of these - model_comm_group_id, model_comm_group_nr, model_comm_group_rank = self.get_my_model_comm_group( + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( self.model_comm_group_size, ) model_comm_group = model_comm_groups[model_comm_group_id] - self.model.set_model_comm_group(model_comm_group) + self.model.set_model_comm_group( + model_comm_group, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + self.model_comm_group_size, + ) + + # set up reader groups by further splitting model_comm_group_ranks with read_group_size: + + assert self.model_comm_group_size % self.read_group_size == 0, ( + f"Number of GPUs per model ({self.model_comm_group_size}) must be divisible by read_group_size " + f"({self.read_group_size})." 
+ ) + + reader_group_ranks = np.array( + [ + np.split(group_ranks, int(self.model_comm_group_size / self.read_group_size)) + for group_ranks in model_comm_group_ranks + ], + ) # Shape: (num_model_comm_groups, model_comm_grp_size/read_group_size, read_group_size) + reader_groups = [[torch.distributed.new_group(x) for x in group_ranks] for group_ranks in reader_group_ranks] + reader_group_id, reader_group_rank, reader_group_size, reader_group_root = self.get_my_reader_group( + model_comm_group_rank, + self.read_group_size, + ) + # get all reader groups of the current model group + model_reader_groups = reader_groups[model_comm_group_id] + self.model.set_reader_groups( + model_reader_groups, + reader_group_id, + reader_group_rank, + reader_group_size, + ) + LOGGER.debug( - "Rank %d model_comm_group is %s, group number %d, with local group rank %d and comms_group_ranks %s", + "Rank %d model_comm_group_id: %d model_comm_group: %s model_comm_group_rank: %d " + "reader_group_id: %d reader_group: %s reader_group_rank: %d reader_group_root (global): %d", self.global_rank, - str(model_comm_group_nr), model_comm_group_id, - model_comm_group_rank, str(model_comm_group_ranks[model_comm_group_id]), + model_comm_group_rank, + reader_group_id, + reader_group_ranks[model_comm_group_id, reader_group_id], + reader_group_rank, + reader_group_root, ) # register hooks for correct gradient reduction @@ -109,7 +149,7 @@ def setup(self, trainer: pl.Trainer) -> None: # seed ranks self.seed_rnd(model_comm_group_id) - def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndarray, int]: + def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, int, int]: """Determine tasks that work together and from a model group. Parameters @@ -119,19 +159,69 @@ def get_my_model_comm_group(self, num_gpus_per_model: int) -> tuple[int, np.ndar Returns ------- - tuple[int, np.ndarray, int] - Model_comm_group id, Model_comm_group Nr, Model_comm_group rank + tuple[int, int, int] + Model_comm_group id, Model_comm_group rank, Number of model_comm_groups + """ + model_comm_group_id = self.global_rank // num_gpus_per_model + model_comm_group_rank = self.global_rank % num_gpus_per_model + model_comm_num_groups = self.world_size // num_gpus_per_model + + return model_comm_group_id, model_comm_group_rank, model_comm_num_groups + + def get_my_reader_group(self, model_comm_group_rank: int, read_group_size: int) -> tuple[int, int, int]: + """Determine tasks that work together and from a reader group. + + Parameters + ---------- + model_comm_group_rank : int + Rank within the model communication group. + read_group_size : int + Number of dataloader readers per model group. 
+ + Returns + ------- + tuple[int, int, int] + Reader_group id, Reader_group rank, Reader_group root (global rank) """ - model_comm_groups = np.arange(0, self.world_size, dtype=np.int32) - model_comm_groups = np.split(model_comm_groups, self.world_size / num_gpus_per_model) + reader_group_id = model_comm_group_rank // read_group_size + reader_group_rank = model_comm_group_rank % read_group_size + reader_group_size = read_group_size + reader_group_root = (self.global_rank // read_group_size) * read_group_size + + return reader_group_id, reader_group_rank, reader_group_size, reader_group_root - model_comm_group_id = None - for i, model_comm_group in enumerate(model_comm_groups): - if self.global_rank in model_comm_group: - model_comm_group_id = i - model_comm_group_nr = model_comm_group - model_comm_group_rank = np.ravel(np.asarray(model_comm_group == self.global_rank).nonzero())[0] - return model_comm_group_id, model_comm_group_nr, model_comm_group_rank + def process_dataloader(self, dataloader: torch.utils.data.DataLoader) -> torch.utils.data.DataLoader: + """Pass communication group information to the dataloader for distributed training. + + Parameters + ---------- + dataloader : torch.utils.data.DataLoader + Dataloader to process. + + Returns + ------- + torch.utils.data.DataLoader + Processed dataloader. + + """ + dataloader = super().process_dataloader(dataloader) + + # pass model and reader group information to the dataloaders dataset + model_comm_group_id, model_comm_group_rank, model_comm_num_groups = self.get_my_model_comm_group( + self.model_comm_group_size, + ) + _, reader_group_rank, _, _ = self.get_my_reader_group(model_comm_group_rank, self.read_group_size) + + dataloader.dataset.set_comm_group_info( + self.global_rank, + model_comm_group_id, + model_comm_group_rank, + model_comm_num_groups, + reader_group_rank, + self.read_group_size, + ) + + return dataloader def seed_rnd(self, model_comm_group_id: int) -> None: """Seed the random number generators for the rank.""" @@ -145,7 +235,7 @@ def seed_rnd(self, model_comm_group_id: int) -> None: "Strategy: Rank %d, model comm group id %d, base seed %d, seeded with %d, " "running with random seed: %d, sanity rnd: %s" ), - int(os.environ.get("SLURM_PROCID", "0")), + self.global_rank, model_comm_group_id, base_seed, initial_seed, diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 5c9f5e84..659c906c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -9,8 +9,6 @@ import logging -import math -import os from collections import defaultdict from collections.abc import Generator from collections.abc import Mapping @@ -138,17 +136,20 @@ def __init__( self.use_zero_optimizer = config.training.zero_optimizer self.model_comm_group = None + self.reader_groups = None LOGGER.debug("Rollout window length: %d", self.rollout) LOGGER.debug("Rollout increase every : %d epochs", self.rollout_epoch_increment) LOGGER.debug("Rollout max : %d", self.rollout_max) LOGGER.debug("Multistep: %d", self.multi_step) - self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model - self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model - self.model_comm_num_groups = math.ceil( - config.hardware.num_gpus_per_node * config.hardware.num_nodes / config.hardware.num_gpus_per_model, - ) + # lazy init model and reader group info, will be set by the DDPGroupStrategy: + 
self.model_comm_group_id = 0 + self.model_comm_group_rank = 0 + self.model_comm_num_groups = 1 + + self.reader_group_id = 0 + self.reader_group_rank = 0 def forward(self, x: torch.Tensor) -> torch.Tensor: return self.model(x, self.model_comm_group) @@ -313,9 +314,31 @@ def get_variable_scaling( return torch.from_numpy(variable_loss_scaling) - def set_model_comm_group(self, model_comm_group: ProcessGroup) -> None: - LOGGER.debug("set_model_comm_group: %s", model_comm_group) + def set_model_comm_group( + self, + model_comm_group: ProcessGroup, + model_comm_group_id: int, + model_comm_group_rank: int, + model_comm_num_groups: int, + model_comm_group_size: int, + ) -> None: self.model_comm_group = model_comm_group + self.model_comm_group_id = model_comm_group_id + self.model_comm_group_rank = model_comm_group_rank + self.model_comm_num_groups = model_comm_num_groups + self.model_comm_group_size = model_comm_group_size + + def set_reader_groups( + self, + reader_groups: list[ProcessGroup], + reader_group_id: int, + reader_group_rank: int, + reader_group_size: int, + ) -> None: + self.reader_groups = reader_groups + self.reader_group_id = reader_group_id + self.reader_group_rank = reader_group_rank + self.reader_group_size = reader_group_size def advance_input( self, @@ -425,6 +448,8 @@ def _step( validation_mode: bool = False, ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: del batch_idx + batch = self.allgather_batch(batch) + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) metrics = {} y_preds = [] @@ -442,6 +467,44 @@ def _step( loss *= 1.0 / self.rollout return loss, metrics, y_preds + def allgather_batch(self, batch: torch.Tensor) -> torch.Tensor: + """Allgather the batch-shards across the reader group. + + Parameters + ---------- + batch : torch.Tensor + Batch-shard of current reader rank + + Returns + ------- + torch.Tensor + Allgathered (full) batch + """ + grid_size = self.model.metadata["dataset"]["shape"][-1] + + if grid_size == batch.shape[-2]: + return batch # already have the full grid + + grid_shard_size = grid_size // self.reader_group_size + last_grid_shard_size = grid_size - (grid_shard_size * (self.reader_group_size - 1)) + + # prepare tensor list with correct shapes for all_gather + shard_shape = list(batch.shape) + shard_shape[-2] = grid_shard_size + last_shard_shape = list(batch.shape) + last_shard_shape[-2] = last_grid_shard_size + + tensor_list = [torch.empty(tuple(shard_shape), device=self.device) for _ in range(self.reader_group_size - 1)] + tensor_list.append(torch.empty(last_shard_shape, device=self.device)) + + torch.distributed.all_gather( + tensor_list, + batch, + group=self.reader_groups[self.reader_group_id], + ) + + return torch.cat(tensor_list, dim=-2) + def calculate_val_metrics( self, y_pred: torch.Tensor, diff --git a/src/anemoi/training/train/train.py b/src/anemoi/training/train/train.py index 553114f5..80fc70d3 100644 --- a/src/anemoi/training/train/train.py +++ b/src/anemoi/training/train/train.py @@ -12,7 +12,6 @@ import datetime import logging -import os from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING @@ -106,7 +105,7 @@ def initial_seed(self) -> int: (torch.rand(1), np_rng.random()) LOGGER.debug( "Initial seed: Rank %d, initial seed %d, running with random seed: %d", - int(os.environ.get("SLURM_PROCID", "0")), + self.strategy.global_rank, initial_seed, rnd_seed, ) @@ -345,6 +344,7 @@ def strategy(self) -> DDPGroupStrategy: """Training strategy.""" return 
DDPGroupStrategy( self.config.hardware.num_gpus_per_model, + self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model), static_graph=not self.config.training.accum_grad_batches > 1, ) From e3fe023553230c57d6d1b40720188f9cd28d1f3b Mon Sep 17 00:00:00 2001 From: gabrieloks <116646686+gabrieloks@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:35:37 +0000 Subject: [PATCH 5/8] Validation and Training dataset dates assertion (#154) * Change training end date/validation start date assertion to a warning to allow flexibility. Co-authored-by: Harrison Cook --------- Co-authored-by: Harrison Cook --- src/anemoi/training/data/datamodule.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index 6d8e6da0..ba9ff0c3 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -116,10 +116,12 @@ def ds_train(self) -> NativeGridDataset: def ds_valid(self) -> NativeGridDataset: r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1)) - assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( - f"Training end date {self.config.dataloader.training.end} is not before" - f"validation start date {self.config.dataloader.validation.start}" - ) + if not self.config.dataloader.training.end < self.config.dataloader.validation.start: + LOGGER.warning( + "Training end date %s is not before validation start date %s.", + self.config.dataloader.training.end, + self.config.dataloader.validation.start, + ) return self._get_dataset( open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)), shuffle=False, From 0608f21abb78d425d965cd79ea040aaf3a66b5f8 Mon Sep 17 00:00:00 2001 From: Ana Prieto Nemesio <91897203+anaprietonem@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:25:17 +0100 Subject: [PATCH 6/8] warmup config for reproducibility of aifs v0.3 (#155) * warmup config for reproducibility of aifs v0.3 * add entry to changelog * update docs --- CHANGELOG.md | 2 ++ docs/user-guide/training.rst | 15 ++++++++++----- src/anemoi/training/config/training/default.yaml | 1 + src/anemoi/training/train/forecaster.py | 3 ++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 88488369..f50ee7f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Keep it human-readable, your future self will thank you! Fixed bug in power spectra plotting for the n320 resolution. ### Added +- Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) + - Add reader groups to reduce CPU memory usage and increase dataloader throughput [#76](https://github.com/ecmwf/anemoi-training/pull/76) diff --git a/docs/user-guide/training.rst b/docs/user-guide/training.rst index 5be08222..588b34d9 100644 --- a/docs/user-guide/training.rst +++ b/docs/user-guide/training.rst @@ -188,10 +188,11 @@ level has a weighting less than 0.2). *************** Anemoi training uses the ``CosineLRScheduler`` from PyTorch as it's -learning rate scheduler. The user can configure the maximum learning -rate by setting ``config.training.lr.rate``. Note that this learning -rate is scaled by the number of GPUs where for the `data parallelism -`_. +learning rate scheduler. 
Docs for this scheduler can be found here +https://github.com/huggingface/pytorch-image-models/blob/main/timm/scheduler/cosine_lr.py +The user can configure the maximum learning rate by setting +``config.training.lr.rate``. Note that this learning rate is scaled by +the number of GPUs where for the `data parallelism `_. .. code:: yaml @@ -201,7 +202,11 @@ The user can also control the rate at which the learning rate decreases by setting the total number of iterations through ``config.training.lr.iterations`` and the minimum learning rate reached through ``config.training.lr.min``. Note that the minimum learning rate -is not scaled by the number of GPUs. +is not scaled by the number of GPUs. The user can also control the +warmup period by setting ``config.training.lr.warmup_t``. If the warmup +period is set to 0, the learning rate will start at the maximum learning +rate. If no warmup period is defined, a default warmup period of 1000 +iterations is used. ********* Rollout diff --git a/src/anemoi/training/config/training/default.yaml b/src/anemoi/training/config/training/default.yaml index 1c103827..af168ecc 100644 --- a/src/anemoi/training/config/training/default.yaml +++ b/src/anemoi/training/config/training/default.yaml @@ -83,6 +83,7 @@ lr: rate: 0.625e-4 #local_lr iterations: ${training.max_steps} # NOTE: When max_epochs < max_steps, scheduler will run for max_steps min: 3e-7 #Not scaled by #GPU + warmup_t: 1000 # Changes in per-gpu batch_size should come with a rescaling of the local_lr # in order to keep a constant global_lr diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 659c906c..a3abd59c 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -127,6 +127,7 @@ def __init__( * config.training.lr.rate / config.hardware.num_gpus_per_model ) + self.warmup_t = getattr(config.training.lr, "warmup_t", 1000) self.lr_iterations = config.training.lr.iterations self.lr_min = config.training.lr.min self.rollout = config.training.rollout.start @@ -638,6 +639,6 @@ def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]] optimizer, lr_min=self.lr_min, t_initial=self.lr_iterations, - warmup_t=1000, + warmup_t=self.warmup_t, ) return [optimizer], [{"scheduler": scheduler, "interval": "step"}] From 1abb65ea50a4fd00d85c60175af791c4f5f48b98 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 25 Nov 2024 17:05:07 +0000 Subject: [PATCH 7/8] hotfix: Expand scalar to prevent index out of bound error (#160) * Disable scalar indices if no variable scalar is used in val_metrics --- src/anemoi/training/losses/weightedloss.py | 2 ++ src/anemoi/training/train/forecaster.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/training/losses/weightedloss.py b/src/anemoi/training/losses/weightedloss.py index 0deccc9d..7ed97b21 100644 --- a/src/anemoi/training/losses/weightedloss.py +++ b/src/anemoi/training/losses/weightedloss.py @@ -107,6 +107,8 @@ def scale( if scalar_indices is None: return x * scalar + + scalar = scalar.expand_as(x) return x * scalar[scalar_indices] def scale_by_node_weights(self, x: torch.Tensor, squash: bool = True) -> torch.Tensor: diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index a3abd59c..f92050cf 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -547,7 +547,7 @@ def calculate_val_metrics( metrics[f"{metric_name}/{mkey}/{rollout_step + 
1}"] = metric( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], - scalar_indices=[..., indices], + scalar_indices=[..., indices] if -1 in metric.scalar else None, ) return metrics From fa430782331f9825eaf986cc5e80ed7a0dd83364 Mon Sep 17 00:00:00 2001 From: Sara Hahner <44293258+sahahner@users.noreply.github.com> Date: Tue, 26 Nov 2024 10:05:39 +0100 Subject: [PATCH 8/8] Callback PlotHistogram breaks if only one variable is specified. (#165) * enable plothistogram and plotspectrum for only one variable --- CHANGELOG.md | 3 ++- src/anemoi/training/diagnostics/plots.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f50ee7f0..80a4bcca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,8 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.3.0...HEAD) ### Fixed -Fixed bug in power spectra plotting for the n320 resolution. +- Fixed bug in power spectra plotting for the n320 resolution. +- Allow histogram and spectrum plot for one variable [#165](https://github.com/ecmwf/anemoi-training/pull/165) ### Added - Introduce variable to configure (Cosine Annealing) optimizer warm up [#155](https://github.com/ecmwf/anemoi-training/pull/155) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 6d1d1c8c..93e2d324 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -166,6 +166,8 @@ def plot_power_spectrum( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) + if n_plots_x == 1: + ax = [ax] pc_lat, pc_lon = equirectangular_projection(latlons) @@ -293,6 +295,8 @@ def plot_histogram( figsize = (n_plots_y * 4, n_plots_x * 3) fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT) + if n_plots_x == 1: + ax = [ax] for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()): yt = y_true[..., variable_idx].squeeze()