๋ชฉํ : ์ฌ๋์ ์ ๋ฉด ์ฌ์ง์ผ๋ก 1. ๋ง์คํฌ ์ฐฉ์ฉ ์ฌ๋ถ 2. ์ฑ๋ณ 3. ๋์ด๋ฅผ ์๋์ ๊ฐ์ด 18๊ฐ class๋ก ์์ธก
#ViT32 #ViT16 #ModelSoups #Relabeling #Oversampling #ContrastiveLearnig #WeightedAverageEnsemble #HardVoting #SoftVoting #Optuna #Wandb
conda env create -f environment.yml
conda activate model_soups
- wandb, albumentations ๋ฑ ์ถ๊ฐ ์ค์น
- Model soups์์ ์ ๊ณตํ ViT-B/32 ๋ชจ๋ธ ๋ค์ด
- ์ด 72๊ฐ, ๋ณธ ํ๋ก์ ํธ์์๋ ์ต๋ 40๊ฐ๋ฅผ ์ฌ์ฉํ์ต๋๋ค.
python main.py --download-models --model-location <where models will be stored>
python finetune.py --name {๋ชจ๋ธ๋ช
} --i {๋ชจ๋ธ number} --random-seed {์๋ ์ค์ }
- Model soups์์ ์ ๊ณตํ pretrained ๋ชจ๋ธ์ 18๊ฐ์ class vector๋ฅผ output์ผ๋ก ํ๋ 1๊ฐ์ linear layer๋ฅผ ์ถ๊ฐํ์ฌ ํ์ตํฉ๋๋ค.
- ViT-B/16 ์ ๊ฒฝ์ฐ Model soups pretrained weight ๊ฐ ์์ผ๋ฏ๋ก clip ๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ์ ๊ณตํ๋ ImageNet pretrained weight ์ ์ฌ์ฉํฉ๋๋ค.
--model {ViT-B/32 | ViT-B/16}
: base ๋ชจ๋ธ ์ค์ --name
: ์ ์ฅํ ๋ชจ๋ธ ์ด๋ฆ--i
: pretrained model์ index--random-seed
: random seed--lr
,--batch-size
,--epochs
,--data-location
,--model-location
: learning rate, batch size, epoch, ๋ฐ์ดํฐ ๊ฒฝ๋ก, ์ ์ฅํ ๋ชจ๋ธ ๊ฒฝ๋ก
- Tip : ์ ์คํฌ๋ฆฝํธ๋ฅผ ์ฌ์ฉํ์ฌ ํ์ต ์๋ํํ๊ธฐ. training.sh ํ์ผ ์์ฑ ํ ๋ค์ ๋ช ๋ น์ด ์คํ
bash training.sh
- Age ์์ฑ์ Old class์ ์ ์ train dataset์ผ๋ก ์ ํ๋ ํ์ต ์ฑ๋ฅ์ ๊ฐ์ ํ๊ธฐ ์ํด Old class data๋ง์ ์ถ๊ฐ๋ก Over samplingํ์์ต๋๋ค.
--old-aug True
: Old class 1ํ ์ถ๊ฐ over sampling
- Interclass์ ๊ฑฐ๋ฆฌ๋ฅผ ๋ํ๊ณ , Intraclass์ ๊ฑฐ๋ฆฌ๋ฅผ ์ขํ๋ Contrastive Learning์ ์ฌ์ฉํ์์ต๋๋ค.
--loss-fn
: ContrastiveLoss or CrossEntropyLoss, default๋ CrossEntropyLoss
- Model soups๋ ์ฌ๋ฌ ๊ฐ์ ๋์ผํ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ง pretrained ํ์ต ๋ชจ๋ธ๋ค์ ์กฐํฉํ์ฌ ํ๋์ ํ์ต ๋ชจ๋ธ์ ๋ง๋๋ ์์๋ธ ๊ธฐ๋ฒ์ ๋๋ค.
- ์ํ ๊ณผ์ ์ ์๋์ ๊ฐ์ต๋๋ค.
- ์ฌ๋ฌ๊ฐ์ pretrained model์ Test ํ์ฌ Accuracy๋ฅผ ์ป๋๋ค.
- Accuracy ๊ฐ์ผ๋ก ๋ด๋ฆผ์ฐจ์์ผ๋ก ์ ๋ ฌํ๋ค.
- ์์ฐจ์ ์ผ๋ก ๋ค์ ๋ชจ๋ธ๊ณผ์ weight๊ฐ์ averageํ์ฌ ํ๋์ ๋ชจ๋ธ์ ์์ฑํ๋ค.
- ์์ฑ๋ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ์ธก์ ์ ํ์์ ๋ ํ์ฌ๊น์ง ๊ฐ์ฅ ์ข์ Accuracy๋ณด๋ค ์ฑ๋ฅ์ด ์ข์ผ๋ฉด ์ ์ฅํ๊ณ , 3, 4๋ฒ์ ๋ฐ๋ณตํ๋ค. ๊ทธ๋ ์ง ์์ผ๋ฉด averageํ์ง ์๊ณ 3, 4๋ฒ์ ๋ฐ๋ณตํ๋ค.
- ๊ฐ์ฅ Accuracy๊ฐ ๋์ ๋ชจ๋ธ์ ์ต์ข ๋ชจ๋ธ๋ก ์ ์ ํ๋ค.
- Model soups์์ ์ ๊ณตํ pretrained model์ ViT-B/32 ๋ชจ๋ธ์ ๋๋ค.
- 1๋ฒ๊ณผ ๋์ผํ๊ฒ Fine tuning์ ์งํํฉ๋๋ค.
python main.py --eval-individual-models --name {๋ชจ๋ธ๋ช
} --model-num {๋ชจ๋ธ ๊ฐ์} --random-seed {๋๋ค ์๋}
- finetune์ ํตํด ๋ง๋ ๋ชจ๋ธ๋ค์ accuracy๋ฅผ ์ธก์ ํ์ฌ ๊ธฐ๋กํฉ๋๋ค.
--name
: ์ ์ฅ๋ ๋ชจ๋ธ๋ช--model-num
: Evaludationํ ๋ชจ๋ธ์ ๊ฐ์--random-seed
: ๋๋ค ์๋--val-ratio
,--epoch
,--data-location
,--model-locatoin
: validation dataset ๋น์จ, epoch, ๋ฐ์ดํฐ์ ๊ฒฝ๋ก, ์ ์ฅํ ๋ชจ๋ธ ๊ฒฝ๋ก
- ์คํ์ด ์๋ฃ๋๋ฉด logs ํด๋ ์์ ๊ฐ ๋ชจ๋ธ์ accuracy๊ฐ ์ ํ jsonl ํ์ผ์ด ์์ฑ๋ฉ๋๋ค.
python main.py --greedy-soup --name {๋ชจ๋ธ๋ช
} --model-num {๋ชจ๋ธ ๊ฐ์} --random-seed {๋๋ค ์๋}
- individual Evaluation์์ ์ ์ฅํ ์ฌ๋ฌ ๋ชจ๋ธ์ accuracy์ ๋ณด๋ฅผ ๋ด๋ฆผ์ฐจ์์ผ๋ก ์ ๋ ฌํฉ๋๋ค. ์ ๋ ฌ ๊ธฐ์ค์ผ๋ก ์ข์ ์ฑ๋ฅ์ ๋ด๋ ๋ชจ๋ธ๋ค์ ์์๋๋ก ๋ถ๋ฌ์ greedyํ๊ฒ ์กฐํฉํ์ฌ(averaging) ๋ ์ข์ ์ฑ๋ฅ์ ๋ด๋๋ก ํ๋ ์ต์ข ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค.
--name
: ์ ์ฅ๋ ๋ชจ๋ธ๋ช--model-num
: Evaludationํ ๋ชจ๋ธ์ ๊ฐ์--random-seed
: ๋๋ค ์๋--val-ratio
,--epoch
,--data-location
,--model-locatoin
: validation dataset ๋น์จ, epoch, ๋ฐ์ดํฐ์ ๊ฒฝ๋ก, ์ ์ฅํ ๋ชจ๋ธ ๊ฒฝ๋ก- ์คํ ๊ฒฐ๊ณผ model ํด๋ ์์ ์ต์ข ๋ชจ๋ธ์ด ์ ์ฅ๋ฉ๋๋ค.
- log ํด๋ ์์ ๋ณ์ GREEDY_SOUP_LOG_FILE๊ฐ ์ด๋ฆ์ ๋ก๊ทธ๋ฅผ ์ ์ฅํฉ๋๋ค. ํด๋น ๋ก๊ทธ์๋ averaging๋ ๋ชจ๋ธ ์ ๋ณด๊ฐ ์ ์ฅ๋ฉ๋๋ค.
python validation.py --model-name {๋ชจ๋ธ๋ช
.pt ํ์ผ}
- Validation set์์ class๋ณ๋ก ์๋ชป ์์ธกํ ๋น์จ์ ์ถ๋ ฅํฉ๋๋ค.
- ํด๋น ๋ชจ๋ธ์ ํ์ตํ์ ๋, ์ฌ์ฉํ๋ random seed ๊ฐ์ ๋์ผํ๊ฒ ์ ์งํด ์ฃผ์ด์ผ ์ ํํ ํ๋ฅ ๊ณผ ์์ธก๊ฐ์ด ๋์ต๋๋ค.
--model-name
: evaluationํ ๋ชจ๋ธ๋ช ,--i
: pretrained model์ index--random-seed
: ๋๋ค ์๋
- Age class์ ๋ถ๋ฅ ์ฑ๋ฅ์ ๋์ด๊ณ ์ Age ์์ฑ๋ง์ ๋ถ๋ฅํ๋ ๋ชจ๋ธ์ ํ์ตํ์ฌ, ์ด๋ฅผ ์ ์ฒด class(18๊ฐ) ๋ถ๋ฅ ๋ชจ๋ธ์ ์์ธก๊ฐ๊ณผ weighted sum์ ํ์์ต๋๋ค.
--weighted-ensemble
: Age class๋ฅผ ํ์ตํ ๋ชจ๋ธ๋ช , Default๋ None
python finetune_age.py --name {๋ชจ๋ธ๋ช
} --i {๋ชจ๋ธ number} --random-seed {์๋ ์ค์ }
- finetune_age.py๋ Age class๋ง์ ํ์ตํฉ๋๋ค.
--name
,--i
,--random-seed
๋ finetune.py์ ๋์ผํ๊ฒ ์ค์
- 2๊ฐ์ ํ์ต ๋ชจ๋ธ์ ๊ฐ class์ ํ๋ฅ ๊ฐ์ minmax scaling ํ ๋ํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
--soft-voting
: soft votingํ ๋ชจ๋ธ๋ช , Default๋ None
- ์ต์ข ์์ธก ๊ฒฐ๊ณผ csv ํ์ผ์ ์ฌ๋ฌ๊ฐ๋ฅผ ์ต์ข ์ ์ผ๋ก Hard voting์ ์ํํ์ฌ Ensemble์ ์งํํ์์ต๋๋ค.
- hard_voting.ipynb ์ ์คํํ์ฌ, ์์๋ธ์ ์ํ๋ csv๋ฅผ ๊ฐ์ง๊ณ hard voting์ ์ํํ ์ ์์ต๋๋ค.
์๋ ๊ทธ๋ฆผ์ ์ถ๋ ฅ ์์์ ๋๋ค.
python inference.py --model-name {๋ชจ๋ธ๋ช
.pt ํ์ผ}
- ์์ฑํ ๋ชจ๋ธ ํ์ผ(.pt)๋ฅผ ์ด์ฉํ์ฌ Test data๋ฅผ ์์ธกํ๋ ๋ถ๋ถ์ ๋๋ค.
--model-name
: inferenceํ ๋ชจ๋ธ๋ช--weighted-ensemble
,--soft-voting
: Weighted average ensemble ์ ๋ชจ๋ธ๋ช , Soft Voting ์ ๋ชจ๋ธ๋ช
- ์ต์ข ๊ฒฐ๊ณผ csv ํ์ผ์ด output ํด๋์ ์ ์ฅ๋ฉ๋๋ค.
- ์๋ชป ๋ผ๋ฒจ๋ง๋ ๋ฐ์ดํฐ id ๋ชฉ๋ก์ ๋ด์ relabel_dict ๋์ ๋๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ Relabeling์ ์งํํ์์ต๋๋ค.
python optuna_script.py
- Optuna๋ฅผ ์ด์ฉํ์ฌ Hyper paramter tuning์ ์งํํฉ๋๋ค.
- optuna_script.py ํ์ผ์์ hyper parameter tuning์ ์ํ ์ค์ ์ ์๋ ์ฌ์ง๊ณผ ๊ฐ์ด ๋ฃ์ด์ฃผ๊ณ ์คํํฉ๋๋ค.
- Private score 3rd / F1 score - 0.7613 / Accuracy - 81.3175
- Public score 6th / F1 score - 0.7653 / Accuracy - 81.3968
์ ํ์ค | ํํ๋ฏผ | ์ ํ์ | ๊น์ง๋ฒ | ์ค์ ๋ฆผ |
---|---|---|---|---|
|
|
|
|
|
Model soups : Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time.
ViT : https://github.com/google-research/vision_transformer
ContrastiveLoss : https://github.com/KevinMusgrave/pytorch-metric-learning
Optuna : https://optuna.org/
albumentations : https://albumentations.ai/
PyTorch : https://pytorch.org/
Wandb : https://wandb.ai/site