forked from skypilot-org/skypilot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
tpuvm_mnist.yaml
33 lines (28 loc) · 893 Bytes
/
tpuvm_mnist.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Name of this task.
name: tpuvm_mnist

# Hardware requested for the task.
resources:
  accelerators: tpu-v2-8
# The setup command. Will be run under the working directory.
# Written so it can be run more than once against the same working directory:
# the clone and the environment creation are both skipped when already present.
setup: |
  # Clone only when no checkout is present; `git clone` into an existing
  # non-empty directory aborts with "destination path already exists".
  if [ ! -d flax/.git ]; then
    git clone https://github.com/google/flax.git --branch v0.8.2
  fi

  # Test `conda activate` directly rather than inspecting `$?` afterwards.
  # The activation still happens in the current shell, so on success the
  # `flax` env stays active for the rest of the script.
  if conda activate flax; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.10 -y
    conda activate flax
    # Make sure to install TPU related packages in a conda env to avoid
    # package conflicts.
    pip install \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.25" \
      clu \
      tensorflow tensorflow-datasets
    # Editable install of the checkout cloned above.
    pip install -e flax
  fi
# Commands executed for the task; they start from the working directory.
run: |
  # Switch to the `flax` conda env and make sure `clu` is available in it.
  conda activate flax
  pip install clu

  # Launch the MNIST example from the flax checkout, with /tmp/mnist as its
  # work directory and the learning rate / epoch count overridden on the
  # command line.
  cd flax/examples/mnist
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10