Fix RNODE deepcopy (hack)
davidnabergoj committed Nov 15, 2024
1 parent 1c9f3c6 commit 31b1a08
Showing 4 changed files with 29 additions and 37 deletions.
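In short: continuous bijections previously stored their Jacobian regularization values inside nn.Module containers, and those cached values are non-leaf tensors in the autograd graph, which torch refuses to deepcopy. This commit replaces the module-based bookkeeping with a plain stored_reg attribute and clears it at the end of fit and variational_fit.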
8 changes: 7 additions & 1 deletion test/test_deepcopy.py
@@ -16,5 +16,11 @@ def test_post_variational_fit():
     b = RNODE(event_shape=(10,))
     f = Flow(b)
     f.variational_fit(lambda x: torch.sum(-x ** 2), n_epochs=2)
+    b.eval()
     deepcopy(b)
+
+
+def test_post_fit():
+    torch.manual_seed(0)
+    b = RNODE(event_shape=(10,))
+    f = Flow(b)
+    f.fit(x_train=torch.randn(3, *b.event_shape), n_epochs=2)
+    deepcopy(b)
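For context, these tests guard against a PyTorch limitation: copy.deepcopy only supports leaf tensors, so a module that caches an intermediate result of its forward pass cannot be deep-copied while that cache is populated. A minimal standalone sketch of the failure mode (toy module for illustration, not torchflows code):

import copy

import torch
import torch.nn as nn


class CachingModule(nn.Module):
    """Toy module that caches an intermediate (non-leaf) tensor, as RNODE did."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(3))
        self.cached = None

    def forward(self, x):
        out = (self.weight * x).sum()
        self.cached = out * 2  # non-leaf tensor attached to the autograd graph
        return out


m = CachingModule()
m(torch.randn(3))
try:
    copy.deepcopy(m)  # RuntimeError: only graph leaves support deepcopy
except RuntimeError as err:
    print(err)

m.cached = None  # the commit's fix: drop the cached tensor first
copy.deepcopy(m)  # now succeeds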
30 changes: 11 additions & 19 deletions torchflows/bijections/continuous/base.py
@@ -1,14 +1,12 @@
 import math
-from typing import Union, Tuple, List, Optional, Dict
+from typing import Union, Tuple, List
 
 import torch
 import torch.nn as nn
-from torch.nn import ModuleDict, ModuleList
 
 from torchflows.bijections.base import Bijection
-from torchflows.bijections.continuous.layers import DiffEqLayer, ConcatConv2d, IgnoreConv2d
+import torchflows.bijections.continuous.layers as diff_eq_layers
-from torchflows.bijections.continuous.regularization import GeodesicRegularization, JacobianRegularization
 from torchflows.utils import flatten_event, flatten_batch, get_batch_shape, unflatten_batch, unflatten_event

@@ -251,31 +249,22 @@ def __init__(self,
         if isinstance(regularization, str):
             regularization = (regularization,)
 
-        reg_modules = nn.ModuleDict()
+        self.supported_reg_types = ['sq_jac_norm']
         for rt in regularization:
-            if rt == 'geodesic':
-                reg_modules[rt] = GeodesicRegularization()
-            elif rt == "sq_jac_norm":
-                reg_modules[rt] = JacobianRegularization()
-            else:
+            if rt not in self.supported_reg_types:
                 raise ValueError
+        self.used_reg_types = regularization
 
-        self.register_module('reg_modules', reg_modules)
-
-    def regularization(self):
-        total = torch.tensor(0.0)
-        for key, val in self.reg_modules.items():
-            total += val.coef * val.value.mean()
-        return total
+        self.reg_jac_coef = 1.0
+        self.stored_reg = None
 
     def divergence_step(self, dy, y) -> torch.Tensor:
         batch_size = y.shape[0]
 
-        if "sq_jac_norm" in self.reg_modules:
+        if "sq_jac_norm" in self.used_reg_types and self.training:
             divergence, sq_jac_norm = divergence_approx_extended(dy, y, e=self.hutch_noise)
             # Store regularization data
             sq_jac_norm = sq_jac_norm.view(batch_size, 1)
-            self.reg_modules['sq_jac_norm'].value = sq_jac_norm
+            self.stored_reg = self.reg_jac_coef * sq_jac_norm.mean()
         else:
             divergence = divergence_approx_basic(dy, y, e=self.hutch_noise)
             divergence = divergence.view(batch_size, 1)
@@ -284,6 +273,9 @@ def divergence_step(self, dy, y) -> torch.Tensor:
 
         return divergence
 
+    def regularization(self):
+        return (self.stored_reg or 0) + super().regularization()
+
 
 class ContinuousBijection(Bijection):
     """
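The sq_jac_norm value that now feeds stored_reg comes from the RNODE-style Hutchinson estimator: a single vector-Jacobian product yields both a stochastic estimate of the divergence and an estimate of the squared Frobenius norm of the Jacobian. A sketch of what divergence_approx_extended plausibly computes (the name and signature follow this diff; the body is an illustration, not the package's implementation):

import torch


def divergence_approx_extended_sketch(dy, y, e):
    """Hutchinson estimator: E[e^T J e] approximates tr(J) = div f,
    and E[||e^T J||^2] approximates ||J||_F^2."""
    # One vector-Jacobian product: e^T (d dy / d y)
    vjp = torch.autograd.grad(dy, y, grad_outputs=e, create_graph=True)[0]
    # Per-sample divergence estimate
    divergence = (vjp * e).flatten(start_dim=1).sum(dim=1)
    # Per-sample squared Jacobian norm estimate (the regularization signal)
    sq_jac_norm = (vjp ** 2).flatten(start_dim=1).sum(dim=1)
    return divergence, sq_jac_norm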
16 changes: 0 additions & 16 deletions torchflows/bijections/continuous/regularization.py

This file was deleted.

12 changes: 11 additions & 1 deletion torchflows/flows.py
@@ -6,6 +6,8 @@
 import torch
 import torch.nn as nn
 from tqdm import tqdm
+
+from torchflows.bijections.continuous.rnode import RNODE
 from torchflows.bijections.base import Bijection
 from torchflows.utils import flatten_event, unflatten_event, create_data_loader
 from torchflows.base_distributions.gaussian import DiagonalGaussian
@@ -260,6 +262,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean):
         if x_val is not None and keep_best_weights:
             self.load_state_dict(best_weights)
 
+        # hacky error handling (Jacobian regularization is a non-leaf node within RNODE's autograd graph)
+        if isinstance(self.bijection, RNODE):
+            self.bijection.f.stored_reg = None
+
         self.eval()
 
     def variational_fit(self,
@@ -272,7 +278,7 @@ def variational_fit(self,
                         keep_best_weights: bool = True,
                         show_progress: bool = False,
                         check_for_divergences: bool = False,
-                        time_limit_seconds:Union[float, int] = None):
+                        time_limit_seconds: Union[float, int] = None):
         """Train the normalizing flow to fit a target log probability.
 
         Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset.
@@ -367,6 +373,10 @@
         elif keep_best_weights:
             self.load_state_dict(best_weights)
 
+        # hacky error handling (Jacobian regularization is a non-leaf node within RNODE's autograd graph)
+        if isinstance(self.bijection, RNODE):
+            self.bijection.f.stored_reg = None
+
         self.eval()
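With stored_reg cleared at the end of training, a fitted RNODE flow can be deep-copied again. A usage sketch mirroring the new tests (import paths follow this diff; adjust to the installed package layout):

from copy import deepcopy

import torch

from torchflows.flows import Flow
from torchflows.bijections.continuous.rnode import RNODE

torch.manual_seed(0)
bijection = RNODE(event_shape=(10,))
flow = Flow(bijection)
flow.fit(x_train=torch.randn(3, 10), n_epochs=2)

# fit() has cleared bijection.f.stored_reg, so no non-leaf tensor remains
clone = deepcopy(bijection)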
