
MurrellGroup/LogitSamplers.jl


LogitSamplers


A Julia package for GPU-friendly sampling from logit distributions with various transformation methods commonly used in language models.
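The description above does not say how a sample is actually drawn. One standard GPU-friendly technique is the Gumbel-max trick. If you add independent Gumbel(0, 1) noise to every logit and take the `argmax`, you get index `i` with probability `softmax(logits)[i]`. This needs only elementwise operations and one reduction, with no sorting or cumulative sums. The sketch below is a plain-Julia illustration of that technique, assuming nothing about the package internals. `gumbel_argmax` is a hypothetical name, not part of LogitSamplers.

```julia
using Random

# Hypothetical helper (not part of LogitSamplers): draw index i with
# probability proportional to exp(logits[i]) via the Gumbel-max trick.
function gumbel_argmax(rng::AbstractRNG, logits::AbstractVector{<:Real})
    # Uniform draws in [0, 1); clamp away from 0 so log(u) stays finite.
    u = max.(rand(rng, length(logits)), eps())
    # Gumbel(0, 1) noise: -log(-log(u)).
    g = -log.(-log.(u))
    # Logits masked with -Inf stay -Inf and are never selected
    # (unless every entry is -Inf).
    return argmax(logits .+ g)
end

# Convenience method using the default global RNG.
gumbel_argmax(logits) = gumbel_argmax(Random.default_rng(), logits)
```

Because every step is a broadcast or a single `argmax`, the same expression maps naturally onto GPU arrays. That is the property "GPU-friendly" usually refers to.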

Usage

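The package provides a set of logit transforms that modify distributions in the log domain, acting on logits rather than on probabilities. Each transform is a callable object, so transforms can be chained with `∘` and combined with `logitsample`.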
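Working in the log domain keeps transforms simple. For example, temperature scaling is usually defined as dividing the logits by `T`. That is the same as raising each probability to the power `1/T` and renormalizing. The README does not spell out what `Temperature(T)` computes, so the sketch below assumes this usual convention. `stable_softmax` and `temperature_sketch` are hypothetical names used only for illustration.

```julia
# Numerically stable softmax, written out so the example needs no packages.
function stable_softmax(x::AbstractVector{<:Real})
    z = exp.(x .- maximum(x))  # subtracting the max avoids overflow in exp
    return z ./ sum(z)
end

# Hypothetical stand-in for a temperature transform: scale logits by 1/T.
# T > 1 flattens the distribution; T < 1 sharpens it.
temperature_sketch(T::Real) = logits -> logits ./ T

logits = [2.0, 1.0, 0.0, -1.0]
p = stable_softmax(logits)                           # original probabilities
q = stable_softmax(temperature_sketch(1.5)(logits))  # after temperature
```

Here `q` equals `p .^ (1 / 1.5)` renormalized to sum to 1, and its largest entry is smaller than that of `p`.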

```julia
using LogitSamplers

# Create a temperature transform
temperature = Temperature(1.5)

# Create a top-p transform
top_p = Top_p(0.5)

# Compose a function that first applies temperature, then top-p
transform = top_p ∘ temperature

# Create a token index sampler function from the transform
sampler = logitsample ∘ transform

# or equivalently:
sampler = logits -> logitsample(top_p(temperature(logits)))

logits = randn(100)

# Get token probabilities with the transformed logits
probs = softmax(transform(logits))

# Sample a token index from the sampler
index = sampler(logits)
```
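`Top_p(0.5)` presumably implements nucleus (top-p) filtering. This keeps the smallest set of most likely tokens whose probabilities sum to at least `p` and masks the rest. The CPU-only sketch below shows that idea under that assumption. It is not the package's GPU implementation. `top_p_sketch` and `stable_softmax` are hypothetical names, and an inline `l -> l ./ 1.5` stands in for the temperature step.

```julia
# Numerically stable softmax, so the example needs no packages.
function stable_softmax(x::AbstractVector{<:Real})
    z = exp.(x .- maximum(x))
    return z ./ sum(z)
end

# Hypothetical nucleus (top-p) filter: keep the highest-probability tokens
# until their total probability reaches p, and mask the rest with -Inf.
function top_p_sketch(p::Real)
    0 < p <= 1 || throw(ArgumentError("p must be in (0, 1]"))
    return function (logits::AbstractVector{<:Real})
        probs = stable_softmax(logits)
        order = sortperm(probs; rev = true)  # token indices, most likely first
        cumulative = cumsum(probs[order])
        # Number of tokens needed to reach mass p; fall back to all tokens
        # if rounding leaves the total just below p.
        k = something(findfirst(>=(p), cumulative), length(order))
        masked = fill(-Inf, length(logits))
        keep = order[1:k]
        masked[keep] = logits[keep]
        return masked
    end
end

# Usage: temperature (inline stand-in) followed by top-p, mirroring the
# README's `top_p ∘ temperature` composition.
transform = top_p_sketch(0.5) ∘ (l -> l ./ 1.5)
masked = transform(log.([0.5, 0.3, 0.15, 0.05]))
```

Masking with `-Inf` instead of deleting entries keeps the vector length fixed, so token indices are preserved and the output can go straight into a softmax or a sampler.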