diff --git a/realhf/base/topology.py b/realhf/base/topology.py index 7d008d1c..fd77fd2d 100755 --- a/realhf/base/topology.py +++ b/realhf/base/topology.py @@ -1,6 +1,6 @@ # Modified from https://github.com/microsoft/DeepSpeed/blob/aed599b4422b1cdf7397abb05a58c3726523a333/deepspeed/runtime/pipe/topology.py# -from itertools import product as cartesian_product +from itertools import product as cartesian_product, permutations from typing import Dict, List, NamedTuple, Optional, Tuple import torch.distributed as dist @@ -43,7 +43,7 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]: for j in range(i, int((n // i) ** (1 / 2)) + 1): if (n // i) % j == 0: k = (n // i) // j - factors += list(set(itertools.permutations([i, j, k]))) + factors += list(set(permutations([i, j, k]))) return factors