BUG: FSDP2 + stochastic rounding optimizers don't work #1505

Open
cassanof opened this issue Jan 6, 2025 · 6 comments
cassanof commented Jan 6, 2025

_fp32_to_bf16_sr uses view(dtype), which does not have a sharding strategy registered for DTensor. Therefore, if you use it in combination with FSDP2, you get the following error:

Operator aten.view.dtype does not have a sharding strategy registered.
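For context, here is a minimal sketch that reproduces the failure. The setup below is illustrative, not the original training code; it assumes a run launched with torchrun and one CUDA device per rank.

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Shard a plain fp32 tensor over a 1D mesh, then bit-cast it the way
# _fp32_to_bf16_sr does. The bit-cast dispatches aten.view.dtype, which
# has no sharding strategy registered for DTensor.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
x = DTensor.from_local(torch.randn(1024, device="cuda"), mesh, [Shard(0)])
x.view(torch.int32)  # raises: Operator aten.view.dtype does not have a sharding strategy registered.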
cassanof (Author) commented Jan 6, 2025

I monkey-patched the stochastic rounding helper to convert the DTensor to a regular tensor and back, but this is not a great solution since there is significant overhead (it actually runs slower than AdamW now):

import torch
from torch import Tensor
from torch.distributed.tensor import DTensor

def _fp32_to_bf16_sr(x_f32_orig: Tensor) -> Tensor:
    # For an FP32 number      [a31, ..., a16, a15, ..., a0] to be converted to BF16
    # - Round towards zero:   [a31, ..., a16,   0, ...,  0]
    # - Round away from zero: [a31, ..., a16+1, 0, ...,  0]
    # (since the value can be negative, we use round towards/away from zero instead of round up/down)
    #
    # For stochastic rounding, we round away from zero with the probability of
    # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
    #
    # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16

    # handle DTensor case, which does not support .view(dtype)
    if isinstance(x_f32_orig, DTensor):
        x_f32 = x_f32_orig.to_local()
    else:
        x_f32 = x_f32_orig

    rand_16bit = torch.randint(
        0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
    )
    x_f32_bits = x_f32.view(torch.int32)
    x_fraction = x_f32_bits & 0xFFFF  # lower 16 bits
    x_bf16_towards_zero = x_f32_bits & 0xFFFF0000  # upper 16 bits

    x_f32_bits = torch.where(
        rand_16bit < x_fraction,  # this is True with the probability of p_fraction
        x_bf16_towards_zero
        + 0x10000,  # this might overflow, which would be UB due to signed integer overflow
        x_bf16_towards_zero,
    )
    # alternative, slightly faster
    # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
    out = x_f32_bits.view(torch.float32).bfloat16()
    if isinstance(x_f32_orig, DTensor):
        # convert back to DTensor
        out = DTensor.from_local(
            out,
            device_mesh=x_f32_orig.device_mesh,
            placements=x_f32_orig.placements,
        )
    return out

cassanof (Author) commented Jan 6, 2025

Created an issue in PyTorch too: pytorch/pytorch#144286

gau-nernst (Collaborator) commented

I think we should just wait for it to be resolved in PyTorch core. The DTensor <-> local tensor conversion probably introduces graph breaks, hence the bad perf you observed. In the meantime, if you want to unblock yourself, you can try unwrapping the DTensor before calling the Adam step:

torch.compile(single_param_adam, fullgraph=True, dynamic=False)(
    p.detach(),
    grad,
    state["step"],
    state["exp_avg"],
    state["exp_avg_sq"],
    state.get("max_exp_avg_sq", None),
    group["lr"],
    group["betas"][0],
    group["betas"][1],
    group["weight_decay"],
    group["eps"],
    self.is_adamw,
    self.bf16_stochastic_round and p.dtype is torch.bfloat16,
)

Unwrapping the DTensor before the Adam step will make sure the compiled Adam is a single graph. You can use ._local_tensor to access the local tensor directly, though use it at your own risk: the public .to_local() may introduce a copy, and the result might not share the same storage.
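As a rough sketch of that unwrapping (the _maybe_unwrap helper below is hypothetical, not part of torchao; it just pulls out the local shard so the compiled step sees plain tensors):

from typing import Optional

import torch
from torch.distributed.tensor import DTensor


def _maybe_unwrap(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Hypothetical helper: return the local shard of a DTensor (private attribute,
    # no copy, same storage), or the tensor unchanged. Because the storage is shared,
    # in-place updates by the compiled step are reflected in the sharded parameter.
    return t._local_tensor if isinstance(t, DTensor) else t

Each tensor argument in the snippet above (p.detach(), grad, the exp_avg / exp_avg_sq states) would be passed through this helper before the compiled call; the scalar hyperparameters are left untouched.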

cassanof (Author) commented Jan 7, 2025

Thanks! I wasn't aware of these internals. Is it guaranteed that _local_tensor will be == dtensor.to_local() when the DTensor is replicated?

Also, wouldn't I still need to convert back to DTensor at some point?

gau-nernst (Collaborator) commented

._local_tensor is the storage of the DTensor, I think. Think of it like .view(); hence, you don't need to convert back to a DTensor.

It's not guaranteed that DTensor.to_local() == DTensor._local_tensor, though in simple cases they should be the same / share the same storage. I don't remember the exact cases where they differ, so in the end, use it at your own risk.
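To illustrate the .view() analogy, a small sketch (assuming an already-initialized process group, e.g. launched via torchrun): in-place updates made through ._local_tensor are visible through the DTensor, so no conversion back is needed.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
param = DTensor.from_local(torch.zeros(8, device="cuda"), mesh, [Shard(0)])

local = param._local_tensor   # private attribute: no copy, shares storage with the DTensor
local.add_(1.0)               # e.g. an optimizer's in-place update on the local shard
print(param.to_local())       # reflects the update: a tensor of ones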

awgu (Contributor) commented Jan 8, 2025

I think DTensor.to_local() is a differentiable function returning DTensor._local_tensor. Running DTensor.to_local() under a no_grad() context is functionally equivalent to accessing DTensor._local_tensor (with perhaps very minor extra CPU overhead).
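A quick way to check this on a given PyTorch version (sketch, same assumed torchrun setup as above; the expected result follows from the comment above rather than a documented guarantee):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

mesh = init_device_mesh("cuda", (dist.get_world_size(),))
dt = DTensor.from_local(torch.randn(8, device="cuda"), mesh, [Shard(0)])

with torch.no_grad():
    # Expected to print True: to_local() under no_grad() should alias the
    # same storage as ._local_tensor (no copy).
    print(dt.to_local().data_ptr() == dt._local_tensor.data_ptr())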
