Skip to content

Commit

Permalink
BUG: avoid unnecessary conversions numpy<->torch
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierverdier committed Sep 27, 2024
1 parent 4c0f00d commit 8ed55ce
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
14 changes: 7 additions & 7 deletions diffeopt/cometric/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
def get_fourier_cometric(group, s):
shape = group.shape
idx, idy = group.get_raw_identity()
lap = 4. - 2.*(np.cos(2.*np.pi*idx/shape[0]) + np.cos(2.*np.pi*idy/shape[1]))
lap = 4. - 2.*(torch.cos(2.*torch.pi*idx/shape[0]) + torch.cos(2.*torch.pi*idy/shape[1]))
lap[0,0] = 1.
lapinv = (1./lap)**s
lap[0,0] = 0.
lapinv[0,0] = 1.
return lapinv

def get_laplace_cometric(group, s=1):
lapinv = get_fourier_cometric(group, s).numpy()
def cometric(momentum):
fx = np.fft.fftn(momentum[0])
fy = np.fft.fftn(momentum[1])
lapinv = get_fourier_cometric(group, s)
fx = torch.fft.fftn(momentum[0])
fy = torch.fft.fftn(momentum[1])
fx *= lapinv
fy *= lapinv
vx = np.real(np.fft.ifftn(fx))
vy = np.real(np.fft.ifftn(fy))
return torch.from_numpy(np.array([vx,vy]))
vx = torch.real(torch.fft.ifftn(fx))
vy = torch.real(torch.fft.ifftn(fy))
return torch.stack([vx,vy])
return cometric
2 changes: 1 addition & 1 deletion diffeopt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def normalize(I):

def get_random_diffeo(group, nb_steps=10, scale=1.):
cometric = laplace.get_laplace_cometric(group, s=2)
rm = np.random.randn(*group.zero().shape)
rm = torch.randn(*group.zero().shape)
rv = cometric(rm)
vmx = rv.abs().max()
shape_scale = (group.shape[0] + group.shape[1])/2
Expand Down

0 comments on commit 8ed55ce

Please sign in to comment.