
Differentiability of Spatial Gradients #598

Open · maximilian-gelbrecht opened this issue Oct 29, 2024 · 2 comments
Labels: differentiability 🤖 Making the model differentiable via AD
maximilian-gelbrecht (Member) commented Oct 29, 2024

After/with #589, I continued my hunt for differentiability with Enzyme. With the transforms (more or less) differentiable, the first Enzyme error shows up at the spatial gradients, more precisely in _divergence!.

Enzyme seems to have a problem with the way we write these methods, namely with function kernels/closures that branch on a captured variable. As the following shows, Enzyme only yields an error when we use a kernel with such a conditional closure, as we commonly do in our code for these functions:

import Pkg
Pkg.activate(".")

using SpeedyWeather
using Enzyme

grid_types = [FullGaussianGrid] #, OctahedralGaussianGrid] # one full and one reduced grid, both Gaussian to have exact transforms 
grid_dealiasing = [2] #, 3]

grid_type = grid_types[1]
i_grid = 1 

spectral_grid = SpectralGrid(Grid=grid_type, dealiasing=grid_dealiasing[i_grid])

# forwards 
S = SpectralTransform(spectral_grid)
dS = deepcopy(S)

u_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers)
v_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers)

u = transform(u_grid, S)
v = transform(v_grid, S)
du = zero(u)
dv = zero(v)

div = zero(u)
ddiv = zero(u)
fill!(ddiv, 1+1im)

func(o, a, b, c) = a - b + c

# no error 
autodiff(Reverse, SpeedyWeather.SpeedyTransforms._divergence!, Const, Const(func), Duplicated(div, ddiv), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS)) # this doesn't give an error directly

mul = 2
kernel2(o, a, b, c) = mul .* (a-b+c)  # just some arbitrary closure over a function

# no error 
autodiff(Reverse, SpeedyWeather.SpeedyTransforms._divergence!, Const, Const(kernel2), Duplicated(div, ddiv), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS)) # this also doesn't give an error directly

add = false
kernel(o, a, b, c) = add ? o-(a-b+c) : -(a-b+c) # this is (similar to) what we actually use in divergence!

# directly an error
autodiff(Reverse, SpeedyWeather.SpeedyTransforms._divergence!, Const, Const(kernel), Duplicated(div, ddiv), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS)) # this will yield an error from the compiler 

I am not sure why that is the case. I am posting it here to document it; let's see if I still have some time before my vacation to condense this down to a MWE for the Enzyme devs. So far my attempts at a quick MWE have not been successful, because very simple examples do actually work. The following, for instance, causes no issue:

using Enzyme 

function apply_func!(kernel, a, b, c)
    
    for i in eachindex(b)
        a[i] = kernel(b[i], c[i])
    end 

    return nothing
end 

add_or_mul = false

kernel(b, c) = add_or_mul ? b + c : b * c

a = zeros(10)
da = ones(10)

b = rand(10)
c = rand(10)

db = zeros(10)
dc = zeros(10)

# no error, gradients are correct
autodiff(Reverse, apply_func!, Const, Const(kernel), Duplicated(a, da), Duplicated(b, db), Duplicated(c, dc))
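As a quick sanity check of the last claim (added here for illustration, not part of the original snippet): with add_or_mul = false the kernel computes b[i] * c[i], and with the seed da .== 1 reverse mode should return db == c and dc == b, which is easy to assert:

# correctness check: for a[i] = b[i]*c[i] seeded with da .== 1,
# the reverse-mode gradients are db[i] = c[i] and dc[i] = b[i]
@assert db ≈ c
@assert dc ≈ b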

All of that being said, we may have to compute those divergences in a slightly different way for a GPU/KernelAbstractions version anyway, so I won't invest too much effort into this now.
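For reference, a minimal sketch of what a KernelAbstractions-style element-wise kernel could look like, using the apply_func! toy example from above; this is purely illustrative and not SpeedyWeather's actual implementation (the real _divergence! additionally needs the running lm index and neighbour accesses):

using KernelAbstractions

# illustrative only: apply_func! rewritten as a KernelAbstractions kernel
@kernel function apply_func_kernel!(a, @Const(b), @Const(c), f)
    i = @index(Global, Linear)
    @inbounds a[i] = f(b[i], c[i])
end

backend = CPU()                             # e.g. CUDABackend() on GPU
apply_func_ka! = apply_func_kernel!(backend)
apply_func_ka!(a, b, c, kernel; ndrange = length(a))
KernelAbstractions.synchronize(backend)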

maximilian-gelbrecht added the differentiability 🤖 label and self-assigned this on Oct 29, 2024

milankl (Member) commented Oct 30, 2024

@vchuravy we are running into some problems differentiating through the divergence function defined here

function _divergence!(
    kernel,
    div::LowerTriangularArray,
    u::LowerTriangularArray,
    v::LowerTriangularArray,
    S::SpectralTransform;
    radius = DEFAULT_RADIUS,
)
    (; grad_y_vordiv1, grad_y_vordiv2) = S
    @boundscheck ismatching(S, div) || throw(DimensionMismatch(S, div))

    lmax, mmax = size(div, OneBased, as=Matrix)
    for k in eachmatrix(div, u, v)          # also checks size compatibility
        lm = 0
        @inbounds for m in 1:mmax           # 1-based l, m
            # DIAGONAL (separate to avoid access to v[l-1, m])
            lm += 1
            ∂u∂λ  = ((m-1)*im)*u[lm, k]
            ∂v∂θ1 = 0                       # always above the diagonal
            ∂v∂θ2 = grad_y_vordiv2[lm] * v[lm+1, k]
            div[lm, k] = kernel(div[lm, k], ∂u∂λ, ∂v∂θ1, ∂v∂θ2)

            # BELOW DIAGONAL (but skip last row)
            for l in m+1:lmax-1
                lm += 1
                ∂u∂λ  = ((m-1)*im)*u[lm, k]
                ∂v∂θ1 = grad_y_vordiv1[lm] * v[lm-1, k]
                ∂v∂θ2 = grad_y_vordiv2[lm] * v[lm+1, k]  # this pulls in data from the last row though
                div[lm, k] = kernel(div[lm, k], ∂u∂λ, ∂v∂θ1, ∂v∂θ2)
            end

            # Last row: only vectors make use of the lmax+1 row, set to zero for scalars div, curl
            lm += 1
            div[lm, k] = 0
        end
    end

    # /radius scaling if not unit sphere
    if radius != 1
        div .*= inv(radius)
    end
    return div
end

This is the CPU version of the code (with scalar indexing and a running index lm), but before I rewrite this towards KernelAbstractions: does anything pop out that would explain why Enzyme struggles to differentiate it?
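(Side note added for clarity, not from the original thread: a LowerTriangularArray stores only the lower triangle, column by column, so the running index lm walks down each column m. For lmax = 4, mmax = 3 the flat indices are

l\m   m=1  m=2  m=3
l=1    1
l=2    2    5
l=3    3    6    8
l=4    4    7    9

so v[lm-1, k] and v[lm+1, k] access the coefficients directly above and below within the same column.)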

maximilian-gelbrecht (Member, Author) commented Oct 30, 2024

It's not the _divergence! function alone; the problem is calling it with specific kernels.

Selecting between anonymous functions instead of branching inside a single closure at least doesn't produce an error. I haven't checked the gradients for correctness yet, though.

So doing this works

kernel = flipsign ? (add ? ((o, a, b, c) -> o - (a-b+c)) : ((o, a, b, c) -> -(a-b+c))) :
                    (add ? ((o, a, b, c) -> o + (a-b+c)) : ((o, a, b, c) -> a-b+c))

but this (our version in the main code) doesn't:

kernel(o, a, b, c) = flipsign ? (add ? o-(a-b+c) : -(a-b+c)) :
                                (add ? o+(a-b+c) :   a-b+c )

But yeah, I couldn't condense this down to a MWE yet, because in simpler cases it works to hand over a function like this, as in the apply_func! example above.
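For completeness, a sketch of how the working variant plugs into the autodiff call from the first comment (same setup as above; the flipsign and add values are illustrative):

# sketch (values illustrative): with the branch resolved outside the anonymous
# functions, the autodiff call from the first comment no longer errors
flipsign, add = false, false
kernel = flipsign ? (add ? ((o, a, b, c) -> o - (a-b+c)) : ((o, a, b, c) -> -(a-b+c))) :
                    (add ? ((o, a, b, c) -> o + (a-b+c)) : ((o, a, b, c) -> a-b+c))

autodiff(Reverse, SpeedyWeather.SpeedyTransforms._divergence!, Const,
    Const(kernel), Duplicated(div, ddiv), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS))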
