Skip to content

Commit

Permalink
docstrings (#160)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Mar 13, 2024
2 parents 63c040d + b96f79f commit 2298ee4
Show file tree
Hide file tree
Showing 43 changed files with 1,016 additions and 94 deletions.
15 changes: 14 additions & 1 deletion dacapo/compute_context/local_torch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .compute_context import ComputeContext

import torch
import attr

Expand All @@ -8,6 +7,16 @@

@attr.s
class LocalTorch(ComputeContext):
"""
The LocalTorch class is a subclass of the ComputeContext class.
It is used to specify the context in which computations are to be done.
LocalTorch is used to specify that computations are to be done on the local machine using PyTorch.
Attributes:
_device (Optional[str]): This stores the type of device on which torch computations are to be done. It can
take "cuda" for GPU or "cpu" for CPU. None value results in automatic detection of device type.
"""

_device: Optional[str] = attr.ib(
default=None,
metadata={
Expand All @@ -18,6 +27,10 @@ class LocalTorch(ComputeContext):

@property
def device(self):
"""
A property method that returns the torch device object. It automatically detects and uses "cuda" (GPU) if
available, else it falls back on using "cpu".
"""
if self._device is None:
if torch.cuda.is_available():
return torch.device("cuda")
Expand Down
50 changes: 45 additions & 5 deletions dacapo/experiments/architectures/architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,75 @@


class Architecture(torch.nn.Module, ABC):
"""
An abstract base class for defining the architecture of a neural network model.
It is inherited from PyTorch's Module and built-in class `ABC` (Abstract Base Classes).
Other classes can inherit this class to define their own specific variations of architecture.
It requires to implement several property methods, and also includes additional methods related to the architecture design.
"""

@property
@abstractmethod
def input_shape(self) -> Coordinate:
"""The spatial input shape (i.e., not accounting for channels and batch
dimensions) of this architecture."""
"""
Abstract method to define the spatial input shape for the neural network architecture.
The shape should not account for the channels and batch dimensions.
Returns:
Coordinate: The spatial input shape.
"""
pass

@property
def eval_shape_increase(self) -> Coordinate:
"""
How much to increase the input shape during prediction.
Provides information about how much to increase the input shape during prediction.
Returns:
Coordinate: An instance representing the amount to increase in each dimension of the input shape.
"""
return Coordinate((0,) * self.input_shape.dims)

@property
@abstractmethod
def num_in_channels(self) -> int:
"""Return the number of input channels this architecture expects."""
"""
Abstract method to return number of input channels required by the architecture.
Returns:
int: Required number of input channels.
"""
pass

@property
@abstractmethod
def num_out_channels(self) -> int:
"""Return the number of output channels of this architecture."""
"""
Abstract method to return the number of output channels provided by the architecture.
Returns:
int: Number of output channels.
"""
pass

@property
def dims(self) -> int:
"""
Returns the number of dimensions of the input shape.
Returns:
int: The number of dimensions.
"""
return self.input_shape.dims

def scale(self, input_voxel_size: Coordinate) -> Coordinate:
"""
Method to scale the input voxel size as required by the architecture.
Args:
input_voxel_size (Coordinate): The original size of the input voxel.
Returns:
Coordinate: The scaled voxel size.
"""
return input_voxel_size
26 changes: 21 additions & 5 deletions dacapo/experiments/architectures/architecture_config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
import attr

from typing import Tuple


@attr.s
class ArchitectureConfig:
"""Base class for architecture configurations. Each subclass of an
`Architecture` should have a corresponding config class derived from
`ArchitectureConfig`.
"""
A class to represent the base configurations of any architecture.
Attributes
----------
name : str
a unique name for the architecture.
Methods
-------
verify()
validates the given architecture.
"""

name: str = attr.ib(
Expand All @@ -20,6 +29,13 @@ class ArchitectureConfig:

def verify(self) -> Tuple[bool, str]:
"""
Check whether this is a valid architecture
A method to validate an architecture configuration.
Returns
-------
bool
A flag indicating whether the config is valid or not.
str
A description of the architecture.
"""
return True, "No validation for this Architecture"
46 changes: 44 additions & 2 deletions dacapo/experiments/architectures/dummy_architecture.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,27 @@
from .architecture import Architecture

from funlib.geometry import Coordinate

import torch


class DummyArchitecture(Architecture):
"""
A class used to represent a dummy architecture layer for a 3D CNN.
Attributes:
channels_in: An integer representing the number of input channels.
channels_out: An integer representing the number of output channels.
conv: A 3D convolution object.
input_shape: A coordinate object representing the shape of the input.
Methods:
forward(x): Performs the forward pass of the network.
"""

def __init__(self, architecture_config):
"""
Args:
architecture_config: An object containing the configuration settings for the architecture.
"""
super().__init__()

self.channels_in = architecture_config.num_in_channels
Expand All @@ -16,15 +31,42 @@ def __init__(self, architecture_config):

@property
def input_shape(self):
"""
Returns the input shape for this architecture.
Returns:
Coordinate: Input shape of the architecture.
"""
return Coordinate(40, 20, 20)

@property
def num_in_channels(self):
"""
Returns the number of input channels for this architecture.
Returns:
int: Number of input channels.
"""
return self.channels_in

@property
def num_out_channels(self):
"""
Returns the number of output channels for this architecture.
Returns:
int: Number of output channels.
"""
return self.channels_out

def forward(self, x):
"""
Perform the forward pass of the network.
Args:
x: Input tensor.
Returns:
Tensor: Output tensor after the forward pass.
"""
return self.conv(x)
24 changes: 22 additions & 2 deletions dacapo/experiments/architectures/dummy_architecture_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,19 @@

@attr.s
class DummyArchitectureConfig(ArchitectureConfig):
"""This is just a dummy architecture config used for testing. None of the
attributes have any particular meaning."""
"""A dummy architecture configuration class used for testing purposes.
It extends the base class "ArchitectureConfig". This class contains dummy attributes and always
returns that the configuration is invalid when verified.
Attributes:
architecture_type (:obj:`DummyArchitecture`): A class attribute assigning
the DummyArchitecture class to this configuration.
num_in_channels (int): The number of input channels. This is a dummy attribute and has no real
functionality or meaning.
num_out_channels (int): The number of output channels. This is also a dummy attribute and
has no real functionality or meaning.
"""

architecture_type = DummyArchitecture

Expand All @@ -18,4 +29,13 @@ class DummyArchitectureConfig(ArchitectureConfig):
num_out_channels: int = attr.ib(metadata={"help_text": "Dummy attribute."})

def verify(self) -> Tuple[bool, str]:
"""Verifies the configuration validity.
Since this is a dummy configuration for testing purposes, this method always returns False
indicating that the configuration is invalid.
Returns:
tuple: A tuple containing a boolean validity flag and a reason message string.
"""

return False, "This is a DummyArchitectureConfig and is never valid"
7 changes: 7 additions & 0 deletions dacapo/experiments/arraytypes/arraytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,11 @@ class ArrayType(ABC):
@property
@abstractmethod
def interpolatable(self) -> bool:
"""
This is an abstract method which should be overridden in each of the subclasses
to determine if an array is interpolatable or not.
Returns:
bool: True if the array is interpolatable, False otherwise.
"""
pass
19 changes: 17 additions & 2 deletions dacapo/experiments/arraytypes/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@
@attr.s
class BinaryArray(ArrayType):
"""
An BinaryArray is a bool or uint8 Array where each
voxel is either 1 or 0.
A subclass of ArrayType representing BinaryArray. The BinaryArray object is created with two attributes; channels.
Each voxel in this array is either 1 or 0.
Attributes:
channels (Dict[int, str]): A dictionary attribute representing channel mapping with its binary classification.
Args:
channels (Dict[int, str]): A dictionary input where keys are channel numbers and values are their corresponding class for binary classification.
Methods:
interpolatable: Returns False as binary array type is not interpolatable.
"""

channels: Dict[int, str] = attr.ib(
Expand All @@ -20,4 +29,10 @@ class BinaryArray(ArrayType):

@property
def interpolatable(self) -> bool:
"""
This function returns the interpolatable property value of the binary array.
Returns:
bool: Always returns False because interpolation is not possible.
"""
return False
17 changes: 17 additions & 0 deletions dacapo/experiments/arraytypes/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,23 @@

@attr.s
class Mask(ArrayType):
"""
A class that inherits the ArrayType class. This is a representation of a Mask in the system.
Methods
-------
interpolatable():
It is a method that returns False.
"""

@property
def interpolatable(self) -> bool:
"""
Method to return False.
Returns
------
bool
Returns a boolean value of False representing that the values are not interpolatable.
"""
return False
20 changes: 14 additions & 6 deletions dacapo/experiments/arraytypes/probabilities.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from .arraytype import ArrayType

import attr

from typing import List


@attr.s
class ProbabilityArray(ArrayType):
"""
An array containing probabilities for each voxel. I.e. each voxel has a vector
of length `c` where `c` is the number of classes. The l1 norm of this vector should
always be 1. The class of this voxel can be determined by simply taking the
argmax.
Class to represent an array containing probability distributions for each voxel pointed by its coordinate.
The class defines a ProbabilityArray object with each voxel having a vector of length `c`, where `c` is the
number of classes. The l1 norm of this vector should always be 1. The class of each voxel can be
determined by simply taking the argmax.
Attributes:
classes (List[str]): A mapping from channel to class on which distances were calculated.
"""

classes: List[str] = attr.ib(
Expand All @@ -22,4 +24,10 @@ class ProbabilityArray(ArrayType):

@property
def interpolatable(self) -> bool:
"""
Checks if the array is interpolatable. Returns True for this class.
Returns:
bool: True indicating that the data can be interpolated.
"""
return True
Loading

0 comments on commit 2298ee4

Please sign in to comment.