Skip to content

Commit

Permalink
Base STELARR model
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 20, 2023
1 parent bb54f1d commit a328f01
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/autoseg/models/stelarr_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
from funlib.learn.torch.models import UNet, ConvPass
from ..utils import neighborhood


class MultitaskModel(torch.nn.Module):
def __init__(self, unet: UNet, num_fmaps: int) -> None:
super(MultitaskModel, self).__init__()

self.unet: UNet = unet
self.lsd_head: ConvPass = ConvPass(num_fmaps, 10, [[1, 1, 1]], activation="Sigmoid")
self.aff_head: ConvPass = ConvPass(num_fmaps, len(neighborhood), [[1, 1, 1]], activation="Sigmoid")
self.enhancement_head: ConvPass = ConvPass(num_fmaps, 1, [[1, 1, 1]], activation="Sigmoid")

def forward(self, input):
x = self.unet(input)
lsds = self.lsd_head(x[0])
affs = self.aff_head(x[1])
fake = self.enhancement_head(x[2])

return lsds, affs, fake

0 comments on commit a328f01

Please sign in to comment.