From 843573f1c4de6d9f88bef0914c6bead376dd31d5 Mon Sep 17 00:00:00 2001 From: Olivier Verdier Date: Fri, 27 Sep 2024 14:25:33 +0200 Subject: [PATCH] TYP: remove string types --- diffeopt/cometric/laplace.py | 6 +++--- diffeopt/utils.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/diffeopt/cometric/laplace.py b/diffeopt/cometric/laplace.py index 5a48e5c..ee06256 100644 --- a/diffeopt/cometric/laplace.py +++ b/diffeopt/cometric/laplace.py @@ -1,7 +1,7 @@ -import numpy as np import torch +from ..group.base import BaseDiffeoGroup -def get_fourier_cometric(group: "BaseDiffeoGroup", s: int): +def get_fourier_cometric(group: BaseDiffeoGroup, s: int): shape = group.shape idx, idy = group.get_raw_identity() lap = 4. - 2.*(torch.cos(2.*torch.pi*idx/shape[0]) + torch.cos(2.*torch.pi*idy/shape[1])) @@ -11,7 +11,7 @@ def get_fourier_cometric(group: "BaseDiffeoGroup", s: int): lapinv[0,0] = 1. return lapinv -def get_laplace_cometric(group: "BaseDiffeoGroup", s:int=1): +def get_laplace_cometric(group: BaseDiffeoGroup, s:int=1): lapinv = get_fourier_cometric(group, s) def cometric(momentum: torch.Tensor) -> torch.Tensor: fx = torch.fft.fftn(momentum[0]) diff --git a/diffeopt/utils.py b/diffeopt/utils.py index bc50e0c..a82d4db 100644 --- a/diffeopt/utils.py +++ b/diffeopt/utils.py @@ -1,5 +1,6 @@ import numpy as np import torch +from .group.base import BaseDiffeoGroup def get_volume(shape): @@ -18,7 +19,7 @@ def normalize(I): from diffeopt.cometric import laplace -def get_random_diffeo(group: "BaseDiffeoGroup", nb_steps:int=10, scale:float=1.) -> torch.Tensor: +def get_random_diffeo(group: BaseDiffeoGroup, nb_steps:int=10, scale:float=1.) -> torch.Tensor: cometric = laplace.get_laplace_cometric(group, s=2) rm = torch.randn(*group.zero().shape) rv = cometric(rm)