Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add axis permutedims #36

Open
wants to merge 13 commits into
base: main
Choose a base branch
from

Conversation

nomadbl
Copy link

@nomadbl nomadbl commented May 25, 2023

This pull request introduces a new feature to OneHotArray: the axis in which the vectors are one-hot can be changed at initialization time.
This is achieved by a new constructor which includes this new axis variable, and returns a OneHotArray wrapped by a PermutedDimsArray.

This way of implementing it does not require any code changes which may be hard to maintain and debug.
The performance degradation I've seen in testing is around 10% when using the wrapped OneHotArray, which I think is reasonable.
The alternative to this method would be to introduce a new axis variable to the struct and change many constructors and functions to get the desired behavior - which I've already done in a separate branch. However it is easier to go the easy route and let it go through field testing than go in the complex route first, which may have various downsides.

See also #35 for further discussion

PR Checklist

  • ☑️ Tests are added
  • ☑️ Documentation, if applicable

@codecov-commenter
Copy link

codecov-commenter commented May 25, 2023

Codecov Report

Patch coverage: 100.00% and project coverage change: +0.10 🎉

Comparison is base (469b192) 96.37% compared to head (1f6599c) 96.47%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files
@@            Coverage Diff             @@
##             main      #36      +/-   ##
==========================================
+ Coverage   96.37%   96.47%   +0.10%     
==========================================
  Files           3        4       +1     
  Lines         138      142       +4     
==========================================
+ Hits          133      137       +4     
  Misses          5        5              
Impacted Files Coverage Δ
src/array.jl 95.31% <100.00%> (+0.23%) ⬆️

... and 1 file with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

