Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: rohitrango <[email protected]>
  • Loading branch information
rohitrango committed Jul 10, 2024
1 parent 508198f commit 2254bf9
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ def __init__(

time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim),
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)

self.input_blocks = nn.ModuleList(
Expand Down Expand Up @@ -505,24 +507,26 @@ def __init__(
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
)
if not use_spatial_transformer
else SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
)
),
ResBlock(
ch,
Expand Down Expand Up @@ -693,7 +697,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):
# handle asynchronous grad reduction
no_sync_func = None
if not forward_only and self.with_distributed_adam:
no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,)
no_sync_func = partial(
self._optimizer.no_sync,
greedy_grad_copy=self.megatron_amp_O2,
)

# pipeline schedules will get these from self.model.config
for module in self.get_module_list():
Expand Down Expand Up @@ -737,12 +744,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only):

def training_step(self, dataloader_iter):
"""
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
Our dataloaders produce a micro-batch and then we fetch
a number of microbatches depending on the global batch size and model parallel size
from the dataloader to produce a list of microbatches.
Batch should be a list of microbatches and those microbatches should on CPU.
Microbatches are then moved to GPU during the pipeline.
The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions.
"""
# we zero grads here because we also call backward in the apex fwd/bwd functions
self._optimizer.zero_grad()
Expand Down Expand Up @@ -786,20 +793,20 @@ def training_step(self, dataloader_iter):
return loss_mean

def backward(self, *args, **kwargs):
""" LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""LightningModule hook to do backward.
We want this to do nothing since we run backward in the fwd/bwd functions from apex.
No need to call it here.
"""
pass

def optimizer_zero_grad(self, *args, **kwargs):
""" LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""LightningModule hook to zero grad.
We want this to do nothing as we are zeroing grads during the training_step.
"""
pass

def _append_sequence_parallel_module_grads(self, module, grads):
""" Helper method for allreduce_sequence_parallel_gradients"""
"""Helper method for allreduce_sequence_parallel_gradients"""

for param in module.parameters():
sequence_parallel_param = getattr(param, 'sequence_parallel', False)
Expand All @@ -812,8 +819,8 @@ def _append_sequence_parallel_module_grads(self, module, grads):

def get_forward_output_and_loss_func(self):
def process_batch(batch):
""" Prepares the global batch for apex fwd/bwd functions.
Global batch is a list of micro batches.
"""Prepares the global batch for apex fwd/bwd functions.
Global batch is a list of micro batches.
"""
# noise_map, condition
batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True)
Expand All @@ -823,7 +830,8 @@ def process_batch(batch):

# SD has more dedicated structure for encoding, so we enable autocasting here as well
with torch.cuda.amp.autocast(
self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype,
self.autocast_dtype in (torch.half, torch.bfloat16),
dtype=self.autocast_dtype,
):
x, c = self.model.get_input(batch, self.cfg.first_stage_key)

Expand Down Expand Up @@ -890,7 +898,7 @@ def validation_step(self, batch, batch_idx):
self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True)

def setup(self, stage=None):
""" PTL hook that is executed after DDP spawns.
"""PTL hook that is executed after DDP spawns.
We setup datasets here as megatron datasets require DDP to instantiate.
See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information.
Args:
Expand Down Expand Up @@ -944,7 +952,8 @@ def build_train_valid_test_datasets(self):

if self.cfg.first_stage_key.endswith("encoded"):
self._train_ds, self._validation_ds = build_train_valid_precached_datasets(
model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0),
model_cfg=self.cfg,
consumed_samples=self.compute_consumed_samples(0),
)
else:
self._train_ds, self._validation_ds = build_train_valid_datasets(
Expand Down Expand Up @@ -998,20 +1007,23 @@ def setup_test_data(self, cfg):
f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}'
)
self._test_dl = torch.utils.data.DataLoader(
self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True,
self._test_ds,
batch_size=self._micro_batch_size,
num_workers=cfg.num_workers,
pin_memory=True,
)

def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any:
""" PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device
When using pipeline parallelism, we need the global batch to remain on the CPU,
since the memory overhead will be too high when using a large number of microbatches.
Microbatches are transferred from CPU to GPU inside the pipeline.
"""
return batch

def _validate_trainer(self):
""" Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""Certain trainer configurations can break training.
Here we try to catch them and raise an error.
"""
if self.trainer.accumulate_grad_batches > 1:
raise ValueError(
Expand Down
Loading

0 comments on commit 2254bf9

Please sign in to comment.