diff --git a/torchrec/distributed/planner/planners.py b/torchrec/distributed/planner/planners.py index 1543970d5..101240d60 100644 --- a/torchrec/distributed/planner/planners.py +++ b/torchrec/distributed/planner/planners.py @@ -67,7 +67,7 @@ from torchrec.distributed.utils import none_throws -def _to_sharding_plan( +def to_sharding_plan( sharding_options: List[ShardingOption], topology: Topology, ) -> ShardingPlan: @@ -388,7 +388,7 @@ def plan( best_plan = callback(best_plan) self._best_plan = best_plan - sharding_plan = _to_sharding_plan(best_plan, self._topology) + sharding_plan = to_sharding_plan(best_plan, self._topology) end_time = perf_counter() for stats in self._stats: @@ -737,7 +737,7 @@ def plan( best_plan = callback(best_plan) self._best_plan = best_plan - sharding_plan = _to_sharding_plan( + sharding_plan = to_sharding_plan( best_plan, self._topology_groups[group] ) best_plans.append(sharding_plan)