diff --git a/diffeopt/cometric/laplace.py b/diffeopt/cometric/laplace.py index 6d04974..5a48e5c 100644 --- a/diffeopt/cometric/laplace.py +++ b/diffeopt/cometric/laplace.py @@ -1,7 +1,7 @@ import numpy as np import torch -def get_fourier_cometric(group, s): +def get_fourier_cometric(group: "BaseDiffeoGroup", s: int): shape = group.shape idx, idy = group.get_raw_identity() lap = 4. - 2.*(torch.cos(2.*torch.pi*idx/shape[0]) + torch.cos(2.*torch.pi*idy/shape[1])) @@ -11,9 +11,9 @@ def get_fourier_cometric(group, s): lapinv[0,0] = 1. return lapinv -def get_laplace_cometric(group, s=1): - def cometric(momentum): +def get_laplace_cometric(group: "BaseDiffeoGroup", s:int=1): lapinv = get_fourier_cometric(group, s) + def cometric(momentum: torch.Tensor) -> torch.Tensor: fx = torch.fft.fftn(momentum[0]) fy = torch.fft.fftn(momentum[1]) fx *= lapinv diff --git a/diffeopt/utils.py b/diffeopt/utils.py index 6b1f61e..bc50e0c 100644 --- a/diffeopt/utils.py +++ b/diffeopt/utils.py @@ -18,7 +18,7 @@ def normalize(I): from diffeopt.cometric import laplace -def get_random_diffeo(group, nb_steps=10, scale=1.): +def get_random_diffeo(group: "BaseDiffeoGroup", nb_steps:int=10, scale:float=1.) -> torch.Tensor: cometric = laplace.get_laplace_cometric(group, s=2) rm = torch.randn(*group.zero().shape) rv = cometric(rm)