
Not seeing much memory savings with FP8 optimizer suddenly #1499

Open

asahni04 opened this issue Jan 6, 2025 · 9 comments

@asahni04

asahni04 commented Jan 6, 2025

Not seeing much memory savings with the FP8 optimizer suddenly. I tried it on torchtitan with Llama 13B.

@gau-nernst
Collaborator

Do you have a snippet to reproduce the issue? Also, what are your PyTorch and torchao versions?

@asahni04
Author

asahni04 commented Jan 7, 2025

Hi @gau-nernst, I tried it with the torchtitan repo. I launched training for Llama 13B and 8B with FP8 AdamW (block_size 128) on H100s and see no memory savings at all: https://github.com/pytorch/torchtitan/blob/main/train_configs/llama3_8b.toml, using TP + DP on a single node (TP degree 8, DP degree 1).
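
(For context, a rough back-of-envelope of the savings one might expect here; these are editorial numbers, not from the thread, and assume fp32 optimizer state, an fp32 scale per quantization block, and state sharded evenly across the TP group:)

    # Adam keeps two states (exp_avg, exp_avg_sq) per parameter.
    n_params = 8e9                              # Llama 3 8B
    adamw_state = 2 * n_params * 4              # fp32 state: 4 bytes/element
    fp8_state = 2 * n_params * (1 + 4 / 128)    # fp8 state: 1 byte/element + one fp32 scale per 128-element block
    print(f"AdamW state:    {adamw_state / 2**30:.1f} GiB")   # ~59.6 GiB
    print(f"AdamWFp8 state: {fp8_state / 2**30:.1f} GiB")     # ~15.4 GiB
    print(f"Saved per GPU at TP degree 8: {(adamw_state - fp8_state) / 8 / 2**30:.1f} GiB")  # ~5.5 GiB

So at TP=8 the expected per-GPU saving is on the order of a few GiB, which can be easy to miss next to activations and framework overhead.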

@gau-nernst
Collaborator

I will try to reproduce. Btw, if you switch to AdamW8bit or AdamW4bit, do you observe memory savings?
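
(For reference, a minimal sketch of that swap, assuming the same spot in torchtitan's optimizer construction as the AdamWFp8 change further below; the lr/weight_decay values are placeholders:)

    from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit

    # same call signature as torch.optim.AdamW; these also do not take fused/foreach
    optimizer = AdamW8bit(model.parameters(), lr=3e-4, weight_decay=0.1)
    # or: optimizer = AdamW4bit(model.parameters(), lr=3e-4, weight_decay=0.1)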

@gau-nernst
Collaborator

I can't reproduce the issue. On a 2xH100 machine from vast.ai, using:

data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 2

set NGPU=2

Changes for AdamWFp8 (replacing torchtitan's default optimizer construction):
            # if name == "Adam":
            #     # TODO: make the optimizer options configurable by toml/cmd args
            #     optimizer = torch.optim.Adam(model.parameters(), **optimizer_kwargs)
            # elif name == "AdamW":
            #     optimizer = torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
            # else:
            #     raise NotImplementedError(f"Optimizer {name} not added.")

            from torchao.prototype.low_bit_optim import AdamWFp8

            # AdamWFp8 does not accept torch.optim's fused/foreach flags, so drop them
            optimizer_kwargs.pop("fused", None)
            optimizer_kwargs.pop("foreach", None)
            optimizer = AdamWFp8(model.parameters(), **optimizer_kwargs)

            self.optimizers.append(optimizer)

torch==2.7.0.dev20250107+cu126, torchtitan commit 90567fc98

(screenshot of the run attached)

Without TP, I observed that due to the selective activation checkpointing policy, memory consumption might be similar, but AdamWFp8 is faster end-to-end since there are fewer recomputed activations. You might want to set activation checkpointing to "full" to make sure the comparison is fair.
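
(One way to compare runs on equal footing, independent of throughput, is to log PyTorch's peak-memory counters directly around a few training steps; a generic sketch:)

    import torch

    torch.cuda.reset_peak_memory_stats()
    # ... run a few training steps ...
    print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**30:.2f} GiB")
    print(f"peak reserved:  {torch.cuda.max_memory_reserved() / 2**30:.2f} GiB")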

@asahni04
Author

asahni04 commented Jan 9, 2025

@gau-nernst Hi, which version of torchao do you use?

@gau-nernst
Collaborator

Latest stable 0.7
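
(A quick way to confirm what is installed locally; both packages expose __version__:)

    import torch, torchao
    print(torch.__version__, torchao.__version__)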

@asahni04
Author

asahni04 commented Jan 9, 2025

Also, which config did you use? Did you use bf16 training?

@gau-nernst
Collaborator

The default Llama 8B config, with all the changes I mentioned in my previous reply.

@asahni04
Author

asahni04 commented Jan 9, 2025

Very weird, but I do notice savings on torchtitan, just not on my own modified model.
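
(A minimal isolation test may help narrow this down: compare optimizer-state memory for torch.optim.AdamW vs torchao's AdamWFp8 on a toy model, then swap in the modified model. This is an editorial sketch with placeholder sizes, assuming a single CUDA GPU (run it on the same H100 setup) and that optimizer state is created lazily on the first step:)

    import torch
    from torchao.prototype.low_bit_optim import AdamWFp8

    def optim_state_gib(optim_cls):
        torch.cuda.empty_cache()
        # toy model; replace with the modified model being debugged
        model = torch.nn.Sequential(*[torch.nn.Linear(4096, 4096) for _ in range(8)]).cuda()
        params_only = torch.cuda.memory_allocated()

        optim = optim_cls(model.parameters(), lr=1e-4)
        x = torch.randn(8, 4096, device="cuda")
        model(x).sum().backward()
        optim.step()                         # exp_avg / exp_avg_sq are allocated here
        optim.zero_grad(set_to_none=True)    # free gradients so mostly state remains
        torch.cuda.synchronize()

        return (torch.cuda.memory_allocated() - params_only) / 2**30

    print("AdamW    state:", optim_state_gib(torch.optim.AdamW), "GiB")
    print("AdamWFp8 state:", optim_state_gib(AdamWFp8), "GiB")

If the FP8 number is not roughly 4x smaller for the toy model, the environment is suspect; if it is, the difference likely lies in how the modified model's parameters are created or sharded.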
