Skip to content

Commit

Permalink
Merge pull request #20 from janelia-cellmap/zouinkhim_fixes
Browse files Browse the repository at this point in the history
bug fixes and better logs
  • Loading branch information
rhoadesScholar authored Feb 7, 2024
2 parents e73afa2 + 3c5f2da commit 34e8253
Show file tree
Hide file tree
Showing 8 changed files with 19 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import numpy as np

from typing import Dict, Any
import logging

logger = logging.getLogger(__file__)


class ConcatArray(Array):
Expand Down Expand Up @@ -116,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
raise Exception(f"{concatenated.shape}, shapes")
logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def attrs(self):

@property
def axes(self):
return ["t", "z", "y", "x"][-self.dims :]
return ["c", "z", "y", "x"][-self.dims :]

@property
def dims(self) -> int:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def from_gp_array(cls, array: gp.Array):
((["b", "c"] if len(array.data.shape) == instance.dims + 2 else []))
+ (["c"] if len(array.data.shape) == instance.dims + 1 else [])
+ [
"t",
"c",
"z",
"y",
"x",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def axes(self):
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
)
return ["t", "z", "y", "x"][-self.dims : :]
return ["c", "z", "y", "x"][-self.dims : :]

@property
def dims(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
architecture: Architecture,
prediction_head: torch.nn.Module,
eval_activation: torch.nn.Module = None,
eval_activation: torch.nn.Module | None = None,
):
super().__init__()

Expand Down
4 changes: 3 additions & 1 deletion dacapo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def train(run_name: str, compute_context: ComputeContext = LocalTorch()):
"""Train a run"""

if compute_context.train(run_name):
logger.error("Run %s is already being trained", run_name)
# if compute context runs train in some other process
# we are done here.
return
Expand Down Expand Up @@ -96,7 +97,8 @@ def train_run(
weights_store.retrieve_weights(run, iteration=trained_until)

elif latest_weights_iteration > trained_until:
raise RuntimeError(
weights_store.retrieve_weights(run, iteration=latest_weights_iteration)
logger.error(
f"Found weights for iteration {latest_weights_iteration}, but "
f"run {run.name} was only trained until {trained_until}."
)
Expand Down
2 changes: 2 additions & 0 deletions dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,15 @@ def validate_run(
prediction_array_identifier = array_store.validation_prediction_array(
run.name, iteration, validation_dataset
)
logger.info("Predicting on dataset %s", validation_dataset.name)
predict(
run.model,
validation_dataset.raw,
prediction_array_identifier,
compute_context=compute_context,
output_roi=validation_dataset.gt.roi,
)
logger.info("Predicted on dataset %s", validation_dataset.name)

post_processor.set_prediction(prediction_array_identifier)

Expand Down
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,9 @@
"funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
"gunpowder>=1.3",
"lsds>=0.1.3",
"xarray",
"cattrs",
"numpy-indexed",
"click",
],
)

0 comments on commit 34e8253

Please sign in to comment.