diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index b96a4a1c..dd41b8ae 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -239,16 +239,17 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor): all_reduced_amax_tensor = all_reduced_amax_tensor.wait() - ( - reduced_fp8_amax_tensor, - reduced_fp8_amax_w_tensor, - reduced_fp8_amax_dL_dY_tensor, - ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list)) - - for idx, child in enumerate(fp8_layers): - child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx]) - child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) - child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx]) + # Split the reduced tensor into single element tensors + # [x1, x2, x3, w1, w2, w3, dL_dY1, dL_dY2, dL_dY3] -> [[x1], [x2], [x3], [w1], [w2], [w3], [dL_dY1], [dL_dY2], [dL_dY3]] + splits = torch.split(all_reduced_amax_tensor, 1) + + # Then foreach_copy the split tensors back into the original tensors + torch._foreach_copy_( + fp8_amax_x_tensor_list + + fp8_amax_w_tensor_list + + fp8_amax_dL_dY_tensor_list, + splits, + ) # We create two stacked tensor groups, one for the amax history and one for the current scales fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)