Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Dec 14, 2023
1 parent aba4430 commit 1d3bd47
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 90 deletions.
25 changes: 13 additions & 12 deletions src/autoseg/gp_filters/random_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,31 @@
import random
from skimage.util import random_noise


class RandomNoiseAugment(gp.BatchFilter):
"""
Random Noise Augmentation for Gunpowder.
This class applies random noise augmentation to the specified array in a Gunpowder batch.
Args:
array (str):
array (str):
The name of the array in the batch to which noise should be applied.
seed (int, optional):
seed (int, optional):
Seed for the random number generator. Default is None.
clip (bool, optional):
clip (bool, optional):
Whether to clip the values after applying noise. Default is True.
**kwargs:
**kwargs:
Additional keyword arguments to be passed to the `random_noise` function.
Attributes:
array (str):
array (str):
The name of the array in the batch to which noise is applied.
seed (int):
seed (int):
Seed for the random number generator.
clip (bool):
clip (bool):
Whether to clip the values after applying noise.
kwargs (dict):
kwargs (dict):
Additional keyword arguments passed to the `random_noise` function.
"""

Expand All @@ -48,11 +49,11 @@ def prepare(self, request):
Prepare the dependencies for processing based on the requested batch.
Args:
request (BatchRequest):
request (BatchRequest):
The requested batch.
Returns:
BatchRequest:
BatchRequest:
The dependencies for processing.
"""
deps = gp.BatchRequest()
Expand All @@ -64,9 +65,9 @@ def process(self, batch, request):
Apply random noise augmentation to the specified array in the batch.
Args:
batch (Batch):
batch (Batch):
The input batch.
request (BatchRequest):
request (BatchRequest):
The requested batch.
"""
raw = batch.arrays[self.array]
Expand Down
13 changes: 7 additions & 6 deletions src/autoseg/gp_filters/smooth_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@
import random
from scipy.ndimage import gaussian_filter


class SmoothArray(gp.BatchFilter):
"""
Smooth Array in a Gunpowder Batch.
This class applies Gaussian smoothing to a 3D array in a Gunpowder batch.
Args:
array (str):
array (str):
The name of the array in the batch to be smoothed.
blur_range (tuple):
blur_range (tuple):
The range of sigma values for the Gaussian filter.
Attributes:
array (str):
array (str):
The name of the array in the batch to be smoothed.
range (tuple):
range (tuple):
The range of sigma values for the Gaussian filter.
"""

Expand All @@ -31,9 +32,9 @@ def process(self, batch, request):
Apply Gaussian smoothing to the specified array in the batch.
Args:
batch (Batch):
batch (Batch):
The input batch.
request (BatchRequest):
request (BatchRequest):
The requested batch.
"""
array = batch[self.array].data
Expand Down
18 changes: 9 additions & 9 deletions src/autoseg/losses/ACLSDLoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class WeightedACLSD_MSELoss(torch.nn.MSELoss):
weighting term for Auto-Context LSD (ACLSD) segmentation.
Parameters:
aff_lambda (float, optional):
aff_lambda (float, optional):
Weighting factor for the affinity loss. Default is 1.0.
"""

