Two SimpleChains overwrite each other's predictions? #138
Comments
I can reproduce this. Note:
julia> y1 === y2
true
julia> sc1 === sc2
true
The two chains are identical (`sc1 === sc2`), and a SimpleChain overwrites its previous answer to avoid allocating:
julia> @benchmark $sc2($X, $p2)
BenchmarkTools.Trial: 10000 samples with 7 evaluations.
Range (min … max): 4.615 μs … 8.601 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 4.643 μs ┊ GC (median): 0.00%
Time (mean ± σ): 4.649 μs ± 61.947 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▁▄▇██▄
▂▂▃▄▆███████▆▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂ ▃
4.62 μs Histogram: frequency by time 4.8 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
So evaluating the same
But there should at least be a
Another hacky workaround:
julia> @inline mytanh(x) = Base.FastMath.tanh_fast(x)
mytanh (generic function with 1 method)
julia> sc3 = SimpleChain(
static(1),
TurboDense{true}(tanh, 128),
TurboDense{true}(mytanh, 128),
TurboDense{false}(identity, 1),
)
SimpleChain with the following layers:
TurboDense static(128) with bias.
Activation layer applying: tanh
TurboDense static(128) with bias.
Activation layer applying: mytanh
TurboDense static(1) without bias.
julia> sc1 === sc3 # not the same chain
false
julia> y3 = sc3(X, p2)
1×10 StrideArray{Float32, 2, (1, 2), Tuple{StaticInt{1}, Int64}, Tuple{Nothing, Nothing}, Tuple{StaticInt{1}, StaticInt{1}}, Vector{UInt8}} with indices static(1):static(1)×Base.OneTo(10):
0.337447 0.488095 -0.00920447 -0.2206 -0.242737 0.150495 -0.502708 0.291047 -0.895863 -0.337516
julia> y1 = sc1(X, p1)
1×10 StrideArray{Float32, 2, (1, 2), Tuple{StaticInt{1}, Int64}, Tuple{Nothing, Nothing}, Tuple{StaticInt{1}, StaticInt{1}}, Vector{UInt8}} with indices static(1):static(1)×Base.OneTo(10):
-0.0441075 -0.0641206 0.00119711 0.0287531 0.0316527 … 0.0660765 -0.037995 0.118903 0.0441165
julia> y1 === y3 # different chains, different answers
false
julia> y3
1×10 StrideArray{Float32, 2, (1, 2), Tuple{StaticInt{1}, Int64}, Tuple{Nothing, Nothing}, Tuple{StaticInt{1}, StaticInt{1}}, Vector{UInt8}} with indices static(1):static(1)×Base.OneTo(10):
0.337447 0.488095 -0.00920447 -0.2206 -0.242737 0.150495 -0.502708 0.291047 -0.895863 -0.337516
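To show the buffer reuse in isolation, here is a minimal sketch in plain Julia. `BufferedModel` is a hypothetical stand-in, not SimpleChains' actual implementation; it only mimics the "return the same preallocated output on every call" behaviour. It also shows that taking a `copy` of the result keeps an earlier answer intact, assuming `copy` is supported for the output type you get back.

```julia
# Hypothetical stand-in for a model that reuses one output buffer per instance.
# This is NOT SimpleChains' implementation, only an illustration of the aliasing.
struct BufferedModel
    buf::Vector{Float64}   # preallocated output, reused on every call
end
BufferedModel(n::Integer) = BufferedModel(zeros(n))

# Calling the model writes into the shared buffer and returns it (no new array).
function (m::BufferedModel)(x::AbstractVector, scale::Real)
    m.buf .= scale .* x
    return m.buf
end

m = BufferedModel(3)
x = [1.0, 2.0, 3.0]

y1 = m(x, 1.0)        # y1 is the buffer itself
y2 = m(x, 10.0)       # overwrites the buffer, so y1 now shows the new values
aliased = y1 === y2   # both names refer to the same array

z1 = copy(m(x, 1.0))  # copying detaches the result from the buffer
z2 = m(x, 10.0)       # this call no longer affects z1
```

Here `y1` and `y2` end up as the same array, while `z1` keeps the values from the first call because it was copied before the second call ran.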
Thanks for the reply! Ok I didn't realize the
Yes, something like a wrapper or
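A wrapper along those lines might look like the sketch below. `CopyOutput` and `reuse` are hypothetical names made up for illustration, not part of SimpleChains, and the sketch assumes `copy` works on whatever the wrapped model returns. It trades one allocation per call for results that do not alias each other.

```julia
# Hypothetical wrapper: calls any callable model and returns a fresh copy of its output.
struct CopyOutput{M}
    model::M
end
(w::CopyOutput)(args...) = copy(w.model(args...))

# Stand-in model that reuses a single buffer (illustration only, not SimpleChains).
buf = zeros(2)
reuse(x, s) = (buf .= s .* x; buf)

safe = CopyOutput(reuse)
a = safe([1.0, 2.0], 1.0)   # copy of the buffer after the first call
b = safe([1.0, 2.0], 5.0)   # a separate copy, so `a` is left untouched
```

The unwrapped model stays allocation-free for hot loops, and the wrapper is only used where results need to be kept side by side.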
The predictions of two SimpleChains that should have nothing to do with each other seem to be overwriting each other. This seems like a bug, but let me know if I am doing something obviously wrong.
Here is a ~MWE of some code I have been running: