From 466924bc64e5faff020b0ec615c05595b3b2bb2a Mon Sep 17 00:00:00 2001 From: Tobi Lipede <12844474+TLipede@users.noreply.github.com> Date: Fri, 11 Mar 2022 22:55:03 +0000 Subject: [PATCH] Adding `similar`/`setindex!`, updating `argmax` and `argmin` and adding conversion methods --- src/array.jl | 72 ++++++++++++++++++++++++++++++++++++++++++------ test/array.jl | 18 ++++++++++++ test/runtests.jl | 18 ++++++------ 3 files changed, 91 insertions(+), 17 deletions(-) diff --git a/src/array.jl b/src/array.jl index 3ac48bf..5af5d3d 100644 --- a/src/array.jl +++ b/src/array.jl @@ -47,6 +47,24 @@ end Base.getindex(x::OneHotArray, ::Colon) = BitVector(reshape(x, :)) Base.getindex(x::OneHotArray{<:Any, <:Any, N}, ::Colon, ::Vararg{Colon, N}) where N = x +function Base.similar(::OneHotArray{T, L}, ::Type{Bool}, dims::Dims) where {T, L} + if first(dims) == L + indices = ones(T, Base.tail(dims)) + return OneHotArray(indices, first(dims)) + else + return BitArray(undef, dims) + end +end + +function Base.setindex!(x::OneHotLike{<:Any, <:Any, N}, v::Bool, i::Integer, I::Vararg{Integer, N}) where N + @boundscheck checkbounds(x, i, I...) + if v + _indices(x)[I...] = i + else + error("OneHotArray cannot be set with false values") + end +end + function Base.showarg(io::IO, x::OneHotArray, toplevel) print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(") Base.showarg(io, x.indices, false) @@ -69,6 +87,30 @@ Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, _onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N} _onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N} +_onehot_compatible(x::OneHotLike) = _isonehot(x) +_onehot_compatible(x::AbstractVector{Bool}) = count(x) == 1 +_onehot_compatible(x::AbstractArray{Bool}) = all(isone, reduce(+, x; dims=1)) +_onehot_compatible(x::AbstractArray) = _onehot_compatible(BitArray(x)) + +function OneHotArray(x::OneHotLike) + !_onehot_compatible(x) && error("Array is not onehot compatible") + return x +end + +function OneHotArray(x::AbstractVector) + !_onehot_compatible(x) && error("Array is not onehot compatible") + return OneHotArray(findfirst(x), length(x)) +end + +function OneHotArray(x::AbstractArray) + !_onehot_compatible(x) && error("Array is not onehot compatible") + dims = size(x) + dim1, dim2 = dims[1], reduce(*, Base.tail(dims)) + rx = reshape(x, (dim1, dim2)) + indices = UInt32[findfirst(==(true), col) for col in eachcol(rx)] + return OneHotArray(reshape(indices, Base.tail(dims)), dim1) +end + function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L if isone(dims) || any(x -> !_isonehot(x), (x, xs...)) return cat(map(x -> convert(_onehot_bool_type(x), x), (x, xs...))...; dims = dims) @@ -80,11 +122,9 @@ end Base.hcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 2) Base.vcat(x::OneHotLike, xs::OneHotLike...) = cat(x, xs...; dims = 1) -# optimized concatenation for matrices and vectors of same parameters -Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 2}} = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) -Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any, 1}} = - OneHotMatrix(reduce(vcat, _indices.(xs); init = _indices(x)), L) +# optimized concatenation for arrays of same parameters +Base.hcat(x::T, xs::T...) where {L, T <: OneHotLike{<:Any, L, <:Any}} = + OneHotArray(reduce(vcat, _indices.(xs); init = _indices(x)), L) MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatrix(_indices.(xs), L) @@ -94,7 +134,21 @@ Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}} Base.map(f, x::OneHotLike) = Base.broadcast(f, x) -Base.argmax(x::OneHotLike; dims = Colon()) = - (_isonehot(x) && dims == 1) ? - reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) : - invoke(argmax, Tuple{AbstractArray}, x; dims = dims) +function Base.argmax(x::OneHotLike; dims = Colon()) + if _isonehot(x) && dims == 1 + cart_inds = CartesianIndex.(_indices(x), CartesianIndices(_indices(x))) + return reshape(cart_inds, (1, size(_indices(x))...)) + else + return argmax(BitArray(x); dims=dims) + end +end + +function Base.argmin(x::OneHotLike; dims = Colon()) + if _isonehot(x) && dims == 1 + labelargs = ifelse.(_indices(x) .== 1, 2, 1) + cart_inds = CartesianIndex.(labelargs, CartesianIndices(_indices(x))) + return reshape(cart_inds, (1, size(_indices(x))...)) + else + return argmin(BitArray(x); dims=dims) + end +end diff --git a/test/array.jl b/test/array.jl index ed27727..3c6510b 100644 --- a/test/array.jl +++ b/test/array.jl @@ -46,6 +46,14 @@ end @test_throws BoundsError oa[:, :] end +@testset "Converting" begin + compat_arr = BitArray(OneHotArray(rand(1:5, (3, 5)), 5)) + + @test_throws Exception OneHotArray([1 0 0; 0 1 0]) + @test OneHotArray(compat_arr) == compat_arr + @test OneHotArray(oa) === oa +end + @testset "Concatenating" begin # vector cat @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10) @@ -101,6 +109,16 @@ end @test argmax(oa; dims = 3) == argmax(convert(Array{Bool}, oa); dims = 3) end +@testset "Base.argmin" begin + # argmin test + @test argmin(ov) == argmin(convert(Array{Bool}, ov)) + @test argmin(om) == argmin(convert(Array{Bool}, om)) + @test argmin(om; dims = 1) == argmin(convert(Array{Bool}, om); dims = 1) + @test argmin(om; dims = 2) == argmin(convert(Array{Bool}, om); dims = 2) + @test argmin(oa; dims = 1) == argmin(convert(Array{Bool}, oa); dims = 1) + @test argmin(oa; dims = 3) == argmin(convert(Array{Bool}, oa); dims = 3) +end + @testset "Forward map to broadcast" begin @test map(identity, oa) == oa @test map(x -> 2 * x, oa) == 2 .* oa diff --git a/test/runtests.jl b/test/runtests.jl index f23b34b..696807a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,16 @@ using OneHotArrays using Test -@testset "OneHotArray" begin - include("array.jl") -end +@testset verbose=true "OneHotArrays" begin + @testset "Array" begin + include("array.jl") + end -@testset "Constructors" begin - include("onehot.jl") -end + @testset "Constructors" begin + include("onehot.jl") + end -@testset "Linear Algebra" begin - include("linalg.jl") + @testset "Linear Algebra" begin + include("linalg.jl") + end end