OpenBMB/Locret

Locret: Accelerating Long-Context LLM Inference with Retaining Heads

Blog | Paper (arXiv)

A Light-weight Training-based KV Cache Compression Algorithm

Locret is a light-weight, training-based KV cache eviction method. It achieves KV cache compression ratios of 20x for Phi-3-mini-128K and 8x for Llama-3.1-8B-instruct, enabling 128K+ long-context inference on a single NVIDIA 4090 GPU.
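To make the compression ratios above concrete, here is a minimal sketch of score-based KV cache eviction under a fixed budget. It is not Locret's implementation: the function names token_budget and evict are hypothetical, and the importance scores are hand-written placeholders (in Locret they would presumably come from the trained retaining heads).

```python
import math


def token_budget(context_len, compression_ratio):
    """Number of KV cache entries kept for a given compression ratio."""
    return math.ceil(context_len / compression_ratio)


def evict(scores, budget):
    """Return the positions of the `budget` highest-scoring tokens, in order.

    `scores` holds one importance score per cached token (placeholder values
    in this sketch; not produced by a real model).
    """
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:budget])


# Toy example: 8 cached tokens, 4x compression, so 2 entries are kept.
scores = [0.9, 0.1, 0.4, 0.8, 0.05, 0.7, 0.3, 0.6]
kept = evict(scores, token_budget(len(scores), 4))
print(kept)

# Budgets implied by the ratios quoted above for a 128K-token context.
print(token_budget(128 * 1024, 20), token_budget(128 * 1024, 8))
```

Under these assumptions, a 20x ratio on a 128K-token context leaves room for roughly 6.5K cached tokens, and an 8x ratio for about 16K.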

This project was made possible thanks to a collaboration with

Design

Overall Framework

Locret Inference

Usage

Locret currently supports only Phi-3-mini-128K and Llama-3.1-8B-instruct. Support for more models is planned.

Environment Setup

Run the following commands to set up the environment.

cd locret
conda env create -f locret_env.yaml
conda activate locret_env
pip install -e .

Official checkpoints of Locret can be downloaded from hyx21/Locret-phi-3-mini-128K and hyx21/Locret-llama-3.1-8B-instruct.

Now you are all set!

Training

First, enter the working directory:

cd locret/train

Then, generate the training dataset:

python data_gen.py --model_dir <model_dir>

Finally, start training the model.

python train.py --model_dir <model_dir>

All hyperparameters are set automatically to the values used in our paper. The training script infers the model type from the model_dir path, so make sure the path names the model you use; for example, if the path contains phi, the script detects a Phi model.
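The path-based detection described above can be pictured roughly as follows. This is only a sketch: detect_model_family is a hypothetical helper, not a function from the repository, and the llama branch is an assumption by analogy with the phi case mentioned in the text.

```python
from pathlib import Path


def detect_model_family(model_dir):
    """Guess the model family from a checkpoint path (illustrative only)."""
    name = str(Path(model_dir)).lower()
    if "phi" in name:
        return "phi"
    if "llama" in name:  # assumed; the README only mentions the phi case
        return "llama"
    raise ValueError(f"cannot infer a supported model from path: {model_dir}")


# Hypothetical local paths used purely for illustration.
print(detect_model_family("/models/Phi-3-mini-128k-instruct"))
print(detect_model_family("/models/Llama-3.1-8B-instruct"))
```

A practical consequence is that renaming a checkpoint directory to something that drops the model name could prevent detection, so keeping the model name in the path is the safer choice.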

After training, the trained checkpoint is saved at locret/train/checkpoints/<model_name>. You can either load the complete model saved with save_pretrained or load only the retaining heads. We provide a script, locret/train/convert.py, that converts checkpoints saved in safetensors format to PyTorch format.
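For readers curious what a safetensors-to-PyTorch conversion involves, a minimal sketch follows. It is not the contents of convert.py: the helper names and the output naming scheme are assumptions, and it relies only on safetensors.torch.load_file and torch.save, which are imported lazily so the path helper can be used without either library installed.

```python
from pathlib import Path


def bin_path_for(safetensors_path):
    """Output path for the converted file (naming scheme is an assumption)."""
    return Path(safetensors_path).with_suffix(".bin")


def convert_to_pytorch(safetensors_path):
    """Load tensors from a .safetensors file and re-save them with torch.save."""
    # Imported here so this module loads even without torch/safetensors.
    import torch
    from safetensors.torch import load_file

    state_dict = load_file(str(safetensors_path))  # dict of name -> tensor
    out_path = bin_path_for(safetensors_path)
    torch.save(state_dict, out_path)
    return out_path


# Hypothetical file name, shown only to illustrate the naming.
print(bin_path_for("checkpoints/example/model.safetensors"))
```

The repository's own script should be preferred in practice, since it may also select or rename the retaining-head weights.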

Inference

example.py runs a single entry from the R.PassKey task of InfiniteBench. To run it, execute

python example.py --model_dir <model_dir> # for saved full checkpoint, or
python example.py --model_dir <model_dir> --retaining_head_path <*.bin> # original model + saved retaining heads
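The two invocations above differ only in whether --retaining_head_path is given. A sketch of how such a command line could be parsed is shown below; the flag names come from the commands above, while build_parser, describe_mode, and the sample paths are hypothetical, and the actual argument handling in example.py may differ.

```python
import argparse


def build_parser():
    """Parser mirroring the two flags used in the commands above."""
    parser = argparse.ArgumentParser(description="Locret inference example (sketch)")
    parser.add_argument("--model_dir", required=True,
                        help="full saved checkpoint, or the original model")
    parser.add_argument("--retaining_head_path", default=None,
                        help="optional *.bin file with saved retaining heads")
    return parser


def describe_mode(args):
    """Say which of the two loading modes the arguments select."""
    if args.retaining_head_path is None:
        return "full checkpoint"
    return "original model + saved retaining heads"


parser = build_parser()
print(describe_mode(parser.parse_args(["--model_dir", "ckpt"])))
print(describe_mode(parser.parse_args(
    ["--model_dir", "base", "--retaining_head_path", "heads.bin"])))
```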

For the other experiments in our paper, run the code in benchmark/infinite_bench and benchmark/LEval-main. Each script corresponds to one experimental setting, which is indicated by the script name.

Development Roadmap

  • Add support for Llama-3.2-1B and Llama-3.2-3B
  • Add support for MiniCPM-2.4B and MiniCPM-1.2B
  • Add support for Qwen-2.5-1.5B and Qwen-2.5-1.5B

Citation

Please cite our paper if you find our work valuable.

@article{huang2024locret,
  title={Locret: Accelerating Long-Context LLM Inference with Retaining Heads},
  author={Yuxiang Huang and Binhang Yuan and Xu Han and Chaojun Xiao and Zhiyuan Liu},
  journal={arXiv preprint arXiv:2410.01805},
  year={2024}
}

Acknowledgement

We acknowledge the discussion with Ruisi Cai@UT Austin, the author of LoCoCo, on the implementation of LoCoCo. We deeply thank Xinrong Zhang@Tsinghua for details on InfiniteBench, Weilin Zhao@Tsinghua for advice on system implementation, Chenyang Song@Tsinghua for discussions on model sparsity, Shuo Wang@Tsinghua for the long-context training recipe, and Yuan Yao@NUS for discussions on future work on multimodal LLMs.