Alpa automatically parallelizes tensor computation graphs and runs them on a distributed cluster.
Use Alpa's single-line API `@parallelize`
to scale your single-node training code to distributed clusters, even if
your model is much larger than the memory of a single device.
```python
import alpa
import jax.numpy as jnp
from jax import grad

# Parallelize the training step with a single decorator.
@alpa.parallelize
def train_step(model_state, batch):
    def loss_func(params):
        out = model_state.forward(params, batch["x"])
        return jnp.mean((out - batch["y"]) ** 2)

    grads = grad(loss_func)(model_state.params)
    new_model_state = model_state.apply_gradient(grads)
    return new_model_state

# The training loop now automatically runs on your designated cluster.
model_state = create_train_state()
for batch in data_loader:
    model_state = train_step(model_state, batch)
```
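To run the loop across multiple nodes, Alpa is typically initialized on top of a Ray cluster before training starts. The minimal sketch below assumes a Ray cluster has already been launched (e.g. with `ray start`); check the documentation for the exact setup supported by your Alpa version.

```python
import alpa

# Connect Alpa to the running Ray cluster so that functions decorated
# with @alpa.parallelize are executed across all available GPUs.
alpa.init(cluster="ray")

# ... create the model state and run the training loop shown above ...

# Release the devices held by Alpa when training is done.
alpa.shutdown()
```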
Check out the Alpa documentation site for installation instructions, tutorials, examples, and more.

More resources:
- Alpa paper (OSDI'22)
- Blog
If you are interested in contributing to Alpa, please read the contributor guide and connect with other contributors via the Alpa Slack.
Alpa is licensed under the Apache-2.0 license.