WIP: Add axis permutedims #36

Open
wants to merge 13 commits into
base: main
from
12 changes: 8 additions & 4 deletions src/array.jl
@@ -1,8 +1,8 @@
"""
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M}
- OneHotArray(indices, L)
+ OneHotArray(indices, L, [axis=1])
Member

If we do add this, it should probably be a dims::Integer keyword on onehotbatch. IMO it's weird if a type constructor does not return the stated type.

Author

I understand the sentiment, but it seems awkward to maintain a set of functions and increase their complexity just to get the same functionality.
If this is a complete no-go, then how about either the alternative implementation (which might have even more problems) or adding this functionality as a separate utility function?

Member

Why is it more complex to alter the lower-case function than the upper-case type constructor?

Author

I thought some more about this point and I agree that changing onehot/onehotbatch/etc. is the better approach.
I'll take a closer look at the functions (I'm still not very familiar with them all).
Which of them would you say is the appropriate place?
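
To make the "keyword on the lower-case function" idea concrete, here is a minimal sketch using only Base. `onehot_dense` is a hypothetical name, not part of OneHotArrays, and it returns a dense `BitArray` rather than the compact index-backed representation; it only illustrates how a `dims` keyword could place the label axis without a type constructor returning a different type.

```julia
# Hypothetical helper (not an OneHotArrays API): dense one-hot encoding of
# integer `indices` in 1:L, with the label axis placed at position `dims`.
function onehot_dense(indices::AbstractArray{<:Integer}, L::Integer; dims::Integer = 1)
    N = ndims(indices) + 1
    1 <= dims <= N || throw(ArgumentError("dims must be in 1:$N"))
    sz = size(indices)
    # Label values 1:L laid out along axis `dims`, singleton elsewhere.
    labels = reshape(1:L, ntuple(d -> d == dims ? L : 1, N))
    # Insert a singleton axis into `indices` at position `dims`.
    idx = reshape(indices, ntuple(d -> d < dims ? sz[d] : d == dims ? 1 : sz[d - 1], N))
    # Broadcasting the comparison yields a BitArray with size(_, dims) == L.
    return labels .== idx
end

x = onehot_dense([1 2; 3 1], 3; dims = 2)
size(x)                        # (2, 3, 2)
all(sum(x; dims = 2) .== 1)    # true
```

A real implementation would presumably keep the compact storage and wrap or permute it lazily; the dense version is only meant to show the expected shape and invariant for a given `dims`.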


- A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, 1) == L` and `sum(A, dims=1) == 1`)
+ A one-hot `M`-dimensional array with `L` labels (i.e. `size(A, axis) == L` and `sum(A, dims=axis) == 1`)
stored as a compact `N == M-1`-dimensional array of indices.

Typically constructed by [`onehot`](@ref) and [`onehotbatch`](@ref).
@@ -15,6 +15,10 @@ end
OneHotArray{T, N, I}(indices, L::Int) where {T, N, I} = OneHotArray{T, N, N+1, I}(indices, L)
OneHotArray(indices::T, L::Int) where {T<:Integer} = OneHotArray{T, 0, 1, T}(indices, L)
OneHotArray(indices::I, L::Int) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, N, N+1, I}(indices, L)
function OneHotArray(indices, L, axis::Int)
    a = collect(1:length(size(indices))+1)
    PermutedDimsArray(OneHotArray(indices, L), insert!(a, 1, popat!(a, axis)))
Member

The permutation can be computed without mutating an array, something like this:

julia> let dims=2
       ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
       end
(2, 1, 3, 4)

julia> let dims=3
       ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
       end
(2, 3, 1, 4)

Author

Notice your second example gives (2,3,1,4), which is not what I meant.
But the general suggestion is on point.

Author

julia> dims = 3
julia> ntuple(d -> (d==dims ? 1 : (d==1 ? dims : d)), 4)
(3, 2, 1, 4)

Member

I thought I might be off by an invperm, but did not check carefully.
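
For reference, the three permutations discussed in this thread can be compared side by side with a small Base-only sketch. A dense array stands in for a `OneHotArray` whose label axis comes first; the variable names are illustrative only. Note that for `axis = 2` on a 3-dimensional array all three give `(2, 1, 3)`, so they only differ for later axes.

```julia
# Dense stand-in with parent layout (labels, batch dims...); sizes chosen to be distinct.
A = zeros(Bool, 10, 5, 6, 7)
dims = 3   # position where the label axis should end up

# Move axis 1 to position `dims`, keeping the other axes in order (first suggestion).
perm_move = ntuple(d -> d == dims ? 1 : d < dims ? d + 1 : d, ndims(A))

# Swap axis 1 with axis `dims`; this permutation is its own inverse (second suggestion).
perm_swap = ntuple(d -> d == dims ? 1 : d == 1 ? dims : d, ndims(A))

# The mutating version from the diff: pop `dims` and push it to the front.
a = collect(1:ndims(A))
perm_front = Tuple(insert!(a, 1, popat!(a, dims)))

perm_move                                   # (2, 3, 1, 4)
perm_swap                                   # (3, 2, 1, 4)
perm_front                                  # (3, 1, 2, 4)

# size(PermutedDimsArray(A, perm)) == size(A)[perm], so the label axis (length 10)
# lands wherever `perm` contains a 1.
size(PermutedDimsArray(A, perm_move))       # (5, 6, 10, 7)
size(PermutedDimsArray(A, perm_swap))       # (6, 5, 10, 7)
size(PermutedDimsArray(A, perm_front))      # (6, 10, 5, 7)

invperm(perm_front) == perm_move            # true
```

In this sketch the move and swap variants both put the labels at position `dims`, while the pop-and-insert variant is the inverse of the move variant and places them at position 2 for `dims = 3`.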

end

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
@@ -69,7 +73,7 @@ end
# the method above is faster on the CPU but will scalar index on the GPU
# so we define the method below to pass the extra indices directly to GPU array
function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:AbstractGPUArray},
i::Int,
I::Vararg{Any, N}) where N
@boundscheck (1 <= i <= x.nlabels) || throw(BoundsError(x, (i, I...)))
return x.indices[I...] .== i
@@ -154,5 +158,5 @@ Base.map(f, x::OneHotLike) = Base.broadcast(f, x)

Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)
33 changes: 33 additions & 0 deletions test/array.jl
@@ -3,12 +3,14 @@ ov2 = OneHotVector(rand(1:11), 11)
om = OneHotMatrix(rand(1:10, 5), 10)
om2 = OneHotMatrix(rand(1:11, 5), 11)
oa = OneHotArray(rand(1:10, 5, 5), 10)
oa2 = OneHotArray(rand(1:10, 5, 5), 10, 2)

# sizes
@testset "Base.size" begin
@test size(ov) == (10,)
@test size(om) == (10, 5)
@test size(oa) == (10, 5, 5)
@test size(oa2) == (5, 10, 5)
end

@testset "Indexing" begin
@@ -32,18 +34,30 @@ end
@test oa[:, :, :] == oa
@test oa[:] == reshape(oa, :)

@test oa2[3, 3, 3] == (oa2.parent.indices[3, 3] == 3)
@test oa2[3, :, 3] == OneHotVector(oa2.parent.indices[3, 3], 10)
@test oa2[:, 3, 3] == (oa2.parent.indices[:, 3] .== 3)
@test oa2[:, 3, :] == (oa2.parent.indices .== 3)
@test oa2[3, :, :] == OneHotMatrix(oa2.parent.indices[3, :], 10)
@test oa2[:, :, :] == oa2
@test oa2[:] == reshape(oa2, :)

# cartesian indexing
@test oa[CartesianIndex(3, 3, 3)] == oa[3, 3, 3]
@test oa2[CartesianIndex(3, 3, 3)] == oa2[3, 3, 3]

# linear indexing
@test om[11] == om[1, 2]
@test oa[52] == oa[2, 1, 2]
@test oa2[56] == oa2[1, 2, 2]  # size (5, 10, 5): 1 + (2-1)*5 + (2-1)*50 == 56

# bounds checks
@test_throws BoundsError ov[0]
@test_throws BoundsError om[2, -1]
@test_throws BoundsError oa[11, 5, 5]
@test_throws BoundsError oa[:, :]
@test_throws BoundsError oa2[5, 11, 5]
@test_throws BoundsError oa2[:, :]
end

@testset "Concatenating" begin
@@ -64,6 +78,9 @@ end
@test cat(oa, oa; dims = 3) isa OneHotArray
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)

@test cat(oa2, oa2; dims = 3) == OneHotArray(cat(oa2.parent.indices, oa2.parent.indices; dims = 2), 10, 2)
@test cat(oa2, oa2; dims = 2) == cat(collect(oa2), collect(oa2); dims = 2)

# stack
@test stack([ov, ov]) == hcat(ov, ov)
@test stack([ov, ov, ov]) isa OneHotMatrix
@@ -96,6 +113,18 @@ end
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
@test OneHotArrays._fast_argmax(r) == collect(reshape(oa.indices, :))
end

@testset "w/ cat" begin
r = reshape(oa2, 10, :)
@test vcat(r, r) isa Array{Bool}
end

@testset "w/ argmax" begin
oa2p = PermutedDimsArray(oa2, [2,1,3])
r = reshape(oa2p, 10, :)
@test argmax(r) == argmax(OneHotMatrix(reshape(oa2p.parent.parent.indices, :), 10))
@test stack(collect(Tuple.(OneHotArrays._fast_argmax(r))))[1,:] == collect(reshape(oa2p.parent.parent.indices, :))
end
end

@testset "Base.argmax" begin
@@ -106,9 +135,13 @@ end
@test argmax(om; dims = 2) == argmax(convert(Array{Bool}, om); dims = 2)
@test argmax(oa; dims = 1) == argmax(convert(Array{Bool}, oa); dims = 1)
@test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3)
@test argmax(oa2; dims = 2) == argmax(convert(Array{Bool}, oa2); dims = 2)
@test argmax(oa2; dims = 3) == argmax(convert(Array{Bool}, oa2); dims = 3)
end

@testset "Forward map to broadcast" begin
@test map(identity, oa) == oa
@test map(x -> 2 * x, oa) == 2 .* oa
@test map(identity, oa2) == oa2
@test map(x -> 2 * x, oa2) == 2 .* oa2
end