From 31b1a08e174b9f5f48323d803b47c933aa9261ce Mon Sep 17 00:00:00 2001 From: David Nabergoj Date: Fri, 15 Nov 2024 12:16:31 +0100 Subject: [PATCH] Fix RNODE deepcopy (hack) --- test/test_deepcopy.py | 8 ++++- torchflows/bijections/continuous/base.py | 30 +++++++------------ .../bijections/continuous/regularization.py | 16 ---------- torchflows/flows.py | 12 +++++++- 4 files changed, 29 insertions(+), 37 deletions(-) delete mode 100644 torchflows/bijections/continuous/regularization.py diff --git a/test/test_deepcopy.py b/test/test_deepcopy.py index 1745114..9ca0663 100644 --- a/test/test_deepcopy.py +++ b/test/test_deepcopy.py @@ -16,5 +16,11 @@ def test_post_variational_fit(): b = RNODE(event_shape=(10,)) f = Flow(b) f.variational_fit(lambda x: torch.sum(-x ** 2), n_epochs=2) - b.eval() + deepcopy(b) + +def test_post_fit(): + torch.manual_seed(0) + b = RNODE(event_shape=(10,)) + f = Flow(b) + f.fit(x_train=torch.randn(3, *b.event_shape), n_epochs=2) deepcopy(b) diff --git a/torchflows/bijections/continuous/base.py b/torchflows/bijections/continuous/base.py index fea1037..4928a28 100644 --- a/torchflows/bijections/continuous/base.py +++ b/torchflows/bijections/continuous/base.py @@ -1,14 +1,12 @@ import math -from typing import Union, Tuple, List, Optional, Dict +from typing import Union, Tuple, List import torch import torch.nn as nn -from torch.nn import ModuleDict, ModuleList from torchflows.bijections.base import Bijection from torchflows.bijections.continuous.layers import DiffEqLayer, ConcatConv2d, IgnoreConv2d import torchflows.bijections.continuous.layers as diff_eq_layers -from torchflows.bijections.continuous.regularization import GeodesicRegularization, JacobianRegularization from torchflows.utils import flatten_event, flatten_batch, get_batch_shape, unflatten_batch, unflatten_event @@ -251,31 +249,22 @@ def __init__(self, if isinstance(regularization, str): regularization = (regularization,) - reg_modules = nn.ModuleDict() + self.supported_reg_types = ['sq_jac_norm'] for rt in regularization: - if rt == 'geodesic': - reg_modules[rt] = GeodesicRegularization() - elif rt == "sq_jac_norm": - reg_modules[rt] = JacobianRegularization() - else: + if rt not in self.supported_reg_types: raise ValueError + self.used_reg_types = regularization - self.register_module('reg_modules', reg_modules) - - def regularization(self): - total = torch.tensor(0.0) - for key, val in self.reg_modules.items(): - total += val.coef * val.value.mean() - return total + self.reg_jac_coef = 1.0 + self.stored_reg = None def divergence_step(self, dy, y) -> torch.Tensor: batch_size = y.shape[0] - if "sq_jac_norm" in self.reg_modules: + if "sq_jac_norm" in self.used_reg_types and self.training: divergence, sq_jac_norm = divergence_approx_extended(dy, y, e=self.hutch_noise) # Store regularization data - sq_jac_norm = sq_jac_norm.view(batch_size, 1) - self.reg_modules['sq_jac_norm'].value = sq_jac_norm + self.stored_reg = self.reg_jac_coef * sq_jac_norm.mean() else: divergence = divergence_approx_basic(dy, y, e=self.hutch_noise) divergence = divergence.view(batch_size, 1) @@ -284,6 +273,9 @@ def divergence_step(self, dy, y) -> torch.Tensor: return divergence + def regularization(self): + return (self.stored_reg or 0) + super().regularization() + class ContinuousBijection(Bijection): """ diff --git a/torchflows/bijections/continuous/regularization.py b/torchflows/bijections/continuous/regularization.py deleted file mode 100644 index 89b8cba..0000000 --- a/torchflows/bijections/continuous/regularization.py +++ /dev/null @@ -1,16 +0,0 @@ -import torch -import torch.nn as nn - - -class GeodesicRegularization(nn.Module): - def __init__(self, coef: float = 1.0): - super().__init__() - self.register_buffer('coef', torch.tensor(coef)) - self.register_buffer('value', torch.tensor(0.)) - - -class JacobianRegularization(nn.Module): - def __init__(self, coef: float = 1.0): - super().__init__() - self.register_buffer('coef', torch.tensor(coef)) - self.register_buffer('value', torch.tensor(0.)) diff --git a/torchflows/flows.py b/torchflows/flows.py index 4713509..e6be04c 100644 --- a/torchflows/flows.py +++ b/torchflows/flows.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn from tqdm import tqdm + +from torchflows.bijections.continuous.rnode import RNODE from torchflows.bijections.base import Bijection from torchflows.utils import flatten_event, unflatten_event, create_data_loader from torchflows.base_distributions.gaussian import DiagonalGaussian @@ -260,6 +262,10 @@ def compute_batch_loss(batch_, reduction: callable = torch.mean): if x_val is not None and keep_best_weights: self.load_state_dict(best_weights) + # hacky error handling (Jacobian regularization is a non-leaf node within RNODE's autograd graph) + if isinstance(self.bijection, RNODE): + self.bijection.f.stored_reg = None + self.eval() def variational_fit(self, @@ -272,7 +278,7 @@ def variational_fit(self, keep_best_weights: bool = True, show_progress: bool = False, check_for_divergences: bool = False, - time_limit_seconds:Union[float, int] = None): + time_limit_seconds: Union[float, int] = None): """Train the normalizing flow to fit a target log probability. Stochastic variational inference lets us train a distribution using the unnormalized target log density instead of a fixed dataset. @@ -367,6 +373,10 @@ def variational_fit(self, elif keep_best_weights: self.load_state_dict(best_weights) + # hacky error handling (Jacobian regularization is a non-leaf node within RNODE's autograd graph) + if isinstance(self.bijection, RNODE): + self.bijection.f.stored_reg = None + self.eval()