diff --git a/src/cryo_sbi/inference/models/embedding_nets.py b/src/cryo_sbi/inference/models/embedding_nets.py index daa99ed..e502ff2 100644 --- a/src/cryo_sbi/inference/models/embedding_nets.py +++ b/src/cryo_sbi/inference/models/embedding_nets.py @@ -77,22 +77,27 @@ def forward(self, x): @add_embedding("RESNET18_FFT") -class ResNet18_Encoder(nn.Module): +class ResNet18_FFT(nn.Module): def __init__(self, output_dimension: int): - super(ResNet18_Encoder, self).__init__() - print("Using FFT ResNet18") - self.resnet = models.resnet18() - self.resnet.conv1 = nn.Conv2d( - 2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False - ) - self.resnet.fc = nn.Linear( - in_features=512, out_features=output_dimension, bias=True + super(ResNet18_FFT, self).__init__() + print("Using FFT1") + + self.drgn_encoder = nn.Sequential( + nn.Linear(12892, 512), + nn.GELU(), + nn.Linear(512, 256), + nn.GELU(), + nn.Linear(256, output_dimension), + nn.GELU(), ) + self.mask = Mask(128, 64, inside=True).mask.flatten() def forward(self, x): - x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1))) - x = torch.stack([x.real, x.imag], dim=1) - x = self.resnet(x) + if x.dim == 2: + x = x.unsqueeze(0) + x = torch.fft.fftshift(torch.fft.fft2(x, dim=(-2, -1))).real + x = x.flatten(start_dim=1)[:, self.mask] + x = self.drgn_encoder(x) return x