Skip to content

Commit

Permalink
Add comments to DP proposer
Browse files Browse the repository at this point in the history
Summary:
# context
* added comments and changed some variable name for better readability

# explicit assumptions
* constraint is the total_hbm
* objective function is the **sum** of the perf metric over all tables

# implicit assumptions
* each table is evenly sharded to all the devices, or can be treated as evenly sharded, so we don't need to consider bottleneck effect
* device doesn't need to be identical, i.e., could have different hbm storage
* each hbm_bin is distributed over all the devices (equivalently)

# variables
* `proposal`: includes a sharding option for each table
* `proposal_list`: includes a list of `proposal`s, each of which has the best perf under a given total_hbm constraint.
* `hbm_by_fqn`: memory constraint lookup table: [table_id][sharding_option_id]
* `perf_by_fqn`: performance metrics lookup table: [table_id][sharding_option_id]
NOTE: hbm is measured in unit of `bin`, such as 0.4 bin, 1.6 bin, etc.

# dp table
* dimensions: `table_count` x `bin_count` x `case`, where `case` is a tuple of (`perf`, `hbm`)
* dp table caches the best **case** that has [0 - table_i] tables, under given `hbm`.
NOTE: **memory complexity** is `table_count` x `bin_count`, assuming `bin_count` >> `option_count`, otherwise `table_count` x `option_count`.
**time complexity** is `table_count` x `bin_count` x `option_count`
* firstly, `table_i` loops over the tables to add each table one by one
* secondly, `option_j` loops over the options of the current `table_i`
* thirdly, `hbm` loops over the bin_count for each hbm constraint to find the minimal

# usage
* dp algorithm will only be called once when `self._inited == False`
* `propose` will return a proposal from the `proposal_list`
* each time calling the `feedback` will move the current proposal to the next (from a higher bhm constraint to a lower one)

Differential Revision: D61565731
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Aug 20, 2024
1 parent b6380be commit 313d4e7
Showing 1 changed file with 68 additions and 47 deletions.
115 changes: 68 additions & 47 deletions torchrec/distributed/planner/proposers.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ def __init__(self, hbm_bins_per_device: int = 100) -> None:
self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = (
OrderedDict()
)
self._proposal_indices: List[List[int]] = []
# list of proposals with different total_hbm, a proposal is a list of indices of sharding_options
self._proposal_list: List[List[int]] = []
self._current_proposal: int = -1

