Skip to content

Commit

Permalink
finish refactoring and testing backward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
dzhang314 committed Aug 27, 2024
1 parent ad404a3 commit 00550ad
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 182 deletions.
284 changes: 102 additions & 182 deletions src/RungeKuttaToolKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,162 +599,102 @@ end
############################################ GRADIENT COMPUTATION (REVERSE-MODE)


function pullback_dPhi_from_residual!(
dPhi::AbstractMatrix{T},
b::AbstractVector{T},
residuals::AbstractVector{T},
source_indices::AbstractVector{Int},
) where {T}

# Validate array dimensions.
stage_indices = axes(dPhi, 1)
internal_indices = axes(dPhi, 2)
output_indices = axes(residuals, 1)
@assert axes(dPhi) == (stage_indices, internal_indices)
@assert axes(b) == (stage_indices,)
@assert axes(residuals) == (output_indices,)
@assert axes(source_indices) == (internal_indices,)

# Construct numeric constants.
_zero = zero(T)

@inbounds for (k, s) in Iterators.reverse(pairs(source_indices))
if s == NULL_INDEX
@simd ivdep for j in stage_indices
dPhi[j, k] = _zero
end
else
residual = residuals[s]
twice_residual = residual + residual
@simd ivdep for j in stage_indices
dPhi[j, k] = twice_residual * b[j]
end
end
end

return dPhi
end


