Skip to content

Commit

Permalink
feat: ✨ Features from rhoadesj/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Feb 8, 2024
1 parent 5f50f9b commit 6ff9682
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
22 changes: 22 additions & 0 deletions dacapo/experiments/trainers/gp_augments/gaussian_noise_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from .augment_config import AugmentConfig

import gunpowder as gp

import attr


@attr.s
class GaussianNoiseAugmentConfig(AugmentConfig):
mean: float = attr.ib(
metadata={"help_text": "The mean of the gaussian noise to apply to your data."},
default=0.0,
)
var: float = attr.ib(
metadata={"help_text": "The variance of the gaussian noise."},
default=0.05,
)

def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None):
return gp.NoiseAugment(
array=raw_key, mode="gaussian", mean=self.mean, var=self.var
)
48 changes: 46 additions & 2 deletions dacapo/experiments/trainers/gp_augments/simple_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List, Optional
from .augment_config import AugmentConfig

import gunpowder as gp
Expand All @@ -7,5 +8,48 @@

@attr.s
class SimpleAugmentConfig(AugmentConfig):
def node(self, _raw_key=None, _gt_key=None, _mask_key=None):
return gp.SimpleAugment()
mirror_only: Optional[List[int]] = attr.ib(
default=None,
metadata={
"help_text": (
"If set, only mirror between the given axes. This is useful to exclude channels that have a set direction, like time."
)
},
)
transpose_only: Optional[List[int]] = attr.ib(
default=None,
metadata={
"help_text": (
"If set, only transpose between the given axes. This is useful to exclude channels that have a set direction, like time."
)
},
)
mirror_probs: Optional[List[float]] = attr.ib(
default=None,
metadata={
"help_text": (
"Probability of mirroring along each axis. Defaults to 0.5 for each axis."
)
},
)
transpose_probs: Optional[List[float]] = attr.ib(
default=None,
metadata={
"help_text": (
"Probability of transposing along each axis. Defaults to 0.5 for each axis."
)
},
)

def node(
self,
_raw_key=None,
_gt_key=None,
_mask_key=None,
):
return gp.SimpleAugment(
self.mirror_only,
self.transpose_only,
self.mirror_probs,
self.transpose_probs,
)
2 changes: 1 addition & 1 deletion dacapo/gp/elastic_augment_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _min_max_mean_std(ndarray, prefix=""):
return ""


class ElasticAugment(BatchFilter):
class ElasticAugment(BatchFilter): # TODO: replace DeformAugment node from gunpowder
"""
Elasticly deform a batch. Requests larger batches upstream to avoid data
loss due to rotation and jitter.
Expand Down

0 comments on commit 6ff9682

Please sign in to comment.