First, install "jax" and "jaxlib" by running the following command in your own conda environment:
pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
Note: you can check all available versions here
pip install --user flax
pip install --user tensorflow
pip install --user transformers
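After installing, it can be worth a quick sanity check that JAX imports cleanly and can see your accelerator. A minimal sketch (the helper name `check_jax` is illustrative, not part of any package):

```python
def check_jax():
    """Report the installed JAX version and visible devices, if any."""
    try:
        import jax
        # jax.devices() lists GPUs/TPUs when the CUDA wheel installed correctly,
        # otherwise it falls back to CPU devices.
        return f"jax {jax.__version__}, devices: {jax.devices()}"
    except ImportError:
        return "jax is not installed; re-run the pip install command above"

print(check_jax())
```

If this reports only CPU devices on a GPU machine, the CUDA-specific jaxlib wheel from the command above likely did not match your CUDA version.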
Please first process your data into .txt format, where each line represents one data point.
Here, ./ refers to Flax_T5_Pre-training/transformers/examples/flax/language-modeling:
cd Flax_T5_Pre-training/transformers/examples/flax/language-modeling
mkdir pretrained_model
python run_t5_mlm_flax.py --output_dir="./pretrained_model" \
--train_file="./data/train.txt" \
--validation_file="./data/val.txt" \
--model_type="t5" \
--model_name_or_path="Salesforce/CodeT5" \
--config_name="Salesforce/CodeT5" \
--tokenizer_name="Salesforce/CodeT5" \
--from_pt \
--max_seq_length="512" \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="8" \
--adafactor \
--learning_rate="0.005" \
--weight_decay="0.001" \
--warmup_steps="2000" \
--overwrite_output_dir \
--logging_steps="500" \
--save_steps="10000" \
--eval_steps="2500"
--model_name_or_path / --config_name / --tokenizer_name
These parameters specify the checkpoint used to initialize the T5 model to be pre-trained. Each value can be either a local path or a model ID hosted on the Hugging Face Hub.
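The distinction between the two can be sketched as follows. This is a simplified illustration of the resolution rule, not the actual transformers loading logic:

```python
import os

def resolve_checkpoint(name_or_path):
    """Decide whether a checkpoint argument is a local path or a Hub model ID.

    Simplified sketch: a directory that exists on disk is loaded locally;
    anything else is treated as an ID to fetch from the Hugging Face Hub.
    """
    if os.path.isdir(name_or_path):
        return ("local", os.path.abspath(name_or_path))
    return ("hub", name_or_path)
```

So passing `./pretrained_model` would resume from your local output directory, while a string like `Salesforce/CodeT5` is resolved against the Hub.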