Skip to content

Commit

Permalink
Fix OTFlow forward output sign
Browse files · Browse the repository at this point in the history
  • Loading branch information
davidnabergoj committed Dec 27, 2023
1 parent 8ee1e7c commit e82e79a
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions normalizing_flows/bijections/continuous/otflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,21 @@ def __init__(self, event_size: int, hidden_size: int, step_size: float = 0.01):

divisor = max(event_size ** 2, 10)

K0_delta = torch.randn(size=(hidden_size, event_size)) / divisor
b0_delta = torch.randn(size=(hidden_size,)) / divisor
self.K0_delta = nn.Parameter(torch.randn(size=(hidden_size, event_size)) / divisor)
self.b0 = nn.Parameter(torch.randn(size=(hidden_size,)) / divisor)

K1_delta = torch.randn(size=(hidden_size, hidden_size)) / divisor
b1_delta = torch.randn(size=(hidden_size,)) / divisor
self.K1_delta = nn.Parameter(torch.randn(size=(hidden_size, hidden_size)) / divisor)
self.b1 = nn.Parameter(torch.randn(size=(hidden_size,)) / divisor)

self.K0 = nn.Parameter(torch.eye(hidden_size, event_size) + K0_delta)
self.b0 = nn.Parameter(0 + b0_delta)
self.step_size = step_size

self.K1 = nn.Parameter(torch.eye(hidden_size, hidden_size) + K1_delta)
self.b1 = nn.Parameter(0 + b1_delta)
@property
def K0(self):
return torch.eye(*self.K0_delta.shape) + self.K0_delta / 1000

self.step_size = step_size
@property
def K1(self):
return torch.eye(*self.K1_delta.shape) + self.K1_delta / 1000

@staticmethod
def sigma(x):
Expand Down Expand Up @@ -115,7 +117,7 @@ def hessian_trace(self,

t0 = torch.sum(
torch.multiply(
(self.sigma_prime_prime(torch.nn.functional.linear(s, self.K0, self.b0)) * z1),
self.sigma_prime_prime(torch.nn.functional.linear(s, self.K0, self.b0)) * z1,
torch.nn.functional.linear(ones, self.K0[:, :-1] ** 2)
),
dim=1
Expand Down Expand Up @@ -164,7 +166,7 @@ def __init__(self, event_size: int, hidden_size: int = None, **kwargs):
self.resnet = OTResNet(event_size + 1, hidden_size, **kwargs) # (x, t) has d+1 elements

def forward(self, t, x):
return -self.gradient(concatenate_x_t(x, t))
return self.gradient(concatenate_x_t(x, t))

def gradient(self, s):
# Equation 12
Expand Down

0 comments on commit e82e79a

Please sign in to comment.