A trivial MNIST example with RedCoast.
+A trivial MNIST example with RedCoast. Runnable by
+
XLA_FLAGS="--xla_force_host_platform_device_count=8" python main.py --n_model_shards 2
+XLA_FLAGS="--xla_force_host_platform_device_count=8" python main.py
Source Code
from functools import partial
import fire
-import jax
import numpy as np
-import jax.numpy as jnp
from flax import linen as nn
import optax
from torchvision.datasets import MNIST
from redco import Deployer, Trainer, Predictor
+# A simple CNN model
+# Copied from https://github.com/google/flax/blob/main/examples/mnist/train.py
class CNN(nn.Module):
@nn.compact
- def __call__(self, x, training):
+ def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
@@ -575,131 +616,75 @@ Source Code