Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add axis permutedims #36

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
64 changes: 48 additions & 16 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
# Method for an empty `labels` tuple: there is nothing to compare `val` against,
# so the result is always `nothing`. The index `i` is accepted but not used.
# NOTE(review): presumably the terminating case of a recursive search over a
# tuple of labels — the sibling methods are outside this view, confirm there.
function _findval(val, labels::Tuple{}, i::Integer)
    return nothing
end

"""
onehotbatch(xs, labels, [default])
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
onehotbatch(xs, labels, [default]; dims = Val(1))


Returns a [`OneHotMatrix`](@ref) where the `k`th column of the matrix is [`onehot(xs[k], labels)`](@ref onehot).
This is a sparse matrix, which stores just a `Vector{UInt32}` containing the indices of the
Expand All @@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`.
Note that `xs` can be any iterable, such as a string. And that using a tuple
for `labels` will often speed up construction, certainly for less than 32 classes.

If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
If the `dims` keyword is given, the onehot vectors lie on the `dims` dimension rather than the first one. `dims` should be provided as a `Val` to guarantee type stability (but a plain integer is valid as well).


# Examples
```jldoctest
julia> oh = onehotbatch("abracadabra", 'a':'e', 'e')
Expand All @@ -74,30 +76,40 @@ julia> oh = onehotbatch("abracadabra", 'a':'e', 'e')
⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅

julia> oh = onehotbatch("abracadabra", 'a':'e', 'e'; dims=2)
nomadbl marked this conversation as resolved.
Show resolved Hide resolved
5×11 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
1 ⋅ ⋅ ⋅ ⋅
⋅ 1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ 1
1 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ 1 ⋅ ⋅
1 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ 1 ⋅
1 ⋅ ⋅ ⋅ ⋅
⋅ 1 ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ 1
1 ⋅ ⋅ ⋅ ⋅

julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
3×11 Matrix{Int64}:
1 4 13 1 7 1 10 1 4 13 1
2 5 14 2 8 2 11 2 5 14 2
3 6 15 3 9 3 12 3 6 15 3
```
"""
function onehotbatch(data, labels, default...)
    # Label collections with fewer than 32 entries are converted to a tuple before
    # being handed to the worker; larger ones are passed through unchanged.
    labelset = length(labels) < 32 ? Tuple(labels) : labels
    return _onehotbatch(data, labelset, default...)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is how I would suggest writing this: add the keyword `dims`, but leave the basic path as close to untouched as you can, like so:

onehotbatch(data, labels, default...; dims=Val(1)) = _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...)

function _onehotbatch(::Val{1}, data, labels)
  # as before
   return OneHotArray(indices, length(labels))
end
function _onehotbatch(::Val{1}, data, labels, default)
  # as before
  return OneHotArray(indices, length(labels))
end

In particular, this does not call collect(data). That shouldn't be necessary, since we can just iterate over the data.

Readers uninterested in permutations can stop there. But to handle them, make it obvious that we call the same path, and then permute it.

_onehotbatch(dims::Integer, data, labels, default...) = _onehotbatch(Val(dims), data, labels, default...)
_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...))

_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array)
function _permute(::Val{d}, array::OneHotArray{<:Any, N,M}) where {d, N, M}
  # this is where you compute perm, can use N or M, I forget...
  PermutedDimsArray(array, perm)
end

I made a special case for transpose, as I think that's always preferable to PermutedDimsArray.


function _onehotbatch(data, labels)
indices = UInt32[something(_findval(i, labels), 0) for i in data]
if 0 in indices
for x in data
isnothing(_findval(x, labels)) && error("Value $x not found in labels")
end
function onehotbatch(data::String, labels, default...; dims::Val{D} = Val(1)) where D
    # Materialise the string's characters as a Vector, then re-dispatch on it
    # with the same labels, optional default and `dims`.
    chars = collect(data)
    return onehotbatch(chars, labels, default...; dims=dims)
end
function onehotbatch(
    data::AbstractRange, labels, default...; dims::Val{D} = Val(1)
) where D
    # Expand the range into a Vector, then re-dispatch on it with the same
    # labels, optional default and `dims`.
    elements = collect(data)
    return onehotbatch(elements, labels, default...; dims=dims)
