diff --git a/src/autoseg/models/ACLSDModel.py b/src/autoseg/models/ACLSDModel.py new file mode 100644 index 0000000..ec814ab --- /dev/null +++ b/src/autoseg/models/ACLSDModel.py @@ -0,0 +1,17 @@ +import torch + +from ..networks.UNet import UNet, ConvPass +from ..utils import neighborhood + + +class ACLSDModel(torch.nn.Module): + def __init__(self, unet: UNet, num_fmaps: int): + super(ACLSDModel, self).__init__() + + self.unet: UNet = unet + self.aff_head: ConvPass = ConvPass(num_fmaps, len(neighborhood), [[1, 1, 1]], activation="Sigmoid") + + def forward(self, input): + x = self.unet(input) + affs = self.aff_head(x) + return affs \ No newline at end of file diff --git a/src/autoseg/models/MTLSDModel.py b/src/autoseg/models/MTLSDModel.py new file mode 100644 index 0000000..749011e --- /dev/null +++ b/src/autoseg/models/MTLSDModel.py @@ -0,0 +1,19 @@ +import torch + +from ..networks.UNet import UNet, ConvPass +from ..utils import neighborhood + + +class MTLSDModel(torch.nn.Module): + def __init__(self, unet: UNet, num_fmaps: int): + super(MTLSDModel, self).__init__() + + self.unet: UNet = unet + self.lsd_head: ConvPass = ConvPass(num_fmaps, 10, [[1, 1, 1]], activation="Sigmoid") + self.aff_head: ConvPass = ConvPass(num_fmaps, len(neighborhood), [[1, 1, 1]], activation="Sigmoid") + + def forward(self, input): + x = self.unet(input) + lsds = self.lsd_head(x[0]) + affs = self.aff_head(x[1]) + return lsds, affs diff --git a/src/autoseg/models/STELARRModel.py b/src/autoseg/models/STELARRModel.py index 831fd3c..58726f1 100644 --- a/src/autoseg/models/STELARRModel.py +++ b/src/autoseg/models/STELARRModel.py @@ -1,5 +1,6 @@ import torch -from funlib.learn.torch.models import UNet, ConvPass + +from ..networks.UNet import UNet, ConvPass from ..utils import neighborhood