-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add logitsample * Add logit transforms * tweaks * mutable types * Add tests; Add Temperature, Top_pk, Top_p, Top_k; Add deprecated functions for compat * Remove x86 and 1.9 testing * Update README * Add test for show method
- Loading branch information
1 parent
d8a397c
commit 00a13bf
Showing
11 changed files
with
299 additions
and
88 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,3 +3,4 @@ | |
*.jl.mem | ||
/docs/Manifest.toml | ||
/docs/build/ | ||
Manifest.toml |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,17 @@ | ||
name = "LogitSamplers" | ||
uuid = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041" | ||
authors = ["murrellb <[email protected]> and contributors"] | ||
version = "1.0.0-DEV" | ||
version = "1.1.0-DEV" | ||
|
||
[deps] | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[compat] | ||
NNlib = "0.9" | ||
Random = "1.11.0" | ||
StatsBase = "0.34" | ||
Random = "1" | ||
Statistics = "1" | ||
julia = "1.9" | ||
|
||
[extras] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,22 @@ | ||
module LogitSamplers | ||
|
||
using NNlib, StatsBase | ||
using NNlib: softmax | ||
using Random | ||
using Statistics: std | ||
|
||
include("samplers.jl") | ||
include("mask.jl") | ||
|
||
include("sample.jl") | ||
export logitsample | ||
|
||
include("transforms.jl") | ||
export LogitTransform | ||
export Temperature | ||
export Top_pk, Top_p, Top_k | ||
export Min_p | ||
export Top_nσ | ||
|
||
include("deprecated.jl") | ||
export argmax_sampler, top_pk_sampler, min_p_sampler, top_nσ_sampler | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
@deprecate argmax_sampler(logits; device=identity) Top_k(1)(device(logits)) | ||
@deprecate argmax_sampler(; kwargs...) logits -> argmax_sampler(logits; kwargs...) | ||
|
||
@deprecate top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) Top_pk(p, k)(device(logits)) | ||
@deprecate top_pk_sampler(; kwargs...) logits -> top_pk_sampler(logits; kwargs...) | ||
|
||
@deprecate min_p_sampler(logits; pbase = 0.5f0, device = identity) Min_p(pbase)(device(logits)) | ||
@deprecate min_p_sampler(; kwargs...) logits -> min_p_sampler(logits; kwargs...) | ||
|
||
@deprecate top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) Top_nσ(temperature, n)(device(logits)) | ||
@deprecate top_nσ_sampler(; kwargs...) logits -> top_nσ_sampler(logits; kwargs...) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
apply_mask(x::AbstractVector{T}, mask::AbstractVector{Bool}) where T<:AbstractFloat = | ||
T(-Inf) * .!mask + x | ||
|
||
function create_mask(x::AbstractVector, indices::AbstractVector{Int}) | ||
mask = similar(x, Bool) .= false | ||
mask[indices] .= true | ||
return mask | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
""" | ||
logitsample([rng], logits, [buffer=similar(logits)]) -> Int | ||
Sample an index from a logit distribution using the Gumbel argmax trick. | ||
Alternatively pass a buffer to avoid allocating a new array when creating | ||
the random numbers. | ||
""" | ||
function logitsample(rng::AbstractRNG, x::AbstractVector{T}, u::AbstractVector{T}=similar(x)) where T<:AbstractFloat | ||
length(x) == length(u) || throw(DimensionMismatch("Expected buffer of same length as logits")) | ||
rand!(rng, u) | ||
argmax(-log.(-log.(u)) + x) | ||
end | ||
|
||
@inline logitsample(args...) = logitsample(Random.default_rng(), args...) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
abstract type LogitTransform <: Function end | ||
|
||
Base.show(io::IO, ::MIME"text/plain", t::LogitTransform) = show(io, t) | ||
|
||
|
||
""" | ||
Temperature(T) | ||
A logit transform that scales (divides) the logits by a temperature parameter. | ||
""" | ||
mutable struct Temperature{T<:Real} <: LogitTransform | ||
T::T | ||
end | ||
|
||
(t::Temperature)(logits::AbstractVector{T}) where T = logits / T(t.T) | ||
|
||
|
||
""" | ||
Top_pk(p, k) | ||
A logit transform that masks logits to only include tokens in the top `k` or the top `p` cumulative probability. | ||
""" | ||
mutable struct Top_pk{P<:Real,K<:Union{Integer,Nothing}} <: LogitTransform | ||
p::P | ||
k::K | ||
end | ||
|
||
function (t::Top_pk)(logits::AbstractVector{T}) where T<:AbstractFloat | ||
0 < t.p <= 1 || throw(DomainError(t.p, "p must be in the interval (0, 1]")) | ||
probs = softmax(logits) | ||
sorted_probs = sort(probs, rev=true) | ||
cutoff_p = maximum(sorted_probs[cumsum(sorted_probs) .>= t.p]; init=zero(T)) | ||
cutoff_k = t.k isa Integer ? maximum(sorted_probs[t.k:t.k]) : zero(T) | ||
return apply_mask(logits, probs .>= max(cutoff_p, cutoff_k)) | ||
end | ||
|
||
Top_p(p) = Top_pk(p, nothing) | ||
Top_k(k) = Top_pk(1, k) | ||
|
||
|
||
""" | ||
Min_p(pbase) | ||
A logit transform that samples from the most probable tokens using the min-p strategy. | ||
See: https://arxiv.org/pdf/2407.01082 | ||
""" | ||
mutable struct Min_p{T<:Real} <: LogitTransform | ||
pbase::T | ||
end | ||
|
||
function (t::Min_p)(logits::AbstractVector) | ||
p = softmax(logits) | ||
return apply_mask(logits, p .>= t.pbase * maximum(p)) | ||
end | ||
|
||
|
||
""" | ||
Top_nσ(n) | ||
A logit transform that samples within `n` standard deviations of the maximum logit. | ||
Top-nσ is temperature-invariant, i.e. the candidate set does not change with temperature. | ||
See: https://arxiv.org/pdf/2411.07641 | ||
""" | ||
mutable struct Top_nσ{T<:Real} <: LogitTransform | ||
n::T | ||
end | ||
|
||
function (t::Top_nσ)(logits::AbstractVector) | ||
M, σ = maximum(logits), std(logits) | ||
return apply_mask(logits, logits .>= M - t.n * σ) | ||
end |
Oops, something went wrong.