T5 Pre-training Using FLAX Framework

Table of Contents
  1. How to set up your environment
  2. How to prepare your pre-training data
  3. How to pre-train a T5 model using Transformers and Flax

How to set up your environment

First, install "jax" and "jaxlib" by running the following command in your own conda environment:

pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Note: you can check all available jax/jaxlib versions at https://storage.googleapis.com/jax-releases/jax_releases.html (the same index used by the -f flag above).
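
As a quick sanity check (not part of the original instructions), you can verify that JAX detects your GPU after installation, for example:

import jax

# With a matching jaxlib+CUDA build, this should list GPU devices
# rather than falling back to CPU.
print(jax.devices())
print(jax.default_backend())  # expected: "gpu"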

Second, install "flax" by running the following command in your own conda environment:

pip install --user flax

Third, install "tensorflow" by running the following command in your own conda environment:

pip install --user tensorflow

Fourth, install "transformers" by running the following command in your own conda environment:

pip install --user transformers
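
Optionally, you can confirm that all four libraries import from the same conda environment (a minimal check, not required by the steps above):

import jax
import flax
import tensorflow as tf
import transformers

# Print installed versions to confirm the environment is set up consistently.
print("jax:", jax.__version__)
print("flax:", flax.__version__)
print("tensorflow:", tf.__version__)
print("transformers:", transformers.__version__)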

How to prepare your pre-training data

Please first process your data into plain-text (.txt) format, where each line represents one data point.

By default, the training data path is ./data/train.txt

By default, the validation data path is ./data/val.txt

Here, ./ refers to Flax_T5_Pre-training/transformers/examples/flax/language-modeling
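
For illustration, here is a minimal sketch of producing files in the expected layout; the raw_samples list is a hypothetical placeholder for your own corpus:

import os

# Hypothetical data points; replace with your own corpus.
raw_samples = [
    "def add(a, b): return a + b",
    "def greet(name): print('hello', name)",
    "for i in range(10): print(i)",
]

os.makedirs("./data", exist_ok=True)

# Write one data point per line, as the pre-training script expects.
with open("./data/train.txt", "w") as f:
    for sample in raw_samples[:2]:
        f.write(sample.replace("\n", " ") + "\n")

with open("./data/val.txt", "w") as f:
    for sample in raw_samples[2:]:
        f.write(sample.replace("\n", " ") + "\n")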

How to pre-train a T5 model using Transformers and Flax

Step 1: cd to the pre-training directory:

cd Flax_T5_Pre-training/transformers/examples/flax/language-modeling

Step 2: make a directory to save your pre-trained model:

mkdir pretrained_model

Step 3: run the following command to start pre-training:

python run_t5_mlm_flax.py --output_dir="./pretrained_model" \
                          --train_file="./data/train.txt" \
                          --validation_file="./data/val.txt" \
                          --model_type="t5" \
                          --model_name_or_path="Salesforce/CodeT5" \
                          --config_name="Salesforce/CodeT5" \
                          --tokenizer_name="Salesforce/CodeT5" \
                          --from_pt \
                          --max_seq_length="512" \
                          --per_device_train_batch_size="8" \
                          --per_device_eval_batch_size="8" \
                          --adafactor \
                          --learning_rate="0.005" \
                          --weight_decay="0.001" \
                          --warmup_steps="2000" \
                          --overwrite_output_dir \
                          --logging_steps="500" \
                          --save_steps="10000" \
                          --eval_steps="2500"
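
Once pre-training finishes, the checkpoint saved in ./pretrained_model can be loaded with the standard Flax classes from Transformers. A minimal sketch, assuming the tokenizer was saved alongside the model (the input string is a hypothetical example):

from transformers import AutoTokenizer, FlaxT5ForConditionalGeneration

# Load the tokenizer and Flax weights written by run_t5_mlm_flax.py.
tokenizer = AutoTokenizer.from_pretrained("./pretrained_model")
model = FlaxT5ForConditionalGeneration.from_pretrained("./pretrained_model")

# Hypothetical input, only to confirm the checkpoint loads and generates.
inputs = tokenizer("def hello():", return_tensors="np")
outputs = model.generate(inputs["input_ids"], max_length=20)
print(tokenizer.decode(outputs.sequences[0], skip_special_tokens=True))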

Important Note on the Pre-training Settings

The pre-training settings above are the defaults provided by the authors of the Transformers library.

Please modify them to fit your needs.

--model_name_or_path / --config_name / --tokenizer_name

These parameters specify the model checkpoint used to initialize the T5 model to be pre-trained. Each can be either a local path or a model identifier from the Hugging Face Hub.

If your checkpoint is a Flax model, please change "--from_pt" to "--from_flax".

By default, the script accepts a checkpoint model in PyTorch format.
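
For reference, the same PyTorch-to-Flax conversion can be reproduced directly with the from_pt argument of from_pretrained; a minimal sketch, reusing the checkpoint name from the command above:

from transformers import FlaxT5ForConditionalGeneration

# Load a PyTorch checkpoint and convert its weights to Flax on the fly;
# this is what the --from_pt flag in the command above refers to.
model = FlaxT5ForConditionalGeneration.from_pretrained(
    "Salesforce/CodeT5", from_pt=True
)

# If the checkpoint is already in Flax format (the --from_flax case),
# drop from_pt and load it directly:
# model = FlaxT5ForConditionalGeneration.from_pretrained("path/to/flax_checkpoint")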