Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Torch: engine update #6

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions animus/torch/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,16 @@
import os

from accelerate import Accelerator
from accelerate.state import DistributedType
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from animus.torch import IS_TORCH_XLA_AVAILABLE

if IS_TORCH_XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _ddp_sum_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """Sum ``tensor`` across processes without touching the input.

    Args:
        tensor: value contributed by the calling process.

    Returns:
        A clone of ``tensor`` after ``dist.all_reduce`` with ``ReduceOp.SUM``.
    """
    # all_reduce operates in place, so the reduction is applied to a copy
    # and the caller's tensor keeps its local value.
    cloned = tensor.clone()
    dist.all_reduce(cloned, dist.ReduceOp.SUM)
    return cloned


def _ddp_mean_reduce(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    """Average ``tensor`` across processes.

    Args:
        tensor: value contributed by the calling process.
        world_size: divisor applied to the cross-process sum.

    Returns:
        The result of ``_ddp_sum_reduce(tensor)`` divided by ``world_size``.
    """
    return _ddp_sum_reduce(tensor) / world_size


class Engine(Accelerator):
def spawn(self, fn: Callable, *args, **kwargs):
    """Call ``fn`` directly in the current process.

    Args:
        fn: callable to execute.
        *args: positional arguments forwarded to ``fn``.
        **kwargs: keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn(*args, **kwargs)`` returns.
    """
    return fn(*args, **kwargs)
Expand All @@ -39,27 +25,6 @@ def setup(self, local_rank: int, world_size: int):
def cleanup(self):
    """No-op cleanup: this implementation has nothing to tear down."""

def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
    """Average every metric across processes for the active distributed type.

    Args:
        metrics: mapping of metric name to value.

    Returns:
        For ``MULTI_CPU`` / ``MULTI_GPU``: a new dict whose values are tensors
        created on ``self.device`` and averaged over
        ``self.state.num_processes``.
        For ``TPU``: a new dict whose values come from ``xm.mesh_reduce``
        with ``np.mean``.
        For any other distributed type: ``metrics`` itself, unchanged.
    """
    distributed_type = self.state.distributed_type

    if distributed_type in (DistributedType.MULTI_CPU, DistributedType.MULTI_GPU):
        return {
            name: _ddp_mean_reduce(
                torch.tensor(value, device=self.device),
                world_size=self.state.num_processes,
            )
            for name, value in metrics.items()
        }

    if distributed_type == DistributedType.TPU:
        return {
            # Tensors are unwrapped to Python scalars before the mesh reduce;
            # the metric name doubles as the reduce tag.
            name: xm.mesh_reduce(
                name,
                value.item() if isinstance(value, torch.Tensor) else value,
                np.mean,
            )
            for name, value in metrics.items()
        }

    return metrics


class CPUEngine(Engine):
def __init__(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -124,16 +89,6 @@ def setup(self, local_rank: int, world_size: int):
def cleanup(self):
    """Destroy the default ``torch.distributed`` process group if one exists.

    Calling this when no process group was initialised (for example after a
    failed setup) is a no-op instead of an error, so cleanup can safely run
    in a ``finally`` block.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
    """Average every metric across processes.

    Args:
        metrics: mapping of metric name to value.

    Returns:
        A new dict with the same keys; each value is turned into a tensor on
        ``self.device`` and averaged over ``self.state.num_processes`` via
        ``_ddp_mean_reduce``.
    """
    return {
        name: _ddp_mean_reduce(
            torch.tensor(value, device=self.device),
            world_size=self.state.num_processes,
        )
        for name, value in metrics.items()
    }


class XLAEngine(Engine):
def __init__(self, *args, **kwargs):
Expand All @@ -147,12 +102,5 @@ def spawn(self, fn: Callable, *args, **kwargs):
def setup(self, local_rank: int, world_size: int):
    """Run the parent initializer with the arguments held on the instance.

    Args:
        local_rank: accepted for interface compatibility; not used here.
        world_size: accepted for interface compatibility; not used here.
    """
    # super() already binds ``self``; passing it again would shift every
    # positional argument by one slot in the parent ``__init__``.
    super().__init__(*self._args, **self._kwargs)

def mean_reduce_ddp_metrics(self, metrics: Dict) -> Dict:
    """Average every metric through ``xm.mesh_reduce`` with ``np.mean``.

    Args:
        metrics: mapping of metric name to value.

    Returns:
        A new dict with the same keys and the values returned by
        ``xm.mesh_reduce``.
    """
    return {
        # Tensors are unwrapped to Python scalars before the mesh reduce;
        # the metric name doubles as the reduce tag.
        name: xm.mesh_reduce(
            name,
            value.item() if isinstance(value, torch.Tensor) else value,
            np.mean,
        )
        for name, value in metrics.items()
    }


# ``__all__`` must list public names as strings; putting the class objects
# here makes ``from <module> import *`` fail with TypeError.
__all__ = ["Engine", "CPUEngine", "GPUEngine", "DPEngine", "DDPEngine", "XLAEngine"]
2 changes: 1 addition & 1 deletion bin/check_torch_dl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
set -eo pipefail -v

# pip install animus
# pip install accelerate packaging torch torchvision tqdm
# pip install "accelerate>=0.7.0" packaging torch torchvision tqdm

python examples/torch_dl/torch_run.py --engine="cpu"

Expand Down
1 change: 1 addition & 0 deletions examples/notebooks/XLA_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
"outputs": [],
"source": [
"from jax import __version__\n",
"\n",
"print(__version__)"
]
},
Expand Down
1 change: 1 addition & 0 deletions examples/notebooks/XLA_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
],
"source": [
"from torch import __version__\n",
"\n",
"print(__version__)"
]
},
Expand Down
6 changes: 3 additions & 3 deletions examples/torch_dl/torch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ def run_dataset(self) -> None:
output = self.model(data)
loss = self.criterion(output, target)
pred = output.argmax(dim=1, keepdim=True)
total_loss += loss.sum().item()
total_accuracy += pred.eq(target.view_as(pred)).sum().item()
total_loss += loss.sum().detach()
total_accuracy += pred.eq(target.view_as(pred)).sum().detach()
if self.is_train_dataset:
self.engine.backward(loss)
self.optimizer.step()

total_loss /= self.dataset_batch_step
total_accuracy /= self.dataset_batch_step * self.batch_size
self.dataset_metrics = {"loss": total_loss, "accuracy": total_accuracy}
self.dataset_metrics = self.engine.mean_reduce_ddp_metrics(self.dataset_metrics)
self.engine.reduce(self.dataset_metrics, reduction="mean")
self.dataset_metrics = {k: float(v) for k, v in self.dataset_metrics.items()}

def on_epoch_end(self, exp: "IExperiment") -> None:
Expand Down
3 changes: 2 additions & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ catalyst-codestyle==21.09.2
black==21.8b0
catalyst-sphinx-theme==1.2.0
tomlkit==0.7.2
pre-commit==2.13.0
pre-commit==2.13.0
click
2 changes: 1 addition & 1 deletion requirements/requirements_torch_dl.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
accelerate
accelerate>=0.7.0
packaging
torch
torchvision
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

setup(
name=NAME,
version="0.0.2",
version="0.0.3",
url=URL,
download_url=URL,
description=DESCRIPTION,
Expand Down