Skip to content

Commit

Permalink
Merge pull request #150 from DiffEqML/avoid-saving-autograd-func
Browse files Browse the repository at this point in the history
Avoid saving autograd func
  • Loading branch information
massastrello authored Jun 22, 2022
2 parents 3b3fd7b + d9ede6c commit c6639c7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 47 deletions.
57 changes: 16 additions & 41 deletions torchdyn/core/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,40 +66,25 @@ def __init__(self, vector_field:Union[Callable, nn.Module], solver:Union[str, nn
self.vf.register_parameter('dummy_parameter', dummy_parameter)
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])

# instantiates an underlying autograd.Function that overrides the backward pass with the intended version
# sensitivity algorithm
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
problem_type='standard').apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, atol, rtol, interpolator,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
problem_type='standard').apply


def _prep_odeint(self):
def _autograd_func(self):
"create autograd functions for backward pass"
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
problem_type='standard').apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, self.atol, self.rtol, self.interpolator,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
problem_type='standard').apply

problem_type='standard').apply

def odeint(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
"Returns Tuple(`t_eval`, `solution`)"
self._prep_odeint()
if self.sensalg == 'autograd':
return odeint(self.vf, x, t_span, self.solver, self.atol, self.rtol, interpolator=self.interpolator,
save_at=save_at, args=args)

else:
return self.autograd_function(self.vf_params, x, t_span, save_at, args)
return self._autograd_func()(self.vf_params, x, t_span, save_at, args)

def forward(self, x:Tensor, t_span:Tensor, save_at:Tensor=(), args={}):
"For safety redirects to intended method `odeint`"
Expand Down Expand Up @@ -128,39 +113,29 @@ def __init__(self, vector_field:Callable, solver:str, sensitivity:str='autograd'
self.parallel_solver = solver
self.fine_steps, self.maxiter = fine_steps, maxiter

def _autograd_func(self):
"create autograd functions for backward pass"
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
'multiple_shooting', fine_steps, maxiter).apply
return _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, solver, 0, 0, None,
solver_adjoint, atol_adjoint, rtol_adjoint, integral_loss,
'multiple_shooting', fine_steps, maxiter).apply

return _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply
def odeint(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
"Returns Tuple(`t_eval`, `solution`)"
self._prep_odeint()
if self.sensalg == 'autograd':
return odeint_mshooting(self.vf, x, t_span, self.parallel_solver, B0, self.fine_steps, self.maxiter)
else:
return self.autograd_function(self.vf_params, x, t_span, B0)
return self._autograd_func()(self.vf_params, x, t_span, B0)

def forward(self, x:Tensor, t_span:Tensor, B0:Tensor=None):
"For safety redirects to intended method `odeint`"
return self.odeint(x, t_span, B0)

def _prep_odeint(self):
"create autograd functions for backward pass"
self.vf_params = torch.cat([p.contiguous().flatten() for p in self.vf.parameters()])
if self.sensalg == 'adjoint': # alias .apply as direct call to preserve consistency of call signature
self.autograd_function = _gather_odefunc_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply
elif self.sensalg == 'interpolated_adjoint':
self.autograd_function = _gather_odefunc_interp_adjoint(self.vf, self.vf_params, self.solver, 0, 0, None,
self.solver_adjoint, self.atol_adjoint, self.rtol_adjoint, self.integral_loss,
'multiple_shooting', self.fine_steps, self.maxiter).apply


class SDEProblem(nn.Module):
def __init__(self):
Expand Down
10 changes: 4 additions & 6 deletions torchdyn/numerics/sensitivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ def generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator
def _gather_odefunc_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint,
atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
"Prepares definition of autograd.Function for adjoint sensitivity analysis of the above `ODEProblem`"
global _ODEProblemFuncAdjoint
class _ODEProblemFuncAdjoint(Function):
class _ODEProblemFunc(Function):
@staticmethod
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
Expand Down Expand Up @@ -98,15 +97,14 @@ def adjoint_dynamics(t, A):
λ_tspan = torch.stack([dLdt[0], dLdt[-1]])
return (μ, λ, λ_tspan, None, None, None)

return _ODEProblemFuncAdjoint
return _ODEProblemFunc


#TODO: introduce `t_span` grad as above
def _gather_odefunc_interp_adjoint(vf, vf_params, solver, atol, rtol, interpolator, solver_adjoint,
atol_adjoint, rtol_adjoint, integral_loss, problem_type, maxiter=4, fine_steps=4):
"Prepares definition of autograd.Function for interpolated adjoint sensitivity analysis of the above `ODEProblem`"
global _ODEProblemFuncInterpAdjoint
class _ODEProblemFuncInterpAdjoint(Function):
class _ODEProblemFunc(Function):
@staticmethod
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
Expand Down Expand Up @@ -160,4 +158,4 @@ def adjoint_dynamics(t, A):
λ, μ = λ.reshape(λT.shape), μ.reshape(μT.shape)
return (μ, λ, None, None, None)

return _ODEProblemFuncInterpAdjoint
return _ODEProblemFunc

0 comments on commit c6639c7

Please sign in to comment.