src/array.jl Outdated
@@ -1,8 +1,8 @@
"""
OneHotArray{T, N, M, I} <: AbstractArray{Bool, M}
OneHotArray(indices, L)
OneHotArray(indices, L, [axis=1])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do add this, it should probably be a dims::Integer keyword on onehotbatch. IMO it's weird if a type constructor does not return the stated type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the sentiment, but it seems awkward to have to maintain a set of functions and increase their complexity to just have the same functionality..
If this is a complete no go then either the alternative implementation (which might have even more problems) or maybe add this functionality as a separate utility function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it more complex to alter the lower-case function than the upper-case type constructor?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought some more about this point and I agree changing onehot/onehotbatch/etc is the better approach.
I'll take a closer look at the functions (still not very familiar with them all).
Which would you say would be appropriate?

src/array.jl Outdated
Comment on lines 19 to 20
a = collect(1:length(size(indices))+1)
PermutedDimsArray(OneHotArray(indices, L), insert!(a, 1, popat!(a, axis)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The permutation can be computed without mutating an array, something like this:

julia> let dims=2
       ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
       end
(2, 1, 3, 4)

julia> let dims=3
       ntuple(d -> d==dims ? 1 : d<dims ? d+1 : d, 4)
       end
(2, 3, 1, 4)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Notice your second example gives (2,3,1,4) which is not what I meant.
But the general suggestion is on point

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

julia> dims = 3
julia> ntuple(d -> (d==dims ? 1 : (d==1 ? dims : d)), 4)
(3, 2, 1, 4)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought I may be wrong by an invperm, but did not check carefully.

@nomadbl
Copy link
Author

nomadbl commented May 26, 2023

@mcabbott
I've moved the functionality into onehotbatch as you suggested.
I'm struggling with getting the inferred output type of onehotbatch right - It is currently either PermutedDimsArray or OneHotArray.
If you have suggestions I'll implement them.

@nomadbl nomadbl mentioned this pull request Jul 22, 2023
2 tasks
@darsnack
Copy link
Member

The following is type-stable:

function ohaxis(data::AbstractArray{<:Any, N}, labels, dims::Val{D} = Val(1)) where {N, D}
   out = onehotbatch(data, labels)
   if D == 1
       return out
   else
       perm = (D, ntuple(i -> i, D - 1)..., ntuple(i -> i + D, N - D)...)
       return PermutedDimsArray(out, perm)
   end
end

We can offer the Val based method for people who need type-stability, while still offering a more user-friendly keyword interface.

@mcabbott
Copy link
Member

This issue comes up in Base too, where cat([1], [2]; dims=3) isn't type-stable, but it also accepts dims=Val(3) if required.

@nomadbl
Copy link
Author

nomadbl commented Jul 24, 2023

I'll get around to it soon. Thanks for the suggestion! :)

@nomadbl
Copy link
Author

nomadbl commented Jul 26, 2023

Fixed! :)
Thank you for your support, this should do it.
I also had to add an additional constructor:

onehotbatch(data::AbstractRange{<:Integer}, labels::AbstractUnitRange{<:Integer}) = onehotbatch(collect(data), labels)

Not sure why, but CUDA tests were failing without it.

@nomadbl nomadbl requested a review from mcabbott July 26, 2023 08:27
@nomadbl
Copy link
Author

nomadbl commented Jul 26, 2023

@darsnack
Can I ask for your review as well?

julia> reshape(1:15, 3, 5) * oh # this matrix multiplication is done efficiently
3×11 Matrix{Int64}:
1 4 13 1 7 1 10 1 4 13 1
2 5 14 2 8 2 11 2 5 14 2
3 6 15 3 9 3 12 3 6 15 3
```
"""
onehotbatch(data, labels, default...) = _onehotbatch(data, length(labels) < 32 ? Tuple(labels) : labels, default...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My suggestion for how to write this would be this. Add the keyword dims but leave the basic path as close to untouched as you can, like so:

onehotbatch(data, labels, default...; dims=Val(1)) = _onehotbatch(dims, data, length(labels) < 32 ? Tuple(labels) : labels, default...)

function _onehotbatch(::Val{1}, data, labels)
  # as before
   return OneHotArray(indices, length(labels))
end
function _onehotbatch(::Val{1}, data, labels, default)
  # as before
  return OneHotArray(indices, length(labels))
end

In particular, this does not call collect(data), as this shouldn't be necessary, we can just iterate things.

Readers uninterested in permutations can stop there. But to handle them, make it obvious that we call the same path, and then permute it.

_onehotbatch(dims::Integer, data, labels, default...) = _onehotbatch(Val(dims), data, labels, default...)
_onehotbatch(dims::Val, data, labels, default...) = _permute(dims, _onehotbatch(Val(1), data, labels, default...))

_permute(::Val{2}, array::OneHotArray{<:Any, 1, 2}) = transpose(array)
function _permute(::Val{d}, array::OneHotArray{<:Any, N,M}) where {d, N, M}
  # this is where you compute perm, can use N or M, I forget...
  PermutedDimsArray(array, perm)
end

I made a special case for transpose, as I think that's always preferable to PermutedDimsArray.

src/onehot.jl Outdated Show resolved Hide resolved
@nomadbl
Copy link
Author

nomadbl commented Jul 26, 2023

while working on the suggestions of @mcabbott , I seem to have found errors, so I'm marking this as WIP for now. Do not merge this.

@nomadbl nomadbl changed the title Add axis permutedims WIP: Add axis permutedims Jul 26, 2023
Copy link
Member

@darsnack darsnack left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only suggestion is to rewrite the dispatch matching @mcabbott‘s suggestion.

@@ -48,7 +48,7 @@ end
_findval(val, labels::Tuple{}, i::Integer) = nothing

"""
onehotbatch(xs, labels, [default])
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
onehotbatch(xs, labels, [default]; dims::Val{D}=Val{1})
onehotbatch(xs, labels, [default]; dims = Val(1))

@@ -64,6 +64,8 @@ i.e. `result[:, k...] == onehot(xs[k...], labels)`.
Note that `xs` can be any iterable, such as a string. And that using a tuple
for `labels` will often speed up construction, certainly for less than 32 classes.

If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
If dims keyword is given, the onehot vectors lie on the [dims] dimension rather than the first one.
If dims keyword is given, the onehot vectors lie on the `dims` dimension rather than the first one. `dims` should be provided as a `Val` to guarantee type stability (but an plain integer is valid as well).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants