diff --git a/Project.toml b/Project.toml index 362f43c..63f4015 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "1.0.0-DEV" [deps] NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] julia = "1.9" diff --git a/src/LogitSamplers.jl b/src/LogitSamplers.jl index 8ae6da2..2524dc8 100644 --- a/src/LogitSamplers.jl +++ b/src/LogitSamplers.jl @@ -1,5 +1,7 @@ module LogitSamplers +using NNlib, StatsBase + include("samplers.jl") export argmax_sampler, top_pk_sampler, min_p_sampler, top_nσ_sampler diff --git a/src/samplers.jl b/src/samplers.jl index 2f14809..6dcd6ea 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -1,3 +1,6 @@ +#To do: refactor into a combination of modified_softmax and sample. This way we can viz the result of the modified logits without having to sample. +#This won't be visible to the user. Any method that doesn't fit this interface can be implemented directly. + function argmax_sampler(logits::AbstractVector; device = identity) return argmax(device(logits)) end