end
function onehotbatch(data::AbstractArray{<:Any, N}, labels, default...; dims::Val{D}= Val(1)) where {N,D}
out = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)
if D==1
out
else
perm = Tuple(ntuple(d -> d==D ? 1 : (d==1 ? D : d), N+1))
# need to use obtuse PermutedDimsArray constructor in order to stabilise permutation types
iperm = invperm(perm)
PermutedDimsArray{eltype(out),N+1,(perm...,),(iperm...,),typeof(out)}(out)
end
return OneHotArray(indices, length(labels))
end

# Build a `OneHotArray` from `data`, substituting `default` for every element that
# is not found in `labels`. Errors if `default` itself is not one of the labels.
function _onehotbatch(data, labels, default)
    fallback = _findval(default, labels)
    if isnothing(fallback)
        error("Default value $default is not in labels")
    end
    # Elements with no match in `labels` take the index of `default`.
    indices = UInt32[something(_findval(x, labels), fallback) for x in data]
    return OneHotArray(indices, length(labels))
end

function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
Expand All @@ -108,6 +120,8 @@ function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<
indices = UInt32.(data .+ offset)
return OneHotArray(indices, length(labels))
end
function onehotbatch(
    data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}
)
    # Collect the integer range into a Vector and re-dispatch on that Vector.
    return onehotbatch(collect(data), labels)
end

# That bounds check with extrema synchronises on GPU, much slower than rest of the function,
# hence add a special method, with a less helpful error message:
function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
Expand All @@ -120,6 +134,24 @@ function onehotbatch(data::AbstractGPUArray{<:Integer}, labels::AbstractUnitRang
return OneHotArray(indices, length(labels))
end


# Build a `OneHotArray` from `data`, erroring if any element is absent from `labels`.
function _onehotbatch(data, labels)
    # 0 serves as a placeholder index for elements that have no match in `labels`.
    positions = UInt32[something(_findval(item, labels), 0) for item in data]
    if any(iszero, positions)
        # Only reached on failure: rescan `data` to name the first missing value.
        for x in data
            if isnothing(_findval(x, labels))
                error("Value $x not found in labels")
            end
        end
    end
    return OneHotArray(positions, length(labels))
end

# Variant with a fallback label: elements of `data` that are not in `labels` are
# encoded as `default`. `default` must itself be a member of `labels`.
function _onehotbatch(data, labels, default)
    default_slot = _findval(default, labels)
    default_slot === nothing && error("Default value $default is not in labels")
    slots = UInt32[something(_findval(entry, labels), default_slot) for entry in data]
    return OneHotArray(slots, length(labels))
end

"""
onecold(y::AbstractArray, labels = 1:size(y,1))

Expand Down
14 changes: 14 additions & 0 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,17 @@ end
@test y[:,1] isa OneHotVector
@test y[:,:] isa OneHotMatrix
end

@testset "onehotbatch dims" begin
    # Expected result shared by several cases below: with the one-hot axis on
    # dimension 2 there is one row per input element and one column per label.
    swapped = Bool[0 1 0; 1 0 0]

    # basic tests
    @test onehotbatch([20, 10], 10:10:30; dims=Val(2)) == swapped
    @test onehotbatch([10, 20], [30, 40, 50], 30; dims=Val(2)) == Bool[1 0 0; 1 0 0]

    # higher dimensions
    grid = reshape(collect(1:12), 3, 4)
    @test size(onehotbatch(grid, 1:12; dims=Val(2))) == (3, 12, 4)  # test shape
    # test onehot on the second dim
    @test sum(onehotbatch(grid, 1:12; dims=Val(2)), dims=2)[:] == ones(12)

    # works with strings
    @test onehotbatch("ba", 'a':'c'; dims=Val(2)) == swapped

    # the return type must be inferrable when `dims` is passed as a `Val`
    @test @inferred(onehotbatch([20, 10], 10:10:30; dims=Val(2))) == swapped
    @test @inferred(onehotbatch([40, 10], (10,20,30), 20; dims=Val(2))) == swapped
end