BUG: FSDP2 + stochastic rounding optimizers don't work #1505
_fp32_to_bf16_sr uses view(dtype), which does not have a sharding strategy with DTensor. Therefore, if you use it in combination with FSDP2, you get the following error:

Comments
I monkey-patched the stochastic rounding helper to convert the DTensor to a regular tensor and back, but this is not a great solution since the overhead is significant (it actually runs slower than plain AdamW now):

```python
import torch
from torch import Tensor
from torch.distributed.tensor import DTensor


def _fp32_to_bf16_sr(x_f32_orig: Tensor) -> Tensor:
    # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
    # - Round towards zero: [a31, ..., a16, 0, ..., 0]
    # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
    # (since the value can be negative, we use round towards/away from zero
    # instead of round up/down)
    #
    # For stochastic rounding, we round away from zero with the probability of
    # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted
    # as uint16.
    #
    # We have to use int32 since most arithmetic ops are not implemented for
    # uint32/int16/uint16.

    # Handle the DTensor case, which does not support .view(dtype):
    # work on the local shard instead.
    if isinstance(x_f32_orig, DTensor):
        x_f32 = x_f32_orig.to_local()
    else:
        x_f32 = x_f32_orig

    rand_16bit = torch.randint(
        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
    )
    x_f32_bits = x_f32.view(torch.int32)
    x_fraction = x_f32_bits & 0xFFFF  # lower 16 bits
    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000  # upper 16 bits

    x_f32_bits = torch.where(
        rand_16bit < x_fraction,  # True with the probability of p_fraction
        # this might overflow, which will result in UB due to signed integer
        x_bf16_towards_zero + 0x10000,
        x_bf16_towards_zero,
    )
    # alternative, slightly faster
    # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000

    out = x_f32_bits.view(torch.float32).bfloat16()

    if isinstance(x_f32_orig, DTensor):
        # convert back to a DTensor with the original mesh and placements
        out = DTensor.from_local(
            out,
            device_mesh=x_f32_orig.device_mesh,
            placements=x_f32_orig.placements,
        )
    return out
```
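For reference, a minimal sketch of how such a patched helper could be wired in; the module paths below are assumptions about torchao's internals and may differ between versions:

```python
# Hypothetical monkey-patch wiring: quant_utils is an assumed location for the
# original helper and may not match your torchao version.
import torchao.prototype.low_bit_optim.quant_utils as quant_utils

# Replace the library helper with the DTensor-aware version defined above.
quant_utils._fp32_to_bf16_sr = _fp32_to_bf16_sr

# If the optimizer module imported the helper by name
# (from .quant_utils import _fp32_to_bf16_sr), that module-level reference
# would need to be replaced as well, e.g.:
# import torchao.prototype.low_bit_optim.adam as adam_mod
# adam_mod._fp32_to_bf16_sr = _fp32_to_bf16_sr
```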
Created an issue in PyTorch too: pytorch/pytorch#144286
I think we just wait for it to be resolved in PyTorch core. The DTensor <-> local tensor conversion probably introduces graph breaks, hence the bad perf you observed. In the meantime, if you want to unblock yourself, you can try to unwrap the DTensor before calling the adam step (ao/torchao/prototype/low_bit_optim/adam.py, lines 124 to 138 at 270a90f).
Unwrapping the DTensor before the adam step will make sure the compiled adam step is a single graph. You can use …
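To make the suggestion concrete, here is a minimal sketch of what unwrapping before the step could look like; `unwrap_local` and `adam_step_on_locals` are hypothetical helper names, and the argument list is a stand-in for the compiled step in adam.py rather than its exact signature:

```python
from torch import Tensor
from torch.distributed.tensor import DTensor


def unwrap_local(x: Tensor) -> Tensor:
    # to_local() returns the DTensor's local shard, which shares storage with
    # the DTensor, so in-place updates on the unwrapped tensor still land in
    # the original parameter / optimizer state.
    return x.to_local() if isinstance(x, DTensor) else x


def adam_step_on_locals(adam_step_fn, p, grad, step, exp_avg, exp_avg_sq, **kwargs):
    # Unwrap everything before the call so the compiled step function only
    # ever sees plain torch.Tensor and stays a single graph.
    return adam_step_fn(
        unwrap_local(p),
        unwrap_local(grad),
        step,
        unwrap_local(exp_avg),
        unwrap_local(exp_avg_sq),
        **kwargs,
    )
```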
Thanks! I wasn't aware of these internals. Is it the invariant that …? Also, wouldn't I still need to convert back to DTensor at some point?
It's not guaranteed that …
I think …
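On the "convert back" question: a small standalone check (assumed to be launched with torchrun, e.g. 2 CPU processes with the gloo backend) suggesting that to_local() gives a view of the DTensor's local shard, so in-place updates made on the unwrapped tensor are already visible through the DTensor without any explicit conversion back:

```python
# Assumed launch command (file name and world size are arbitrary):
#   torchrun --nproc-per-node=2 check_to_local.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

dt = distribute_tensor(torch.zeros(8), mesh, [Shard(0)])
local = dt.to_local()  # local shard, shares storage with the DTensor
local.add_(1.0)        # in-place update, like an optimizer step would do

# The update is visible through the DTensor without converting back explicitly.
print(dt.full_tensor())  # tensor([1., 1., 1., 1., 1., 1., 1., 1.])

dist.destroy_process_group()
```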