Skip to content

Commit

Permalink
TYP: remove string types
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierverdier committed Sep 27, 2024
1 parent 3805203 commit 843573f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 3 additions & 3 deletions diffeopt/cometric/laplace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import torch
from ..group.base import BaseDiffeoGroup

def get_fourier_cometric(group: "BaseDiffeoGroup", s: int):
def get_fourier_cometric(group: BaseDiffeoGroup, s: int):
shape = group.shape
idx, idy = group.get_raw_identity()
lap = 4. - 2.*(torch.cos(2.*torch.pi*idx/shape[0]) + torch.cos(2.*torch.pi*idy/shape[1]))
Expand All @@ -11,7 +11,7 @@ def get_fourier_cometric(group: "BaseDiffeoGroup", s: int):
lapinv[0,0] = 1.
return lapinv

def get_laplace_cometric(group: "BaseDiffeoGroup", s:int=1):
def get_laplace_cometric(group: BaseDiffeoGroup, s:int=1):
lapinv = get_fourier_cometric(group, s)
def cometric(momentum: torch.Tensor) -> torch.Tensor:
fx = torch.fft.fftn(momentum[0])
Expand Down
3 changes: 2 additions & 1 deletion diffeopt/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch
from .group.base import BaseDiffeoGroup


def get_volume(shape):
Expand All @@ -18,7 +19,7 @@ def normalize(I):

from diffeopt.cometric import laplace

def get_random_diffeo(group: "BaseDiffeoGroup", nb_steps:int=10, scale:float=1.) -> torch.Tensor:
def get_random_diffeo(group: BaseDiffeoGroup, nb_steps:int=10, scale:float=1.) -> torch.Tensor:
cometric = laplace.get_laplace_cometric(group, s=2)
rm = torch.randn(*group.zero().shape)
rv = cometric(rm)
Expand Down

0 comments on commit 843573f

Please sign in to comment.