Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add features from cellmap fork #29

Open
wants to merge 49 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
eb1309c
fix distance mask predictor
mzouink Aug 30, 2023
5fe699a
Update black.yaml - force formatting
rhoadesScholar Jan 16, 2024
0098fa2
Update black.yaml - automatic black
rhoadesScholar Jan 16, 2024
13ae1f0
Update black.yaml - remove auto black
rhoadesScholar Jan 16, 2024
d741626
Update black.yaml
mzouink Feb 7, 2024
9351c16
remove black
mzouink Feb 7, 2024
bb8cab5
add black format check
mzouink Feb 7, 2024
c25cb79
black format on pull request
mzouink Feb 7, 2024
08f134d
:art: Format Python code with psf/black
mzouink Feb 7, 2024
e73afa2
Merge pull request #19 from janelia-cellmap/actions/black
mzouink Feb 7, 2024
ff61f7c
bug fixes and better logs
mzouink Feb 7, 2024
1490440
Update train.py
mzouink Feb 7, 2024
3c5f2da
Update train.py
mzouink Feb 7, 2024
34e8253
Merge pull request #20 from janelia-cellmap/zouinkhim_fixes
rhoadesScholar Feb 7, 2024
33bbc8a
feat: 🚧 Incorporate simple change from rhoadesj/dev
rhoadesScholar Feb 8, 2024
fe23b5d
feat: 🚧 Incorporate simple change from rhoadesj/dev
rhoadesScholar Feb 8, 2024
5f50f9b
Merge pull request #25 from janelia-cellmap/rhoadesj_simple_changes
mzouink Feb 8, 2024
812acc1
feat: ⚡️ Incorporate start related changes from rhoadesj/dev
rhoadesScholar Feb 8, 2024
ce5d272
docs: 📝 Add authors and versioning.
rhoadesScholar Feb 9, 2024
4f1dfed
starter partial weight load
mzouink Feb 9, 2024
910b2e0
Merge pull request #35 from janelia-cellmap/fix_start_loader
mzouink Feb 9, 2024
906dfd6
:art: Format Python code with psf/black
mzouink Feb 9, 2024
8cce986
Merge pull request #37 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
f5e584a
publish to pypi
mzouink Feb 9, 2024
281a768
logo
rhoadesScholar Feb 9, 2024
65482f4
Update README.md
rhoadesScholar Feb 9, 2024
52409f9
Update README.md
rhoadesScholar Feb 9, 2024
b8e18b4
Update publish.yaml
mzouink Feb 9, 2024
b4b2780
fix: 🐛 Fix broken dependencies for MacOS.
rhoadesScholar Feb 9, 2024
55a3892
include and use more biases during watershed post processing of affin…
davidackerman Feb 9, 2024
58c7abe
include weighting argument for affinities+lsd loss
davidackerman Feb 9, 2024
ce71fb5
make predictor node optional
davidackerman Feb 9, 2024
353b8cb
:art: Format Python code with psf/black
davidackerman Feb 9, 2024
4a54e31
Merge pull request #42 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
f243c7c
styles fixes for mypy
mzouink Feb 9, 2024
cebc737
update git action, fix doc and no more publish
mzouink Feb 9, 2024
7feab6a
remove unfinished cli and apply from main
mzouink Feb 9, 2024
5d77af0
fix test action, pytest 8.0.0 working
mzouink Feb 9, 2024
e46acf0
:art: Format Python code with psf/black
mzouink Feb 9, 2024
cea9c03
Merge pull request #45 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
232047c
:art: Format Python code with psf/black
mzouink Feb 9, 2024
8b9d44a
Merge pull request #46 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
62d6278
test only with python 3.10
mzouink Feb 9, 2024
3c2f0fe
bug fix: loading starter weight, layer exist but mismatch shape
mzouink Feb 12, 2024
84436e8
Merge pull request #60 from janelia-cellmap/fix_bug_starter_shape
mzouink Feb 12, 2024
4a4bd94
update size checking
mzouink Feb 12, 2024
5855dd0
:art: Format Python code with psf/black
mzouink Feb 12, 2024
6afcacd
Merge pull request #63 from janelia-cellmap/actions/black
mzouink Feb 12, 2024
d0512f9
Merge pull request #64 from janelia-cellmap/fix_bug_starter_shape
mzouink Feb 12, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions .github/workflows/black.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
name: Python Black
name: black-action

on: [push, pull_request]

