
Commit d6dd38b

fix docstring

Fridge003 committed Jul 21, 2023
1 parent 9d449cf
Showing 1 changed file with 3 additions and 2 deletions.
colossalai/shardformer/policies/t5.py (3 additions & 2 deletions)

@@ -178,7 +178,8 @@ def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
                              num_stages: int) -> Tuple[List[int], int]:
         """
         Distribute t5 layers into stages when pipeline parallel is used.
-        Return the layer distribution as a list and the starting stage of decoder (if decoder exists).
+        Return the layer distribution as a list and the starting stage of the decoder.
+        If no decoder exists, the returned decoder starting stage is set to num_stages.
         """
 
         # number of encoder layers must be a positive integer
@@ -189,7 +190,7 @@ def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
         if num_encoder_layers + num_decoder_layers < num_stages:
             raise ValueError("The total number of layers can't be smaller than number of stages.")
 
-        # int the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
+        # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
             return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
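For context, the sketch below illustrates the contract the corrected docstring describes: a per-stage layer distribution plus the decoder's starting stage, which equals num_stages for encoder-only models. The even-split and proportional-split logic here stands in for ColossalAI's actual Policy.distribute_layers and the rest of distribute_t5_layers, neither of which this diff shows in full, so treat both as illustrative assumptions.

from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # Assumed even split: each stage gets num_layers // num_stages layers,
    # with the remainder handed to the last stages. The real
    # Policy.distribute_layers may balance differently.
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - remainder else 0)
            for stage in range(num_stages)]

def distribute_t5_layers(num_encoder_layers: int, num_decoder_layers: int,
                         num_stages: int) -> Tuple[List[int], int]:
    # Encoder-only model (T5EncoderModel): every stage holds encoder layers,
    # and the decoder starting stage is reported as num_stages.
    if num_decoder_layers == 0:
        return distribute_layers(num_encoder_layers, num_stages), num_stages
    # Assumed strategy: split stages between encoder and decoder in
    # proportion to their layer counts, keeping at least one stage each.
    total = num_encoder_layers + num_decoder_layers
    encoder_stages = min(max(1, round(num_stages * num_encoder_layers / total)),
                         num_stages - 1)
    layers = (distribute_layers(num_encoder_layers, encoder_stages)
              + distribute_layers(num_decoder_layers, num_stages - encoder_stages))
    return layers, encoder_stages

# 12 encoder layers, no decoder, 4 stages -> ([3, 3, 3, 3], 4)
print(distribute_t5_layers(12, 0, 4))
# 12 encoder + 12 decoder layers, 4 stages -> ([6, 6, 6, 6], 2)
print(distribute_t5_layers(12, 12, 4))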
