
[WIP] JAX.JIT Switch and Sharding #822

Open
rka97 wants to merge 4 commits into dev

Conversation

rka97 commented Dec 9, 2024

Purpose

The goal of this PR is to enable sharding of model parameters and optimizer state, and to migrate the JAX code from jax.pmap to jax.jit.
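For context, the core of the pmap-to-jit pattern this migration targets looks roughly like the sketch below. It is a minimal illustration with assumed names (`train_step`, the toy loss, the 1D `'batch'` mesh), not code from this PR:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D data-parallel mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
replicated = NamedSharding(mesh, P())            # same copy on every device
data_parallel = NamedSharding(mesh, P('batch'))  # leading axis split across devices

@jax.jit
def train_step(params, batch):
    # Unlike with pmap, no explicit jax.lax.pmean is needed for gradient
    # averaging: the compiler derives the communication from the shardings.
    loss, grads = jax.value_and_grad(
        lambda p: jnp.mean((batch @ p) ** 2))(params)
    return params - 0.1 * grads, loss

# Placement is decided once via device_put; jit propagates it through the step.
params = jax.device_put(jnp.ones((4,)), replicated)
batch = jax.device_put(jnp.ones((8, 4)), data_parallel)  # leading dim must divide evenly over devices
params, loss = train_step(params, batch)
```

Compared to pmap, nothing is replicated by hand: placement comes from `device_put` shardings, and jit propagates them through the step.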

TODOs:

  • Migrate reference optimizers to use jax.jit (a sketch of a jitted optimizer step follows this list)
    • Nesterov
    • AdamW
    • Others
  • Migrate workloads to use jax.jit
    • (Test workload) MNIST
    • (Test workload) CIFAR
    • WMT
    • Criteo1TB
    • FastMRI
    • Librispeech
    • OGBG
    • ImageNet
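
An assumed sketch of what a migrated optimizer step might look like, an optax-based `update` under `jax.jit`, not the repository's reference implementation:

```python
import jax
import jax.numpy as jnp
import optax

tx = optax.adamw(learning_rate=1e-3)

@jax.jit
def update(params, opt_state, grads):
    # The optimizer state is a plain jit input; unlike with pmap, it is not
    # replicated by hand, and its placement follows whatever sharding it was
    # given with device_put.
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state

params = {'w': jnp.ones((4, 4))}
opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
params, opt_state = update(params, opt_state, grads)
```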

Changelog

  • Added sharding utilities to handle distributing data across devices (a helper of this kind is sketched below)
  • Replaced pmap code for CIFAR/MNIST with jit
  • Modified AdamW and Nesterov accordingly
  • Updated checkpoint and data_utils to support the new approach (mostly removing explicit jax_utils.replicate calls).
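
A minimal sketch of the kind of data-sharding helper the first changelog item describes; `shard_batch` and the mesh layout are assumptions, not the PR's actual utility:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def shard_batch(batch, mesh):
    """Place every array in a (possibly nested) batch on the mesh,
    splitting the leading dimension across the 'batch' axis."""
    sharding = NamedSharding(mesh, P('batch'))
    return jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), batch)

mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
batch = {'inputs': np.zeros((16, 32)), 'labels': np.zeros((16,))}
sharded = shard_batch(batch, mesh)  # no jax_utils.replicate needed
```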

Issues

  • Prefetching in CIFAR is temporarily disabled (marked with FIXME); it is not yet clear how best to support it under jit (one possible approach is sketched below).
  • I haven't edited any of the PyTorch code; we will need to verify that the PyTorch workloads still perform comparably.
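
One possible way to bring prefetching back under jit (a suggestion, not what this PR implements): since `jax.device_put` is asynchronous, the next batch can be transferred with its target sharding while the current step computes. This mirrors the structure of flax's `jax_utils.prefetch_to_device`:

```python
import collections
import itertools
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def prefetch_to_sharding(iterator, sharding, size=2):
    """Keep up to `size` batches in flight on device ahead of the consumer."""
    queue = collections.deque()

    def enqueue(n):
        for batch in itertools.islice(iterator, n):
            queue.append(jax.tree_util.tree_map(
                lambda x: jax.device_put(x, sharding), batch))

    enqueue(size)  # fill the buffer
    while queue:
        yield queue.popleft()
        enqueue(1)

mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
data = iter([{'x': np.ones((8, 32))} for _ in range(10)])
for batch in prefetch_to_sharding(data, NamedSharding(mesh, P('batch'))):
    pass  # run the jitted train step on `batch` here
```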

rka97 requested a review from a team as a code owner on December 9, 2024 at 21:21

github-actions bot commented Dec 9, 2024

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

rka97 (Author) commented Dec 9, 2024

recheck