jobs:
lint:
name: Python Lint
linter_name:
name: runner / black
runs-on: ubuntu-latest
steps:
- name: Setup Python
uses: actions/setup-python@v1
- name: Setup checkout
uses: actions/checkout@master
- name: Lint with Black
run: |
pip install black
black -v --check dacapo tests
- uses: actions/checkout@v2
- name: Check files using the black formatter
uses: rickstaa/action-black@v1
id: action_black
with:
black_args: "."
- name: Create Pull Request
if: steps.action_black.outputs.is_formatted == 'true'
uses: peter-evans/create-pull-request@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
title: "Format Python code with psf/black push"
commit-message: ":art: Format Python code with psf/black"
body: |
There appear to be some python formatting errors in ${{ github.sha }}. This pull request
uses the [psf/black](https://github.com/psf/black) formatter to fix these issues.
base: ${{ github.head_ref }} # Creates pull request onto pull request or commit branch
branch: actions/black
9 changes: 4 additions & 5 deletions .github/workflows/docs.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
name: Pages
on:
push:
branches:
- master
name: Generate Pages

on: [push, pull_request]

jobs:
docs:
runs-on: ubuntu-latest
Expand Down
34 changes: 0 additions & 34 deletions .github/workflows/publish.yaml

This file was deleted.

7 changes: 3 additions & 4 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
name: Test

on:
push:
on: [push, pull_request]

jobs:
test:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v2
Expand All @@ -23,4 +22,4 @@ jobs:
pip install -r requirements-dev.txt
- name: Test with pytest
run: |
pytest tests
pytest tests
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
![DaCapo](docs/source/_static/dacapo.svg)
# DaCapo ![DaCapo](docs/source/_static/icon_dacapo.png)

[![tests](https://github.com/funkelab/dacapo/actions/workflows/tests.yaml/badge.svg)](https://github.com/funkelab/dacapo/actions/workflows/tests.yaml)
[![black](https://github.com/funkelab/dacapo/actions/workflows/black.yaml/badge.svg)](https://github.com/funkelab/dacapo/actions/workflows/black.yaml)
Expand Down
1 change: 1 addition & 0 deletions dacapo/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ def apply(run_name: str, iteration: int, dataset_name: str):
iteration,
dataset_name,
)
raise NotImplementedError("This function is not yet implemented.")
2 changes: 1 addition & 1 deletion dacapo/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
"-r", "--run", required=True, type=str, help="The name of the run to use."
"-r", "--run-name", required=True, type=str, help="The name of the run to use."
)
@click.option(
"-i",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import numpy as np

from typing import Dict, Any
import logging

logger = logging.getLogger(__file__)


class ConcatArray(Array):
Expand Down Expand Up @@ -116,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
raise Exception(f"{concatenated.shape}, shapes")
logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def attrs(self):

@property
def axes(self):
return ["t", "z", "y", "x"][-self.dims :]
return ["c", "z", "y", "x"][-self.dims :]

@property
def dims(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array):
((["b", "c"] if len(array.data.shape) == instance.dims + 2 else []))
+ (["c"] if len(array.data.shape) == instance.dims + 1 else [])
+ [
"t",
"c",
"z",
"y",
"x",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def voxel_size(self) -> Coordinate:

@lazy_property.LazyProperty
def roi(self) -> Roi:
return Roi(self._offset * self.shape)
return Roi(self._offset, self.shape)

@property
def writable(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def axes(self):
logger.debug(
"DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
)
return ["t", "z", "y", "x"][-self.dims : :]
return ["c", "z", "y", "x"][-self.dims : :]

@property
def dims(self) -> int:
Expand Down
4 changes: 2 additions & 2 deletions dacapo/experiments/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
architecture: Architecture,
prediction_head: torch.nn.Module,
eval_activation: torch.nn.Module = None,
eval_activation: torch.nn.Module | None = None,
):
super().__init__()

Expand All @@ -46,7 +46,7 @@ def forward(self, x):
result = self.eval_activation(result)
return result

def compute_output_shape(self, input_shape: Coordinate) -> Coordinate:
def compute_output_shape(self, input_shape: Coordinate) -> Tuple[int, Coordinate]:
"""Compute the spatial shape (i.e., not accounting for channels and
batch dimensions) of this model, when fed a tensor of the given spatial
shape as input."""
Expand Down
14 changes: 13 additions & 1 deletion dacapo/experiments/starts/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,21 @@ def initialize_weights(self, model):
weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)
logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")

# load the model weights (taken from torch load_state_dict source)
try:
model.load_state_dict(weights.model)
except RuntimeError as e:
logger.warning(e)
# if the model is not the same, we can try to load the weights
# of the common layers
model_dict = model.state_dict()
pretrained_dict = {
k: v
for k, v in weights.model.items()
if k in model_dict and v.size() == model_dict[k].size()
}
model_dict.update(
pretrained_dict
) # update only the existing and matching layers
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")
4 changes: 3 additions & 1 deletion dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def __init__(self, task_config):
self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood, lsds=task_config.lsds
)
self.loss = AffinitiesLoss(len(task_config.neighborhood))
self.loss = AffinitiesLoss(
len(task_config.neighborhood), task_config.lsds_to_affs_weight_ratio
)
self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
self.evaluator = InstanceEvaluator()
6 changes: 6 additions & 0 deletions dacapo/experiments/tasks/affinities_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,9 @@ class AffinitiesTaskConfig(TaskConfig):
"It has been shown that lsds as an auxiliary task can help affinity predictions."
},
)
lsds_to_affs_weight_ratio: float = attr.ib(
default=1,
metadata={
"help_text": "If training with lsds, set how much they should be weighted compared to affs."
},
)
5 changes: 3 additions & 2 deletions dacapo/experiments/tasks/losses/affinities_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@


class AffinitiesLoss(Loss):
def __init__(self, num_affinities: int):
def __init__(self, num_affinities: int, lsds_to_affs_weight_ratio: float):
self.num_affinities = num_affinities
self.lsds_to_affs_weight_ratio = lsds_to_affs_weight_ratio

def compute(self, prediction, target, weight):
affs, affs_target, affs_weight = (
Expand All @@ -21,7 +22,7 @@ def compute(self, prediction, target, weight):
return (
torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
* affs_weight
).mean() + (
).mean() + self.lsds_to_affs_weight_ratio * (
torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
* aux_weight
).mean()
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def enumerate_parameters(self):
"""Enumerate all possible parameters of this post-processor. Should
return instances of ``PostProcessorParameters``."""

for i, bias in enumerate([0.1, 0.5, 0.9]):
for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]):
yield WatershedPostProcessorParameters(id=i, bias=bias)