def load(
Expand All @@ -340,6 +341,7 @@ def load(
) -> None:
"""Load search space."""
self._reset()
# order the sharding_option by total_storage.hbm from low to high
for sharding_option in sorted(search_space, key=lambda x: x.total_storage.hbm):
fqn = sharding_option.fqn
if fqn not in self._sharding_options_by_fqn:
Expand All @@ -348,7 +350,7 @@ def load(

def _reset(self) -> None:
self._sharding_options_by_fqn = OrderedDict()
self._proposal_indices = []
self._proposal_list = []
self._current_proposal = -1

def propose(self) -> Optional[List[ShardingOption]]:
Expand All @@ -359,7 +361,7 @@ def propose(self) -> Optional[List[ShardingOption]]:
for sharding_options in self._sharding_options_by_fqn.values()
]
elif self._current_proposal >= 0:
proposal_index = self._proposal_indices[self._current_proposal]
proposal_index = self._proposal_list[self._current_proposal]
return [
self._sharding_options_by_fqn[fqn][index]
for fqn, index in zip(
Expand All @@ -379,62 +381,81 @@ def feedback(
"""Feedback last proposed plan."""
if not self._inited:
self._inited = True
M = len(self._sharding_options_by_fqn)
N = max([len(x) for x in self._sharding_options_by_fqn.values()])
table_count = len(self._sharding_options_by_fqn)
option_count = max([len(x) for x in self._sharding_options_by_fqn.values()])

assert storage_constraint is not None
# are we assuming the table will be evenly sharded on all devices?
hbm_total = sum([x.storage.hbm for x in storage_constraint.devices])
K = self._hbm_bins_per_device * len(storage_constraint.devices)
bin_size = float(hbm_total) / K

dp = [[(float("inf"), float("inf"))] * K for _ in range(M)]
backtrack = [[(-1, -1)] * K for _ in range(M)]

hbm_by_fqn = [[float("inf") for _ in range(N)] for _ in range(M)]
perf_by_fqn = [[float("inf") for _ in range(N)] for _ in range(M)]
for m, sharding_options in enumerate(
bin_count = self._hbm_bins_per_device * len(storage_constraint.devices)
bin_size = float(hbm_total) / bin_count

dp = [
[(float("inf"), float("inf"))] * bin_count for _ in range(table_count)
] # [table_id][hbm_bin][perf, hbm]

backtrack = [
[(-1, -1)] * bin_count for _ in range(table_count)
] # [table_id][hbm_bin][opt_id, prev_hbm_bin]

hbm_by_fqn = [
[float("inf") for _ in range(option_count)] for _ in range(table_count)
] # memory constraint lookup table: [table_id][sharding_option_id]
perf_by_fqn = [
[float("inf") for _ in range(option_count)] for _ in range(table_count)
] # performance metrics lookup table: [table_id][sharding_option_id]

# populate hbm and perf for each sharding option and table: A[table_id][sharding_option_id]
for table_id, sharding_options in enumerate(
self._sharding_options_by_fqn.values()
):
for n, sharding_option in enumerate(sharding_options):
hbm_by_fqn[m][n] = _bytes_to_float_bin(
for opt_id, sharding_option in enumerate(sharding_options):
hbm_by_fqn[table_id][opt_id] = _bytes_to_float_bin(
sharding_option.total_storage.hbm, bin_size
)
perf_by_fqn[m][n] = sharding_option.total_perf

for j in range(N):
if hbm_by_fqn[0][j] < K:
hbm_i = int(hbm_by_fqn[0][j])
if dp[0][hbm_i][0] > perf_by_fqn[0][j]:
dp[0][hbm_i] = (perf_by_fqn[0][j], hbm_by_fqn[0][j])
backtrack[0][hbm_i] = (j, -1)

for i in range(1, M):
for j in range(N):
for c in range(K):
prev_perf, perv_hbm = dp[i - 1][c]
perf_by_fqn[table_id][opt_id] = sharding_option.total_perf

table_0 = 0
for opt_j in range(option_count):
if hbm_by_fqn[0][opt_j] < bin_count:
hbm_i = int(hbm_by_fqn[0][opt_j])
# options are ordered in increasing order of hbm, we only want to consider
# a sharding option that has higher hbm and better perf (the smaller the better)
if dp[table_0][hbm_i][0] > perf_by_fqn[table_0][opt_j]:
dp[table_0][hbm_i] = (
perf_by_fqn[table_0][opt_j],
hbm_by_fqn[table_0][opt_j],
)
backtrack[table_0][hbm_i] = (opt_j, -1)

# dp: table_count x option_count x bin_count
for table_i in range(1, table_count):
for opt_j in range(option_count):
for hbm in range(bin_count):
prev_perf, perv_hbm = dp[table_i - 1][hbm]
if prev_perf < float("inf"):
new_hbm = perv_hbm + hbm_by_fqn[i][j]
if new_hbm < K:
new_hbm = perv_hbm + hbm_by_fqn[table_i][opt_j]
if new_hbm < bin_count:
new_hbm_i = int(new_hbm)
new_perf = prev_perf + perf_by_fqn[i][j]
if dp[i][new_hbm_i][0] > new_perf:
dp[i][new_hbm_i] = (new_perf, new_hbm)
backtrack[i][new_hbm_i] = (j, c)

self._proposal_indices = []
for c in range(K - 1, -1, -1):
cur_col_idx, cur_hbm_idx = backtrack[M - 1][c]
if cur_col_idx >= 0:
column_indices = [-1] * M
column_indices[M - 1] = cur_col_idx
for i in range(M - 2, -1, -1):
column_indices[i], cur_hbm_idx = backtrack[i][cur_hbm_idx]
self._proposal_indices.append(column_indices)
if len(self._proposal_indices) > 0:
new_perf = prev_perf + perf_by_fqn[table_i][opt_j]
if dp[table_i][new_hbm_i][0] > new_perf:
dp[table_i][new_hbm_i] = (new_perf, new_hbm)
backtrack[table_i][new_hbm_i] = (opt_j, hbm)
self._proposal_list = []
# fill in all the proposals, starting from highest hbm to lowest hbm
for c in range(bin_count - 1, -1, -1):
cur_opt_idx, cur_hbm_idx = backtrack[table_count - 1][c]
if cur_opt_idx >= 0:
proposal_indices = [-1] * table_count
proposal_indices[table_count - 1] = cur_opt_idx
for i in range(table_count - 2, -1, -1):
proposal_indices[i], cur_hbm_idx = backtrack[i][cur_hbm_idx]
self._proposal_list.append(proposal_indices)
if len(self._proposal_list) > 0:
self._current_proposal = 0
else:
self._current_proposal += 1
if self._current_proposal >= len(self._proposal_indices):
if self._current_proposal >= len(self._proposal_list):
self._current_proposal = -1


Expand Down

0 comments on commit 313d4e7

Please sign in to comment.