Commit

stash: fsdp method

Secbone committed May 7, 2024
1 parent 19b2bbc commit 047b878
Showing 15 changed files with 548 additions and 86 deletions.
24 changes: 24 additions & 0 deletions toad/nn/distributed/__init__.py
@@ -0,0 +1,24 @@


def get_distributed_module(module, backend = None, **kwargs):
    from .ddp import DDPModule
    from .fsdp import FSDPModule

    if backend == 'fsdp':
        return FSDPModule(module, **kwargs)

    return DDPModule(module, **kwargs)



def prepare(module, backend = None, **kwargs):
    from ..module import ModuleMixin

    if backend == 'fsdp':
        from .fsdp import FSDP
        module = FSDP(module, **kwargs)
    else:
        from .ddp import DDP
        module = DDP(module, **kwargs)

    return ModuleMixin.mixin(module)
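
A minimal usage sketch of prepare, assuming the process group has already been initialized and that MyModel and my_loader are placeholder names for a toad Module subclass and its DataLoader:

# hypothetical usage; MyModel and my_loader are placeholders
from toad.nn.distributed import prepare

model = MyModel()

# backend = 'fsdp' wraps with FSDP, anything else falls back to DDP;
# ModuleMixin.mixin re-attaches toad's fit / save / load / log helpers
dist_model = prepare(model, backend = 'fsdp')
dist_model.fit(my_loader, epoch = 10)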
Empty file.
77 changes: 77 additions & 0 deletions toad/nn/distributed/accelerate/accelerator.py
@@ -0,0 +1,77 @@
from dataclasses import dataclass
from .strategy import Strategy, DDPStrategy, FSDPStrategy


@dataclass
class AcceleratorState:
    rank: int = -1
    size: int = 0
    strategy: Strategy = None

    @property
    def initialized(self):
        import torch

        return torch.distributed.is_initialized()


class Accelerator:
    def __init__(self, rank = None, size = None, strategy = None):
        # default to a DDP strategy so the isinstance checks in prepare_model work
        if strategy is None:
            strategy = DDPStrategy()

        self.state = AcceleratorState(
            rank = rank,
            size = size,
            strategy = strategy,
        )


    @property
    def rank(self):
        return self.state.rank

    @property
    def size(self):
        return self.state.size

    @property
    def initialized(self):
        return self.state.initialized

    @property
    def strategy(self):
        return self.state.strategy

    def setup(self):
        import torch

        if not self.initialized:
            # choose a process group backend
            backend = 'nccl' if torch.distributed.is_nccl_available() else 'gloo'

            torch.distributed.init_process_group(
                backend,
                rank = self.rank,
                world_size = self.size,
            )


    def prepare(self, module, loader, optimizer):
        self.setup()

        module = self.prepare_model(module)

        return module, loader, optimizer


    def prepare_model(self, module, **kwargs):
        from ...module import ModuleMixin

        if isinstance(self.strategy, FSDPStrategy):
            from ..fsdp import FSDP
            module = FSDP(module, **kwargs)

        elif isinstance(self.strategy, DDPStrategy):
            from ..ddp import DDP
            module = DDP(module, **kwargs)

        return ModuleMixin.mixin(module)
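
A hedged usage sketch of Accelerator, assuming it runs inside an already-spawned worker process; net, train_loader and optimizer are placeholder names, and in this stash only the module is actually wrapped (the loader and optimizer are returned unchanged):

# hypothetical usage; net, train_loader and optimizer are placeholders
from toad.nn.distributed.accelerate.accelerator import Accelerator
from toad.nn.distributed.accelerate.strategy import FSDPStrategy

acc = Accelerator(rank = 0, size = 2, strategy = FSDPStrategy())

# setup() picks 'nccl' when available (else 'gloo') and initializes the process group;
# prepare() then wraps the module according to the configured strategy
net, train_loader, optimizer = acc.prepare(net, train_loader, optimizer)
net.fit(train_loader, epoch = 10)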

17 changes: 17 additions & 0 deletions toad/nn/distributed/accelerate/strategy.py
@@ -0,0 +1,17 @@
from dataclasses import dataclass


@dataclass
class Strategy:
    method: str = None


@dataclass
class DDPStrategy(Strategy):
    method: str = "ddp"


@dataclass
class FSDPStrategy(DDPStrategy):
    method: str = "fsdp"
    policy: str = None
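
Because FSDPStrategy subclasses DDPStrategy, an isinstance check against DDPStrategy also matches FSDP strategies, so dispatch code should test the more specific class first (as Accelerator.prepare_model does above). A small sketch of that ordering; kind is a throwaway local name:

from toad.nn.distributed.accelerate.strategy import DDPStrategy, FSDPStrategy

s = FSDPStrategy()
assert isinstance(s, DDPStrategy)   # True: FSDPStrategy inherits from DDPStrategy
assert s.method == "fsdp"           # the subclass overrides the default method

# check the more specific strategy first
if isinstance(s, FSDPStrategy):
    kind = "fsdp"
elif isinstance(s, DDPStrategy):
    kind = "ddp"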
8 changes: 8 additions & 0 deletions toad/nn/distributed/clusters/base.py
@@ -0,0 +1,8 @@


class Cluster:
    def __init__(self):
        pass

    def spawn(self, func):
        pass
4 changes: 4 additions & 0 deletions toad/nn/distributed/clusters/torch.py
@@ -0,0 +1,4 @@


class TorchCluster:
    pass
17 changes: 17 additions & 0 deletions toad/nn/distributed/ddp.py
@@ -0,0 +1,17 @@
from torch.nn.parallel import DistributedDataParallel as DDP


class DDPModule(DDP):
    """DistributedDataParallel wrapper that delegates toad module helpers
    to the wrapped model
    """
    def fit(self, *args, **kwargs):
        return self.module.fit(*args, **kwargs)

    def save(self, *args, **kwargs):
        return self.module.save(*args, **kwargs)

    def load(self, *args, **kwargs):
        return self.module.load(*args, **kwargs)

    def log(self, *args, **kwargs):
        return self.module.log(*args, **kwargs)
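
DistributedDataParallel exposes the wrapped network as .module, so the overrides above simply forward toad's helpers to the original model. A brief usage sketch, assuming a process group is already initialized and net / loader are placeholder names:

# hypothetical usage; net and loader are placeholders
ddp_model = DDPModule(net)

ddp_model.fit(loader, epoch = 5)         # forwards to net.fit(...)
ddp_model.save("data/ddp_model.pkl")     # forwards to net.save(...)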
27 changes: 27 additions & 0 deletions toad/nn/distributed/distributor.py
@@ -0,0 +1,27 @@
from dataclasses import dataclass


@dataclass
class DistributedState:
    size: int = 0
    backend: str = 'ddp'


class Distributor:
    def __init__(self, size, backend = 'ddp', cluster = 'mp'):
        self.state = DistributedState(
            size = size,
            backend = backend,
        )


    def init(self):
        pass

    def prepare(self, module, loader, optimizer, scheduler):
        pass


    def spawn(self, func, *args):
        pass
153 changes: 142 additions & 11 deletions toad/nn/distributed/fsdp_test.py
@@ -15,10 +15,12 @@ class TestModel(Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()

        self.linear = nn.Linear(in_feats, out_feats)
        self.linear_1 = nn.Linear(in_feats, in_feats)
        self.linear_2 = nn.Linear(in_feats, out_feats)

    def forward(self, x):
        x = self.linear(x)
        x = self.linear_1(x)
        x = self.linear_2(x)
        return F.relu(x)

    def fit_step(self, batch):
@@ -35,9 +37,9 @@ def worker(rank, world):

    torch.manual_seed(0)

    NUM_FEATS = 4096
    NUM_FEATS = 1024*2
    NUM_CLASSES = 1024
    DATASET_SIZE = 10000
    DATASET_SIZE = 1000


    X = torch.rand(DATASET_SIZE, NUM_FEATS, dtype = torch.float)
@@ -54,46 +56,175 @@ def worker(rank, world):

    model = TestModel(NUM_FEATS, NUM_CLASSES)
    # print(next(model.linear.parameters()).shape)

    model.load(f"data/origin_model_{rank}.pkl")

    model.distributed(rpc = "gloo", rank = rank, world_size = world)

    q_model = quantize(model)
    # q_model.eval()

    model.distributed(backend = "gloo", rank = rank, world_size = world)
    peft_model = get_peft_model(q_model)

    fdsp_model = FSDPModule(
        model,
        peft_model,
        # use_orig_params = True,
        # sync_module_states = True,
        # auto_wrap_policy = my_auto_wrap_policy,
        # policy = ModuleWrapPolicy([nn.Linear,]),
        device_id=torch.device("cpu"),
    )

    # for p in fdsp_model.parameters():
    #     print(p, p.shape)

    optimizer = optim.Adam(fdsp_model.parameters(), lr = 1e-3)

    state_path = f"data/fsdp_model_{rank}.pkl"

    fdsp_model.load(state_path)
    # fdsp_model.load(state_path)

    print('before fit:', fdsp_model(X[0]).sum())

    # inputs = torch.rand(10, features_dim)
    fdsp_model.fit(loader, epoch = 20, early_stopping = False)
    # inputs = torch.rand(10, NUM_FEATS)
    # fdsp_model.fit(loader, epoch = 20, early_stopping = False)
    train(fdsp_model, loader, epoch = 20)

    print('after fit:', fdsp_model(X[0]).sum())

    print(fdsp_model)
    # print(fdsp_model.flatten_sharded_optim_state_dict())
    print("##### fsdp parameters:", get_parameters(fdsp_model).shape)
    print("##### fsdp q model flatten:", fdsp_model.linear_2._handle.flat_param)
    print("##### q_model parameters:", type(get_parameters(q_model.linear_2)))

    # out = fdsp_model(inputs).sum()

    # out.backward()

    # print("~~~~~", out)

    print(model)
    # print(model)
    model.save(f"data/origin_model_{rank}.pkl")

    fdsp_model.save(state_path)


def train(model, loader, **kwargs):
    from ..trainer import Trainer
    trainer = Trainer(model, loader, early_stopping = False)

    @trainer.fit_step
    def fit_step(model, batch):
        x, y = batch
        y_hat = model(x)
        # return F.cross_entropy(y_hat, y)
        return F.mse_loss(y_hat, y)

    trainer.train(**kwargs)


def get_parameters(model):
    return next(model.parameters())


def quantize(model):
    import copy
    from quanto import Calibration, freeze, qfloat8, qint4, qint8, quantize

    m_copy = copy.deepcopy(model)

    quantize(m_copy, weights=qint4)
    freeze(m_copy)

    # m_copy = replace_hqq_linear(m_copy)

    # m_copy = replace_qlinear(m_copy)

    print("### q model linear_2 weight", m_copy.linear_2.weight)
    print("### q model linear_2 parameters", get_parameters(m_copy.linear_2).dtype)

    return m_copy



def replace_qlinear(model, skip_modules=["lm_head"], **kwargs):
    """Replace linear modules with quantized QLinear modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module`, as the function is run recursively.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
    """
    from ..quantize import QLinear

    for name, module in model.named_children():
        if name in skip_modules:
            continue

        if isinstance(module, torch.nn.Linear):
            model._modules[name] = QLinear.qcreate(module, **kwargs)
            model._modules[name].quantize()
            model._modules[name].freeze()

    return model



def replace_hqq_linear(model, skip_modules=["lm_head"], **kwargs):
    """Replace linear modules with HQQ quantized linear modules.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module`, as the function is run recursively.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
    """
    from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig

    quant_config = BaseQuantizeConfig(
        nbits=4,
        group_size=64,
        # quant_zero=True,
        # quant_scale=True,
        # offload_meta=True,
        view_as_float=True
    )

    for name, module in model.named_children():
        if name in skip_modules:
            continue

        if len(list(module.children())) > 0:
            # recurse into container modules
            replace_hqq_linear(module, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear):
            model._modules[name] = HQQLinear(module, quant_config, **kwargs)

    HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP)
    return model



def get_peft_model(model):
    from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

    peft_config = LoraConfig(
        # task_type=TaskType.SEQ_2_SEQ_LM,
        # task_type=TaskType.FEATURE_EXTRACTION,
        target_modules = ['linear_1'],
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
    )

    model = get_peft_model(model, peft_config)

    return model



def test_fsdp_model():
    import torch.multiprocessing as mp