
Block Transformer: Global-to-Local Language Modeling for Fast Inference (NeurIPS 2024)


Namgyu Ho¹,²†*   Sangmin Bae¹*   Taehyeon Kim¹   Hyunjik Jo²   Yireun Kim²   Tal Schuster³   Adam Fisch³
James Thorne¹‡   Se-Young Yun¹‡

¹KAIST AI   ²LG AI Research   ³Google DeepMind
†Work done during an internship at LG AI Research.   *Equal contribution.   ‡Corresponding authors.

  • We propose the Block Transformer architecture, which applies hierarchical global-to-local language modeling to autoregressive transformers to mitigate the inference bottlenecks of self-attention.
  • Block Transformer models global dependencies through self-attention between coarse blocks at the lower layers (in the block decoder), and decodes fine-grained tokens within each local block at the upper layers (in the token decoder).
  • We leverage the inference-time benefits of both the global and local modules, achieving 10-20x gains in throughput over vanilla transformers with equivalent perplexity.
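To make the global-to-local split concrete, here is a rough back-of-the-envelope sketch in Python (not the repository's code) of how tokens are grouped into blocks and how many key/value positions each design caches per sequence. It assumes a fixed block length, that every block-decoder layer caches one position per block, and that each token-decoder layer attends only within the current block plus a few context positions (`n_context` is a hypothetical parameter). The layer counts in the example are illustrative, and real memory use also depends on hidden sizes and head counts, which this ignores.

```python
from math import ceil
from typing import List, Tuple


def split_into_blocks(token_ids: List[int], block_len: int) -> List[List[int]]:
    """Group a token sequence into consecutive blocks of `block_len` tokens.

    The final block may be shorter if the sequence length is not a multiple
    of `block_len`.
    """
    return [token_ids[i:i + block_len] for i in range(0, len(token_ids), block_len)]


def kv_cache_positions(
    seq_len: int,
    n_vanilla_layers: int,
    n_block_layers: int,
    n_token_layers: int,
    block_len: int,
    n_context: int = 1,  # hypothetical: context positions the token decoder attends to
) -> Tuple[int, int]:
    """Count cached key/value positions (summed over layers) for one sequence.

    Vanilla transformer: every layer caches every token position.
    Block Transformer (simplified): block-decoder layers cache one position
    per block, while token-decoder layers cache only the current block plus
    `n_context` positions, independent of `seq_len`.
    """
    vanilla = n_vanilla_layers * seq_len
    block = n_block_layers * ceil(seq_len / block_len) + n_token_layers * (block_len + n_context)
    return vanilla, block


if __name__ == "__main__":
    print(split_into_blocks(list(range(10)), block_len=4))
    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print(kv_cache_positions(2048, n_vanilla_layers=24, n_block_layers=12,
                             n_token_layers=12, block_len=4))
    # (49152, 6204)
```

Under these assumptions the token decoder's cache stays fixed as the sequence grows, and the block decoder's cache grows `block_len` times more slowly than a vanilla layer's. This only illustrates where the savings come from; it does not reproduce the measured throughput gains.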

⚡️ Real-World Decoding Speed Comparison

https://www.youtube.com/watch?v=c0D7EvffYnU

🚀 Getting Started

To try out our pretrained Block Transformer models, install the requirements and download our pretrained checkpoints (see the sections below).

Note: before running any code, run the following command so that absolute imports resolve.

python setup.py develop

Inference with Custom Prompts

Use our demo notebook at ./notebooks/inference_demo.ipynb.

Batch Inference Speed Demo

CUDA_VISIBLE_DEVICES=0 python inference_demo.py --model=block_main_b4_1.2b --batch_size=128
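Batch decoding throughput is commonly reported as newly generated tokens per second, summed over the batch. A minimal timing sketch in Python, assuming a hypothetical `generate_fn(batch_size, decode_length)` callable that stands in for batched generation (it is not a function from this repository, and inference_demo.py may measure things differently):

```python
import time
from typing import Callable


def measure_throughput(
    generate_fn: Callable[[int, int], None],
    batch_size: int,
    decode_length: int,
) -> float:
    """Return newly generated tokens per second for one batched generation call.

    `generate_fn` is a hypothetical stand-in: it should generate `decode_length`
    new tokens for each of `batch_size` sequences and return only when done.
    """
    start = time.perf_counter()
    generate_fn(batch_size, decode_length)
    elapsed = time.perf_counter() - start
    return batch_size * decode_length / elapsed


if __name__ == "__main__":
    def fake_generate(batch_size: int, decode_length: int) -> None:
        time.sleep(0.05)  # pretend generation takes 50 ms

    rate = measure_throughput(fake_generate, batch_size=128, decode_length=128)
    print(f"{rate:.0f} tokens/s")
```

When timing GPU work, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously and the timer could otherwise stop before generation has actually finished.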

💎 Pretrained Checkpoints

We share all checkpoints of our main models, pretrained on tens of thousands of A100 hours. With ❤️ from LG AI Research.

To use our code as-is, unzip the checkpoints into the ./results directory, as shown below.

block-transformer/
|-- results/
  |-- block_main_b4_1.2b/
    |-- checkpoint-570000/
      |-- model.safetensors
  |-- ...
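With this layout, a model directory may hold one or more `checkpoint-<step>/` subdirectories. A small stdlib sketch for locating the most recent `model.safetensors` under `./results/<model>/`; the helper name `latest_checkpoint` is hypothetical and not part of the repository:

```python
import re
from pathlib import Path
from typing import Optional

_CKPT_RE = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(results_dir: str, model_name: str) -> Optional[Path]:
    """Return model.safetensors from the highest-numbered checkpoint-<step> dir.

    Directories without a model.safetensors file are skipped. Returns None if
    the model directory or any usable checkpoint is missing.
    """
    model_dir = Path(results_dir) / model_name
    if not model_dir.is_dir():
        return None
    best_step, best_path = -1, None
    for child in model_dir.iterdir():
        match = _CKPT_RE.match(child.name)
        weights = child / "model.safetensors"
        if child.is_dir() and match and weights.is_file():
            step = int(match.group(1))  # compare numerically, not as strings
            if step > best_step:
                best_step, best_path = step, weights
    return best_path


if __name__ == "__main__":
    path = latest_checkpoint("./results", "block_main_b4_1.2b")
    print(path or "checkpoint not found; check where the archive was unzipped")
```

Comparing step numbers as integers avoids the string-ordering trap where `checkpoint-90000` would sort after `checkpoint-570000`. The returned file can be loaded as a dict of tensors with `safetensors.torch.load_file(path)`.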

💻 Requirements

Refer to requirements.txt.


Transformers version

Our Block Transformer subclasses of the GPTNeoX models have been tested with:

transformers==4.39.3
accelerate==0.33.0

Installing FlashAttention

Requires CUDA>=11.6 and PyTorch>=1.12 with GPU support. See https://github.com/Dao-AILab/flash-attention#installation-and-features.

pip install packaging ninja
ninja --version; echo $?  # make sure 0 is printed; otherwise, reinstall ninja
pip install flash-attn --no-build-isolation

Building the wheel can take a while (we have seen 10+ minutes).

FlashAttention support for GPTNeoX was added to transformers on Dec 7, 2023 and released in v4.36.0 (huggingface/transformers#26463).
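Because the FlashAttention build is optional and slow, loading code may want to fall back when `flash_attn` is not importable. A minimal sketch: `pick_attn_implementation` is a hypothetical helper, and whether this repository's model classes read an `attn_implementation` argument is an assumption. In stock transformers (v4.36.0 and later), `from_pretrained(..., attn_implementation=...)` accepts values such as `"flash_attention_2"` and `"eager"`.

```python
import importlib.util
from typing import Callable, Optional


def pick_attn_implementation(
    find_spec: Callable[[str], Optional[object]] = importlib.util.find_spec,
) -> str:
    """Return "flash_attention_2" if the flash_attn package can be found, else "eager".

    This only checks that the package is installed; it does not verify the
    GPU, CUDA version, or fp16/bf16 dtype that FlashAttention needs at runtime.
    """
    return "flash_attention_2" if find_spec("flash_attn") is not None else "eager"


if __name__ == "__main__":
    impl = pick_attn_implementation()
    print(f"attn_implementation={impl!r}")
    # Assumed usage with a stock transformers class (not this repo's subclasses):
    # GPTNeoXForCausalLM.from_pretrained(path, torch_dtype=torch.float16,
    #                                    attn_implementation=impl)
```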

📚 Pretraining

  • Vanilla (HuggingFace) model training: pretrain_vanilla_transformer.py

    deepspeed --include localhost:0,1,2,3 --no_local_rank --master_port 29540 pretrain_vanilla_transformer.py --config-name vanilla_31 pythia_pile_idxmaps_path=/path/to/pythia_pile_idxmaps
  • Block transformer training: pretrain_block_transformer.py

    deepspeed --include localhost:0,1,2,3 --no_local_rank --master_port 29540 pretrain_block_transformer.py --config-name block_main_b4_5 pythia_pile_idxmaps_path=/path/to/pythia_pile_idxmaps
  • Using the torch.distributed launcher (in place of the deepspeed launcher; follow it with the training script and its arguments, as above)

    OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --nproc_per_node=4 --master_port=29540
    • Note that this still uses DeepSpeed optimization. To run without it, append --deepspeed=null.
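The launch commands above share one shape: a launcher, a training script, a `--config-name`, and `key=value` overrides. A hypothetical helper sketching how either launcher could be assembled as an argument list (for example, for `subprocess.run`); that the torch.distributed variant takes the same script arguments as the deepspeed one is an assumption:

```python
from typing import Dict, List, Sequence, Tuple


def build_launch(
    script: str,
    config_name: str,
    data_path: str,
    gpus: Sequence[int] = (0, 1, 2, 3),
    launcher: str = "deepspeed",
    master_port: int = 29540,
) -> Tuple[Dict[str, str], List[str]]:
    """Return (extra environment variables, argv) for one pretraining launch.

    `launcher` is "deepspeed" or "torch"; both variants mirror the commands
    shown in this README.
    """
    gpu_list = ",".join(str(g) for g in gpus)
    script_args = [script, "--config-name", config_name,
                   f"pythia_pile_idxmaps_path={data_path}"]
    if launcher == "deepspeed":
        env: Dict[str, str] = {}  # GPU selection is passed via --include
        argv = ["deepspeed", "--include", f"localhost:{gpu_list}", "--no_local_rank",
                "--master_port", str(master_port)] + script_args
    elif launcher == "torch":
        env = {"OMP_NUM_THREADS": "4", "CUDA_VISIBLE_DEVICES": gpu_list}
        argv = ["python", "-m", "torch.distributed.run", f"--nproc_per_node={len(gpus)}",
                f"--master_port={master_port}"] + script_args
    else:
        raise ValueError(f"unknown launcher: {launcher!r}")
    return env, argv


if __name__ == "__main__":
    env, argv = build_launch("pretrain_block_transformer.py", "block_main_b4_5",
                             "/path/to/pythia_pile_idxmaps", launcher="torch")
    print(env)
    print(" ".join(argv))
    # To launch: subprocess.run(argv, env={**os.environ, **env}, check=True)
```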

🔬 Evaluation

  • Zero-shot evaluation: eval_zero_shot_task.py

    CUDA_VISIBLE_DEVICES=0 python eval_zero_shot_task.py --config-name=eval_multiple_ckpt 'configs.hf=["vanilla_31"]' batch_size=64
    CUDA_VISIBLE_DEVICES=0 python eval_zero_shot_task.py --config-name=eval_multiple_ckpt 'configs.block=["block_main_b4_5"]' batch_size=64
  • Inference throughput wall-time measurement: measure_generation_time.py

    CUDA_VISIBLE_DEVICES=0 python measure_generation_time.py --config-name=block_main_b4_5 ++benchmark_prefill_length=2048 ++benchmark_decode_length=128
    CUDA_VISIBLE_DEVICES=0 python measure_generation_time.py --config-name=block_main_b4_5 ++benchmark_prefill_length=128 ++benchmark_decode_length=2048
    • Works for both HF and block models.
    • By default, the batch size is auto-tuned via binary search to maximize VRAM utilization. To set a specific batch size, use ++batch_size=64.
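Auto-tuning the batch size by binary search amounts to finding the largest batch size for which a trial run fits in memory. In this stdlib sketch, `fits(batch_size)` is a hypothetical predicate (for example, one that runs a generation step and returns False on a CUDA out-of-memory error), and it assumes that if a batch size fits, every smaller one does too. It is not the repository's implementation:

```python
from typing import Callable


def max_batch_size(fits: Callable[[int], bool], upper: int = 4096) -> int:
    """Return the largest batch size in [1, upper] for which fits() is True, or 0 if none.

    Assumes monotonicity: fits(b) implies fits(b') for every b' < b.
    """
    lo, hi = 0, upper  # invariant: the answer lies in [lo, hi]
    while lo < hi:
        mid = (lo + hi + 1) // 2  # round up so `lo = mid` always makes progress
        if fits(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo


if __name__ == "__main__":
    print(max_batch_size(lambda b: b <= 300))  # 300
```

Each probe costs one trial run, so the search needs roughly log2(upper) runs instead of trying every batch size.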

📑 Pretraining Data Preparation

The Pile (Pythia version)

Refer to https://github.com/EleutherAI/pythia/. The resulting files are a Megatron-LM compatible dataset of The Pile (in memory-mapped Numpy format), pre-shuffled document-wise and pre-tokenized, without any added special tokens. The dataset can be accessed via https://github.com/EleutherAI/pythia/blob/main/utils/mmap_dataset.py.
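For intuition about the memory-mapped format: in Megatron-LM's indexed datasets, the .bin file holds the token IDs back to back, while the .idx file records the dtype and where each document starts. A stdlib sketch that reads a window of tokens from a .bin file without loading the whole file into RAM, assuming uint16 tokens in native byte order (enough for the GPT-NeoX vocabulary; the authoritative dtype is in the .idx file). `read_token_window` is hypothetical; for real use, prefer the linked mmap_dataset.py.

```python
import mmap
from array import array
from typing import List

TOKEN_BYTES = 2  # assumption: uint16 token IDs (check the dtype recorded in the .idx file)


def read_token_window(bin_path: str, start: int, length: int) -> List[int]:
    """Read up to `length` token IDs beginning at token offset `start`.

    The file is memory-mapped, so only the requested byte range is copied out.
    Tokens are decoded as unsigned 16-bit integers in native byte order,
    which is little-endian on x86 machines.
    """
    with open(bin_path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        raw = mm[start * TOKEN_BYTES:(start + length) * TOKEN_BYTES]  # slice clamps at EOF
    tokens = array("H")  # "H" = unsigned short, 2 bytes on mainstream platforms
    tokens.frombytes(raw)
    return tokens.tolist()


if __name__ == "__main__":
    path = "pythia_pile_idxmaps/pile_0.87_deduped_text_document.bin"
    print(read_token_window(path, start=0, length=16))
```

A raw window like this ignores document boundaries, which are recorded in the .idx file.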

git clone https://github.com/EleutherAI/pythia/  # about 500MB
cd pythia

GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/EleutherAI/pythia_deduped_pile_idxmaps
cd pythia_deduped_pile_idxmaps
git config lfs.activitytimeout 3600
# sudo apt-get update; sudo apt-get install git-lfs -y
git lfs pull

cd ..

# Optionally, check the downloaded shards for corrupt files
python utils/checksum_shards.py

# Unshard data
python utils/unshard_memmap.py --input_file ./pythia_deduped_pile_idxmaps/pile_0.87_deduped_text_document-00000-of-00082.bin --num_shards 83 --output_dir ./pythia_pile_idxmaps/

# Copy over idx data
cp pythia_deduped_pile_idxmaps/pile_0.87_deduped_text_document.idx pythia_pile_idxmaps

# Verify the checksum of the final file
echo "Expected checksum: 0cd548efd15974d5cca78f9baddbd59220ca675535dcfc0c350087c79f504693"
sha256sum pythia_pile_idxmaps/pile_0.87_deduped_text_document.bin
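Because the unmodified .idx file is copied next to the unsharded .bin, the shards appear to be plain byte-level splits of the original .bin, so unsharding amounts to concatenating them in order. The sketch below illustrates that assumption: the shard naming is inferred from the command above, and `shard_paths`, `concat_shards`, and `sha256_file` are hypothetical helpers, not `utils/unshard_memmap.py` itself. It also includes a streaming SHA-256 that mirrors the final `sha256sum` check.

```python
import hashlib
import re
import shutil
from pathlib import Path
from typing import List


def shard_paths(first_shard: str, num_shards: int) -> List[Path]:
    """Expand '...-00000-of-00082.bin' into the paths of all `num_shards` shards, in order."""
    first = Path(first_shard)
    match = re.search(r"-(\d+)-of-(\d+)\.bin$", first.name)
    if match is None:
        raise ValueError(f"unexpected shard name: {first.name}")
    width = len(match.group(1))  # keep the zero-padding width, e.g. 5 digits
    prefix = first.name[:match.start()]
    suffix = f"-of-{match.group(2)}.bin"
    return [first.with_name(f"{prefix}-{i:0{width}d}{suffix}") for i in range(num_shards)]


def concat_shards(shards: List[Path], output: Path, chunk_size: int = 1 << 24) -> None:
    """Concatenate shard files byte-for-byte into `output`, streaming in chunks."""
    with open(output, "wb") as out:
        for shard in shards:
            with open(shard, "rb") as src:
                shutil.copyfileobj(src, out, chunk_size)


def sha256_file(path: Path, chunk_size: int = 1 << 24) -> str:
    """Return the hex SHA-256 digest of a file without reading it into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
```

For example, `sha256_file(Path("pythia_pile_idxmaps/pile_0.87_deduped_text_document.bin"))` should equal the expected checksum printed above.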

🌟 BibTeX

@article{ho2024block,
  title={Block Transformer: Global-to-Local Language Modeling for Fast Inference},
  author={Ho, Namgyu and Bae, Sangmin and Kim, Taehyeon and Jo, Hyunjik and Kim, Yireun and Schuster, Tal and Fisch, Adam and Thorne, James and Yun, Se-Young},
  journal={arXiv preprint arXiv:2406.02657},
  year={2024}
}