This repository contains the code for computing the statistics, merging models, and running inference for Realistic Evaluation of Model Merging for Compositional Generalization.
We use git-theta to compute the merges, which allows each parameter block to be merged independently and makes it possible to merge large models on smaller machines.
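As a rough illustration of the idea (not git-theta's actual implementation), per-block merging only ever needs one parameter block from each model in memory at a time. The sketch below assumes checkpoints stored as safetensors files, which support loading individual tensors:

```python
# Illustrative sketch of per-parameter-block merging (simple averaging here),
# NOT git-theta's actual implementation. Assumes checkpoints are stored as
# safetensors files, which allow loading one tensor at a time, so only a
# single parameter block per model needs to be resident in memory at once.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

def merge_blockwise(checkpoint_paths, output_path):
    readers = [safe_open(path, framework="pt", device="cpu") for path in checkpoint_paths]
    merged = {}
    for name in readers[0].keys():
        # Load just this block from each model, merge it, and move on.
        blocks = [reader.get_tensor(name) for reader in readers]
        merged[name] = torch.stack(blocks).mean(dim=0)
    save_file(merged, output_path)
```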
- Create a virtual environment and activate it.
python3.10 -m venv env
source env/bin/activate
- Install dependencies
python -m pip install -r requirements.txt -f https://download.pytorch.org/whl/cu113/torch_stable.html
A modified version of Promptsource must be installed from source.
cd promptsource
python -m pip install -e .
Download the multilingual ROUGE scorer from https://github.com/csebuetnlp/xl-sum/tree/master/multilingual_rouge_scoring
git clone https://github.com/csebuetnlp/xl-sum.git
cd xl-sum/multilingual_rouge_scoring
(io.py is temporarily renamed so that it does not shadow Python's built-in io module while running python -m pip from this directory; it is restored afterwards.)
mv io.py io_original.py
python -m pip install -r requirements.txt
python -m pip install .
mv io_original.py io.py
A modified version of open_clip must be installed from source.
cd open_clip
python -m pip install -e .
- Set environment variables (This step has to be done every session)
source bin/setup_{machine}.sh
Examples of how to save the statistics for TIES, RegMean, Fisher Merging, and MaTS are shown below. Note that although TIES has no statistics per se, we treat the trimmed model as a statistic so that the merge can be computed for each parameter block independently.
Trimmed model
python src/save_trimmed_model.py --pretrained_model_name {pretrained_model_name} --checkpoint_filepath {checkpoint_filepath} --save_path_for_trimmed_model {path_to_save_model}
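For intuition, here is a minimal sketch of the TIES trim step: keep only the top-k% largest-magnitude entries of each task vector (fine-tuned minus pretrained) and zero out the rest. Function and argument names are illustrative; see src/save_trimmed_model.py for the actual implementation.

```python
# Hedged sketch of TIES trimming; names are illustrative, not the repo's API.
import torch

def trim_task_vector(pretrained_state, finetuned_state, keep_fraction=0.2):
    trimmed = {}
    for name, pretrained_param in pretrained_state.items():
        task_vector = finetuned_state[name] - pretrained_param
        k = max(1, int(keep_fraction * task_vector.numel()))
        # Magnitude of the k-th largest entry; everything smaller is zeroed.
        threshold = task_vector.abs().flatten().kthvalue(task_vector.numel() - k + 1).values
        trimmed[name] = torch.where(task_vector.abs() >= threshold, task_vector, torch.zeros_like(task_vector))
    return trimmed
```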
RegMean and MaTS
python src/save_gram_matrices.py -c configs/evaluation_run/language.json configs/evaluation_dataset/p3.json configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json -er eval_batch_size=2 -ed dataset=squad split=train -m filepath_to_load_model="exp_out/p3/squad/models-google-mt5-xl-lm-adapt/full_model/2024-02-06-21-15-12/checkpoints/best_checkpoint_199.pt" --output_path gram_squad.pt
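The statistic RegMean and MaTS use is the Gram matrix X^T X of the inputs to each linear layer, accumulated over a few batches. The sketch below shows one way to gather it with forward hooks; names are illustrative, and see src/save_gram_matrices.py for the actual implementation.

```python
# Hedged sketch of accumulating per-layer input Gram matrices with hooks.
import torch

def compute_gram_matrices(model, dataloader, num_batches=100):
    grams, hooks = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            # Flatten everything but the feature dimension: (tokens, d_in).
            x = inputs[0].reshape(-1, inputs[0].shape[-1]).float()
            grams[name] = grams.get(name, 0) + x.T @ x
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))

    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= num_batches:
                break
            model(**batch)

    for hook in hooks:
        hook.remove()
    return grams
```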
Fisher Merging
python src/save_fisher.py -c configs/evaluation_run/language.json configs/evaluation_dataset/p3.json configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json -ed dataset=squad split=train -m filepath_to_load_model="exp_out/p3/squad/models-google-mt5-xl-lm-adapt/full_model/2024-02-06-21-15-12/checkpoints/best_checkpoint_199.pt" --output_path fisher_squad.pt
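Fisher Merging stores a diagonal Fisher approximation: per-parameter averages of squared gradients of the loss. A minimal sketch of the computation is below; names are illustrative, and see src/save_fisher.py for the actual implementation.

```python
# Hedged sketch of a diagonal Fisher estimate from squared gradients.
import torch

def compute_diagonal_fisher(model, dataloader, num_batches=100):
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    model.eval()
    for i, batch in enumerate(dataloader):
        if i >= num_batches:
            break
        model.zero_grad()
        loss = model(**batch).loss  # assumes a HuggingFace-style model output
        loss.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
    return {name: f / num_batches for name, f in fisher.items()}
```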
Run the training script with:
-c: the list of configs for the model
-td: any training dataset config parameters to update
-ed: any evaluation dataset config parameters to update
-tr: any training run config parameters to update
-er: any evaluation run config parameters to update
-m: any model config parameters to update
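As a rough illustration (the repo's actual argument handling may differ), key=value overrides like the ones above can be applied on top of a loaded config dict as follows; the helper name is hypothetical:

```python
# Hypothetical sketch of applying key=value override strings (as passed to
# -td/-ed/-tr/-er/-m) on top of a config dict; not the repo's actual parser.
import ast

def apply_overrides(config: dict, overrides: list[str]) -> dict:
    for override in overrides:
        key, raw_value = override.split("=", 1)
        try:
            # Interpret numbers, booleans, lists, etc.; fall back to strings.
            value = ast.literal_eval(raw_value)
        except (ValueError, SyntaxError):
            value = raw_value
        config[key] = value
    return config

# e.g. apply_overrides(train_config, ["micro_train_batch_size=2", "dataset=squad"])
```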
Cross Lingual:
python src/train/training.py -c configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/training_run/cross_lingual.json configs/training_dataset/p3.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json -tr micro_train_batch_size=2 train_task_mixture=multitask_multilingual -er eval_batch_size=4 task_mixture=multitask_multilingual
python src/train/training.py -c configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/training_run/cross_lingual.json configs/training_dataset/p3.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json -td dataset=squad -ed dataset=squad -tr micro_train_batch_size=2 -er eval_batch_size=4
python src/train/training.py -c configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/training_run/cross_lingual.json configs/training_dataset/p3.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json -td dataset=xnli language_code=ar -ed dataset=xnli language_code=ar -tr micro_train_batch_size=2 -er eval_batch_size=32
python src/train/training.py -c configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/training_run/cross_lingual.json configs/training_dataset/p3.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json -td dataset=wiki_lingua language=thai -ed dataset=wiki_lingua language=thai -tr micro_train_batch_size=1 -er eval_batch_size=4
python src/train/training.py -c configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/training_run/cross_lingual.json configs/training_dataset/p3.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json -td dataset=tydiqa language=korean -ed dataset=tydiqa language=korean -tr micro_train_batch_size=2 -er eval_batch_size=4
python src/train/training.py -c configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/training_run/cross_lingual.json configs/training_dataset/p3.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json -td dataset=xlwic language_code=de -ed dataset=xlwic language_code=de -tr micro_train_batch_size=4 -er eval_batch_size=16
DomainNet
One Domain
python src/train/training.py -c configs/model/clip.json configs/training_run/domainnet.json configs/training_dataset/domainnet.json configs/evaluation_dataset/domainnet_training.json configs/evaluation_run/vision.json -td domain=clipart task=mammal -ed domain=clipart task=mammal -tr micro_train_batch_size=128 -er eval_batch_size=256
All Domains
python src/train/training.py -c configs/model/clip.json configs/training_run/domainnet.json configs/training_dataset/domainnet_all.json configs/evaluation_dataset/domainnet_all.json configs/evaluation_run/vision.json -tr micro_train_batch_size=128 train_task_mixture=domainnet num_batches=10000 should_eval_before_training=False -er eval_batch_size=256 task_mixture=domainnet
Some methods (TIES, RegMean, Fisher Merging, and MaTS) require saving statistics first. Since the merge is computed independently for each parameter block, the trimmed model serves as the statistic for TIES.
python src/merging/save_statistic.py -c configs/evaluation_run/vision.json configs/evaluation_dataset/domainnet.json configs/model/clip.json configs/model/full_model.json configs/merging/domainnet.json configs/merging/{method}.json -er eval_batch_size=32
python src/merging/save_statistic.py -c configs/evaluation_run/language.json configs/evaluation_dataset/p3.json configs/model/mt5_xl_lm_adapt.json configs/model/full_model.json configs/merging/multitask_multilingual.json configs/merging/{method}.json -er eval_batch_size=32
We use git-theta to compute the merges and recommend tracking the merged models in a separate repository so that the code and model .git directories do not get tangled. First, clone git-theta into a different directory, not under this one:
git clone https://github.com/blester125/git-theta
cd git-theta
git checkout feat/merge-cli
python -m pip install -e .
Then create a new git repository for tracking the models, also not under this one, again to avoid tangling the .git directories:
mkdir merged-models
cd merged-models
mv ../exp_out .
git init
git theta track {path_to_model_checkpoint}
Follow the instructions at https://github.com/blester125/git-theta/tree/feat/merge-cli/plugins/merge_cli to start using git-theta.
Run the inference script with:
-e: the path to the experiment directory of the model
--merged_model: the path to the merged model
Inference on the best checkpoint from an experiment. The correct dataset and model configs are inferred from the experiment path.
python src/eval/inference.py -e {exp_dir}
Inference on a merged model. The correct evaluation run config, evaluation dataset config, and model config must be passed in.
python src/eval/inference.py -c configs/model/mt5_xl_lm_adapt.json configs/evaluation_dataset/p3.json configs/evaluation_run/language.json --merged_model average.pt --output_dir average
The DomainNet and cross-lingual checkpoints can be found here: https://console.cloud.google.com/storage/browser/realistic_evaluation_model_merging_compositional_generalization
We also include a PyTorch version of mT5-xl-lm-adapt already converted from the default JAX format.
If you find this repo helpful, feel free to cite our work:
@article{tam2024remm,
  title={Realistic Evaluation of Model Merging for Compositional Generalization},
  author={Tam, Derek and Kant, Yash and Lester, Brian and Gilitschenski, Igor and Raffel, Colin},
  journal={arXiv preprint arXiv:2409.18314},
  year={2024}
}