diff --git a/src/autoseg/gp_filters/random_noise.py b/src/autoseg/gp_filters/random_noise.py index ceda405..fb2ae57 100644 --- a/src/autoseg/gp_filters/random_noise.py +++ b/src/autoseg/gp_filters/random_noise.py @@ -3,6 +3,7 @@ import random from skimage.util import random_noise + class RandomNoiseAugment(gp.BatchFilter): """ Random Noise Augmentation for Gunpowder. @@ -10,23 +11,23 @@ class RandomNoiseAugment(gp.BatchFilter): This class applies random noise augmentation to the specified array in a Gunpowder batch. Args: - array (str): + array (str): The name of the array in the batch to which noise should be applied. - seed (int, optional): + seed (int, optional): Seed for the random number generator. Default is None. - clip (bool, optional): + clip (bool, optional): Whether to clip the values after applying noise. Default is True. - **kwargs: + **kwargs: Additional keyword arguments to be passed to the `random_noise` function. Attributes: - array (str): + array (str): The name of the array in the batch to which noise is applied. - seed (int): + seed (int): Seed for the random number generator. - clip (bool): + clip (bool): Whether to clip the values after applying noise. - kwargs (dict): + kwargs (dict): Additional keyword arguments passed to the `random_noise` function. """ @@ -48,11 +49,11 @@ def prepare(self, request): Prepare the dependencies for processing based on the requested batch. Args: - request (BatchRequest): + request (BatchRequest): The requested batch. Returns: - BatchRequest: + BatchRequest: The dependencies for processing. """ deps = gp.BatchRequest() @@ -64,9 +65,9 @@ def process(self, batch, request): Apply random noise augmentation to the specified array in the batch. Args: - batch (Batch): + batch (Batch): The input batch. - request (BatchRequest): + request (BatchRequest): The requested batch. """ raw = batch.arrays[self.array] diff --git a/src/autoseg/gp_filters/smooth_array.py b/src/autoseg/gp_filters/smooth_array.py index e7990ff..a659e77 100644 --- a/src/autoseg/gp_filters/smooth_array.py +++ b/src/autoseg/gp_filters/smooth_array.py @@ -3,6 +3,7 @@ import random from scipy.ndimage import gaussian_filter + class SmoothArray(gp.BatchFilter): """ Smooth Array in a Gunpowder Batch. @@ -10,15 +11,15 @@ class SmoothArray(gp.BatchFilter): This class applies Gaussian smoothing to a 3D array in a Gunpowder batch. Args: - array (str): + array (str): The name of the array in the batch to be smoothed. - blur_range (tuple): + blur_range (tuple): The range of sigma values for the Gaussian filter. Attributes: - array (str): + array (str): The name of the array in the batch to be smoothed. - range (tuple): + range (tuple): The range of sigma values for the Gaussian filter. """ @@ -31,9 +32,9 @@ def process(self, batch, request): Apply Gaussian smoothing to the specified array in the batch. Args: - batch (Batch): + batch (Batch): The input batch. - request (BatchRequest): + request (BatchRequest): The requested batch. """ array = batch[self.array].data diff --git a/src/autoseg/losses/ACLSDLoss.py b/src/autoseg/losses/ACLSDLoss.py index fee77df..14fc63c 100644 --- a/src/autoseg/losses/ACLSDLoss.py +++ b/src/autoseg/losses/ACLSDLoss.py @@ -9,7 +9,7 @@ class WeightedACLSD_MSELoss(torch.nn.MSELoss): weighting term for Auto-Context LSD (ACLSD) segmentation. Parameters: - aff_lambda (float, optional): + aff_lambda (float, optional): Weighting factor for the affinity loss. Default is 1.0. """ @@ -23,15 +23,15 @@ def _calc_loss(self, prediction, target, weights): Calculates the weighted mean squared error loss. Args: - prediction (torch.Tensor): + prediction (torch.Tensor): Predicted affinities. - target (torch.Tensor): + target (torch.Tensor): Ground truth affinities. - weights (torch.Tensor): + weights (torch.Tensor): Weighting factor for each affinity. Returns: - torch.Tensor: + torch.Tensor: Weighted mean squared error loss. """ scaled = weights * (prediction - target) ** 2 @@ -54,15 +54,15 @@ def forward( Calculates the weighted ACLSD MSE loss. Args: - pred_affs (torch.Tensor): + pred_affs (torch.Tensor): Predicted affinities. - gt_affs (torch.Tensor): + gt_affs (torch.Tensor): Ground truth affinities. - affs_weights (torch.Tensor): + affs_weights (torch.Tensor): Weighting factor for each affinity. Returns: - torch.Tensor: + torch.Tensor: Weighted ACLSD MSE loss. """ aff_loss = self.aff_lambda * self._calc_loss(pred_affs, gt_affs, affs_weights) diff --git a/src/autoseg/losses/GMSELoss.py b/src/autoseg/losses/GMSELoss.py index 55f66d7..df2d9bc 100644 --- a/src/autoseg/losses/GMSELoss.py +++ b/src/autoseg/losses/GMSELoss.py @@ -9,11 +9,11 @@ class Weighted_GMSELoss(torch.nn.Module): GAN (Generative Adversarial Network) loss term for enhanced data. Parameters: - aff_lambda (float, optional): + aff_lambda (float, optional): Weighting factor for the affinity loss. Default is 1.0. - gan_lambda (float, optional): + gan_lambda (float, optional): Weighting factor for the GAN loss. Default is 1.0. - discrim (torch.nn.Module, optional): + discrim (torch.nn.Module, optional): Discriminator network for GAN loss. """ @@ -22,11 +22,11 @@ def __init__(self, aff_lambda=1.0, gan_lambda=1.0, discrim=None): Initializes the Weighted_MSELoss. Args: - aff_lambda (float, optional): + aff_lambda (float, optional): Weighting factor for the affinity loss. Default is 1.0. - gan_lambda (float, optional): + gan_lambda (float, optional): Weighting factor for the GAN loss. Default is 1.0. - discrim (torch.nn.Module, optional): + discrim (torch.nn.Module, optional): Discriminator network for GAN loss. """ super(Weighted_GMSELoss, self).__init__() @@ -40,15 +40,15 @@ def _calc_loss(self, prediction, target, weights=None): Calculates the weighted mean squared error loss. Args: - prediction (torch.Tensor): + prediction (torch.Tensor): Predicted values. - target (torch.Tensor): + target (torch.Tensor): Ground truth values. - weights (torch.Tensor, optional): + weights (torch.Tensor, optional): Weighting factor for each value. Returns: - torch.Tensor: + torch.Tensor: Weighted mean squared error loss. """ if type(weights) != torch.Tensor: @@ -79,25 +79,25 @@ def forward( Calculates the weighted MSE loss with GAN loss. Args: - pred_lsds (torch.Tensor): + pred_lsds (torch.Tensor): Predicted LSD values. - gt_lsds (torch.Tensor): + gt_lsds (torch.Tensor): Ground truth LSD values. - lsds_weights (torch.Tensor, optional): + lsds_weights (torch.Tensor, optional): Weighting factor for each LSD value. - pred_affs (torch.Tensor): + pred_affs (torch.Tensor): Predicted affinity values. - gt_affs (torch.Tensor): + gt_affs (torch.Tensor): Ground truth affinity values. - affs_weights (torch.Tensor, optional): + affs_weights (torch.Tensor, optional): Weighting factor for each affinity value. - pred_enhanced (torch.Tensor): + pred_enhanced (torch.Tensor): Predicted enhanced data. - gt_enhanced (torch.Tensor): + gt_enhanced (torch.Tensor): Ground truth enhanced data. Returns: - torch.Tensor: + torch.Tensor: Combined weighted MSE loss with GAN loss. """ # calculate MSE loss for LSD and Affs diff --git a/src/autoseg/postprocess/segment_mws.py b/src/autoseg/postprocess/segment_mws.py index acb16b0..a5f6b9d 100644 --- a/src/autoseg/postprocess/segment_mws.py +++ b/src/autoseg/postprocess/segment_mws.py @@ -19,21 +19,21 @@ def get_validation_segmentation( Get validation segmentation using the specified segmentation style. Parameters: - segmentation_style (str): + segmentation_style (str): Style of segmentation ("mws" or "mergetree"). - iteration (str): + iteration (str): Iteration or checkpoint to use (default: "latest"). - raw_file (str): + raw_file (str): Path to the input Zarr dataset containing raw data. - raw_dataset (str): + raw_dataset (str): Name of the raw dataset in the input Zarr file. - out_file (str): + out_file (str): Path to the output Zarr file for storing predictions. - pred_affs (bool): + pred_affs (bool): Flag to indicate whether to predict affinities. Returns: - bool: + bool: True if segmentation is successful, False otherwise. """ out_datasets = [ diff --git a/src/autoseg/postprocess/segment_skel_correct.py b/src/autoseg/postprocess/segment_skel_correct.py index 710cca3..acb0b70 100644 --- a/src/autoseg/postprocess/segment_skel_correct.py +++ b/src/autoseg/postprocess/segment_skel_correct.py @@ -24,25 +24,25 @@ def get_skel_correct_segmentation( Generate segmentation with skeleton-based correction using RUSTY_MWS. Parameters: - predict_affs (bool): + predict_affs (bool): Flag to indicate whether to predict affinities. - raw_file (str): + raw_file (str): Path to the input Zarr dataset containing raw data. - raw_dataset (str): + raw_dataset (str): Name of the raw dataset in the input Zarr file. - out_file (str): + out_file (str): Path to the output Zarr file for storing predictions. - out_datasets (list): + out_datasets (list): List of tuples specifying output dataset names and channel counts. - iteration (str): + iteration (str): Iteration or checkpoint to use (default: "latest"). - model_path (str): + model_path (str): Path to the directory containing the trained model checkpoints. - voxel_size (int): + voxel_size (int): Voxel size in all three dimensions. Returns: - None: + None: No return value. Segmentation with skeleton-based correction is stored in the specified Zarr file. """ if predict_affs: diff --git a/src/autoseg/predict/network_predictions.py b/src/autoseg/predict/network_predictions.py index ce652f3..1119f63 100644 --- a/src/autoseg/predict/network_predictions.py +++ b/src/autoseg/predict/network_predictions.py @@ -13,7 +13,7 @@ def predict_task( - model:torch.nn.Module, + model: torch.nn.Module, model_type: str, iteration: int, raw_file: str, @@ -33,31 +33,31 @@ def predict_task( Predict affinities using a trained deep learning model. Parameters: - model: + model: Trained deep learning model. - model_type (str): + model_type (str): Type of the model ("MTLSD", "ACLSD", "STELARR"). - iteration (int): + iteration (int): Iteration or checkpoint of the trained model to use. - raw_file (str): + raw_file (str): Path to the input Zarr or N5 dataset or TIFF file containing raw data. - raw_dataset (str): + raw_dataset (str): Name of the raw dataset in the input Zarr or N5 file. - out_file (str): + out_file (str): Path to the output Zarr file for storing predictions. - out_datasets (list): + out_datasets (list): List of tuples specifying output dataset names and channel counts. - num_workers (int): + num_workers (int): Number of parallel workers for blockwise processing. - n_gpu (int): + n_gpu (int): Number of GPUs available for prediction. - model_path (str): + model_path (str): Path to the directory containing the trained model checkpoints. - voxel_size (int): + voxel_size (int): Voxel size in all three dimensions. Returns: - None: + None: No return value. Predictions are stored in the specified Zarr file. """ if type(iteration) == str and "latest" in iteration: diff --git a/src/autoseg/train_job.py b/src/autoseg/train_job.py index 2c71f63..2180caf 100644 --- a/src/autoseg/train_job.py +++ b/src/autoseg/train_job.py @@ -24,37 +24,37 @@ def train_model( Train a deep learning model for segmentation. Parameters: - model_type (str): + model_type (str): Type of the model to train ("MTLSD", "ACLSD", or "STELARR"). - iterations (int): + iterations (int): Number of training iterations. - warmup (int): + warmup (int): Number of warm-up iterations for ACLSD and STELARR models. - raw_file (str): + raw_file (str): Path to the input Zarr or N5 dataset or TIFF file containing raw data. - rewrite_file (str): + rewrite_file (str): Path to the output Zarr file for TIFF conversion. - rewrite_ds (str): + rewrite_ds (str): Name of the Zarr dataset to store the converted TIFF data. - out_file (str): + out_file (str): Path to the output Zarr file for storing predictions. - get_labels (bool): + get_labels (bool): If True, fetch and convert painting labels to Zarr format. - get_rasters (bool): + get_rasters (bool): If True, fetch and convert skeletons to Zarr format. - generate_masks (bool): + generate_masks (bool): If True, generate masks based on label information. - voxel_size (int): + voxel_size (int): Voxel size in all three dimensions. - save_every (int): + save_every (int): Interval for saving intermediate models during training. - annotation_id (str): + annotation_id (str): WebKnossos annotation ID for fetching labels and skeletons. - wk_token (str): + wk_token (str): WebKnossos API token for authentication. Returns: - None: + None: No return value. Trains the specified model and saves predictions. """ if raw_file.endswith(".tiff") or raw_file.endswith(".tif"):