-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix #429 -- add branch-free paths for simple cases like CuArray{Float32}
#430
base: main
Are you sure you want to change the base?
Conversation
src/projection.jl
Outdated
S = project_type(element) # new idea -- for any number, S is enough. | ||
# Store .element for now too, although it's redundant? Reconstruct from eltype? | ||
if axes(x) isa NTuple{N,Base.OneTo{Int}} | ||
return ProjectTo{AbstractArray{S,N}}(; element=element, axes=axes(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here S
and p.element
are the same information: for every S<:Number
, the projector can be reconstructed as ProjectTo(zero(S))
. So perhaps tidier not to store it.
We can alter getproperty
to do this. In which case, perhaps the projector for an array of arrays (which has an array p.elements
instead) should also define it, via mapreduce(project_type, promote_type, projectors)
, even if that ends up trivial?
If you can always call p.elements
, then you can insert that into some in-place rules, dx -> dx .+= ρ.(dy .* conj.(stuff...))
.
CuArray{Float32}
# Fastest path: N means they are OneTo, hence reshape can be skipped | ||
return ProjectTo{AbstractArray{S,N}}(; element=element, axes=axes) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is slightly weird, but my idea for avoiding the if axes(dx) == projector.axes
branch is to include the N
only when the axes are Base.OneTo
. If ndims(dx)
matches, and its eltype matches, then it can pass through just by dispatch.
This means it won't check for quite as many size mismatches. But it will still reshape
for OffsetArrays, SArrays, etc, those will go the "slow path" as before.
This tries to add some fast paths for the common case, of ordinary arrays of numbers. They look like this:
Might close #429 .
Might also introduce weird dispatch ambiguities. I thought I hit places where I misunderstood how things like
(p::ProjectTo{<:T})(dx::S)
would be treated.