diff --git a/toad/nn/distributed/__init__.py b/toad/nn/distributed/__init__.py
index e69de29..48d6b15 100644
--- a/toad/nn/distributed/__init__.py
+++ b/toad/nn/distributed/__init__.py
@@ -0,0 +1,24 @@
+
+
+def get_distributed_module(module, backend = None, **kwargs):
+    from .ddp import DDPModule
+    from .fsdp import FSDPModule
+
+    if backend == 'fsdp':
+        return FSDPModule(module, **kwargs)
+
+    return DDPModule(module, **kwargs)
+
+
+
+def prepare(module, backend = None, **kwargs):
+    from ..module import ModuleMixin
+
+    if backend == 'fsdp':
+        from .fsdp import FSDP
+        module = FSDP(module, **kwargs)
+    else:
+        from .ddp import DDP
+        module = DDP(module, **kwargs)
+
+    return ModuleMixin.mixin(module)
diff --git a/toad/nn/distributed/accelerate/__init__.py b/toad/nn/distributed/accelerate/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/toad/nn/distributed/accelerate/accelerator.py b/toad/nn/distributed/accelerate/accelerator.py
new file mode 100644
index 0000000..c23ac4b
--- /dev/null
+++ b/toad/nn/distributed/accelerate/accelerator.py
@@ -0,0 +1,81 @@
+from dataclasses import dataclass
+from .strategy import Strategy, DDPStrategy, FSDPStrategy
+
+
+@dataclass
+class AcceleratorState:
+    rank: int = -1
+    size: int = 0
+    strategy: Strategy = None
+
+    @property
+    def initialized(self):
+        import torch
+
+        return torch.distributed.is_initialized()
+
+
+class Accelerator:
+    def __init__(self, rank = None, size = None, strategy = "ddp"):
+        if isinstance(strategy, str):
+            # allow "ddp"/"fsdp" shorthand for a Strategy instance
+            strategy = FSDPStrategy() if strategy == "fsdp" else DDPStrategy()
+
+        self.state = AcceleratorState(
+            rank = rank,
+            size = size,
+            strategy = strategy,
+        )
+
+
+    @property
+    def rank(self):
+        return self.state.rank
+
+    @property
+    def size(self):
+        return self.state.size
+
+    @property
+    def initialized(self):
+        return self.state.initialized
+
+    @property
+    def strategy(self):
+        return self.state.strategy
+
+    def setup(self):
+        import torch
+
+        if not self.initialized:
+            # choose an RPC backend
+            rpc = 'nccl' if torch.distributed.is_nccl_available() else 'gloo'
+
+            torch.distributed.init_process_group(
+                rpc,
+                rank = self.rank,
+                world_size = self.size,
+            )
+
+
+    def prepare(self, module, loader, optimizer):
+        self.setup()
+
+        module = self.prepare_model(module)
+
+        return module, loader, optimizer
+
+
+    def prepare_model(self, module, **kwargs):
+        from ...module import ModuleMixin
+
+        # check FSDP first: FSDPStrategy subclasses DDPStrategy
+        if isinstance(self.strategy, FSDPStrategy):
+            from ..fsdp import FSDP
+            module = FSDP(module, **kwargs)
+        elif isinstance(self.strategy, DDPStrategy):
+            from ..ddp import DDP
+            module = DDP(module, **kwargs)
+
+        return ModuleMixin.mixin(module)
+
diff --git a/toad/nn/distributed/accelerate/strategy.py b/toad/nn/distributed/accelerate/strategy.py
new file mode 100644
index 0000000..151d0f7
--- /dev/null
+++ b/toad/nn/distributed/accelerate/strategy.py
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class Strategy:
+    method: str = None
+
+
+@dataclass
+class DDPStrategy(Strategy):
+    method: str = "ddp"
+
+
+@dataclass
+class FSDPStrategy(DDPStrategy):
+    method: str = "fsdp"
+    policy: str = None
diff --git a/toad/nn/distributed/clusters/base.py b/toad/nn/distributed/clusters/base.py
new file mode 100644
index 0000000..e1a2d20
--- /dev/null
+++ b/toad/nn/distributed/clusters/base.py
@@ -0,0 +1,8 @@
+
+
+class Cluster:
+    def __init__(self):
+        pass
+
+    def spawn(self, func):
+        pass
diff --git a/toad/nn/distributed/clusters/torch.py b/toad/nn/distributed/clusters/torch.py
new file mode 100644
index 0000000..c30e847
--- /dev/null
+++ b/toad/nn/distributed/clusters/torch.py
@@ -0,0 +1,12 @@
+from .base import Cluster
+
+
+class TorchCluster(Cluster):
+    pass
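+
+    # hedged sketch (not part of this change): one way this stub could spawn
+    # workers via torch.multiprocessing, assuming func(rank, *args):
+    #
+    #     def spawn(self, func, size, *args):
+    #         import torch.multiprocessing as mp
+    #         mp.spawn(func, args = args, nprocs = size, join = True)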
diff --git a/toad/nn/distributed/ddp.py b/toad/nn/distributed/ddp.py
new file mode 100644
index 0000000..3b7b4cd
--- /dev/null
+++ b/toad/nn/distributed/ddp.py
@@ -0,0 +1,17 @@
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+class DDPModule(DDP):
+    """distributed module class
+    """
+    def fit(self, *args, **kwargs):
+        return self.module.fit(*args, **kwargs)
+
+    def save(self, *args, **kwargs):
+        return self.module.save(*args, **kwargs)
+
+    def load(self, *args, **kwargs):
+        return self.module.load(*args, **kwargs)
+
+    def log(self, *args, **kwargs):
+        return self.module.log(*args, **kwargs)
diff --git a/toad/nn/distributed/distributor.py b/toad/nn/distributed/distributor.py
new file mode 100644
index 0000000..c4e664a
--- /dev/null
+++ b/toad/nn/distributed/distributor.py
@@ -0,0 +1,27 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class DistributedState:
+    size: int = 0
+    backend: str = 'ddp'
+
+
+class Distributor:
+    def __init__(self, size, backend = 'ddp', cluster = 'mp'):
+        # TODO: build a Cluster instance from the `cluster` type
+        self.state = DistributedState(
+            size = size,
+            backend = backend,
+        )
+
+
+    def init(self):
+        pass
+
+    def prepare(self, module, loader, optimizer, scheduler):
+        pass
+
+
+    def spawn(self, func, *args):
+        pass
diff --git a/toad/nn/distributed/fsdp_test.py b/toad/nn/distributed/fsdp_test.py
index 8ea7dec..9f45b7c 100644
--- a/toad/nn/distributed/fsdp_test.py
+++ b/toad/nn/distributed/fsdp_test.py
@@ -15,10 +15,12 @@ class TestModel(Module):
     def __init__(self, in_feats, out_feats):
         super().__init__()
 
-        self.linear = nn.Linear(in_feats, out_feats)
+        self.linear_1 = nn.Linear(in_feats, in_feats)
+        self.linear_2 = nn.Linear(in_feats, out_feats)
 
     def forward(self, x):
-        x = self.linear(x)
+        x = self.linear_1(x)
+        x = self.linear_2(x)
         return F.relu(x)
 
     def fit_step(self, batch):
@@ -35,9 +37,9 @@ def worker(rank, world):
 
     torch.manual_seed(0)
 
-    NUM_FEATS = 4096
+    NUM_FEATS = 1024*2
     NUM_CLASSES = 1024
-    DATASET_SIZE = 10000
+    DATASET_SIZE = 1000
 
     X = torch.rand(DATASET_SIZE, NUM_FEATS, dtype = torch.float)
 
@@ -54,32 +56,50 @@ def worker(rank, world):
 
     model = TestModel(NUM_FEATS, NUM_CLASSES)
    # print(next(model.linear.parameters()).shape)
+
+    model.load(f"data/origin_model_{rank}.pkl")
+
+    model.distributed(rpc = "gloo", rank = rank, world_size = world)
+
+    q_model = quantize(model)
+    # q_model.eval()
 
-    model.distributed(backend = "gloo", rank = rank, world_size = world)
+    peft_model = get_peft_model(q_model)
 
     fdsp_model = FSDPModule(
-        model,
+        peft_model,
+        # use_orig_params = True,
         # sync_module_states = True,
         # auto_wrap_policy = my_auto_wrap_policy,
        # policy = ModuleWrapPolicy([nn.Linear,]),
         device_id=torch.device("cpu"),
     )
 
+    # for p in fdsp_model.parameters():
+    #     print(p, p.shape)
+
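+    # NOTE: the optimizer must be created after the FSDP wrapping above —
+    # FSDP replaces module parameters with flattened shards, so an optimizer
+    # built from the unwrapped module would step stale tensors.
+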
     optimizer = optim.Adam(fdsp_model.parameters(), lr = 1e-3)
 
     state_path = f"data/fsdp_model_{rank}.pkl"
-    fdsp_model.load(state_path)
+    # fdsp_model.load(state_path)
 
     print('before fit:', fdsp_model(X[0]).sum())
 
-    # inputs = torch.rand(10, features_dim)
-    fdsp_model.fit(loader, epoch = 20, early_stopping = False)
+    # inputs = torch.rand(10, NUM_FEATS)
+    # fdsp_model.fit(loader, epoch = 20, early_stopping = False)
+    train(fdsp_model, loader, epoch = 20)
 
     print('after fit:', fdsp_model(X[0]).sum())
 
     print(fdsp_model)
-    # print(fdsp_model.flatten_sharded_optim_state_dict())
+    print("##### fsdp parameters:", get_parameters(fdsp_model).shape)
+    print("##### fsdp q model flatten:", fdsp_model.linear_2._handle.flat_param)
+    print("##### q_model parameters:", type(get_parameters(q_model.linear_2)))
 
     # out = fdsp_model(inputs).sum()
@@ -87,13 +107,135 @@ def worker(rank, world):
 
     # print("~~~~~", out)
 
-    print(model)
+    # print(model)
 
     model.save(f"data/origin_model_{rank}.pkl")
     fdsp_model.save(state_path)
 
 
+def train(model, loader, **kwargs):
+    from ..trainer import Trainer
+    trainer = Trainer(model, loader, early_stopping = False)
+
+
+    @trainer.fit_step
+    def fit_step(model, batch):
+        x, y = batch
+        y_hat = model(x)
+        # return F.cross_entropy(y_hat, y)
+        return F.mse_loss(y_hat, y)
+
+    trainer.train(**kwargs)
+
+
+def get_parameters(model):
+    return next(model.parameters())
+
+
+def quantize(model):
+    import copy
+    from quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize
+
+    m_copy = copy.deepcopy(model)
+
+    quantize(m_copy, weights=qint4)
+    freeze(m_copy)
+
+    # m_copy = replace_hqq_linear(m_copy)
+
+    # m_copy = replace_qlinear(m_copy)
+
+    print("### q model linear_2 weight", m_copy.linear_2.weight)
+    print("### q model linear_2 parameters", get_parameters(m_copy.linear_2).dtype)
+
+    return m_copy
+
+
+
+def replace_qlinear(model, skip_modules=["lm_head"], **kwargs):
+    """
+    Replace linear modules with quantized `QLinear` modules.
+    Parameters:
+        model (`torch.nn.Module`):
+            Input model or `torch.nn.Module` as the function is run recursively.
+        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
+            List of module names not to convert. Defaults to `lm_head`.
+    """
+    from ..quantize import QLinear
+
+    for name, module in model.named_children():
+        if name in skip_modules:
+            continue
+
+        if isinstance(module, torch.nn.Linear):
+            model._modules[name] = QLinear.qcreate(module, **kwargs)
+            model._modules[name].quantize()
+            model._modules[name].freeze()
+
+    return model
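+
+# hedged usage sketch (comment only): `replace_qlinear` swaps `nn.Linear`
+# children in place, so for the TestModel defined above:
+#
+#     m = replace_qlinear(TestModel(16, 4))
+#     assert type(m.linear_1).__name__ == 'QLinear'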
+ """ + from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig + + quant_config = BaseQuantizeConfig( + nbits=4, + group_size=64, + # quant_zero=True, + # quant_scale=True, + # offload_meta=True, + view_as_float=True + ) + + for name, module in model.named_children(): + if name in skip_modules: + continue + + if len(list(module.children())) > 0: + replace_linear(module, HQQLinear, quant_config, skip_modules, **kwargs) + + if isinstance(module, torch.nn.Linear): + model._modules[name] = HQQLinear(module, quant_config, **kwargs) + + HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP) + return model + + + +def get_peft_model(model): + from peft import get_peft_config, get_peft_model, LoraConfig, TaskType + + peft_config = LoraConfig( + # task_type=TaskType.SEQ_2_SEQ_LM, + # task_type=TaskType.FEATURE_EXTRACTION, + target_modules = ['linear_1'], + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + + model = get_peft_model(model, peft_config) + + return model + + def test_fsdp_model(): import torch.multiprocessing as mp diff --git a/toad/nn/lora/__init__.py b/toad/nn/lora/__init__.py new file mode 100644 index 0000000..31454a5 --- /dev/null +++ b/toad/nn/lora/__init__.py @@ -0,0 +1,16 @@ +def get_lora_model(model, config = None): + from peft import get_peft_config, get_peft_model, LoraConfig, TaskType + + peft_config = LoraConfig( + # task_type=TaskType.SEQ_2_SEQ_LM, + # task_type=TaskType.FEATURE_EXTRACTION, + target_modules = ['linear_1'], + inference_mode=False, + r=8, + lora_alpha=32, + lora_dropout=0.1, + ) + + model = get_peft_model(model, peft_config) + + return model diff --git a/toad/nn/module.py b/toad/nn/module.py index 5cf12ce..dd504f6 100644 --- a/toad/nn/module.py +++ b/toad/nn/module.py @@ -1,14 +1,16 @@ +from abc import ABC + import torch import numpy as np from torch import nn, optim -from torch.nn.parallel import DistributedDataParallel + from .trainer.history import get_current_history from ..utils.progress import Progress -class Module(nn.Module): +class ModuleMixin(ABC): """base module for every model Examples: @@ -43,12 +45,6 @@ class Module(nn.Module): ... 
diff --git a/toad/nn/module.py b/toad/nn/module.py
index 5cf12ce..dd504f6 100644
--- a/toad/nn/module.py
+++ b/toad/nn/module.py
@@ -1,14 +1,16 @@
+from abc import ABC
+
 import torch
 import numpy as np
 from torch import nn, optim
-from torch.nn.parallel import DistributedDataParallel
+
 
 from .trainer.history import get_current_history
 from ..utils.progress import Progress
 
 
-class Module(nn.Module):
+class ModuleMixin(ABC):
     """base module for every model
 
     Examples:
@@ -43,12 +45,6 @@ class Module(nn.Module):
     ...     model.fit(train_loader)
     """
-    def __init__(self):
-        """define model struct
-        """
-        super().__init__()
-
-
     @property
     def device(self):
         """device of model
@@ -133,35 +129,101 @@ def log(self, key, value):
             return
 
         return history.log(key, value)
+
+
+    def quantize(self, **kwargs):
+        from .quantize import quantize, freeze
+        quantize(self, **kwargs)
+        freeze(self)
+
+        return self
+
+
+    def lora(self, **kwargs):
+        from .lora import get_lora_model
+        return get_lora_model(self, **kwargs)
 
-    def distributed(self, backend = None, **kwargs):
+
+    def distributed(self, backend = None, rpc = None, **kwargs):
         """get distributed model
         """
         if not torch.distributed.is_initialized():
-            if backend is None:
-                # choose a backend
-                backend = 'nccl' if torch.distributed.is_nccl_available() else 'gloo'
+            if rpc is None:
+                # choose an RPC backend
+                rpc = 'nccl' if torch.distributed.is_nccl_available() else 'gloo'
 
-            torch.distributed.init_process_group(backend, **kwargs)
+            torch.distributed.init_process_group(rpc, **kwargs)
 
-        return DistModule(self)
+
+        from .distributed import get_distributed_module
+        return get_distributed_module(self, backend = backend)
+
+
+    @classmethod
+    def mixin(cls, module):
+        """graft `ModuleMixin` methods onto an existing module instance
+        """
+        import types
+
+        for name in cls.__dict__:
+            if name.startswith('__') and name.endswith('__') \
+                or not type(cls.__dict__[name]) == types.FunctionType \
+                or name in module.__dict__:
+
+                continue
+
+            module.__dict__[name] = types.MethodType(cls.__dict__[name], module)
+
+        return module
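+
+    # hedged usage sketch — `mixin` is meant for modules constructed outside
+    # toad (e.g. wrapped by FSDP/DDP), giving them fit/save/load/log:
+    #
+    #     wrapped = ModuleMixin.mixin(wrapped_module)
+    #     wrapped.fit(loader)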
 
 
-class DistModule(DistributedDataParallel):
-    """distributed module class
+class Module(ModuleMixin, nn.Module):
+    """base module for every model
+
+    Examples:
+        >>> from toad.nn import Module
+        ... from torch import nn
+        ...
+        ... class Net(Module):
+        ...     def __init__(self, inputs, hidden, outputs):
+        ...         super().__init__()
+        ...         self.model = nn.Sequential(
+        ...             nn.Linear(inputs, hidden),
+        ...             nn.ReLU(),
+        ...             nn.Linear(hidden, outputs),
+        ...             nn.Sigmoid(),
+        ...         )
+        ...
+        ...     def forward(self, x):
+        ...         return self.model(x)
+        ...
+        ...     def fit_step(self, batch):
+        ...         x, y = batch
+        ...         y_hat = self(x)
+        ...
+        ...         # log into history
+        ...         self.log('y', y)
+        ...         self.log('y_hat', y_hat)
+        ...
+        ...         return nn.functional.mse_loss(y_hat, y)
+        ...
+        ... model = Net(10, 4, 1)
+        ...
+        ... model.fit(train_loader)
     """
-    def fit(self, *args, **kwargs):
-        return self.module.fit(*args, **kwargs)
-
-    def save(self, *args, **kwargs):
-        return self.module.save(*args, **kwargs)
-
-    def load(self, *args, **kwargs):
-        return self.module.load(*args, **kwargs)
-
-    def log(self, *args, **kwargs):
-        return self.module.log(*args, **kwargs)
+    pass
+
+
diff --git a/toad/nn/quantize/__init__.py b/toad/nn/quantize/__init__.py
new file mode 100644
index 0000000..7696ae4
--- /dev/null
+++ b/toad/nn/quantize/__init__.py
@@ -0,0 +1,9 @@
+from .qlinear import QLinear
+
+
+
+from quanto import quantize, freeze, qfloat8, qint4, qint8
+
+
+
+
diff --git a/toad/nn/quantize/qlinear.py b/toad/nn/quantize/qlinear.py
new file mode 100644
index 0000000..44016be
--- /dev/null
+++ b/toad/nn/quantize/qlinear.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn.functional as F
+# from quanto import QModuleMixin, register_qmodule, quantize_activation
+
+
+
+class QLinear(torch.nn.Linear):
+    @classmethod
+    def qcreate(
+        cls, module, weights = None, activations = None, optimizer = None
+    ):
+        return cls(
+            module.in_features,
+            module.out_features,
+            module.bias is not None,
+            dtype=module.weight.dtype,
+            device=module.weight.device,
+        )
+
+
+    def quantize(self):
+        # placeholder quantization: random uint8 weights stand in for a
+        # real quantization kernel
+        qweight = torch.randint(256, size=self.weight.shape, dtype=torch.uint8)
+        self.qweight = qweight
+        return self.qweight
+
+
+    def freeze(self):
+        self.weight = torch.nn.Parameter(self.pack_weight(self.qweight))
+
+    def pack_weight(self, weight):
+        # reinterpret the uint8 buffer as float32
+        # (requires in_features % 4 == 0)
+        return weight.view(torch.float32)
+
+    def unpack_weight(self, weight):
+        return weight.view(torch.uint8)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        qweight = self.unpack_weight(self.weight).to(torch.float32)
+        print("forward qweight shape", qweight.shape)
+        return F.linear(input, qweight, bias=self.bias)
+
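+
+# hedged sanity check (comment only): pack/unpack is a byte-level
+# reinterpretation, so it round-trips exactly when in_features % 4 == 0:
+#
+#     layer = QLinear(8, 4)
+#     q = layer.quantize()
+#     packed = layer.pack_weight(q)
+#     assert torch.equal(layer.unpack_weight(packed), q)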
diff --git a/toad/nn/trainer/trainer.py b/toad/nn/trainer/trainer.py
index 5f582ad..33dd872 100644
--- a/toad/nn/trainer/trainer.py
+++ b/toad/nn/trainer/trainer.py
@@ -1,3 +1,7 @@
+from typing import Callable
+from dataclasses import dataclass
+from collections import deque
+
 import torch
 import numpy as np
 from torch import optim
@@ -5,6 +9,7 @@
 from .history import History
 from .callback import callback as Callback
 from .event import Event
+from ..distributed.distributor import Distributor
 from ...utils.progress import Progress
@@ -15,6 +20,28 @@
 TRAINER_RUNNING = "running"
 TRAINER_TERMINATED = "terminated"
 
 
+
+@dataclass
+class TrainerStatus:
+    UNSET: str = "unset"
+    INIT: str = "init"
+    RUNNING: str = "running"
+    TERMINATED: str = "terminated"
+
+
+@dataclass
+class TrainerState:
+    module: torch.nn.Module = None
+    loader: torch.utils.data.DataLoader = None
+    optimizer: torch.optim.Optimizer = None
+    scheduler: torch.optim.lr_scheduler.LRScheduler = None
+    step: Callable = None
+    histories: deque = None
+    status: str = TrainerStatus.UNSET
+    distributor: Distributor = None
+
+
 class Trainer(Event):
     """trainer for training models
     """
@@ -33,17 +60,10 @@ def __init__(self, model, loader = None, optimizer = None, loss = None, keep_his
         """
         super().__init__()
 
-        self.set_model(model)
-
-        self.loader = loader
-
-        self._mode = STANDALONE_MODE
-        self._state = TRAINER_INIT
+        step = self._get_step(model)
 
         if optimizer is None:
             optimizer = optim.Adam(model.parameters(), lr = 1e-3)
-
-        self.optimizer = optimizer
 
         self.loss = loss
 
@@ -56,35 +76,47 @@ def __init__(self, model, loader = None, optimizer = None, loss = None, keep_his
             self.register("earlystop:check", early_stopping)
 
         from collections import deque
-        self.history = deque(maxlen = keep_history)
+        histories = deque(maxlen = keep_history)
+
+        self.state = TrainerState(
+            module = model,
+            loader = loader,
+            optimizer = optimizer,
+            scheduler = None,
+            step = step,
+            histories = histories,
+            status = TrainerStatus.INIT,
+        )
 
 
     @property
-    def state(self):
-        return self._state
+    def status(self):
+        return self.state.status
+
+
+    @property
+    def histories(self):
+        return self.state.histories
 
 
     def terminate(self):
-        self._state = TRAINER_TERMINATED
+        self.state.status = TrainerStatus.TERMINATED
 
 
     def run(self):
-        self._state = TRAINER_RUNNING
+        self.state.status = TrainerStatus.RUNNING
 
 
-    def set_model(self, model):
-        """setup model
-        """
+    def _get_step(self, module):
         from ..module import Module
 
-        if isinstance(model, Module):
-            self.fit_step(model.__class__.fit_step)
+        if isinstance(module, Module):
+            return module.__class__.fit_step
 
-        self.model = model
+        return None
 
 
     def fit_step(self, func):
-        self._step = func
+        self.state.step = func
         return func
@@ -98,13 +130,17 @@ def distributed(self, address = None, workers = 4, gpu = False):
             workers (int): compute task's resource
             gpu (Booleans): whether use GPU, "True" or "False"
         '''
-        self._mode = DISTRIBUTED_MODE
-        self._workers = workers
-        self._gpu = gpu
+        # self._mode = DISTRIBUTED_MODE
+        # self._workers = workers
+        # self._gpu = gpu
 
-        import ray
-        if not ray.is_initialized():
-            ray.init(address = address)
+        # import ray
+        # if not ray.is_initialized():
+        #     ray.init(address = address)
+
+        # TODO: init distributor
+        distributor = Distributor(size = workers)
+        self.state.distributor = distributor
@@ -127,10 +163,10 @@ def _train(self, config: dict):
         for c in callback:
             self.register("epoch:end", c)
 
-        loader = self.loader
-        model = self.model
+        loader = self.state.loader
+        model = self.state.module
 
-        if self._mode == DISTRIBUTED_MODE:
+        if self.state.distributor is not None:
             import ray.train as train
             # TODO prepare loader and model
             loader = train.torch.prepare_data_loader(loader)
@@ -161,9 +197,9 @@ def train(self, loader = None, epoch = 10, **kwargs):
             Module: the model with best performents
         """
         if loader is not None:
-            self.loader = loader
+            self.state.loader = loader
 
-        if self.loader is None:
+        if self.state.loader is None:
             raise ValueError("loader is not set, please set a loader for trainning!")
 
         config = {
@@ -172,7 +208,7 @@ def train(self, loader = None, epoch = 10, **kwargs):
         }
 
         # distrubution trainning
-        if self._mode == DISTRIBUTED_MODE:
+        if self.state.distributor is not None:
             from ray.air import ScalingConfig
             from ray.train.torch import TorchTrainer
@@ -190,7 +226,7 @@ def train(self, loader = None, epoch = 10, **kwargs):
         else:
             self._train(config = config)
 
-        return self.model
+        return self.state.module
@@ -210,7 +246,10 @@ def evaluate(self, loader, callback = None):
 
         history = History()
 
-        self.model.eval()
+        model = self.state.module
+        step = self.state.step
+
+        model.eval()
 
         history.start()
@@ -218,12 +257,12 @@ def evaluate(self, loader, callback = None):
         for i, batch in enumerate(p, start = 1):
             # step fit
             if self.loss is None:
-                l = self._step(self.model, batch)
+                l = step(model, batch)
             else:
-                l = self._step(self.model, batch, loss=self.loss)
+                l = step(model, batch, loss=self.loss)
 
             # log loss
-            self.model.log('loss', l)
+            history.log('loss', l)
 
             loss += (l.item() - loss) / i
             p.suffix = 'loss:{:.4f}'.format(loss)
@@ -235,7 +274,7 @@ def evaluate(self, loader, callback = None):
             epoch = None,
             history = history,
             trainer = self,
-            model = self.model,
+            model = model,
         )
 
         return history
@@ -245,6 +284,8 @@ def evaluate(self, loader, callback = None):
 def train_loop(trainer, model, loader, epoch = 10, start = 0, backward_rounds = 1):
     # init progress bar
     p = Progress(loader)
+    step = trainer.state.step
+    optimizer = trainer.state.optimizer
 
     for ep in range(start, epoch):
         # set model to train mode
@@ -254,7 +295,7 @@ def train_loop(trainer, model, loader, epoch = 10, start = 0, backward_rounds =
 
         # setup a new history for model in each epoch
         history = History()
-        trainer.history.append(history)
+        trainer.state.histories.append(history)
 
         # setup callback params
         params = {
@@ -277,18 +318,18 @@ def train_loop(trainer, model, loader, epoch = 10, start = 0, backward_rounds =
 
             # step fit
             if trainer.loss is None:
-                l = trainer._step(model, batch)
+                l = step(model, batch)
             else:
-                l = trainer._step(model, batch, loss=trainer.loss)
+                l = step(model, batch, loss=trainer.loss)
 
             # log loss
             history.log('loss', l)
 
             backward_loss = l + backward_loss
             if i % backward_rounds == 0 or i == len(p):
-                trainer.optimizer.zero_grad()
+                optimizer.zero_grad()
                 backward_loss.backward()
-                trainer.optimizer.step()
+                optimizer.step()
 
                 # reset backward loss
                 backward_loss = 0.
@@ -306,5 +347,5 @@ def train_loop(trainer, model, loader, epoch = 10, start = 0, backward_rounds =
         trainer.emit("earlystop:check", **params)
 
         # check if trainer need terminate
-        if trainer.state == TRAINER_TERMINATED:
+        if trainer.status == TrainerStatus.TERMINATED:
             break
diff --git a/toad/nn/trainer/trainer_test.py b/toad/nn/trainer/trainer_test.py
index ec2e3a8..58aae4b 100644
--- a/toad/nn/trainer/trainer_test.py
+++ b/toad/nn/trainer/trainer_test.py
@@ -45,7 +45,7 @@ def test_trainer():
     model = TestModel(NUM_FEATS, NUM_CLASSES)
     trainer = Trainer(model, loader)
     trainer.train(epoch = 2)
-    assert len(trainer.history) == 2
+    assert len(trainer.histories) == 2
 
 
 def test_trainer_early_stopping():
@@ -57,7 +57,7 @@ def scoring(history):
 
     trainer = Trainer(model, loader, early_stopping = scoring)
     trainer.train(epoch = 200)
-    assert len(trainer.history) == 4
+    assert len(trainer.histories) == 4
 
 
 def test_trainer_fit_step():
@@ -123,7 +123,7 @@ def test_trainer_loss():
     model = TestModel2(NUM_FEATS, NUM_CLASSES)
     trainer = Trainer(model, loader, loss = F.cross_entropy)
     trainer.train(epoch = 2)
-    assert len(trainer.history) == 2
+    assert len(trainer.histories) == 2
 
 
 # def test_trainer_distributed():