Monkfish: Distributed latent video model training on TPUs (and other stuff maybe)

This is the training code for a two-stage autoregressive video model. The first stage (latent training) works; the second stage is a work in progress.

Running on a single TPU:

 python -m monkfish.main.main config.json local [args...]

Running distributed from a head node:

 ray start --head --num-cpus=1 --port=6379
 PROJECT_SOURCE=path/to/monkfish python -m monkfish.main.main config.json distributed [args...]
 ray stop
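
The three distributed commands above can be wrapped so that `ray stop` still runs when training exits with an error. This is a minimal sketch, not part of the repository: the function names `run` and `run_distributed` and the `DRY_RUN` variable are hypothetical, and `path/to/monkfish` is the placeholder from the commands above, used only when `PROJECT_SOURCE` is unset.

```shell
# Hypothetical helper (not part of monkfish): print the command when
# DRY_RUN=1, otherwise execute it.
run() {
  if [ "${DRY_RUN:-0}" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# Hypothetical wrapper: run_distributed CONFIG [ARGS...]
# Starts a Ray head node, launches distributed training, then stops Ray
# and returns the exit status of the training command.
run_distributed() {
  local config="$1"
  shift
  local status=0

  # Same head-node command as in the README; give up if Ray fails to start.
  run ray start --head --num-cpus=1 --port=6379 || return 1

  # PROJECT_SOURCE falls back to the README placeholder if it is not set.
  run env PROJECT_SOURCE="${PROJECT_SOURCE:-path/to/monkfish}" \
    python -m monkfish.main.main "$config" distributed "$@" || status=$?

  # Always stop the head node, even when training failed.
  run ray stop
  return "$status"
}
```

After sourcing the functions, `DRY_RUN=1 run_distributed config.json` prints the three commands without executing them, which is a quick way to check the arguments before starting a real run.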

References For Developers

Parameter scaling:

Jax sharding:

Data loader Design: