bragi

Generating metered verse with LLaMA

This repo provides methods for using Meta's 7B-parameter LLaMA model to generate song lyrics with a specific metric structure. If you've been yearning to rewrite the "Happy Birthday" song so that it's just about dogs, bragi can help :).

The core functionality of bragi is provided via the MetricGenerator class. If you're wondering, Bragi is the Norse god of poetry!

The library also provides wrappers around various methods for extracting metric information, such as syllable counts and rhyme schemes.
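As a rough illustration of the kind of metric information involved, here is a minimal sketch of per-line syllable counting and rhyme-scheme labelling. It is not bragi's implementation: `naive_syllables`, `syllables_per_line`, `rhyme_key`, and `rhyme_scheme` are hypothetical helpers that use crude spelling heuristics rather than a pronunciation backend such as espeak, and they assume lines are separated by `||`, as in the Quick Start example.

```python
import re
from typing import Dict, List

VOWEL_GROUP = re.compile(r"[aeiouy]+")


def naive_syllables(word: str) -> int:
    """Rough syllable count (hypothetical heuristic): vowel groups, minus a silent final 'e'."""
    w = re.sub(r"[^a-z]", "", word.lower())
    if not w:
        return 0  # punctuation-only tokens cost nothing
    count = len(VOWEL_GROUP.findall(w))
    # Drop a likely-silent trailing 'e' ("love"), but keep "-le" endings ("table")
    if w.endswith("e") and not w.endswith("le") and count > 1:
        count -= 1
    return max(count, 1)


def syllables_per_line(text: str, new_line_token: str = "||") -> List[int]:
    """Split on the new-line token and sum the syllables of each non-empty line."""
    lines = [line.strip() for line in text.split(new_line_token)]
    return [sum(naive_syllables(w) for w in line.split()) for line in lines if line]


def rhyme_key(word: str) -> str:
    """Crude rhyme key: the last vowel group plus any trailing consonants."""
    w = re.sub(r"[^a-z]", "", word.lower())
    match = re.search(r"[aeiouy]+[^aeiouy]*$", w)
    return match.group(0) if match else w


def rhyme_scheme(text: str, new_line_token: str = "||") -> str:
    """Label lines A, B, C, ... so that lines whose last words share a rhyme key share a letter."""
    labels: Dict[str, str] = {}
    scheme = []
    for line in text.split(new_line_token):
        words = line.split()
        if not words:
            continue
        key = rhyme_key(words[-1])
        if key not in labels:
            labels[key] = chr(ord("A") + len(labels))
        scheme.append(labels[key])
    return "".join(scheme)


print(syllables_per_line("Happy birthday to you || Happy birthday to you"))  # [6, 6]
print(rhyme_scheme("the cat sat || a big dog || my old hat || a wet log"))  # ABAB
```

Spelling heuristics miscount many words (for example "rhythm" comes out as one syllable), which is one reason a phonetic backend is preferable in practice.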

MetricGenerator controls the metric structure of generated output by constraining the model's probability distribution over tokens. Specifically, tokens that would violate the target metric structure are masked at each inference step. This is implemented via a custom logits warper.
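The masking idea can be sketched without any framework. The snippet below is a simplified, hypothetical illustration rather than bragi's logits warper: it assumes every token id has a known syllable count and that some tokens (punctuation, the new-line token) are free, and it sets the score of each token that would overrun the current line's remaining budget to `-inf`, so that token gets exactly zero probability after softmax.

```python
import math
from typing import Iterable, List


def mask_over_budget(
    scores: List[float],
    token_syllables: List[int],
    remaining: int,
    free_token_ids: Iterable[int] = (),
) -> List[float]:
    """Return a copy of `scores` with over-budget tokens set to -inf.

    scores[i] is the model's logit for token id i, token_syllables[i] is that
    token's (assumed, precomputed) syllable count, and `remaining` is the number
    of syllables left on the current line. Free tokens are never masked.
    """
    free = set(free_token_ids)
    return [
        score if (tok_id in free or syl <= remaining) else -math.inf
        for tok_id, (score, syl) in enumerate(zip(scores, token_syllables))
    ]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax; -inf entries become probability 0.0.

    Assumes at least one score is finite (free tokens guarantee this above).
    """
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Toy vocabulary of 4 tokens; token 3 is free (e.g. the new-line token)
masked = mask_over_budget([0.5, 1.0, 2.0, -1.0], [1, 2, 3, 0], remaining=2, free_token_ids=[3])
print(masked)               # [0.5, 1.0, -inf, -1.0]
print(softmax(masked)[2])   # 0.0
```

A `transformers` logits warper applies the same operation to a batched tensor of scores at every generation step, which is what lets the constraint be combined with other sampling settings such as `top_k` and `temperature`.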

Dev setup

With cog

  1. Install cog:
sudo curl -o /usr/local/bin/cog -L https://github.com/replicate/cog/releases/latest/download/cog_`uname -s`_`uname -m`
sudo chmod +x /usr/local/bin/cog
  2. Clone the cog-llama repo:
git clone https://github.com/replicate/cog-llama.git
  3. Build the image:
cd cog-llama
cog build
  4. Install espeak in the cog environment:
cog run apt-get update -y
cog run apt-get install espeak -y
  5. Open a shell in the container that cog built:
cog run bash
  6. Clone this repo into the cog container:
git clone https://github.com/joehoover/bragi.git
  7. Install requirements:
cd bragi
pip install -r requirements.txt
  8. Install JupyterLab:
pip install jupyterlab
  9. Launch JupyterLab:
jupyter lab --allow-root
  10. Click the last link printed by the JupyterLab process.

General dev setup

  1. You need to install espeak.

On macOS:

brew install espeak

On Linux:

apt-get update -y
apt-get install espeak -y
  2. You also need to make sure torch is installed. I don't like installing torch with poetry, so it's not specified in the pyproject.toml. If your environment doesn't already have torch, run:
pip install torch
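A quick way to confirm both dependencies are in place is a small check like the one below. `check_environment` is a hypothetical helper, not part of bragi; it only tests that an `espeak` executable is on `PATH` and that `torch` is importable, not that either is a compatible version.

```python
import importlib.util
import shutil
from typing import Dict


def check_environment() -> Dict[str, bool]:
    """Report whether the system dependencies from the setup steps are available."""
    return {
        # espeak must be an executable somewhere on PATH
        "espeak": shutil.which("espeak") is not None,
        # torch must be importable from the current Python environment
        "torch": importlib.util.find_spec("torch") is not None,
    }


if __name__ == "__main__":
    for name, found in check_environment().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```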

Quick Start

See this notebook. But, in general:

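from bragi.metric_generator import MetricGenerator
# Note: LLaMAForCausalLM / LLaMATokenizer are the names from the early LLaMA branch of
# transformers; released versions (4.28+) call them LlamaForCausalLM / LlamaTokenizer.
from transformers import LLaMAForCausalLM, LLaMATokenizer
import torch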

CACHE_DIR = 'weights'
SEP = "<sep>"
MODEL_PATH = "/src/weights"
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


# Load model and tokenizer
model = LLaMAForCausalLM.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True).to(device)
tokenizer = LLaMATokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR, local_files_only=True)

# Initialize `MetricGenerator`
generator = MetricGenerator(model=model, tokenizer=tokenizer, device=device)

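# Placeholder inputs (hypothetical; replace with your own). Lines in `text_init`
# are separated by the new-line token '||'.
prompt = "Write a birthday song about dogs:"
text_init = "Happy birthday to you || Happy birthday to you"

# Generate
torch.manual_seed(2)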
output = generator(
    prompt = prompt,
    text_init = text_init,
    free_tokens=['||', '?', '.', ','],
    # syllable_budget = torch.Tensor([6., 6.]),
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    remove_invalid_values=True,
    do_sample=True,
    top_k=25,
    temperature=.7,
    max_length = 100,
    new_line_token='||',
    bad_words_ids=[[8876]],
)

print('---text_init----')
print(text_init)
print('\n')

print('----output-----')
print(output)
print('\n')

print('----Syllables-----')
print(f"Syllables per line in output: {generator.calculate_syllable_budget(output)}")
print(f"Syllables per line in `text_init`: {generator.calculate_syllable_budget(text_init)}")
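The `syllable_budget`, `free_tokens`, and `new_line_token` arguments above hint at how a per-line budget is consumed during generation. The class below is one plausible, hypothetical model of that bookkeeping, not bragi's code: each line gets a syllable allowance, tokens in `free_tokens` cost nothing, and the new-line token advances to the next line. The syllable cost of each token is assumed to be supplied by the caller.

```python
from typing import Iterable, List


class SyllableBudgetTracker:
    """Hypothetical helper: tracks the remaining per-line syllable budget."""

    def __init__(self, budget: List[int], free_tokens: Iterable[str], new_line_token: str = "||"):
        self.remaining = list(budget)  # syllables left on each line
        self.line = 0                  # index of the line currently being generated
        self.free_tokens = set(free_tokens)
        self.new_line_token = new_line_token

    def allows(self, token: str, cost: int) -> bool:
        """True if emitting `token` (worth `cost` syllables) stays within budget."""
        if token == self.new_line_token or token in self.free_tokens:
            return True
        if self.line >= len(self.remaining):
            return False  # every line is already complete
        return cost <= self.remaining[self.line]

    def update(self, token: str, cost: int) -> None:
        """Record an emitted token: advance on the new-line token, otherwise spend syllables."""
        if token == self.new_line_token:
            self.line += 1
        elif token not in self.free_tokens and self.line < len(self.remaining):
            self.remaining[self.line] -= cost


tracker = SyllableBudgetTracker([6, 6], free_tokens=["||", "?", ".", ","])
for token, cost in [("Happy", 2), ("birthday", 2), ("dear", 1), ("dog", 1), ("||", 0)]:
    if tracker.allows(token, cost):
        tracker.update(token, cost)
print(tracker.remaining, tracker.line)  # [0, 6] 1
```

In a masking scheme, `allows` is the test applied to every candidate token at each step, and `update` runs once on the token that was actually sampled.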