Skip to content

Commit

Permalink
Add typing, custom UNet
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 20, 2023
1 parent 0f775c3 commit e0e224a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/autoseg/models/ACLSDModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch

from ..networks.UNet import UNet, ConvPass
from ..utils import neighborhood


class ACLSDModel(torch.nn.Module):
def __init__(self, unet: UNet, num_fmaps: int):
super(ACLSDModel, self).__init__()

self.unet: UNet = unet
self.aff_head: ConvPass = ConvPass(num_fmaps, len(neighborhood), [[1, 1, 1]], activation="Sigmoid")

def forward(self, input):
x = self.unet(input)
affs = self.aff_head(x)
return affs
19 changes: 19 additions & 0 deletions src/autoseg/models/MTLSDModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import torch

from ..networks.UNet import UNet, ConvPass
from ..utils import neighborhood


class MTLSDModel(torch.nn.Module):
def __init__(self, unet: UNet, num_fmaps: int):
super(MTLSDModel, self).__init__()

self.unet: UNet = unet
self.lsd_head: ConvPass = ConvPass(num_fmaps, 10, [[1, 1, 1]], activation="Sigmoid")
self.aff_head: ConvPass = ConvPass(num_fmaps, len(neighborhood), [[1, 1, 1]], activation="Sigmoid")

def forward(self, input):
x = self.unet(input)
lsds = self.lsd_head(x[0])
affs = self.aff_head(x[1])
return lsds, affs
3 changes: 2 additions & 1 deletion src/autoseg/models/STELARRModel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from funlib.learn.torch.models import UNet, ConvPass

from ..networks.UNet import UNet, ConvPass
from ..utils import neighborhood


Expand Down

0 comments on commit e0e224a

Please sign in to comment.