Minor changes to continuous and residual NF parameter specification
davidnabergoj committed Sep 12, 2024
1 parent 9094cdb · commit fc5456b
Showing 5 changed files with 57 additions and 24 deletions.
21 changes: 17 additions & 4 deletions torchflows/bijections/continuous/ddnf.py

@@ -21,26 +21,39 @@ class DeepDiffeomorphicBijection(ApproximateContinuousBijection):
     Reference: Salman et al. "Deep diffeomorphic normalizing flows" (2018); https://arxiv.org/abs/1810.03256.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, solver="euler", **kwargs):
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 n_steps: int = 150,
+                 solver="euler",
+                 nn_kwargs: dict = None,
+                 **kwargs):
         """
         Constructor.
         :param event_shape: shape of the event tensor.
         :param n_steps: parameter T in the paper, i.e. the number of ResNet cells.
         """
-        diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(event_shape))
+        nn_kwargs = nn_kwargs or {}
+        diff_eq = RegularizedApproximateODEFunction(create_nn_time_independent(event_shape, **nn_kwargs))
         self.n_steps = n_steps
         super().__init__(event_shape, diff_eq, solver=solver, **kwargs)


 class ConvolutionalDeepDiffeomorphicBijection(ApproximateContinuousBijection):
     """Convolutional variant of the DDNF architecture.
     Reference: Salman et al. "Deep diffeomorphic normalizing flows" (2018); https://arxiv.org/abs/1810.03256.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], n_steps: int = 150, solver="euler", **kwargs):
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 n_steps: int = 150,
+                 solver="euler",
+                 nn_kwargs: dict = None,
+                 **kwargs):
+        nn_kwargs = nn_kwargs or {}
         if len(event_shape) != 3:
             raise ValueError("Event shape must be of length 3 (channels, height, width).")
-        diff_eq = RegularizedApproximateODEFunction(create_cnn_time_independent(event_shape[0]))
+        diff_eq = RegularizedApproximateODEFunction(create_cnn_time_independent(event_shape[0], **nn_kwargs))
         self.n_steps = n_steps
         super().__init__(event_shape, diff_eq, solver=solver, **kwargs)
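The network configuration is now an explicit pass-through: anything in nn_kwargs is forwarded to create_nn_time_independent (or create_cnn_time_independent). A minimal usage sketch; the 'hidden_size' key is an assumption borrowed from the create_nn call visible in the rnode.py diff below, and may not be accepted by the time-independent helper:

import torch
from torchflows.bijections.continuous.ddnf import DeepDiffeomorphicBijection

# Network configuration travels through nn_kwargs rather than requiring source edits.
bijection = DeepDiffeomorphicBijection(
    event_shape=torch.Size((10,)),
    n_steps=150,
    nn_kwargs={'hidden_size': 64},  # assumed key; must be accepted by create_nn_time_independent
)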
10 changes: 6 additions & 4 deletions torchflows/bijections/continuous/ffjord.py

@@ -18,8 +18,9 @@ class FFJORD(ApproximateContinuousBijection):
     Reference: Grathwohl et al. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models" (2018); https://arxiv.org/abs/1810.01367.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
-        diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape))
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs):
+        nn_kwargs = nn_kwargs or {}
+        diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape, **nn_kwargs))
         super().__init__(event_shape, diff_eq, **kwargs)


@@ -29,8 +30,9 @@ class ConvolutionalFFJORD(ApproximateContinuousBijection):
     Reference: Grathwohl et al. "FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models" (2018); https://arxiv.org/abs/1810.01367.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs):
+        nn_kwargs = nn_kwargs or {}
         if len(event_shape) != 3:
             raise ValueError("Event shape must be of length 3 (channels, height, width).")
-        diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0]))
+        diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0], **nn_kwargs))
         super().__init__(event_shape, diff_eq, **kwargs)
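Both FFJORD variants take the same pass-through dict. A sketch; the keys 'hidden_size'/'n_hidden_layers' for create_nn and 'n_layers' for create_cnn are taken from the defaults visible in the rnode.py diff below, so treat them as assumptions about the helper signatures:

from torchflows.bijections.continuous.ffjord import FFJORD, ConvolutionalFFJORD

# Dense variant: event_shape may be any shape.
dense = FFJORD(event_shape=(10,), nn_kwargs={'hidden_size': 64, 'n_hidden_layers': 2})

# Convolutional variant: event_shape must be (channels, height, width).
conv = ConvolutionalFFJORD(event_shape=(3, 32, 32), nn_kwargs={'n_layers': 2})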
15 changes: 11 additions & 4 deletions torchflows/bijections/continuous/otflow.py

@@ -141,8 +141,9 @@ def hessian_trace(self,


 class OTPotential(TimeDerivative):
-    def __init__(self, event_size: int, hidden_size: int = None, **kwargs):
+    def __init__(self, event_size: int, hidden_size: int = 50, resnet_kwargs: dict = None):
         super().__init__()
+        resnet_kwargs = resnet_kwargs or {}

         # hidden_size = m
         if hidden_size is None:

@@ -163,7 +164,7 @@ def __init__(self, event_size: int, hidden_size: int = None, **kwargs):
         self.w = nn.Parameter(1 + delta_w)
         self.A = nn.Parameter(torch.eye(r, event_size + 1) + delta_A)
         self.b = nn.Parameter(0 + delta_b)
-        self.resnet = OTResNet(event_size + 1, hidden_size, **kwargs)  # (x, t) has d+1 elements
+        self.resnet = OTResNet(event_size + 1, hidden_size, **resnet_kwargs)  # (x, t) has d+1 elements

     def forward(self, t, x):
         return self.gradient(concatenate_x_t(x, t))

@@ -208,7 +209,13 @@ class OTFlow(ExactContinuousBijection):
     Reference: Onken et al. "OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport" (2021); https://arxiv.org/abs/2006.00104.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], solver='dopri8', **kwargs):
+    def __init__(self,
+                 event_shape: Union[torch.Size, Tuple[int, ...]],
+                 ot_flow_kwargs: dict = None,
+                 solver='dopri8',
+                 **kwargs):
+        ot_flow_kwargs = ot_flow_kwargs or {}
+
         n_dim = int(torch.prod(torch.as_tensor(event_shape)))
-        diff_eq = OTFlowODEFunction(n_dim, hidden_size=50)
+        diff_eq = OTFlowODEFunction(n_dim, **ot_flow_kwargs)
         super().__init__(event_shape, diff_eq, solver=solver, **kwargs)
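The hardcoded hidden_size=50 moves from the call site into OTPotential's default, so OTFlow now relies on the downstream defaults unless overridden. A sketch; hidden_size as a keyword of OTFlowODEFunction is grounded in the removed line above, while its forwarding down to OTPotential is an assumption:

from torchflows.bijections.continuous.otflow import OTFlow

bijection = OTFlow(
    event_shape=(10,),
    ot_flow_kwargs={'hidden_size': 50},  # matches the old hardcoded value; forwarding to OTPotential assumed
    solver='dopri8',
)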
21 changes: 16 additions & 5 deletions torchflows/bijections/continuous/rnode.py

@@ -18,9 +18,14 @@ class RNODE(ApproximateContinuousBijection):
     Reference: Finlay et al. "How to train your neural ODE: the world of Jacobian and kinetic regularization" (2020); https://arxiv.org/abs/2002.02798.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
-        diff_eq = RegularizedApproximateODEFunction(create_nn(event_shape, hidden_size=100, n_hidden_layers=1),
-                                                    regularization="sq_jac_norm")
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs):
+        default_nn_kwargs = {'hidden_size': 100, 'n_hidden_layers': 1}
+        nn_kwargs = nn_kwargs or dict()
+        default_nn_kwargs.update(nn_kwargs)
+        diff_eq = RegularizedApproximateODEFunction(
+            create_nn(event_shape, **default_nn_kwargs),
+            regularization="sq_jac_norm"
+        )
         super().__init__(event_shape, diff_eq, **kwargs)


@@ -30,8 +35,14 @@ class ConvolutionalRNODE(ApproximateContinuousBijection):
     Reference: Finlay et al. "How to train your neural ODE: the world of Jacobian and kinetic regularization" (2020); https://arxiv.org/abs/2002.02798.
     """

-    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], **kwargs):
+    def __init__(self, event_shape: Union[torch.Size, Tuple[int, ...]], nn_kwargs: dict = None, **kwargs):
+        default_nn_kwargs = {'n_layers': 2}
+        nn_kwargs = nn_kwargs or dict()
+        default_nn_kwargs.update(nn_kwargs)
         if len(event_shape) != 3:
             raise ValueError("Event shape must be of length 3 (channels, height, width).")
-        diff_eq = RegularizedApproximateODEFunction(create_cnn(event_shape[0]), regularization="sq_jac_norm")
+        diff_eq = RegularizedApproximateODEFunction(
+            create_cnn(event_shape[0], **default_nn_kwargs),
+            regularization="sq_jac_norm"
+        )
         super().__init__(event_shape, diff_eq, **kwargs)
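Unlike the plain pass-through in ddnf.py and ffjord.py, both RNODE constructors merge the caller's keys over their defaults with dict.update, so a partial nn_kwargs keeps the remaining defaults intact. A sketch of the resulting semantics:

from torchflows.bijections.continuous.rnode import RNODE

# Defaults are {'hidden_size': 100, 'n_hidden_layers': 1}; caller keys take precedence,
# so this builds the drift network with hidden_size=200 and n_hidden_layers=1.
bijection = RNODE(event_shape=(10,), nn_kwargs={'hidden_size': 200})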
14 changes: 7 additions & 7 deletions torchflows/bijections/finite/residual/iterative.py

@@ -72,23 +72,23 @@ def forward(self, x):
 class SpectralNeuralNetwork(nn.Module):
     def __init__(self,
                  event_shape: Union[Tuple[int, ...], torch.Size],
-                 n_hidden: int = None,
+                 hidden_size: int = None,
                  n_hidden_layers: int = 1,
                  **kwargs):
         self.event_shape = event_shape
         event_size = int(torch.prod(torch.as_tensor(event_shape)))
-        if n_hidden is None:
-            n_hidden = int(3 * max(math.log(event_size), 4))
+        if hidden_size is None:
+            hidden_size = int(3 * max(math.log(event_size), 4))

         if n_hidden_layers == 0:
             layers = [SpectralLinear(event_size, event_size, **kwargs)]
         else:
-            layers = [SpectralLinear(event_size, n_hidden, **kwargs)]
-            for _ in range(n_hidden):
+            layers = [SpectralLinear(event_size, hidden_size, **kwargs)]
+            for _ in range(hidden_size):
                 layers.append(nn.Tanh())
-                layers.append(SpectralLinear(n_hidden, n_hidden, **kwargs))
+                layers.append(SpectralLinear(hidden_size, hidden_size, **kwargs))
             layers.pop(-1)
-            layers.append(SpectralLinear(n_hidden, event_size, **kwargs))
+            layers.append(SpectralLinear(hidden_size, event_size, **kwargs))
         super().__init__()
         self.layers = nn.ModuleList(layers)
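The rename aligns SpectralNeuralNetwork with the hidden_size name used by the network helpers above. A direct-construction sketch (the class is normally instantiated internally by the residual bijections, so this is illustrative only):

from torchflows.bijections.finite.residual.iterative import SpectralNeuralNetwork

# Width of the spectral-norm hidden layers is now called hidden_size.
net = SpectralNeuralNetwork(event_shape=(10,), hidden_size=16, n_hidden_layers=1)

One detail worth flagging: the layer-building loop iterates range(hidden_size) rather than range(n_hidden_layers) (before the rename it iterated range(n_hidden)), so the network depth scales with the hidden width and n_hidden_layers only distinguishes the zero-layer case.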
