WIP: Add axis permutedims #36
base: main
Codecov Report
Patch coverage:
Additional details and impacted files
@@ Coverage Diff @@
## main #36 +/- ##
==========================================
+ Coverage 96.37% 96.47% +0.10%
==========================================
Files 3 4 +1
Lines 138 142 +4
==========================================
+ Hits 133 137 +4
Misses 5 5
src/array.jl
Outdated
@@ -1,8 +1,8 @@
"""
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M}
OneHotArray(indices, L)
OneHotArray(indices, L, [axis=1])
If we do add this, it should probably be a `dims::Integer` keyword on `onehotbatch`. IMO it's weird if a type constructor does not return the stated type.
I understand the sentiment, but it seems awkward to maintain a whole set of functions and increase their complexity just to get the same functionality.
If this is a complete no-go, then the options are either the alternative implementation (which might have even more problems) or adding this functionality as a separate utility function.
Why is it more complex to alter the lower-case function than the upper-case type constructor?
I thought some more about this point and I agree changing onehot/onehotbatch/etc is the better approach.
I'll take a closer look at the functions (still not very familiar with them all).
Which would you say would be appropriate?
src/array.jl
Outdated
a = collect(1:length(size(indices))+1)
PermutedDimsArray(OneHotArray(indices, L), insert!(a, 1, popat!(a, axis)))
The permutation can be computed without mutating an array, something like this:
julia> let dims=2
ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
end
(2, 1, 3, 4)
julia> let dims=3
ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
end
(2, 3, 1, 4)
Notice your second example gives `(2,3,1,4)`, which is not what I meant.
But the general suggestion is on point.
julia> dims = 3
julia> ntuple(d -> (d==dims ? 1 : (d==1 ? dims : d)), 4)
(3, 2, 1, 4)
I thought I may be wrong by an invperm, but did not check carefully.
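To make the `invperm` question above concrete, here is a small sketch using only Base Julia. A plain 4-dimensional array stands in for the one-hot array, which is an assumption made purely for illustration, and the names `A`, `perm_move`, `perm_swap`, `B_move`, and `B_swap` are hypothetical. It shows where the first axis ends up under each permutation proposed in this thread.

```julia
# Stand-in array: its first axis (length 2) plays the role of the one-hot axis.
A = reshape(1:2*3*4*5, 2, 3, 4, 5)
dims = 3

# First suggestion in this thread: gives (2, 3, 1, 4) for dims = 3.
perm_move = ntuple(d -> d == dims ? 1 : d < dims ? d + 1 : d, 4)

# Second suggestion: exchange axes 1 and dims, giving (3, 2, 1, 4).
perm_swap = ntuple(d -> d == dims ? 1 : (d == 1 ? dims : d), 4)

# PermutedDimsArray(A, perm) satisfies size(B, k) == size(A, perm[k]).
B_move = PermutedDimsArray(A, perm_move)
B_swap = PermutedDimsArray(A, perm_swap)

size(B_move)  # (3, 4, 2, 5): axis 1 of A now sits at position 3, the rest keep their order
size(B_swap)  # (4, 3, 2, 5): axes 1 and 3 of A are exchanged

# The inverse of perm_move instead brings axis 3 of A to the front.
size(PermutedDimsArray(A, invperm(perm_move)))  # (4, 2, 3, 5)
```

Both permutations place the stand-in one-hot axis at position `dims`; they differ only in how the remaining axes are ordered, and for `dims = 2` they coincide.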
@mcabbott
The following is type-stable:
function ohaxis(data::AbstractArray{<:Any, N}, labels, dims::Val{D} = Val(1)) where {N, D}
    out = onehotbatch(data, labels)
    if D == 1
        return out
    else
        perm = (D, ntuple(i -> i, D - 1)..., ntuple(i -> i + D, N - D)...)
        return PermutedDimsArray(out, perm)
    end
end
We can offer the
This issue comes up in Base too, where
I'll get around to it soon. Thanks for the suggestion! :)
Fixed! :)
onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels)
Not sure why, but CUDA tests were failing without it.
@darsnack
julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
3×11 Matrix{Int64}:
1 4 13 1 7 1 10 1 4 13 1
2 5 14 2 8 2 11 2 5 14 2
3 6 15 3 9 3 12 3 6 15 3
```
"""
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)
My suggestion for how to write this: add the keyword `dims`, but leave the basic path as close to untouched as you can, like so:
onehotbatch(data, labels, default...; dims=Val(1)) = _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...)
function _onehotbatch(::Val{1}, data, labels)
    # as before
    return OneHotArray(indices, length(labels))
end
function _onehotbatch(::Val{1}, data, labels, default)
    # as before
    return OneHotArray(indices, length(labels))
end
In particular, this does not call `collect(data)`, as this shouldn't be necessary; we can just iterate things.
Readers uninterested in permutations can stop there. But to handle them, make it obvious that we call the same path and then permute the result.
_onehotbatch(dims::Integer, data, labels, default...) = _onehotbatch(Val(dims), data, labels, default...)
_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...))
_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array)
function _permute(::Val{d}, array::OneHotArray{<:Any, N, M}) where {d, N, M}
    # this is where you compute perm, can use N or M, I forget...
    PermutedDimsArray(array, perm)
end
I made a special case for `transpose`, as I think that's always preferable to `PermutedDimsArray`.
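As a rough sketch of how the elided `perm` computation and the dispatch layout above might fit together, the following uses a plain `Bool` array in place of `OneHotArray` so that it runs with Base Julia alone. The names `toy_onehotbatch`, `_toy_onehotbatch`, `_toy_permute`, and `_perm` are hypothetical and not part of OneHotArrays.jl, `data` is assumed to be an array, and the permutation chosen here (move the one-hot axis to position `d`, keep the other axes in order) is one reading of the discussion rather than the PR's final code.

```julia
# Hypothetical stand-in for the Val(1) path: one-hot vectors lie along axis 1.
# Assumes `data` is an AbstractArray; the result has size (length(labels), size(data)...).
function _toy_onehotbatch(::Val{1}, data, labels)
    labelcol = collect(labels)
    return labelcol .== reshape(data, 1, size(data)...)
end

# An integer `dims` forwards to the Val method; any other Val permutes the Val(1) result.
_toy_onehotbatch(dims::Integer, data, labels) = _toy_onehotbatch(Val(dims), data, labels)
_toy_onehotbatch(dims::Val, data, labels) = _toy_permute(dims, _toy_onehotbatch(Val(1), data, labels))

# Permutation that moves axis 1 to position d and keeps the remaining axes in order.
_perm(d, M) = ntuple(k -> k == d ? 1 : k < d ? k + 1 : k, M)

# Special case: a matrix with dims = 2 is just a transpose.
_toy_permute(::Val{2}, array::AbstractMatrix) = transpose(array)
function _toy_permute(::Val{d}, array::AbstractArray{<:Any, M}) where {d, M}
    return PermutedDimsArray(array, _perm(d, M))
end

toy_onehotbatch(data, labels; dims = Val(1)) = _toy_onehotbatch(dims, data, labels)

# Usage: a 2×2 matrix of labels drawn from 1:3.
data = [1 2; 3 1]
oh1 = toy_onehotbatch(data, 1:3)                  # one-hot along axis 1
oh3 = toy_onehotbatch(data, 1:3; dims = Val(3))   # one-hot along axis 3
oh2 = toy_onehotbatch([1, 3], 1:3; dims = 2)      # integer dims, transpose path

size(oh1)  # (3, 2, 2)
size(oh3)  # (2, 2, 3)
size(oh2)  # (2, 3)
```

With this choice of permutation, `oh3[i, j, :]` is the one-hot vector for `data[i, j]`, mirroring the documented `result[:, k...] == onehot(xs[k...], labels)` with the one-hot axis moved last.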
While working on the suggestions of @mcabbott, I seem to have found errors, so I'm marking this as WIP for now. Do not merge this.
My only suggestion is to rewrite the dispatch to match @mcabbott's suggestion.
@@ -48,7 +48,7 @@ end
_findval(val, labels::Tuple{}, i::Integer) = nothing

"""
onehotbatch(xs, labels, [default])
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
onehotbatch(xs, labels, [default]; dims = Val(1))
@@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`.
Note that `xs` can be any iterable, such as a string. And that using a tuple
for `labels` will often speed up construction, certainly for less than 32 classes.

If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
If dims keyword is given, the onehot vectors lie on the `dims` dimension rather than the first one. `dims` should be provided as a `Val` to guarantee type stability (but a plain integer is valid as well).
This pull request introduces a new feature to OneHotArray: the axis along which the vectors are one-hot can be chosen at construction time.
This is achieved by a new constructor that takes the axis as an additional argument and returns a OneHotArray wrapped in a PermutedDimsArray.
This approach avoids invasive code changes that may be hard to maintain and debug.
The performance degradation I've seen in testing is around 10% when using the wrapped OneHotArray, which I think is reasonable.
The alternative would be to add a new `axis` field to the struct and change many constructors and functions to get the desired behavior, which I've already done in a separate branch. However, it is easier to take the simple route first and let it go through field testing than to start with the complex route, which may have various downsides.
See also #35 for further discussion.
PR Checklist