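A Julia package for GPU-friendly sampling from logit distributions, with transforms commonly used in language models such as temperature scaling and top-p (nucleus) filtering.
The package provides a set of logit transforms that modify a distribution in the log domain, before it is normalized with softmax. Each transform is callable on a logit vector, so transforms compose with `∘` and can be chained in front of `logitsample`.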
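To make "modifying the distribution in the log domain" concrete, here is a minimal plain-Julia sketch of what a temperature transform and a top-p transform do to a logit vector. The functions `scale_temperature`, `stable_softmax`, and `nucleus_mask` are hypothetical illustrations and are not part of LogitSamplers. The package's own `Temperature` and `Top_p` may differ in details such as tie handling or GPU support.

```julia
# Hypothetical stand-ins for illustration only; not LogitSamplers' implementation.

# Temperature: divide logits by T. T > 1 flattens the distribution, T < 1 sharpens it.
scale_temperature(logits::AbstractVector{<:Real}, T::Real) = logits ./ T

# Numerically stable softmax: subtract the maximum before exponentiating.
function stable_softmax(x::AbstractVector{<:Real})
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Top-p (nucleus) filtering: keep the smallest set of highest-probability tokens
# whose cumulative probability reaches p, and mask the rest with -Inf, which
# gives them zero probability after softmax.
function nucleus_mask(logits::AbstractVector{<:Real}, p::Real)
    probs = stable_softmax(logits)
    order = sortperm(probs; rev = true)        # token indices, most likely first
    cum = cumsum(probs[order])
    # number of tokens needed to reach p; fall back to all tokens if rounding
    # keeps the cumulative sum just below p
    n = something(findfirst(>=(p), cum), length(cum))
    keep = falses(length(logits))
    keep[order[1:n]] .= true
    return ifelse.(keep, float.(logits), -Inf)
end

# Composition mirrors `top_p ∘ temperature`: temperature first, then top-p.
transform = (l -> nucleus_mask(l, 0.5)) ∘ (l -> scale_temperature(l, 1.5))
filtered = transform(randn(100))
```

Masking with `-Inf` keeps everything in the log domain, so further transforms can be applied before a final softmax or sampling step.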
```julia
using LogitSamplers
using NNlib: softmax  # softmax is provided by NNlib.jl

# Create a temperature transform
temperature = Temperature(1.5)

# Create a top-p transform
top_p = Top_p(0.5)

# Compose a function that first applies temperature, then top-p
transform = top_p ∘ temperature

# Create a token index sampler function from the transform
sampler = logitsample ∘ transform
# or equivalently:
sampler = logits -> logitsample(top_p(temperature(logits)))

logits = randn(100)

# Get token probabilities from the transformed logits
probs = softmax(transform(logits))

# Sample a token index with the sampler
index = sampler(logits)
```
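For intuition about the final sampling step, here is a minimal sketch of one common way to draw an index from `softmax(logits)` without computing the softmax explicitly: the Gumbel-max trick. The function `gumbel_max_sample` is a hypothetical helper, and this is not necessarily how `logitsample` is implemented. It uses only elementwise operations and a single `argmax` reduction, with no sequential cumulative sum. That property is one reason this style of sampling is often considered GPU-friendly.

```julia
using Random

# Hypothetical helper for illustration: samples an index i with probability
# softmax(logits)[i] via the Gumbel-max trick. If E ~ Exponential(1), then
# -log(E) is a standard Gumbel sample; adding i.i.d. Gumbel noise to the
# logits and taking the argmax yields an exact categorical sample.
function gumbel_max_sample(rng::AbstractRNG, logits::AbstractVector{<:Real})
    gumbel = -log.(randexp(rng, length(logits)))
    return argmax(logits .+ gumbel)   # tokens masked with -Inf can never win
end

gumbel_max_sample(logits::AbstractVector{<:Real}) =
    gumbel_max_sample(Random.default_rng(), logits)

# Usage: empirical frequencies should approach the target probabilities.
rng = MersenneTwister(42)
p = [0.7, 0.2, 0.1]
counts = zeros(Int, length(p))
for _ in 1:100_000
    counts[gumbel_max_sample(rng, log.(p))] += 1
end
freqs = counts ./ 100_000
```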