function pullback_dPhi_initial!(
dPhi::AbstractMatrix{T},
function pullback_dPhi_from_b!(
ev::RKOCEvaluator{T},
b::AbstractVector{T},
Phi::AbstractMatrix{T},
inv_gamma::AbstractVector{T},
source_indices::AbstractVector{Int},
) where {T}

# Validate array dimensions.
stage_indices = axes(dPhi, 1)
internal_indices = axes(dPhi, 2)
output_indices = axes(inv_gamma, 1)
@assert axes(dPhi) == (stage_indices, internal_indices)
@assert axes(b) == (stage_indices,)
@assert axes(Phi) == (stage_indices, internal_indices)
@assert axes(inv_gamma) == (output_indices,)
@assert axes(source_indices) == (internal_indices,)
stage_axis, _, _ = get_axes(ev)
@assert axes(b) == (stage_axis,)

# Construct numeric constants.
_zero = zero(T)

@inbounds for (k, s) in Iterators.reverse(pairs(source_indices))
if s == NULL_INDEX
@simd ivdep for j in stage_indices
dPhi[j, k] = _zero
@inbounds for (k, i) in Iterators.reverse(pairs(ev.table.source_indices))
if i == NULL_INDEX
@simd ivdep for j in stage_axis
ev.dPhi[j, k] = _zero
end
else
# Compute dot product without SIMD for determinism.
residual = _zero
for j in stage_indices
residual += b[j] * Phi[j, k]
lhs = _zero
for j in stage_axis
lhs += b[j] * ev.Phi[j, k]
end
# Subtract inv_gamma[s] at the end for improved numerical stability.
residual -= inv_gamma[s]
# Subtract inv_gamma at the end for improved numerical stability.
residual = lhs - ev.inv_gamma[i]
# Double residual. (Addition is faster than multiplication by two.)
twice_residual = residual + residual
@simd ivdep for j in stage_indices
dPhi[j, k] = twice_residual * b[j]
@simd ivdep for j in stage_axis
ev.dPhi[j, k] = twice_residual * b[j]
end
end
end

return dPhi
return ev
end


function pullback_dPhi!(
dPhi::AbstractMatrix{T},
function pullback_dPhi_from_A!(
ev::RKOCEvaluator{T},
A::AbstractMatrix{T},
Phi::AbstractMatrix{T},
extension_indices::AbstractVector{Int},
rooted_sum_ranges::AbstractVector{UnitRange{Int}},
rooted_sum_indices::AbstractVector{Pair{Int,Int}},
) where {T}

# Validate array dimensions.
stage_indices = axes(dPhi, 1)
internal_indices = axes(dPhi, 2)
@assert axes(dPhi) == (stage_indices, internal_indices)
@assert axes(A) == (stage_indices, stage_indices)
@assert axes(Phi) == (stage_indices, internal_indices)
@assert axes(extension_indices) == (internal_indices,)
@assert axes(rooted_sum_ranges) == (internal_indices,)

@inbounds for k in Iterators.reverse(internal_indices)
c = extension_indices[k]
stage_axis, internal_axis, _ = get_axes(ev)
@assert axes(A) == (stage_axis, stage_axis)

# Iterate over intermediate trees in reverse order.
@inbounds for k in Iterators.reverse(internal_axis)
c = ev.table.extension_indices[k]
if c != NULL_INDEX
for j in stage_indices
temp = dPhi[j, k]
for i in stage_indices
temp += A[i, j] * dPhi[i, c]
# Perform adjoint matrix-vector multiplication.
for i in stage_axis
dphi = ev.dPhi[i, c]
@simd ivdep for j in stage_axis
ev.dPhi[j, k] += A[i, j] * dphi
end
dPhi[j, k] = temp
end
end
for i in rooted_sum_ranges[k]
(p, q) = rooted_sum_indices[i]
@simd ivdep for j in stage_indices
dPhi[j, k] += Phi[j, p] * dPhi[j, q]
for i in ev.table.rooted_sum_ranges[k]
(p, q) = ev.table.rooted_sum_indices[i]
@simd ivdep for j in stage_axis
ev.dPhi[j, k] += ev.Phi[j, p] * ev.dPhi[j, q]
end
end
end

return dPhi
return ev
end


function pullback_dA!(
dA::AbstractMatrix{T},
Phi::AbstractMatrix{T},
dPhi::AbstractMatrix{T},
extension_indices::AbstractVector{Int},
ev::RKOCEvaluator{T},
) where {T}

# Validate array dimensions.
stage_indices = axes(dA, 1)
internal_indices = axes(dPhi, 2)
@assert axes(dA) == (stage_indices, stage_indices)
@assert axes(Phi) == (stage_indices, internal_indices)
@assert axes(dPhi) == (stage_indices, internal_indices)
@assert axes(extension_indices) == (internal_indices,)
stage_axis, _, _ = get_axes(ev)
@assert axes(dA) == (stage_axis, stage_axis)

# Construct numeric constants.
_zero = zero(T)

# Initialize dA to zero.
@inbounds for j in stage_indices
@simd ivdep for i in stage_indices
@inbounds for j in stage_axis
@simd ivdep for i in stage_axis
dA[i, j] = _zero
end
end

# Iterate over intermediate trees obtained by extension.
@inbounds for (k, c) in pairs(extension_indices)
@inbounds for (k, c) in pairs(ev.table.extension_indices)
if c != NULL_INDEX
for t in stage_indices
f = Phi[t, k]
@simd ivdep for s in stage_indices
dA[s, t] += f * dPhi[s, c]
for t in stage_axis
phi = ev.Phi[t, k]
@simd ivdep for s in stage_axis
dA[s, t] += phi * ev.dPhi[s, c]
end
end
end
Expand All @@ -766,47 +706,40 @@ end

function pullback_db!(
db::AbstractVector{T},
ev::RKOCEvaluator{T},
b::AbstractVector{T},
Phi::AbstractMatrix{T},
inv_gamma::AbstractVector{T},
selected_indices::AbstractVector{Int},
) where {T}

# Validate array dimensions.
stage_indices = axes(db, 1)
internal_indices = axes(Phi, 2)
output_indices = axes(inv_gamma, 1)
@assert axes(db) == (stage_indices,)
@assert axes(b) == (stage_indices,)
@assert axes(Phi) == (stage_indices, internal_indices)
@assert axes(inv_gamma) == (output_indices,)
@assert axes(selected_indices) == (output_indices,)
stage_axis, _, _ = get_axes(ev)
@assert axes(db) == (stage_axis,)
@assert axes(b) == (stage_axis,)

# Construct numeric constants.
_zero = zero(T)

@inbounds begin

# Initialize db to zero.
@simd ivdep for i in stage_indices
@simd ivdep for i in stage_axis
db[i] = _zero
end

for (i, k) in pairs(selected_indices)
for (i, k) in pairs(ev.table.selected_indices)
# Compute dot product without SIMD for determinism.
residual = _zero
for j in stage_indices
residual += b[j] * Phi[j, k]
lhs = _zero
for j in stage_axis
lhs += b[j] * ev.Phi[j, k]
end
# Subtract inv_gamma[i] at the end for improved numerical stability.
residual -= inv_gamma[i]
@simd ivdep for j in stage_indices
db[j] += residual * Phi[j, k]
# Subtract inv_gamma at the end for improved numerical stability.
residual = lhs - ev.inv_gamma[i]
@simd ivdep for j in stage_axis
db[j] += residual * ev.Phi[j, k]
end
end

# Double db. (Addition is faster than multiplication by two.)
@simd ivdep for i in stage_indices
@simd ivdep for i in stage_axis
db[i] += db[i]
end

Expand All @@ -816,6 +749,50 @@ function pullback_db!(
end


"""
(adj::RKOCAdjoint{T})(
dA::AbstractMatrix{T},
db::AbstractVector{T},
A::AbstractMatrix{T},
b::AbstractVector{T},
) -> Tuple{AbstractMatrix{T}, AbstractVector{T}}
Compute the gradient of the sum of squared residuals
of the Runge--Kutta order conditions ``\\nabla_{A, \\mathbf{b}}
\\sum_{t \\in T} (\\mathbf{b} \\cdot \\Phi_t(A) - 1/\\gamma(t))^2``
at a given Butcher tableau ``(A, \\mathbf{b})``
over a set of rooted trees ``T`` encoded by an `RKOCEvaluator`.
# Arguments
- `adj`: `RKOCAdjoint` object obtained by applying the adjoint operator `'`
to an `RKOCEvaluator`. In other words, this function should be called as
`ev'(dA, db, A, b)` where `ev` is an `RKOCEvaluator`.
- `dA`: ``s \\times s`` output matrix containing the gradient of the sum of
squared residuals with respect to ``A``.
- `db`: length ``s`` output vector containing the gradient of the sum of
squared residuals with respect to ``\\mathbf{b}``.
- `A`: ``s \\times s`` input matrix containing the coefficients of a
Runge--Kutta method (i.e., the upper-right block of a Butcher tableau).
- `b`: length ``s`` input vector containing the weights of a Runge--Kutta
method (i.e., the lower-right row of a Butcher tableau).
Here, ``s`` denotes the number of stages specified when constructing `ev`.
"""
function (adj::RKOCAdjoint{T})(
dA::AbstractMatrix{T},
db::AbstractVector{T},
A::AbstractMatrix{T},
b::AbstractVector{T},
) where {T}
compute_Phi!(adj.ev, A)
pullback_dPhi_from_b!(adj.ev, b)
pullback_dPhi_from_A!(adj.ev, A)
pullback_dA!(dA, adj.ev)
pullback_db!(db, adj.ev, b)
return (dA, db)
end


########################################################### LEAST-SQUARES SOLVER


Expand Down Expand Up @@ -1311,63 +1288,6 @@ end
# end


# function (adj::RKOCEvaluatorBEAdjoint{T})(g::Vector{T}, x::Vector{T}) where {T}
# reshape_explicit!(adj.ev.A, adj.ev.b, x)
# compute_Phi!(adj.ev.Phi, adj.ev.A, adj.ev.table.instructions)
# compute_residuals!(adj.ev.residuals,
# adj.ev.b, adj.ev.Phi, adj.ev.inv_gamma, adj.ev.table.selected_indices)
# pullback_dPhi_from_residual!(adj.ev.dPhi,
# adj.ev.b, adj.ev.residuals, adj.ev.table.source_indices)
# pullback_dPhi!(adj.ev.dPhi,
# adj.ev.A, adj.ev.Phi, adj.ev.table.extension_indices,
# adj.ev.table.rooted_sum_ranges, adj.ev.table.rooted_sum_indices)
# pullback_dA!(adj.ev.dA,
# adj.ev.Phi, adj.ev.dPhi, adj.ev.table.extension_indices)
# pullback_db!(adj.ev.db,
# adj.ev.Phi, adj.ev.residuals, adj.ev.table.selected_indices)
# reshape_explicit!(g, adj.ev.dA, adj.ev.db)
# return g
# end


# function (adj::RKOCEvaluatorBDAdjoint{T})(g::Vector{T}, x::Vector{T}) where {T}
# reshape_diagonally_implicit!(adj.ev.A, adj.ev.b, x)
# compute_Phi!(adj.ev.Phi, adj.ev.A, adj.ev.table.instructions)
# compute_residuals!(adj.ev.residuals,
# adj.ev.b, adj.ev.Phi, adj.ev.inv_gamma, adj.ev.table.selected_indices)
# pullback_dPhi_from_residual!(adj.ev.dPhi,
# adj.ev.b, adj.ev.residuals, adj.ev.table.source_indices)
# pullback_dPhi!(adj.ev.dPhi,
# adj.ev.A, adj.ev.Phi, adj.ev.table.extension_indices,
# adj.ev.table.rooted_sum_ranges, adj.ev.table.rooted_sum_indices)
# pullback_dA!(adj.ev.dA,
# adj.ev.Phi, adj.ev.dPhi, adj.ev.table.extension_indices)
# pullback_db!(adj.ev.db,
# adj.ev.Phi, adj.ev.residuals, adj.ev.table.selected_indices)
# reshape_diagonally_implicit!(g, adj.ev.dA, adj.ev.db)
# return g
# end


# function (adj::RKOCEvaluatorBIAdjoint{T})(g::Vector{T}, x::Vector{T}) where {T}
# reshape_implicit!(adj.ev.A, adj.ev.b, x)
# compute_Phi!(adj.ev.Phi, adj.ev.A, adj.ev.table.instructions)
# compute_residuals!(adj.ev.residuals,
# adj.ev.b, adj.ev.Phi, adj.ev.inv_gamma, adj.ev.table.selected_indices)
# pullback_dPhi_from_residual!(adj.ev.dPhi,
# adj.ev.b, adj.ev.residuals, adj.ev.table.source_indices)
# pullback_dPhi!(adj.ev.dPhi,
# adj.ev.A, adj.ev.Phi, adj.ev.table.extension_indices,
# adj.ev.table.rooted_sum_ranges, adj.ev.table.rooted_sum_indices)
# pullback_dA!(adj.ev.dA,
# adj.ev.Phi, adj.ev.dPhi, adj.ev.table.extension_indices)
# pullback_db!(adj.ev.db,
# adj.ev.Phi, adj.ev.residuals, adj.ev.table.selected_indices)
# reshape_implicit!(g, adj.ev.dA, adj.ev.db)
# return g
# end


# function (adj::RKOCResidualEvaluatorAEAdjoint{T})(
# jacobian::Matrix{T}, x::Vector{T}
# ) where {T}
Expand Down
Loading

0 comments on commit 00550ad

Please sign in to comment.