def set_prediction(self, prediction_array_identifier):
Expand All @@ -44,9 +44,9 @@ def process(self, parameters, output_array_identifier):
# if a previous segmentation is provided, it must have a "grid graph"
# in its metadata.
pred_data = self.prediction_array[self.prediction_array.roi]
affs = pred_data[: len(self.offsets)]
affs = pred_data[: len(self.offsets)].astype(np.float64)
segmentation = mws.agglom(
affs - 0.5,
affs - parameters.bias,
self.offsets,
)
# filter fragments
Expand All @@ -59,12 +59,17 @@ def process(self, parameters, output_array_identifier):
for fragment, mean in zip(
fragment_ids, measurements.mean(average_affs, segmentation, fragment_ids)
):
if mean < 0.5:
if mean < parameters.bias:
filtered_fragments.append(fragment)

filtered_fragments = np.array(filtered_fragments, dtype=segmentation.dtype)
replace = np.zeros_like(filtered_fragments)
segmentation = npi.remap(segmentation, filtered_fragments, replace)

# DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input
if filtered_fragments.size > 0:
segmentation = npi.remap(
segmentation.flatten(), filtered_fragments, replace
).reshape(segmentation.shape)

output_array[self.prediction_array.roi] = segmentation

Expand Down
33 changes: 24 additions & 9 deletions dacapo/experiments/tasks/predictors/affinities_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@


class AffinitiesPredictor(Predictor):
def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
def __init__(
self,
neighborhood: List[Coordinate],
lsds: bool = True,
num_voxels: int = 20,
downsample_lsds: int = 1,
grow_boundary_iterations: int = 0,
):
self.neighborhood = neighborhood
self.lsds = lsds
self.num_voxels = num_voxels
if lsds:
self._extractor = None
if self.dims == 2:
Expand All @@ -30,12 +38,16 @@ def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
raise ValueError(
f"Cannot compute lsds on volumes with {self.dims} dimensions"
)
self.downsample_lsds = downsample_lsds
else:
self.num_lsds = 0
self.grow_boundary_iterations = grow_boundary_iterations

def extractor(self, voxel_size):
if self._extractor is None:
self._extractor = LsdExtractor(self.sigma(voxel_size))
self._extractor = LsdExtractor(
self.sigma(voxel_size), downsample=self.downsample_lsds
)

return self._extractor

Expand All @@ -45,8 +57,7 @@ def dims(self):

def sigma(self, voxel_size):
voxel_dist = max(voxel_size) # arbitrarily chosen
num_voxels = 10 # arbitrarily chosen
sigma = voxel_dist * num_voxels
sigma = voxel_dist * self.num_voxels # arbitrarily chosen
return Coordinate((sigma,) * self.dims)

def lsd_pad(self, voxel_size):
Expand Down Expand Up @@ -118,7 +129,9 @@ def _grow_boundaries(self, mask, slab):
slice(start[d], start[d] + slab[d]) for d in range(len(slab))
)
mask_slab = mask[slices]
dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1)
dilated_mask_slab = ndimage.binary_dilation(
mask_slab, iterations=self.grow_boundary_iterations
)
foreground[slices] = dilated_mask_slab

# label new background
Expand All @@ -130,10 +143,12 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
(moving_class_counts, moving_lsd_class_counts) = (
moving_class_counts if moving_class_counts is not None else (None, None)
)
# mask_data = self._grow_boundaries(
# mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
# )
mask_data = mask[target.roi]
if self.grow_boundary_iterations > 0:
mask_data = self._grow_boundaries(
mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
)
else:
mask_data = mask[target.roi]
aff_weights, moving_class_counts = balance_weights(
target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8),
2,
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/tasks/predictors/distance_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class DistancePredictor(Predictor):
in the channels argument.
"""

def __init__(self, channels: List[str], scale_factor: float, mask_distances=bool):
def __init__(self, channels: List[str], scale_factor: float, mask_distances: bool):
self.channels = channels
self.norm = "tanh"
self.dt_scale_factor = scale_factor
Expand Down
Loading