Skip to content

Commit

Permalink
TYP: some type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierverdier committed Sep 27, 2024
1 parent 8ed55ce commit 3805203
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions diffeopt/cometric/laplace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch

def get_fourier_cometric(group, s):
def get_fourier_cometric(group: "BaseDiffeoGroup", s: int):
shape = group.shape
idx, idy = group.get_raw_identity()
lap = 4. - 2.*(torch.cos(2.*torch.pi*idx/shape[0]) + torch.cos(2.*torch.pi*idy/shape[1]))
Expand All @@ -11,9 +11,9 @@ def get_fourier_cometric(group, s):
lapinv[0,0] = 1.
return lapinv

def get_laplace_cometric(group, s=1):
def cometric(momentum):
def get_laplace_cometric(group: "BaseDiffeoGroup", s:int=1):
lapinv = get_fourier_cometric(group, s)
def cometric(momentum: torch.Tensor) -> torch.Tensor:
fx = torch.fft.fftn(momentum[0])
fy = torch.fft.fftn(momentum[1])
fx *= lapinv
Expand Down
2 changes: 1 addition & 1 deletion diffeopt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def normalize(I):

from diffeopt.cometric import laplace

def get_random_diffeo(group, nb_steps=10, scale=1.):
def get_random_diffeo(group: "BaseDiffeoGroup", nb_steps:int=10, scale:float=1.) -> torch.Tensor:
cometric = laplace.get_laplace_cometric(group, s=2)
rm = torch.randn(*group.zero().shape)
rv = cometric(rm)
Expand Down

0 comments on commit 3805203

Please sign in to comment.