From 56dd7596ee93716d5eb306f15a6a427ffd807869 Mon Sep 17 00:00:00 2001 From: AlbertDominguez Date: Tue, 12 Nov 2024 12:04:57 +0100 Subject: [PATCH] fix wandb logging --- spotiflow/cli/train.py | 8 ++++++++ spotiflow/model/spotiflow.py | 8 ++++++-- spotiflow/model/trainer.py | 2 +- spotiflow/utils/fitting.py | 10 ++++++---- 4 files changed, 21 insertions(+), 7 deletions(-) diff --git a/spotiflow/cli/train.py b/spotiflow/cli/train.py index 6cf846b..8f62910 100644 --- a/spotiflow/cli/train.py +++ b/spotiflow/cli/train.py @@ -219,6 +219,13 @@ def get_args() -> argparse.Namespace: default="tensorboard", help="Logger to use for monitoring training. Defaults to 'tensorboard'.", ) + train_args.add_argument( + "--smart-crop", + type=str2bool, + required=False, + default=False, + help="Use smart cropping for training. Defaults to False.", + ) args = parser.parse_args() return args @@ -306,6 +313,7 @@ def main(): "pos_weight": args.pos_weight, "num_train_samples":args.train_samples, "finetuned_from": args.finetune_from, + "smart_crop": args.smart_crop, }, ) log.info("Done!") diff --git a/spotiflow/model/spotiflow.py b/spotiflow/model/spotiflow.py index 31dc3ec..9d1bb0d 100644 --- a/spotiflow/model/spotiflow.py +++ b/spotiflow/model/spotiflow.py @@ -6,6 +6,7 @@ from types import SimpleNamespace from typing import Callable, Literal, Optional, Sequence, Tuple, Union +import datetime import dask.array as da import lightning.pytorch as pl import numpy as np @@ -467,9 +468,10 @@ def fit( ] if logger == "tensorboard": - logger = pl.loggers.TensorBoardLogger(save_dir=save_dir) + logger = pl.loggers.TensorBoardLogger(save_dir=save_dir, name=f"spotiflow-{datetime.datetime.now().strftime('%Y%m%d_%H%M')}") elif logger == "wandb": - logger = pl.loggers.WandbLogger(save_dir=save_dir) + Path(save_dir/"wandb").mkdir(parents=True, exist_ok=True) + logger = pl.loggers.WandbLogger(save_dir=save_dir, project="spotiflow", name=f"{datetime.datetime.now().strftime('%Y%m%d_%H%M')}") else: if logger != "none": log.warning(f"Logger {logger} not implemented. Using no logger.") @@ -1067,6 +1069,8 @@ def predict( s_src_corr[:actual_n_dims] ] points.append(p) + del out, img_t, tile, y_tile, p + torch.cuda.empty_cache() if scale is not None and scale != 1: y = zoom(y, (1.0 / scale, 1.0 / scale), order=1) diff --git a/spotiflow/model/trainer.py b/spotiflow/model/trainer.py index e273f68..82b774b 100644 --- a/spotiflow/model/trainer.py +++ b/spotiflow/model/trainer.py @@ -362,7 +362,7 @@ def log_images(self): self.logger.log_image( key="flow", images=[ - 0.5 * (1 + np.squeeze(v, axis=0).transpose(1, 2, 0)) + 0.5 * (1 + v.transpose(1, 2, 0)) for v in self._valid_flows[:n_images_to_log] ], step=self.global_step, diff --git a/spotiflow/utils/fitting.py b/spotiflow/utils/fitting.py index 0ca4dee..d125fb3 100644 --- a/spotiflow/utils/fitting.py +++ b/spotiflow/utils/fitting.py @@ -9,6 +9,7 @@ from tqdm.auto import tqdm from dataclasses import dataclass, fields +from scipy.ndimage import map_coordinates FWHM_CONSTANT = 2 * np.sqrt(2 * np.log(2)) @@ -100,11 +101,12 @@ def _estimate_params_single2( y_range = np.arange(-window, window + 1) y, x = np.meshgrid(y_range, x_range, indexing="ij") - # Crop around the spot - region = image[ + # Crop around the spot with interpolation + y_indices, x_indices = np.mgrid[ center[0] - window : center[0] + window + 1, - center[1] - window : center[1] + window + 1, + center[1] - window : center[1] + window + 1 ] + region = map_coordinates(image, [y_indices, x_indices], order=3, mode='reflect') try: mi, ma = np.min(region), np.max(region) @@ -226,7 +228,7 @@ def estimate_params( peak_range (np.ndarray): peak range of the spots """ img = np.pad(img, window, mode="reflect") - centers = np.asarray(centers).astype(int) + window + centers = np.asarray(centers) + window if max_workers == 1: params = tuple( _estimate_params_single(