Hi, it seems that the rrule for mean(f, x) is not vectorized and thus does not play nicely with CUDA:
julia> using Zygote, CUDA, Statistics
julia> gradient(y -> mean(x -> x.^2, y), CUDA.randn(10))
ERROR: Scalar indexing is disallowed.
Invocation of getindex resulted in scalar indexing of a GPU array.
This is typically caused by calling an iterating implementation of a method.
Such implementations *do not* execute on the GPU, but very slowly on the CPU,
and therefore are only permitted from the REPL for prototyping purposes.
If you did intend to index this array, annotate the caller with @allowscalar.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] assertscalar(op::String)
@ GPUArraysCore ~/.julia/packages/GPUArraysCore/uOYfN/src/GPUArraysCore.jl:103
[3] getindex
@ ~/.julia/packages/GPUArrays/5XhED/src/host/indexing.jl:9 [inlined]
[4] iterate
@ ./abstractarray.jl:1220 [inlined]
[5] iterate
@ ./abstractarray.jl:1218 [inlined]
[6] iterate
@ ./generator.jl:44 [inlined]
[7] collect(itr::Base.Generator{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ChainRules.var"#1655#1660"{Zygote.ZygoteRuleConfig{Zygote.Context{false}}, var"#24#26"}})
@ Base ./array.jl:782
[8] rrule(config::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::typeof(sum), f::var"#24#26", xs::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}; dims::Function)
@ ChainRules ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:102
[9] rrule
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Base/mapreduce.jl:76 [inlined]
[10] #rrule#1808
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:28 [inlined]
[11] rrule
@ ~/.julia/packages/ChainRules/9sNmB/src/rulesets/Statistics/statistics.jl:21 [inlined]
[12] chain_rrule
@ ~/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:223 [inlined]
[13] macro expansion
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
[14] _pullback
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:101 [inlined]
[15] _pullback
@ ./REPL[14]:1 [inlined]
[16] _pullback(ctx::Zygote.Context{false}, f::var"#23#25", args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface2.jl:0
[17] pullback(f::Function, cx::Zygote.Context{false}, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:44
[18] pullback
@ ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:42 [inlined]
[19] gradient(f::Function, args::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
@ Zygote ~/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:96
[20] top-level scope
@ REPL[14]:1
[21] top-level scope
@ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
The problem seems to be that this line does not use map or broadcasting, but the comment next to it suggests that we can't do that here. Is there anything we can do?
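To make "vectorized" concrete, here is a minimal sketch of a forward pass and pullback for mean(f, x) that uses only broadcasting, so nothing iterates element by element. It assumes the caller already knows the elementwise derivative `df`; the name `mean_with_pullback` is hypothetical and not part of ChainRules or Zygote. A generic rrule does not have `df` and presumably has to obtain per-element pullbacks through AD, which is where the iterating code path comes from.

```julia
using Statistics

# Hypothetical helper (not a ChainRules/Zygote API): value and pullback of
# mean(f, x), assuming the elementwise derivative `df` of `f` is supplied.
function mean_with_pullback(f, df, x::AbstractArray)
    y = mean(f.(x))          # broadcast f, then reduce: no scalar getindex
    n = length(x)
    # d/dx_i mean(f, x) = df(x_i) / n, computed with a single broadcast
    pullback(ȳ) = ȳ .* df.(x) ./ n
    return y, pullback
end

y, back = mean_with_pullback(x -> x^2, x -> 2x, [1.0, 2.0, 3.0])
grad = back(1.0)
```

Because every step is a broadcast or a reduction, the same code should run unchanged on any array type that supports those operations, which CuArray does.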
By the way, sum(f, x) with the same f works perfectly, so I'm curious why the result is different. Don't both hit the same rrule?
This appears to be more complicated. It seems that gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10)) does not hit the sum(f, x) rrule, while mean(f, x) does. This is very strange. I have no idea which rule is being hit for sum(f, x).