Skip to content

Commit

Permalink
Un-deprecate sampler interface
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Nov 28, 2024
1 parent b2a4c93 commit 7d4f50b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/LogitSamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export Top_pk, Top_p, Top_k
export Min_p
export Top_nσ

include("deprecated.jl")
include("samplers.jl")
export argmax_sampler, top_pk_sampler, min_p_sampler, top_nσ_sampler

end
11 changes: 0 additions & 11 deletions src/deprecated.jl

This file was deleted.

11 changes: 11 additions & 0 deletions src/samplers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
argmax_sampler(logits; device=identity) = logits |> device |> Top_k(1) |> logitsample
argmax_sampler(; kwargs...) = logits -> argmax_sampler(logits; kwargs...)

top_pk_sampler(logits; p = 0.5f0, k = 5, device = identity) = logits |> device |> Top_pk(p, k) |> logitsample
top_pk_sampler(; kwargs...) = logits -> top_pk_sampler(logits; kwargs...)

min_p_sampler(logits; pbase = 0.5f0, device = identity) = logits |> device |> Min_p(pbase) |> logitsample
min_p_sampler(; kwargs...) = logits -> min_p_sampler(logits; kwargs...)

top_nσ_sampler(logits; temperature = 1.0f0, n = 1.0f0, device = identity) = logits |> device |> Temperature(temperature) |> Top_nσ(n) |> logitsample
top_nσ_sampler(; kwargs...) = logits -> top_nσ_sampler(logits; kwargs...)

2 comments on commit 7d4f50b

@AntonOresten
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

  • Add logitsample function for GPU-friendly weighted sampling in the log domain.
  • Add abstract LogitTransform type.
    • Add Temperate type.
    • Add Top_pk type with additional Top_p and Top_k constructors.
    • Add Min_p type.
    • Add Top_nσ type.
  • Add argmax_sampler, top_pk_sampler, min_p_sampler, and top_nσ_sampler functions.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/120312

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" 7d4f50b11a02ace348f2c02bea97fa329cd4bdb4
git push origin v0.1.0

Please sign in to comment.