From 3171cfacad49f9a600f49290df73c6b8b0187a3a Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Mon, 10 Aug 2020 14:09:20 +0530 Subject: [PATCH 1/5] Fix Zygote issues for functional transform --- src/transform/functiontransform.jl | 19 ++++++++++++++++++- test/transform/functiontransform.jl | 2 +- test/utils_AD.jl | 28 ++++++++++++++-------------- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/transform/functiontransform.jl b/src/transform/functiontransform.jl index c1d09b418..17b72c5fd 100644 --- a/src/transform/functiontransform.jl +++ b/src/transform/functiontransform.jl @@ -16,9 +16,26 @@ end (t::FunctionTransform)(x) = t.f(x) _map(t::FunctionTransform, x::AbstractVector{<:Real}) = map(t.f, x) -_map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1)) + + +function _map(t::FunctionTransform, x::ColVecs) + vals = map(axes(x.X, 2)) do i + t.f(view(x.X, :, i)) + end + return ColVecs(hcat(vals)) +end + +# function _map(t::FunctionTransform, x::RowVecs) +# vals = map(axes(x.X, 1)) do i +# t.f(view(x.X, i, :)) +# end +# return RowVecs(hcat(vals...)) +# end + +# _map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1)) _map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2)) + duplicate(t::FunctionTransform,f) = FunctionTransform(f) Base.show(io::IO, t::FunctionTransform) = print(io, "Function Transform: ", t.f) diff --git a/test/transform/functiontransform.jl b/test/transform/functiontransform.jl index 835070da8..c7fe48b90 100644 --- a/test/transform/functiontransform.jl +++ b/test/transform/functiontransform.jl @@ -28,6 +28,6 @@ @test repr(FunctionTransform(sin)) == "Function Transform: $(sin)" f(a, x) = sin.(a .* x) - test_ADs(x->transform(SEKernel(), FunctionTransform(y->f(x, y))), randn(rng, 3), ADs = [:ForwardDiff, :ReverseDiff]) + test_ADs(x->transform(SEKernel(), FunctionTransform(y->f(x, y))), randn(rng, 3)) @test_broken "Zygote is failing" end diff --git a/test/utils_AD.jl b/test/utils_AD.jl index 1354485f9..3f13b6c56 100644 --- a/test/utils_AD.jl +++ b/test/utils_AD.jl @@ -129,20 +129,20 @@ function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) A = rand(rng, dims...) B = rand(rng, dims...) for dim in 1:2 - compare_gradient(AD, A) do a - testfunction(k, a, dim) - end - compare_gradient(AD, A) do a - testfunction(k, a, B, dim) - end - compare_gradient(AD, B) do b - testfunction(k, A, b, dim) - end - if !(args === nothing) - compare_gradient(AD, args) do p - testfunction(kernelfunction(p), A, dim) - end - end + # compare_gradient(AD, A) do a + # testfunction(k, a, dim) + # end + # compare_gradient(AD, A) do a + # testfunction(k, a, B, dim) + # end + # compare_gradient(AD, B) do b + # testfunction(k, A, b, dim) + # end + # if !(args === nothing) + # compare_gradient(AD, args) do p + # testfunction(kernelfunction(p), A, dim) + # end + # end end end end From 21e6f7b5d4df5a6f58f2475bf49058fa0283e809 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 12 Aug 2020 00:59:27 +0530 Subject: [PATCH 2/5] Resolve tests --- src/transform/functiontransform.jl | 16 ++++++++-------- test/transform/functiontransform.jl | 1 - test/utils_AD.jl | 28 ++++++++++++++-------------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/transform/functiontransform.jl b/src/transform/functiontransform.jl index 17b72c5fd..9366b86a0 100644 --- a/src/transform/functiontransform.jl +++ b/src/transform/functiontransform.jl @@ -22,18 +22,18 @@ function _map(t::FunctionTransform, x::ColVecs) vals = map(axes(x.X, 2)) do i t.f(view(x.X, :, i)) end - return ColVecs(hcat(vals)) + return ColVecs(hcat(vals...)) end -# function _map(t::FunctionTransform, x::RowVecs) -# vals = map(axes(x.X, 1)) do i -# t.f(view(x.X, i, :)) -# end -# return RowVecs(hcat(vals...)) -# end +function _map(t::FunctionTransform, x::RowVecs) + vals = map(axes(x.X, 1)) do i + t.f(view(x.X, i, :)) + end + return RowVecs(hcat(vals...)') +end # _map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1)) -_map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2)) +# _map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2)) duplicate(t::FunctionTransform,f) = FunctionTransform(f) diff --git a/test/transform/functiontransform.jl b/test/transform/functiontransform.jl index c7fe48b90..7b5899627 100644 --- a/test/transform/functiontransform.jl +++ b/test/transform/functiontransform.jl @@ -29,5 +29,4 @@ @test repr(FunctionTransform(sin)) == "Function Transform: $(sin)" f(a, x) = sin.(a .* x) test_ADs(x->transform(SEKernel(), FunctionTransform(y->f(x, y))), randn(rng, 3)) - @test_broken "Zygote is failing" end diff --git a/test/utils_AD.jl b/test/utils_AD.jl index 3f13b6c56..1354485f9 100644 --- a/test/utils_AD.jl +++ b/test/utils_AD.jl @@ -129,20 +129,20 @@ function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3]) A = rand(rng, dims...) B = rand(rng, dims...) for dim in 1:2 - # compare_gradient(AD, A) do a - # testfunction(k, a, dim) - # end - # compare_gradient(AD, A) do a - # testfunction(k, a, B, dim) - # end - # compare_gradient(AD, B) do b - # testfunction(k, A, b, dim) - # end - # if !(args === nothing) - # compare_gradient(AD, args) do p - # testfunction(kernelfunction(p), A, dim) - # end - # end + compare_gradient(AD, A) do a + testfunction(k, a, dim) + end + compare_gradient(AD, A) do a + testfunction(k, a, B, dim) + end + compare_gradient(AD, B) do b + testfunction(k, A, b, dim) + end + if !(args === nothing) + compare_gradient(AD, args) do p + testfunction(kernelfunction(p), A, dim) + end + end end end end From 0747a46bb3cff419e2ceb53b2fce472251090809 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Wed, 12 Aug 2020 01:03:41 +0530 Subject: [PATCH 3/5] Clean up --- src/transform/functiontransform.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transform/functiontransform.jl b/src/transform/functiontransform.jl index 9366b86a0..ac6b0ce3c 100644 --- a/src/transform/functiontransform.jl +++ b/src/transform/functiontransform.jl @@ -32,10 +32,6 @@ function _map(t::FunctionTransform, x::RowVecs) return RowVecs(hcat(vals...)') end -# _map(t::FunctionTransform, x::ColVecs) = ColVecs(mapslices(t.f, x.X; dims=1)) -# _map(t::FunctionTransform, x::RowVecs) = RowVecs(mapslices(t.f, x.X; dims=2)) - - duplicate(t::FunctionTransform,f) = FunctionTransform(f) Base.show(io::IO, t::FunctionTransform) = print(io, "Function Transform: ", t.f) From a8659120f93404f472c2db0887da90e2315ab8a1 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Thu, 13 Aug 2020 15:28:04 +0530 Subject: [PATCH 4/5] Avoid splatting, use reduce instead --- src/transform/functiontransform.jl | 4 ++-- src/zygote_adjoints.jl | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/transform/functiontransform.jl b/src/transform/functiontransform.jl index ac6b0ce3c..67c25a35f 100644 --- a/src/transform/functiontransform.jl +++ b/src/transform/functiontransform.jl @@ -22,14 +22,14 @@ function _map(t::FunctionTransform, x::ColVecs) vals = map(axes(x.X, 2)) do i t.f(view(x.X, :, i)) end - return ColVecs(hcat(vals...)) + return ColVecs(reduce(hcat, vals)) end function _map(t::FunctionTransform, x::RowVecs) vals = map(axes(x.X, 1)) do i t.f(view(x.X, i, :)) end - return RowVecs(hcat(vals...)') + return RowVecs(reduce(hcat, vals)') end duplicate(t::FunctionTransform,f) = FunctionTransform(f) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index f51466fb6..16908389c 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -77,6 +77,24 @@ end return RowVecs(X), back end +@adjoint function reduce(::typeof(hcat), xs) + function back(Δ) + start = 0 + Δs = [begin + d = if ndims(xsi) == 1 + Δ[:, start+1] + else + i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail + Δ[:, start+1:start+size(xsi,2), i...] + end + start += size(xsi, 2) + d + end for xsi in xs] + return (nothing, Δs) + end + return reduce(hcat, xs), back +end + @adjoint function Base.map(t::Transform, X::ColVecs) pullback(_map, t, X) end From b516668c7622f3d808b3e2158c99c5627b0930c1 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Sun, 16 Aug 2020 00:41:30 +0530 Subject: [PATCH 5/5] Remove rrules for reduce(hcat) --- src/zygote_adjoints.jl | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/zygote_adjoints.jl b/src/zygote_adjoints.jl index 16908389c..f51466fb6 100644 --- a/src/zygote_adjoints.jl +++ b/src/zygote_adjoints.jl @@ -77,24 +77,6 @@ end return RowVecs(X), back end -@adjoint function reduce(::typeof(hcat), xs) - function back(Δ) - start = 0 - Δs = [begin - d = if ndims(xsi) == 1 - Δ[:, start+1] - else - i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail - Δ[:, start+1:start+size(xsi,2), i...] - end - start += size(xsi, 2) - d - end for xsi in xs] - return (nothing, Δs) - end - return reduce(hcat, xs), back -end - @adjoint function Base.map(t::Transform, X::ColVecs) pullback(_map, t, X) end