Expand All @@ -23,15 +23,15 @@ def _calc_loss(self, prediction, target, weights):
Calculates the weighted mean squared error loss.
Args:
prediction (torch.Tensor):
prediction (torch.Tensor):
Predicted affinities.
target (torch.Tensor):
target (torch.Tensor):
Ground truth affinities.
weights (torch.Tensor):
weights (torch.Tensor):
Weighting factor for each affinity.
Returns:
torch.Tensor:
torch.Tensor:
Weighted mean squared error loss.
"""
scaled = weights * (prediction - target) ** 2
Expand All @@ -54,15 +54,15 @@ def forward(
Calculates the weighted ACLSD MSE loss.
Args:
pred_affs (torch.Tensor):
pred_affs (torch.Tensor):
Predicted affinities.
gt_affs (torch.Tensor):
gt_affs (torch.Tensor):
Ground truth affinities.
affs_weights (torch.Tensor):
affs_weights (torch.Tensor):
Weighting factor for each affinity.
Returns:
torch.Tensor:
torch.Tensor:
Weighted ACLSD MSE loss.
"""
aff_loss = self.aff_lambda * self._calc_loss(pred_affs, gt_affs, affs_weights)
Expand Down
38 changes: 19 additions & 19 deletions src/autoseg/losses/GMSELoss.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ class Weighted_GMSELoss(torch.nn.Module):
GAN (Generative Adversarial Network) loss term for enhanced data.
Parameters:
aff_lambda (float, optional):
aff_lambda (float, optional):
Weighting factor for the affinity loss. Default is 1.0.
gan_lambda (float, optional):
gan_lambda (float, optional):
Weighting factor for the GAN loss. Default is 1.0.
discrim (torch.nn.Module, optional):
discrim (torch.nn.Module, optional):
Discriminator network for GAN loss.
"""

Expand All @@ -22,11 +22,11 @@ def __init__(self, aff_lambda=1.0, gan_lambda=1.0, discrim=None):
Initializes the Weighted_MSELoss.
Args:
aff_lambda (float, optional):
aff_lambda (float, optional):
Weighting factor for the affinity loss. Default is 1.0.
gan_lambda (float, optional):
gan_lambda (float, optional):
Weighting factor for the GAN loss. Default is 1.0.
discrim (torch.nn.Module, optional):
discrim (torch.nn.Module, optional):
Discriminator network for GAN loss.
"""
super(Weighted_GMSELoss, self).__init__()
Expand All @@ -40,15 +40,15 @@ def _calc_loss(self, prediction, target, weights=None):
Calculates the weighted mean squared error loss.
Args:
prediction (torch.Tensor):
prediction (torch.Tensor):
Predicted values.
target (torch.Tensor):
target (torch.Tensor):
Ground truth values.
weights (torch.Tensor, optional):
weights (torch.Tensor, optional):
Weighting factor for each value.
Returns:
torch.Tensor:
torch.Tensor:
Weighted mean squared error loss.
"""
if type(weights) != torch.Tensor:
Expand Down Expand Up @@ -79,25 +79,25 @@ def forward(
Calculates the weighted MSE loss with GAN loss.
Args:
pred_lsds (torch.Tensor):
pred_lsds (torch.Tensor):
Predicted LSD values.
gt_lsds (torch.Tensor):
gt_lsds (torch.Tensor):
Ground truth LSD values.
lsds_weights (torch.Tensor, optional):
lsds_weights (torch.Tensor, optional):
Weighting factor for each LSD value.
pred_affs (torch.Tensor):
pred_affs (torch.Tensor):
Predicted affinity values.
gt_affs (torch.Tensor):
gt_affs (torch.Tensor):
Ground truth affinity values.
affs_weights (torch.Tensor, optional):
affs_weights (torch.Tensor, optional):
Weighting factor for each affinity value.
pred_enhanced (torch.Tensor):
pred_enhanced (torch.Tensor):
Predicted enhanced data.
gt_enhanced (torch.Tensor):
gt_enhanced (torch.Tensor):
Ground truth enhanced data.
Returns:
torch.Tensor:
torch.Tensor:
Combined weighted MSE loss with GAN loss.
"""
# calculate MSE loss for LSD and Affs
Expand Down
14 changes: 7 additions & 7 deletions src/autoseg/postprocess/segment_mws.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ def get_validation_segmentation(
Get validation segmentation using the specified segmentation style.
Parameters:
segmentation_style (str):
segmentation_style (str):
Style of segmentation ("mws" or "mergetree").
iteration (str):
iteration (str):
Iteration or checkpoint to use (default: "latest").
raw_file (str):
raw_file (str):
Path to the input Zarr dataset containing raw data.
raw_dataset (str):
raw_dataset (str):
Name of the raw dataset in the input Zarr file.
out_file (str):
out_file (str):
Path to the output Zarr file for storing predictions.
pred_affs (bool):
pred_affs (bool):
Flag to indicate whether to predict affinities.
Returns:
bool:
bool:
True if segmentation is successful, False otherwise.
"""
out_datasets = [
Expand Down
18 changes: 9 additions & 9 deletions src/autoseg/postprocess/segment_skel_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,25 @@ def get_skel_correct_segmentation(
Generate segmentation with skeleton-based correction using RUSTY_MWS.
Parameters:
predict_affs (bool):
predict_affs (bool):
Flag to indicate whether to predict affinities.
raw_file (str):
raw_file (str):
Path to the input Zarr dataset containing raw data.
raw_dataset (str):
raw_dataset (str):
Name of the raw dataset in the input Zarr file.
out_file (str):
out_file (str):
Path to the output Zarr file for storing predictions.
out_datasets (list):
out_datasets (list):
List of tuples specifying output dataset names and channel counts.
iteration (str):
iteration (str):
Iteration or checkpoint to use (default: "latest").
model_path (str):
model_path (str):
Path to the directory containing the trained model checkpoints.
voxel_size (int):
voxel_size (int):
Voxel size in all three dimensions.
Returns:
None:
None:
No return value. Segmentation with skeleton-based correction is stored in the specified Zarr file.
"""
if predict_affs:
Expand Down
26 changes: 13 additions & 13 deletions src/autoseg/predict/network_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def predict_task(
model:torch.nn.Module,
model: torch.nn.Module,
model_type: str,
iteration: int,
raw_file: str,
Expand All @@ -33,31 +33,31 @@ def predict_task(
Predict affinities using a trained deep learning model.
Parameters:
model:
model:
Trained deep learning model.
model_type (str):
model_type (str):
Type of the model ("MTLSD", "ACLSD", "STELARR").
iteration (int):
iteration (int):
Iteration or checkpoint of the trained model to use.
raw_file (str):
raw_file (str):
Path to the input Zarr or N5 dataset or TIFF file containing raw data.
raw_dataset (str):
raw_dataset (str):
Name of the raw dataset in the input Zarr or N5 file.
out_file (str):
out_file (str):
Path to the output Zarr file for storing predictions.
out_datasets (list):
out_datasets (list):
List of tuples specifying output dataset names and channel counts.
num_workers (int):
num_workers (int):
Number of parallel workers for blockwise processing.
n_gpu (int):
n_gpu (int):
Number of GPUs available for prediction.
model_path (str):
model_path (str):
Path to the directory containing the trained model checkpoints.
voxel_size (int):
voxel_size (int):
Voxel size in all three dimensions.
Returns:
None:
None:
No return value. Predictions are stored in the specified Zarr file.
"""
if type(iteration) == str and "latest" in iteration:
Expand Down
Loading

0 comments on commit 1d3bd47

Please sign in to comment.