Enabling the computation of validation loss and other metrics when using sequence parallelism (mosaicml#3183)

* fix a bug in eval with seq parallelism

* print debug values

* ..

* ..

* ..

* potentially fixing the eval bug

* minor

* minor

* minor

* ..

* fixing is_sampler_distributed

* removing redundant condition
ShashankMosaicML authored Apr 10, 2024
1 parent a471278 commit 4e54004
Showing 3 changed files with 5 additions and 3 deletions.
composer/core/data_spec.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -221,7 +221,7 @@ def __init__(
         world_size = dist.get_world_size()
         # Check for Distributed Sampler if not using IterableDataset on more than 1 GPU
         if world_size > 1 and not isinstance(dataloader.dataset, torch.utils.data.IterableDataset):
-            is_sampler_distributed = dataloader.sampler and isinstance(dataloader.sampler, DistributedSampler)
+            is_sampler_distributed = isinstance(dataloader.sampler, DistributedSampler)
             is_batch_sampler_distributed = dataloader.batch_sampler is not None and isinstance(
                 dataloader.batch_sampler,
                 DistributedSampler,
```
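For context on the `is_sampler_distributed` fix: `isinstance` already returns `False` when `dataloader.sampler` is `None`, so the extra truthiness check was redundant, and it can also misfire, because `bool(sampler)` falls back to the sampler's `__len__` when one is defined. A standalone illustration of that pitfall (a contrived sketch, not Composer code):

```python
from torch.utils.data import Dataset, DistributedSampler


class EmptyDataset(Dataset):
    """Contrived dataset whose shard happens to contain zero samples."""

    def __len__(self) -> int:
        return 0

    def __getitem__(self, idx):
        raise IndexError(idx)


# Passing num_replicas/rank explicitly avoids needing an initialized process group.
sampler = DistributedSampler(EmptyDataset(), num_replicas=2, rank=0)

# Old-style check: bool(sampler) calls DistributedSampler.__len__(), which is 0
# here, so the whole expression is falsy even though the sampler is distributed.
print(bool(sampler and isinstance(sampler, DistributedSampler)))  # False

# New check: reports the sampler's type correctly regardless of its length.
print(isinstance(sampler, DistributedSampler))  # True
```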
composer/core/evaluator.py (4 changes: 3 additions & 1 deletion)

```diff
@@ -94,6 +94,8 @@ def __init__(
         self._eval_interval = None
         self.eval_interval = eval_interval
         self.auto_microbatching = _is_auto_microbatching(device_eval_microbatch_size)
+        if self.auto_microbatching and hasattr(self.dataloader, 'seq_parallel_world_size'):
+            raise ValueError('`device_eval_microbatch_size="auto"` is not compatible with sequence parallelism.')
         self.device_eval_microbatch_size = _get_initial_device_eval_microbatch_size(
             device_eval_microbatch_size,
             self.auto_microbatching,
@@ -177,7 +179,7 @@ def _get_initial_device_eval_microbatch_size(
                     ),
                 ) from e
         return batch_size
-    elif isinstance(device_eval_microbatch_size, int):
+    elif isinstance(device_eval_microbatch_size, Union[int, float]):
         return device_eval_microbatch_size
     else:
         raise ValueError("device_eval_microbatch_size must be an int or ``'auto'``")
```
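Two behaviors change in this file: `device_eval_microbatch_size='auto'` is rejected when the evaluation dataloader advertises a `seq_parallel_world_size`, and an explicit `float` is now accepted alongside an `int`. A hedged sketch of how the fractional value is meant to be passed (the `eval_dataloader` object and the metric name are placeholders; `Evaluator` and `device_eval_microbatch_size` are the Composer names touched by this diff):

```python
from composer.core import Evaluator

# Assumed example: a sequence-parallel group of 2 ranks, each holding half of
# every sample's tokens. `eval_dataloader` is presumed to be a dataloader that
# exposes a matching `seq_parallel_world_size` attribute (built elsewhere).
seq_parallel_world_size = 2

evaluator = Evaluator(
    label='val',
    dataloader=eval_dataloader,  # placeholder, defined elsewhere
    metric_names=['LanguageCrossEntropy'],  # example metric name
    # Each device requests a fraction of a sample; 0.5 * 2 ranks == 1 full sample
    # per sequence-parallel group, which is what the trainer-side check expects.
    device_eval_microbatch_size=1 / seq_parallel_world_size,
    # device_eval_microbatch_size='auto' would now raise the ValueError added above.
)
```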
composer/trainer/trainer.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -409,7 +409,7 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
     if hasattr(
         evaluator.dataloader,
         'seq_parallel_world_size',
-    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.batch_size * evaluator.dataloader.seq_parallel_world_size != 1:  # type: ignore
+    ) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size != 1:  # type: ignore
         raise ValueError(
             'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
         )
```
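The validation now reads `device_eval_batch_size` from the evaluator's dataloader rather than `batch_size` when enforcing the sequence-parallel constraint. The arithmetic it enforces, restated as a standalone sketch (the helper name below is hypothetical; the condition and error message mirror the hunk above):

```python
def check_seq_parallel_eval_batch_size(device_eval_batch_size: float, seq_parallel_world_size: int) -> None:
    # With sequence parallelism, each sequence-parallel group must evaluate
    # exactly one sample at a time, so the per-device fraction multiplied by
    # the group size has to equal 1.
    if seq_parallel_world_size > 1 and device_eval_batch_size * seq_parallel_world_size != 1:
        raise ValueError(
            'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
        )


check_seq_parallel_eval_batch_size(0.25, 4)  # passes: 0.25 * 4 == 1
try:
    check_seq_parallel_eval_batch_size(1, 4)  # 1 * 4 != 1
except ValueError as err:
    print(err)
```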
