diff --git a/animus/torch/engine.py b/animus/torch/engine.py
index 6c74a3a..1e13b8d 100644
--- a/animus/torch/engine.py
+++ b/animus/torch/engine.py
@@ -5,8 +5,6 @@
 import os
 
 from accelerate import Accelerator
-from accelerate.state import DistributedType
-import numpy as np
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
@@ -14,21 +12,9 @@
 from animus.torch import IS_TORCH_XLA_AVAILABLE
 
 if IS_TORCH_XLA_AVAILABLE:
-    import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
 
-def _ddp_sum_reduce(tensor: torch.Tensor) -> torch.Tensor:
-    cloned = tensor.clone()
-    dist.all_reduce(cloned, dist.ReduceOp.SUM)
-    return cloned
-
-
-def _ddp_mean_reduce(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
-    reduced = _ddp_sum_reduce(tensor) / world_size
-    return reduced
-
-
 class Engine(Accelerator):
     def spawn(self, fn: Callable, *args, **kwargs):
         return fn(*args, **kwargs)
@@ -39,27 +25,6 @@ def setup(self, local_rank: int, world_size: int):
     def cleanup(self):
         pass
 
-    def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
-        if self.state.distributed_type in [
-            DistributedType.MULTI_CPU,
-            DistributedType.MULTI_GPU,
-        ]:
-            metrics = {
-                k: _ddp_mean_reduce(
-                    torch.tensor(v, device=self.device),
-                    world_size=self.state.num_processes,
-                )
-                for k, v in metrics.items()
-            }
-        elif self.state.distributed_type == DistributedType.TPU:
-            metrics = {
-                k: xm.mesh_reduce(
-                    k, v.item() if isinstance(v, torch.Tensor) else v, np.mean
-                )
-                for k, v in metrics.items()
-            }
-        return metrics
-
 
 class CPUEngine(Engine):
     def __init__(self, *args, **kwargs) -> None:
@@ -124,16 +89,6 @@ def setup(self, local_rank: int, world_size: int):
     def cleanup(self):
         dist.destroy_process_group()
 
-    def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
-        metrics = {
-            k: _ddp_mean_reduce(
-                torch.tensor(v, device=self.device),
-                world_size=self.state.num_processes,
-            )
-            for k, v in metrics.items()
-        }
-        return metrics
-
 
 class XLAEngine(Engine):
     def __init__(self, *args, **kwargs):
@@ -147,12 +102,5 @@ def spawn(self, fn: Callable, *args, **kwargs):
     def setup(self, local_rank: int, world_size: int):
         super().__init__(self, *self._args, **self._kwargs)
 
-    def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
-        metrics = {
-            k: xm.mesh_reduce(k, v.item() if isinstance(v, torch.Tensor) else v, np.mean)
-            for k, v in metrics.items()
-        }
-        return metrics
-
 
 __all__ = [Engine, CPUEngine, GPUEngine, DPEngine, DDPEngine, XLAEngine]
diff --git a/bin/check_torch_dl.sh b/bin/check_torch_dl.sh
index 34e6ada..1e6ef84 100644
--- a/bin/check_torch_dl.sh
+++ b/bin/check_torch_dl.sh
@@ -4,7 +4,7 @@
 set -eo pipefail -v
 
 # pip install animus
-# pip install accelerate packaging torch torchvision tqdm
+# pip install accelerate>=0.7.0 packaging torch torchvision tqdm
 
 
 python examples/torch_dl/torch_run.py --engine="cpu"
diff --git a/examples/notebooks/XLA_jax.ipynb b/examples/notebooks/XLA_jax.ipynb
index 1adb81f..c801b99 100644
--- a/examples/notebooks/XLA_jax.ipynb
+++ b/examples/notebooks/XLA_jax.ipynb
@@ -33,6 +33,7 @@
    "outputs": [],
    "source": [
     "from jax import __version__\n",
+    "\n",
     "print(__version__)"
    ]
   },
diff --git a/examples/notebooks/XLA_torch.ipynb b/examples/notebooks/XLA_torch.ipynb
index 0234a6d..6ee13af 100644
--- a/examples/notebooks/XLA_torch.ipynb
+++ b/examples/notebooks/XLA_torch.ipynb
@@ -48,6 +48,7 @@
    ],
    "source": [
     "from torch import __version__\n",
+    "\n",
     "print(__version__)"
    ]
   },
diff --git a/examples/torch_dl/torch_run.py b/examples/torch_dl/torch_run.py
index 2d78034..f7af5ed 100644
--- a/examples/torch_dl/torch_run.py
+++ b/examples/torch_dl/torch_run.py
@@ -96,8 +96,8 @@ def run_dataset(self) -> None:
             output = self.model(data)
             loss = self.criterion(output, target)
             pred = output.argmax(dim=1, keepdim=True)
-            total_loss += loss.sum().item()
-            total_accuracy += pred.eq(target.view_as(pred)).sum().item()
+            total_loss += loss.sum().detach()
+            total_accuracy += pred.eq(target.view_as(pred)).sum().detach()
             if self.is_train_dataset:
                 self.engine.backward(loss)
                 self.optimizer.step()
@@ -105,7 +105,7 @@ def run_dataset(self) -> None:
         total_loss /= self.dataset_batch_step
         total_accuracy /= self.dataset_batch_step * self.batch_size
         self.dataset_metrics = {"loss": total_loss, "accuracy": total_accuracy}
-        self.dataset_metrics = self.engine.mean_reduce_ddp_metrics(self.dataset_metrics)
+        self.dataset_metrics = self.engine.reduce(self.dataset_metrics, reduction="mean")
         self.dataset_metrics = {k: float(v) for k, v in self.dataset_metrics.items()}
 
     def on_epoch_end(self, exp: "IExperiment") -> None:
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index b657612..a7a4a13 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -9,4 +9,5 @@ catalyst-codestyle==21.09.2
 black==21.8b0
 catalyst-sphinx-theme==1.2.0
 tomlkit==0.7.2
-pre-commit==2.13.0
\ No newline at end of file
+pre-commit==2.13.0
+click
\ No newline at end of file
diff --git a/requirements/requirements_torch_dl.txt b/requirements/requirements_torch_dl.txt
index 9354e31..b825d4c 100644
--- a/requirements/requirements_torch_dl.txt
+++ b/requirements/requirements_torch_dl.txt
@@ -1,4 +1,4 @@
-accelerate
+accelerate>=0.7.0
 packaging
 torch
 torchvision
diff --git a/setup.py b/setup.py
index b10d4c7..973130a 100755
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@
 
 setup(
     name=NAME,
-    version="0.0.2",
+    version="0.0.3",
     url=URL,
     download_url=URL,
     description=DESCRIPTION,
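
Aside (not part of the patch): a minimal sketch of the metric-reduction path the patch moves to, with accelerate's built-in Accelerator.reduce standing in for the removed _ddp_mean_reduce / xm.mesh_reduce helpers. It assumes a recent accelerate release (the patch pins >=0.7.0) whose reduce accepts a nested dict of tensors and returns the reduced values; the metric names and numbers below are made up for illustration.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Accumulate metrics as tensors (cf. the switch from .item() to .detach() above)
# so they can be reduced across processes without a hand-rolled all_reduce.
metrics = {
    "loss": torch.tensor(0.42, device=accelerator.device),
    "accuracy": torch.tensor(0.91, device=accelerator.device),
}

# Average each metric tensor over all participating processes,
# then convert back to plain floats for logging.
metrics = accelerator.reduce(metrics, reduction="mean")
metrics = {k: float(v) for k, v in metrics.items()}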