diff --git a/dacapo/experiments/trainers/gp_augments/elastic_config.py b/dacapo/experiments/trainers/gp_augments/elastic_config.py index 40f40e800..67b84383c 100644 --- a/dacapo/experiments/trainers/gp_augments/elastic_config.py +++ b/dacapo/experiments/trainers/gp_augments/elastic_config.py @@ -62,6 +62,12 @@ class ElasticAugmentConfig(AugmentConfig): "3D rotations." }, ) + augmentation_probability: float = attr.ib( + default=1., + metadata={ + "help_text": "Probability of applying the augmentations." + }, + ) def node(self, _raw_key=None, _gt_key=None, _mask_key=None): """ @@ -87,4 +93,5 @@ def node(self, _raw_key=None, _gt_key=None, _mask_key=None): rotation_interval=self.rotation_interval, subsample=self.subsample, uniform_3d_rotation=self.uniform_3d_rotation, + augmentation_probability=self.augmentation_probability, ) diff --git a/dacapo/experiments/trainers/gp_augments/intensity_config.py b/dacapo/experiments/trainers/gp_augments/intensity_config.py index fef1b26df..b6865eaa6 100644 --- a/dacapo/experiments/trainers/gp_augments/intensity_config.py +++ b/dacapo/experiments/trainers/gp_augments/intensity_config.py @@ -35,6 +35,12 @@ class IntensityAugmentConfig(AugmentConfig): "help_text": "Set to False if modified values should not be clipped to [0, 1]" }, ) + augmentation_probability: float = attr.ib( + default=1., + metadata={ + "help_text": "Probability of applying the augmentation." + }, + ) def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): """ @@ -58,4 +64,5 @@ def node(self, raw_key: gp.ArrayKey, _gt_key=None, _mask_key=None): shift_min=self.shift[0], shift_max=self.shift[1], clip=self.clip, + p=self.augmentation_probability, ) diff --git a/dacapo/experiments/trainers/gp_augments/simple_config.py b/dacapo/experiments/trainers/gp_augments/simple_config.py index 77c8e6e5a..9e8bd4160 100644 --- a/dacapo/experiments/trainers/gp_augments/simple_config.py +++ b/dacapo/experiments/trainers/gp_augments/simple_config.py @@ -20,6 +20,13 @@ class SimpleAugmentConfig(AugmentConfig): This class is a subclass of AugmentConfig. """ + augmentation_probability: float = attr.ib( + default=1., + metadata={ + "help_text": "Probability of applying the augmentations." + }, + ) + def node(self, _raw_key=None, _gt_key=None, _mask_key=None): """ Get a gp.SimpleAugment node. @@ -36,4 +43,4 @@ def node(self, _raw_key=None, _gt_key=None, _mask_key=None): >>> node = simple_augment_config.node() """ - return gp.SimpleAugment() + return gp.SimpleAugment(p=self.augmentation_probability)