Commit: todo

Kye committed Oct 3, 2023
1 parent 8db6472 commit 1418f85
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -45,7 +45,7 @@ print(preds)
import torch
from mega_vit.main import MegaVit

-v = ViT(
+v = MegaVit(
image_size = 224,
patch_size = 14,
num_classes = 1000,
@@ -63,6 +63,12 @@ preds = v(img) # (1, 1000)
print(preds)
```
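
The constructor arguments above are cut off by the diff hunk; a complete call might look like the following sketch, where `dim`, `depth`, `heads`, `mlp_dim`, and their values are illustrative assumptions rather than the repo's actual defaults.

```python
import torch
from mega_vit.main import MegaVit

# Hypothetical configuration for illustration; the repo's defaults may differ.
v = MegaVit(
    image_size = 224,
    patch_size = 14,
    num_classes = 1000,
    dim = 1024,       # assumed embedding width
    depth = 12,       # assumed number of transformer blocks
    heads = 16,       # assumed attention heads
    mlp_dim = 4096,   # assumed MLP hidden width
)

img = torch.randn(1, 3, 224, 224)  # dummy batch of one RGB image
preds = v(img)                     # (1, 1000) class logits
print(preds.shape)
```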

# Train
- A basic training script with roughly the same hyperparameters as in the paper is provided; note that it will take some tuning to find reliable hyperparameters: `python3 train.py`. A minimal sketch follows below.
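
A minimal sketch of what such a CIFAR-10 script typically contains (not the repo's actual `train.py`; the `MegaVit` constructor arguments are assumptions):

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from mega_vit.main import MegaVit

device = "cuda" if torch.cuda.is_available() else "cpu"

# CIFAR-10 images are 32x32; resize them to the model's 224x224 input.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])
train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)

# Constructor arguments are assumptions; match them to the repo's actual signature.
model = MegaVit(image_size=224, patch_size=14, num_classes=10,
                dim=1024, depth=12, heads=16, mlp_dim=4096).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for images, labels in loader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
```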



# Model Architecture
- A regular ViT with parallel attention/MLP layers, QK (Query/Key) Normalization, and omitted biases, as sketched below.
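
A hedged sketch of such a block: the attention and MLP branches both read the same LayerNormed input and both add back to the residual stream, queries and keys get a per-head LayerNorm, and the linear projections drop their bias terms. Module names and layout here are illustrative, not the repo's exact code.

```python
import torch
from torch import nn
import torch.nn.functional as F

class ParallelBlock(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.norm = nn.LayerNorm(dim)
        # Biases omitted on all projections.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.q_norm = nn.LayerNorm(head_dim)  # QK-Normalization
        self.k_norm = nn.LayerNorm(head_dim)
        self.attn_out = nn.Linear(dim, dim, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim, bias=False),
            nn.GELU(),
            nn.Linear(mlp_dim, dim, bias=False),
        )

    def forward(self, x):
        b, n, d = x.shape
        h = self.heads
        y = self.norm(x)  # single pre-norm shared by both branches

        # Attention branch with LayerNorm applied to per-head queries and keys.
        q, k, v = self.to_qkv(y).chunk(3, dim=-1)
        q, k, v = (t.view(b, n, h, -1).transpose(1, 2) for t in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = self.attn_out(attn.transpose(1, 2).reshape(b, n, d))

        # Parallel MLP branch; both branches are summed into the residual.
        return x + attn + self.mlp(y)
```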

@@ -105,8 +111,10 @@ Eprint = {arXiv:2302.05442},
```

# Todo
-- [ ] Add flash attention, with layernorm before attn, and then layernorm for QK values,
-- [ ] Basic training script on CIFAR,
+- [x] Add flash attention, with layernorm before attn, and then layernorm for QK values,
+- [x] Basic training script on CIFAR,

- [ ] Add `FSDP` to training script,
- [ ] When using ViT-22B, as with any large-scale model, it is difficult to understand how the model arrived at a specific decision, which could lead to a lack of trust and accountability. Add a mechanism to backtrack decisions.
- [ ] Create logic to train the decoder for 300k steps with a batch size of 64 using Adam (Kingma and Ba, 2015), clipping gradients to a global norm of 0.05 to stabilize training. Linearly increase the learning rate for 2,500 steps to 0.0002 (starting from 0), then decay it back to 0 with a cosine schedule (Loshchilov and Hutter, 2017). See the sketch below.
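
A sketch of that optimizer and schedule, assuming a `model`, `loader`, and `device` like the ones in the training sketch above:

```python
import math
import torch

TOTAL_STEPS, WARMUP_STEPS, PEAK_LR = 300_000, 2_500, 2e-4

optimizer = torch.optim.Adam(model.parameters(), lr=PEAK_LR)

def lr_lambda(step):
    if step < WARMUP_STEPS:
        return step / WARMUP_STEPS                     # linear warmup from 0 to PEAK_LR
    progress = (step - WARMUP_STEPS) / (TOTAL_STEPS - WARMUP_STEPS)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay back to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

for step, (images, labels) in enumerate(loader):       # loader uses batch_size=64
    images, labels = images.to(device), labels.to(device)
    loss = torch.nn.functional.cross_entropy(model(images), labels)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.05)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
    if step + 1 == TOTAL_STEPS:
        break
```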
