Skip to content

Commit

Permalink
feat: 🚧 Incorporate LeibNetz.
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Feb 5, 2024
1 parent 75b98a5 commit 3ea65eb
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 5 deletions.
31 changes: 31 additions & 0 deletions dacapo/experiments/architectures/leibnetz_architecture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# TODO
from .architecture import Architecture

from funlib.geometry import Coordinate

import torch


class DummyArchitecture(Architecture):
    """Minimal placeholder architecture: one unpadded 3D convolution.

    Serves as a stand-in while the real LeibNetz architecture is wired in
    (see the module-level TODO). Channel counts are taken from the supplied
    ``architecture_config``; the input shape is hard-coded.
    """

    def __init__(self, architecture_config):
        super().__init__()

        # Pull channel counts out of the config once, then record them so the
        # property accessors below can report them back.
        in_ch = architecture_config.num_in_channels
        out_ch = architecture_config.num_out_channels
        self.channels_in = in_ch
        self.channels_out = out_ch

        # Single 3x3x3 convolution with default (zero) padding.
        self.conv = torch.nn.Conv3d(in_ch, out_ch, kernel_size=3)

    @property
    def input_shape(self):
        """Fixed spatial input shape expected by this dummy network."""
        return Coordinate(40, 20, 20)

    @property
    def num_in_channels(self):
        """Number of input channels, as given by the config."""
        return self.channels_in

    @property
    def num_out_channels(self):
        """Number of output channels, as given by the config."""
        return self.channels_out

    def forward(self, x):
        """Run ``x`` through the single convolution and return the result."""
        return self.conv(x)
22 changes: 22 additions & 0 deletions dacapo/experiments/architectures/leibnetz_architecture_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# TODO
import attr

from .dummy_architecture import DummyArchitecture
from .architecture_config import ArchitectureConfig

from typing import Tuple


@attr.s
class DummyArchitectureConfig(ArchitectureConfig):
    """A placeholder architecture config used only for testing.

    None of the attributes carry any particular meaning, and ``verify``
    unconditionally reports the config as invalid.
    """

    # Concrete architecture class this config instantiates.
    architecture_type = DummyArchitecture

    num_in_channels: int = attr.ib(metadata={"help_text": "Dummy attribute."})

    num_out_channels: int = attr.ib(metadata={"help_text": "Dummy attribute."})

    def verify(self) -> Tuple[bool, str]:
        """Always fail validation: dummy configs are never usable."""
        reason = "This is a DummyArchitectureConfig and is never valid"
        return False, reason
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ class ResampledArrayConfig(ArrayConfig):
downsample: Coordinate = attr.ib(
metadata={"help_text": "The amount by which to downsample!"}
)
interp_order: bool = attr.ib(
interp_order: int = attr.ib(
metadata={"help_text": "The order of the interpolation!"}
)
4 changes: 2 additions & 2 deletions dacapo/experiments/trainers/leibnetz_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, trainer_config):
self.learning_rate = trainer_config.learning_rate
self.batch_size = trainer_config.batch_size
self.num_data_fetchers = trainer_config.num_data_fetchers
self.print_profiling = 100
self.print_profiling = 100 # TODO: remove all hard coded values
self.snapshot_iteration = trainer_config.snapshot_interval
self.min_masked = trainer_config.min_masked
self.reject_probability = trainer_config.reject_probability
Expand All @@ -47,7 +47,7 @@ def __init__(self, trainer_config):

self.scheduler = None

def create_optimizer(self, model):
def create_optimizer(self, model): # TODO: add optimizer to config
optimizer = torch.optim.RAdam(lr=self.learning_rate, params=model.parameters())
self.scheduler = (
torch.optim.lr_scheduler.LinearLR( # TODO: add scheduler to config
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
"funlib.math>=0.1",
"funlib.geometry>=0.2",
"mwatershed>=0.1",
"funlib.persistence>=0.1",
"funlib.persistence @ git+https://github.com/janelia-cellmap/funlib.persistence.git@ome-zarr-reading",
"funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
"gunpowder>=1.3",
"lsds>=0.1.3",
# "lsds>=0.1.3",
"reloading",
],
)

0 comments on commit 3ea65eb

Please sign in to comment.