Code for "Preference Tuning For Toxicity Mitigation Generalizes Across Languages." Paper accepted at Findings of EMNLP 2024

BatsResearch/cross-lingual-detox

Preference Tuning For Toxicity Mitigation Generalizes Across Languages

🔥 Cross-lingual safety generalization: This is the first work to demonstrate that preference tuning for toxicity mitigation can generalize cross-lingually in a zero-shot manner. We evaluated 17 different languages and different LLMs (such as BLOOM, Llama3, and Aya-23), all of which show cross-lingual detoxification after English DPO preference tuning.

🔍 Mechanistic findings: We show that the dual multilinguality of toxic vectors (in MLP layers) explains the cross-lingual generalization. We find that the toxic vectors in MLPs encode multilingual toxic concepts, and that we can control the output toxicity level by controlling the activation levels of those vectors. We then show that English DPO reduces the activation levels of toxic vectors across languages.


Setup

  1. Create a conda environment with Python 3.11.
conda create --name xgdetox python=3.11
conda activate xgdetox
  2. Install poetry, then use it to install the other dependencies. (Make sure you are in the project's root directory, where pyproject.toml is located.)
pip install poetry
poetry install

DPO Preference Tuning

0. Download Training and Evaluation Data

  • Training (Toxicity Pairwise Data): Download the toxicity_pairwise.zip data from here (Source: Mechanistically Understanding DPO: Toxicity).

  • Evaluation (RTP-LX): Follow instructions from Microsoft to download the dataset of RTP-LX input prompts. It will contain files of RTP-LX/RTP_LX_{language}.json. Our repo and experiments use the dataset released in Apr'24 (May'24 works too).

1. Training

To perform DPO preference tuning (with or without LoRA), run the following command:

python3 xg/training/dpo.py \
    --data_dir /path/to/toxicity_pairwise/ \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --output_dir /path/to/save/model_ckpt \
    --per_device_train_batch_size 4 \
    --wandb_run_name your_wandb_runname \
    --use_lora  # remove this line if you want to do full model finetuning 
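For intuition about what the training command optimizes, here is a minimal sketch of the standard per-example DPO loss. It is not taken from xg/training/dpo.py; the function name, argument names, and the default beta=0.1 are illustrative assumptions. The inputs are summed log-probabilities of the chosen (non-toxic) and rejected (toxic) continuations under the policy being trained and under the frozen reference model.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Per-example DPO loss from sequence log-probabilities.

    beta (assumed default 0.1) scales how strongly the policy is pushed
    away from the reference model.
    """
    # How much more the policy prefers "chosen" over "rejected" ...
    policy_margin = policy_chosen_logp - policy_rejected_logp
    # ... compared with how much the reference model already does.
    ref_margin = ref_chosen_logp - ref_rejected_logp
    logits = beta * (policy_margin - ref_margin)
    # -log(sigmoid(x)) written in a numerically stable form.
    if logits >= 0:
        return math.log1p(math.exp(-logits))
    return -logits + math.log1p(math.exp(logits))
```

The loss equals log(2) when the policy and the reference agree, and it shrinks as the policy assigns relatively more probability to the chosen continuation.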

After DPO training, you can directly use the model checkpoint from /path/to/save/model_ckpt/final_checkpoint/.

However, parameter-efficient training with LoRA saves only the adapter weights, so use the following command to merge the LoRA adapters into the base model and save the merged weights. This helps at the generation stage with the vLLM library: at the time we wrote the code, there were bugs with loading LoRA weights, so it is more straightforward to pass the merged model than the base model plus LoRA weights.

python3 xg/training/merge_peft.py \
    --base_model_name meta-llama/Llama-2-7b-hf \
    --lora_adapter /path/to/save/model_ckpt/final_checkpoint \
    --output_dir /path/to/save/merge_final_checkpoint
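Conceptually, merging folds each low-rank LoRA update into the base weight it adapts: W_merged = W + (alpha / r) * B A. The toy sketch below shows only that arithmetic on nested lists; it is not the code in xg/training/merge_peft.py, and the names `matmul` and `merge_lora` are hypothetical. In practice an adapter library handles this for every adapted layer.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(weight, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * B @ A.

    weight: out x in, lora_b: out x rank, lora_a: rank x in.
    """
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)  # low-rank update, same shape as weight
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(weight, delta)]
```

After merging, the model has the same architecture as the base model, which is why it can be loaded by a generation library without any adapter support.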

We have uploaded our trained models to HuggingFace Hub:

2. Generation

We use the vLLM library to obtain model continuations. We recommend following its installation instructions before running the generation code below. Our code saves the vLLM generations as /path/to/save/outputs/{MODEL_NAME}/output-rtp_lx_{LANG}.json.

PROMPT_FILE=/path/to/RTP-LX/RTP_LX_ZH-Hans.json # you can change ZH-Hans to another language
# --model takes the merged checkpoint, or /path/to/save/model_ckpt/final_checkpoint if you did full finetuning
python3 xg/generate/vllm_script_sample.py \
    --prompt_file "$PROMPT_FILE" \
    --model /path/to/save/merge_final_checkpoint \
    --output_dir /path/to/save/outputs

3. Evaluation

  • Toxicity: First run xg/eval/perspective_api_eval.py to save the toxicity scores from Perspective API. Then run xg/eval/metric_toxicity.py to aggregate the scores.

  • Fluency: Run the xg/eval/metric_perplexity.py script to compute median conditional perplexity with the mT5-xl model. It will also save the array of all perplexity scores.

  • Diversity: Run the xg/eval/metric_diversity.py script.
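As an illustration of the fluency metric, the sketch below computes conditional perplexity from per-token log-probabilities of each continuation given its prompt, then takes the median over continuations. How those log-probabilities are obtained from the scoring model is outside this sketch, and the function names are hypothetical rather than taken from xg/eval/metric_perplexity.py.

```python
import math
import statistics


def conditional_perplexity(continuation_token_logprobs):
    """Perplexity of one continuation.

    Input: natural-log probabilities of each continuation token,
    each conditioned on the prompt and the preceding tokens.
    """
    if not continuation_token_logprobs:
        raise ValueError("continuation has no tokens")
    mean_nll = -sum(continuation_token_logprobs) / len(continuation_token_logprobs)
    return math.exp(mean_nll)


def median_perplexity(all_continuations_logprobs):
    """Median conditional perplexity over many continuations."""
    return statistics.median(
        conditional_perplexity(lp) for lp in all_continuations_logprobs
    )
```

The median is less sensitive than the mean to a few degenerate continuations with extremely high perplexity.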

MODEL_OUTPUTS_FOLDER=... # vllm generations folder (/path/to/save/outputs/{MODEL_NAME})

############### toxicity ###############
# call Perspective API
LANGS=( ar cs de en es fr hi id it ja ko nl pl pt ru sv zh-hans )
for LANG in "${LANGS[@]}"
do
    echo "Processing $LANG"
    # replace ... after --api_key with your Perspective API key
    python3 xg/eval/perspective_api_eval.py \
        --api_key ... \
        --datapath "${MODEL_OUTPUTS_FOLDER}/output-rtp_lx_${LANG}.json" \
        --output_folder "${MODEL_OUTPUTS_FOLDER}/perspective_api_eval/" \
        --language $LANG
done

# aggregate toxicity scores
PERSPECTIVE_OUTPUTS_FOLDER=${MODEL_OUTPUTS_FOLDER}/perspective_api_eval
python3 xg/eval/metric_toxicity.py \
    --perspective_outputs_folder $PERSPECTIVE_OUTPUTS_FOLDER

############### fluency ###############
python3 xg/eval/metric_perplexity.py \
    --model_outputs_folder $MODEL_OUTPUTS_FOLDER

############### diversity ###############
python3 xg/eval/metric_diversity.py \
    --model_outputs_folder $MODEL_OUTPUTS_FOLDER
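This README does not spell out how the toxicity scores are aggregated. As a hedged illustration, the sketch below computes two summaries that are common in toxicity evaluation: expected maximum toxicity (the mean, over prompts, of the highest score among that prompt's sampled continuations) and toxicity probability (the fraction of prompts with at least one continuation at or above a threshold). The function names and the 0.5 threshold are assumptions, not a description of xg/eval/metric_toxicity.py.

```python
def expected_max_toxicity(scores_per_prompt):
    """Mean over prompts of the maximum toxicity score among continuations."""
    return sum(max(scores) for scores in scores_per_prompt) / len(scores_per_prompt)


def toxicity_probability(scores_per_prompt, threshold=0.5):
    """Fraction of prompts with at least one continuation scoring >= threshold."""
    toxic_prompts = sum(
        any(score >= threshold for score in scores)
        for scores in scores_per_prompt
    )
    return toxic_prompts / len(scores_per_prompt)
```

Here `scores_per_prompt` is a list with one inner list per prompt, holding the toxicity score of each sampled continuation for that prompt.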

Interpretability Experiments

0. Download Jigsaw Toxic Comments Dataset

Download the Jigsaw dataset from Kaggle.

1. Probe training

To train a linear probe for binary toxic classification, follow these steps:

  • Replace the train_fp variable with the path to the train split of the Jigsaw dataset.
  • Run the provided script.

All hyperparameters are pre-configured in the script file.

python scripts/run_train_probe.py
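For readers unfamiliar with probing, a linear probe is a logistic-regression classifier trained on fixed model representations; its weight vector is the "probe vector" used in the next step. The following is a minimal pure-Python sketch under that assumption. The function names, learning rate, and epoch count are illustrative, and the actual script's features and hyperparameters may differ.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_probe(features, labels, lr=0.5, epochs=200):
    """Fit weights w and bias b by full-batch gradient descent on the log loss."""
    dim, n = len(features[0]), len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            for j in range(dim):
                grad_w[j] += err * x[j]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def predict(w, b, x):
    """Return 1 (toxic) or 0 (non-toxic) for one feature vector."""
    return 1 if sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) >= 0.5 else 0
```

In this sketch each row of `features` stands in for a hidden representation of one comment, and `labels` holds the binary toxicity labels.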

2. Analyze the value vectors of the model

We first identify the potential sources of toxicity by selecting the top 100 value vectors based on their cosine similarities with the probe vector. Then, we collect the corresponding neuron activations averaged across the next 20 tokens generated from the English RTP-LX prompts. We retain the value vectors whose corresponding neuron activations are positive during the forward pass. We found 36 value vectors meeting these criteria, and they are stored here. We then project them onto the vocabulary space to interpret the tokens they promote when activated. More details can be found in this notebook.
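The selection procedure described above can be sketched as follows: rank value vectors by cosine similarity to the probe vector, keep the top k, drop those whose mean activation is not positive, and score vocabulary tokens by a dot product with the unembedding rows. This is a toy sketch on plain lists; the function names and the `unembedding` and `vocab` arguments are hypothetical, and the notebook remains the reference for the actual analysis.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def select_toxic_vectors(value_vectors, probe, mean_activations, top_k=100):
    """Indices of the top_k most probe-aligned vectors with positive mean activation."""
    ranked = sorted(range(len(value_vectors)),
                    key=lambda i: cosine(value_vectors[i], probe),
                    reverse=True)
    return [i for i in ranked[:top_k] if mean_activations[i] > 0]


def top_tokens(value_vector, unembedding, vocab, n=5):
    """Project a value vector onto the vocabulary and return the n highest-scoring tokens."""
    scores = [sum(a * b for a, b in zip(row, value_vector)) for row in unembedding]
    best = sorted(range(len(vocab)), key=lambda t: scores[t], reverse=True)[:n]
    return [vocab[t] for t in best]
```

Here `mean_activations[i]` plays the role of neuron i's activation averaged over the generated tokens.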

3. Causal Intervention

To better understand these sub-updates, we directly intervene in their corresponding neuron activations and inspect the changes they induce. We provide a minimal experiment demonstrating how such interventions are conducted in this notebook. The same code can be used to quantitatively understand the effect of the changes we exert on the neuron activations across all prompts from different languages.
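To make the idea of an intervention concrete, the toy sketch below treats an MLP layer's output as a sum of value vectors weighted by their neuron activations, and lets the caller force selected activations to a chosen value. This assumes the key-value view of MLP layers; the function name and the `overrides` argument are hypothetical, and a real experiment would patch activations inside the model's forward pass, as the notebook does.

```python
def mlp_output(activations, value_vectors, overrides=None):
    """Sum of value vectors weighted by neuron activations.

    overrides maps a neuron index to the activation to force for that
    neuron, e.g. {3: 0.0} to switch neuron 3 off.
    """
    overrides = overrides or {}
    dim = len(value_vectors[0])
    out = [0.0] * dim
    for i, (activation, vector) in enumerate(zip(activations, value_vectors)):
        activation = overrides.get(i, activation)
        for j in range(dim):
            out[j] += activation * vector[j]
    return out
```

Comparing the output with and without an override isolates the contribution of the chosen neurons, which is the quantity a causal intervention measures.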

4. Analyze neuron activation before and after DPO

This script can be used to collect neuron activations before and after preference tuning across different languages. We also provide the precomputed results here. To reproduce Figure 3 in the paper, see this notebook.


Bilingual Sentence Retrieval

Data: Since the RTP-LX prompts are not aligned across languages (see Issue), we translate 200 prompts with the Google Translate API to obtain multiway parallel RTP-LX prompts. These are stored at assets/translated_pairwise_data.

We first use xg/retrieval/retrieval_acc_save.py to save the per-layer representations for parallel sentence pairs in English and the lang2 language. Then, we use xg/retrieval/retrieval_acc_load.py to load them and calculate the bilingual sentence retrieval accuracy between English and lang2.

LANG2="ar"
for i in "0 50" "50 100" "100 150" "150 200" # process in batches to avoid OOM
do
    set -- $i # Convert the "tuple" into the param args $1 $2...
    python3 xg/retrieval/retrieval_acc_save.py \
        --lang2 $LANG2 \
        --begin $1 \
        --end $2 \
        --model_name "ai-forever/mGPT"
done

python3 xg/retrieval/retrieval_acc_load.py \
  --lang2 $LANG2
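Bilingual sentence retrieval accuracy can be sketched as follows: for each English sentence representation, find the most similar lang2 representation and count a hit when it is the parallel sentence. The sketch assumes cosine similarity and one vector per sentence from a single layer; both are illustrative assumptions, and the names are hypothetical rather than taken from xg/retrieval/retrieval_acc_load.py.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def retrieval_accuracy(english_reps, lang2_reps):
    """Fraction of English sentences whose nearest lang2 sentence is its translation.

    english_reps[i] and lang2_reps[i] must represent parallel sentences.
    """
    correct = 0
    for i, en in enumerate(english_reps):
        nearest = max(range(len(lang2_reps)),
                      key=lambda j: cosine(en, lang2_reps[j]))
        correct += nearest == i
    return correct / len(english_reps)
```

Computing this per layer shows at which depths the model represents the two languages in a shared space.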

Bibtex

@article{li2024preference,
  title={Preference Tuning For Toxicity Mitigation Generalizes Across Languages},
  author={Li, Xiaochen and Yong, Zheng-Xin and Bach, Stephen H},
  journal={arXiv preprint arXiv:2406.16235},
  year={2024}
}