Skip to content

Commit

Permalink
added topaz embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jun 24, 2024
1 parent 83d90b9 commit 1dbc494
Show file tree
Hide file tree
Showing 2 changed files with 424 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torchvision.transforms as transforms

from cryo_sbi.utils.image_utils import LowPassFilter, Mask
import cryo_sbi.inference.models.topaz_embeddings as topaz_embeddings


EMBEDDING_NETS = {}
Expand All @@ -27,6 +28,36 @@ def add(class_):
return add


@add_embedding("TOPAZ_RESNET8")
class TopazResNet8_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(TopazResNet8_Encoder, self).__init__()
self.topaz_resnet8 = topaz_embeddings.ResNet8(
units=[32, 64, output_dimension],
activation=nn.SiLU,
)

def forward(self, x):
x = x.unsqueeze(1)
x = self.topaz_resnet8(x)
return x


@add_embedding("TOPAZ_RESNET16")
class TopazResNet16_Encoder(nn.Module):
def __init__(self, output_dimension: int):
super(TopazResNet16_Encoder, self).__init__()
self.topaz_resnet16 = topaz_embeddings.ResNet16(
units=[32, 64, output_dimension],
activation=nn.SiLU,
)

def forward(self, x):
x = x.unsqueeze(1)
x = self.topaz_resnet16(x)
return x


@add_embedding("RESNET18")
class ResNet18_Encoder(nn.Module):
def __init__(self, output_dimension: int):
Expand Down
Loading

0 comments on commit 1dbc494

Please sign in to comment.