stash: base cluster
Secbone committed Sep 3, 2024
1 parent e1b44b6 commit 44dd042
Showing 6 changed files with 78 additions and 21 deletions.
13 changes: 1 addition & 12 deletions toad/nn/distributed/accelerate/accelerator.py
@@ -81,18 +81,7 @@ def prepare_module(self, module):
         from ...module import ModuleMixin
 
         if isinstance(self.strategy, FSDPStrategy):
-            from ..fsdp import FSDP
-            from torch.distributed.fsdp import CPUOffload
-
-            module = FSDP(
-                module,
-                sync_module_states = True,
-                auto_wrap_policy = self.strategy.policy,
-                device_id = self.device,
-                param_init_fn = self.strategy.init_fn(rank = self.rank, device = self.device),
-                cpu_offload = CPUOffload(offload_params = True) if self.device.type == 'cuda' else None,
-                limit_all_gathers = True,
-            )
+            module = self.strategy.prepare_module(module, self.rank)
 
         elif isinstance(self.strategy, DDPStrategy):
             from ..ddp import DDP
32 changes: 32 additions & 0 deletions toad/nn/distributed/accelerate/strategy.py
@@ -10,6 +10,19 @@ def save(self, module):
         pass
 
 
+    def prepare_module(self, module, rank):
+        pass
+
+
+    def prepare_optimizer(self, optimizer, module):
+        opt_cls = type(optimizer)
+        params = optimizer.param_groups[0]
+        params.pop("params")
+
+        return opt_cls(module.parameters(), **params)
+
+
 @dataclass
 class DDPStrategy(Strategy):
     method: str = "ddp"
@@ -43,3 +56,22 @@ def fn(module):
             return fn
 
         return None
+
+
+    def prepare_module(self, module, rank):
+        from ..fsdp import FSDP
+        from torch.distributed.fsdp import CPUOffload
+
+        module = FSDP(
+            module,
+            sync_module_states = True if self.device.type == 'cuda' else None,
+            auto_wrap_policy = self.policy,
+            device_id = self.device,
+            param_init_fn = self.init_fn(rank = rank, device = self.device),
+            cpu_offload = CPUOffload(offload_params = True) if self.device.type == 'cuda' else None,
+            limit_all_gathers = True,
+        )
+
+        return module
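
A minimal usage sketch (not part of this commit) of the prepare_optimizer helper added above: it rebuilds the optimizer class with the same hyperparameters but bound to the module's current parameters, which matters once a wrapper such as FSDP has replaced the original parameter objects. The import path, the direct DDPStrategy() construction, and the toy nn.Linear/SGD setup are assumptions for illustration.

import torch
from torch import nn

from toad.nn.distributed.accelerate.strategy import DDPStrategy

module = nn.Linear(4, 2)
optimizer = torch.optim.SGD(module.parameters(), lr = 0.1, momentum = 0.9)

# same optimizer class and hyperparameters, freshly bound to module.parameters()
new_optimizer = DDPStrategy().prepare_optimizer(optimizer, module)

assert isinstance(new_optimizer, torch.optim.SGD)
assert new_optimizer.param_groups[0]["lr"] == 0.1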


5 changes: 4 additions & 1 deletion toad/nn/distributed/clusters/__init__.py
@@ -1,8 +1,11 @@
 
 
 def get_cluster(backend = "mp"):
-    if backend is "mp":
+    if backend == "mp":
         from .torch import TorchCluster
         return TorchCluster()
+    elif backend == "base":
+        from .base import Cluster
+        return Cluster()
     else:
         return None
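
As a quick illustration of the new branch (assumed usage, not shown in this commit), the factory now hands back either backend and still falls through to None for anything unknown:

from toad.nn.distributed.clusters import get_cluster

base_cluster = get_cluster("base")    # plain Cluster from .base
torch_cluster = get_cluster("mp")     # TorchCluster, the existing default
assert get_cluster("unknown") is None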
15 changes: 13 additions & 2 deletions toad/nn/distributed/clusters/base.py
@@ -4,7 +4,18 @@ class Cluster:
     def __init__(self):
         pass
 
-    def spawn(self, func):
-        pass
+    def spawn(self, func, size, **kwargs):
+        # TODO: use python multiprocess
+        import torch.multiprocessing as mp
+
+        from .executor import Executor, ExecutorContext
+        context = ExecutorContext(
+            size = size,
+            func = func,
+            params = kwargs,
+        )
+
+        executor = Executor(context)
+
+        mp.spawn(executor, nprocs = size, join = True)
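
A minimal sketch of how the new base Cluster.spawn might be driven. The worker signature follows Executor.run, which calls func(rank, **params); the worker function, its keyword argument, and the process count below are illustrative assumptions, and the worker must live at module level so mp.spawn can pickle it:

from toad.nn.distributed.clusters.base import Cluster

def worker(rank, greeting = "hello"):
    # each spawned process receives its rank as the first positional argument
    print(f"{greeting} from rank {rank}")

if __name__ == "__main__":
    # launches `size` processes; extra kwargs reach the worker via ExecutorContext.params
    Cluster().spawn(worker, size = 2, greeting = "hi")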

28 changes: 25 additions & 3 deletions toad/nn/distributed/clusters/executor.py
@@ -9,10 +9,8 @@
 class ExecutorContext:
     rank: int = -1
     size: int = 0
-    trainer: Trainer = None
     func: Callable = None
-    strategy: Strategy = None
-    accelerator: Accelerator = None
+    params: dict = None
 
 
 class Executor:
@@ -23,6 +21,30 @@ def __init__(self, context: ExecutorContext):
     def rank(self):
         return self.context.rank
 
+
+    def run(self, *args, **kwargs):
+        import torch
+        torch.manual_seed(self.rank)
+
+        res = self.context.func(self.rank, **self.context.params)
+
+        return res
+
+
+    def __call__(self, rank, *args, **kwargs):
+        self.context.rank = rank
+        return self.run(*args, **kwargs)
+
+
+
+@dataclass
+class FSDPExecutorContext(ExecutorContext):
+    trainer: Trainer = None
+    strategy: Strategy = None
+    accelerator: Accelerator = None
+
+
+class FSDPExecutor(Executor):
     @property
     def accelerator(self):
         return self.context.accelerator
6 changes: 3 additions & 3 deletions toad/nn/distributed/clusters/torch.py
@@ -1,13 +1,13 @@
 from .base import Cluster
-from .executor import Executor, ExecutorContext
+from .executor import FSDPExecutor, FSDPExecutorContext
 
 
 def _wrap(rank, size, func, *args, **kwargs):
     from toad.nn.distributed.accelerator import Accelerator
     accelerator = Accelerator(rank = rank, size = size)
 
 
-class TorchExecutor(Executor):
+class TorchExecutor(FSDPExecutor):
     def __call__(self, rank, *args, **kwargs):
         self.context.rank = rank
         self.run(*args, **kwargs)
@@ -17,7 +17,7 @@ class TorchCluster(Cluster):
     def spawn(self, func, size, trainer, strategy = None, **kwargs):
         import torch.multiprocessing as mp
 
-        context = ExecutorContext(
+        context = FSDPExecutorContext(
             trainer = trainer,
             size = size,
             func = func,
