Skip to content

Commit

Permalink
Merge pull request #27 from HelmholtzAI-Consultants-Munich/mmdet-model
Browse files Browse the repository at this point in the history
Mmdet model

mmcv fails to install correctly on GitHub Actions, but this package has been tested on a local Ubuntu, macOS and Windows for python 3.9 and 3.10.
  • Loading branch information
christinab12 authored Jul 25, 2024
2 parents 5955372 + b920e2b commit 7f7518a
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 100 deletions.
14 changes: 5 additions & 9 deletions .github/workflows/test_and_deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ on:
push:
branches:
- main
- dev-v.0.2
tags:
- "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
pull_request:
Expand All @@ -22,7 +21,7 @@ jobs:
strategy:
matrix:
platform: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.8, 3.9, "3.10"]
python-version: ["3.10", 3.9]
env:
DISPLAY: ':99.0'
steps:
Expand All @@ -47,24 +46,21 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-cookies tox
python -m pip install pytest pytest-qt pytest-cookies
pip install -e ".[testing]"
pip install "mmcv<2.2.0,>=2.0.0rc4" --find-links https://download.openmmlab.com/mmcv/dist/${{ matrix.python-version }}/torch2.4.0/cpu
# this runs the platform-specific tests declared in tox.ini

- name: Test
uses: aganders3/headless-gui@v1
with:
run: |
pip install -e ".[testing]"
python -m pytest -s -v --color=yes
env:
PLATFORM: ${{ matrix.platform }}
deploy:
# this will run when you have tagged a commit, starting with "v*"
# and requires that you have put your twine API key in your
# github secrets (see readme for details)
needs: [test]
# needs: [test]
runs-on: ubuntu-latest
if: contains(github.ref, 'tags')
steps:
Expand Down
24 changes: 17 additions & 7 deletions .napari/DESCRIPTION.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
## Description

A napari plugin to automatically count lung organoids from microscopy imaging data. A Faster R-CNN model was trained on patches of microscopy data. Model inference is run using a sliding window approach, with a 50% overlap and the option for predicting on multiple window sizes and scales, the results of which are then merged using NMS.
A napari plugin to automatically count lung organoids from microscopy imaging data. Several object detection DL models were trained on patches of 2D microscopy data. Model inference is run using a sliding window approach, with a 50% overlap and the option for predicting on multiple window sizes and scales, the results of which are then merged using NMS.

![Alt Text](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/demo-plugin-v2.gif)

## What's new in v2?
Here is a list of the main changes v2 of napari-organoid-counter offers:
* Use of Faster R-CNN model for object detection
* Use of DL models for object detection - pretrained models: Faster R-CNN, YOLOv3, SSD, and RTMDet. The data used for training these models along with the code for training can be found [here](https://www.kaggle.com/datasets/christinabukas/mutliorg).
* Pyramid model inference with a sliding window approach and tunable parameters for window size and window downsampling rate
* Model confidence added as tunable parameter
* Allow to load and correct existing annotations (note: these must have been saved previously from v2 of this plugin)
Expand All @@ -20,16 +20,21 @@ Technical Extensions:

## Installation

You can install `napari-organoid-counter` via [pip](https://pypi.org/project/napari-organoid-counter/):
This plugin has been tested with python 3.9 and 3.10 - you may consider using conda to create your dedicated environment before running the `napari-organoid-counter`.

pip install napari-organoid-counter
1. You can install `napari-organoid-counter` via [pip](https://pypi.org/project/napari-organoid-counter/):

```pip install napari-organoid-counter```

To install latest development version :
To install latest development version :

pip install git+https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter.git
```pip install git+https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter.git```

For installing on a Windows machine via napari, follow the instuctions [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/How%20to%20install%20on%20a%20Windows%20machine.pdf).
2. Additionally, you will then need to install one additional dependency:

``` mim install "mmcv<2.2.0,>=2.0.0rc4" ```

For installing on a Windows machine directly from within napari, follow the instuctions [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/How%20to%20install%20on%20a%20Windows%20machine.pdf). Step 2 additionally needs to be performed here too (mim install "mmcv<2.2.0,>=2.0.0rc4").

## Quickstart

Expand Down Expand Up @@ -69,6 +74,11 @@ This plugin has been developed and tested with 2D CZI microscopy images of lunch

[2] Eva Maxfield Brown, Talley Lambert, Peter Sobolewski, Napari-AICSImageIO Contributors (2021). Napari-AICSImageIO: Image Reading in Napari using AICSImageIO [Computer software]. GitHub. https://github.com/AllenCellModeling/napari-aicsimageio

The latest version also uses models developed with the ```mmdetection``` package <sup>[3]</sup>, see [here](https://github.com/open-mmlab/mmdetection)

[3] Chen, Kai, et al. "MMDetection: Open mmlab detection toolbox and benchmark." arXiv preprint arXiv:1906.07155 (2019).


## How to Cite
If you use this plugin for your work, please cite it using the following:

Expand Down
30 changes: 17 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Napari Organoid Counter - Version 0.2 is out!

[![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-organoid-counter)](https://napari-hub.org/plugins/napari-organoid-counter)
![stability-stable](https://img.shields.io/badge/stability-stable-green.svg)
[![DOI](https://zenodo.org/badge/476715320.svg)](https://zenodo.org/badge/latestdoi/476715320)
[![License](https://img.shields.io/pypi/l/napari-organoid-counter.svg?color=green)](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/raw/main/LICENSE)
[![PyPI](https://img.shields.io/pypi/v/napari-organoid-counter.svg?color=green)](https://pypi.org/project/napari-organoid-counter)
[![Python Version](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10-blue)](https://python.org)
[![Python Version](https://img.shields.io/badge/python-3.9%20%7C%203.10-blue)](https://python.org)
[![tests](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/workflows/tests/badge.svg)](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/actions)
[![codecov](https://codecov.io/gh/HelmholtzAI-Consultants-Munich/napari-organoid-counter/branch/main/graph/badge.svg)](https://codecov.io/gh/HelmholtzAI-Consultants-Munich/napari-organoid-counter)
[![napari hub](https://img.shields.io/endpoint?url=https://api.napari-hub.org/shields/napari-organoid-counter)](https://napari-hub.org/plugins/napari-organoid-counter)


A napari plugin to automatically count lung organoids from microscopy imaging data. Note: this plugin only supports single channel grayscale images.

Expand All @@ -22,21 +23,21 @@ This [napari] plugin was generated with [Cookiecutter] using [@napari]'s [cookie

## Installation

You can install `napari-organoid-counter` via [pip]:
This plugin has been tested with python 3.9 and 3.10 - you may consider using conda to create your dedicated environment before running the `napari-organoid-counter`.

pip install napari-organoid-counter
1. You can install `napari-organoid-counter` via [pip](https://pypi.org/project/napari-organoid-counter/):

``` pip install napari-organoid-counter```

To install latest development version :
To install latest development version :

pip install git+https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter.git


For the dev branch you can clone this repo and install with:
```pip install git+https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter.git```

pip install -e .
2. Additionally, you will then need to install one additional dependency:

For installing on a Windows machine via napari, follow the instuctions [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/How%20to%20install%20on%20a%20Windows%20machine.pdf).
``` mim install "mmcv<2.2.0,>=2.0.0rc4" ```

For installing on a Windows machine directly from within napari, follow the instuctions [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/readme-content/How%20to%20install%20on%20a%20Windows%20machine.pdf). Step 2 additionally needs to be performed here too (mim install "mmcv<2.2.0,>=2.0.0rc4").

## What's new in v2?
Checkout our *What's New in v2* [here](https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/blob/main/.napari/DESCRIPTION.md#whats-new-in-v2).
Expand All @@ -48,7 +49,7 @@ For more information on this plugin, its' intended audience, as well as Quicksta

## Contributing

Contributions are very welcome. Tests can be run with [tox], please ensure
Contributions are very welcome. Tests can be run with [pytest], please ensure
the coverage at least stays the same before you submit a pull request.

## License
Expand All @@ -65,6 +66,10 @@ Distributed under the terms of the [MIT] license,

[2] Eva Maxfield Brown, Talley Lambert, Peter Sobolewski, Napari-AICSImageIO Contributors (2021). Napari-AICSImageIO: Image Reading in Napari using AICSImageIO [Computer software]. GitHub. https://github.com/AllenCellModeling/napari-aicsimageio

The latest version also uses models developed with the ```mmdetection``` package <sup>[3]</sup>, see [here](https://github.com/open-mmlab/mmdetection)

[3] Chen, Kai, et al. "MMDetection: Open mmlab detection toolbox and benchmark." arXiv preprint arXiv:1906.07155 (2019).

## Issues

If you encounter any problems, please [file an issue] along with a detailed description.
Expand All @@ -83,7 +88,6 @@ If you encounter any problems, please [file an issue] along with a detailed desc
[file an issue]: https://github.com/HelmholtzAI-Consultants-Munich/napari-organoid-counter/issues

[napari]: https://github.com/napari/napari
[tox]: https://tox.readthedocs.io/en/latest/
[pip]: https://pypi.org/project/pip/
[PyPI]: https://pypi.org/

Expand Down
59 changes: 27 additions & 32 deletions napari_organoid_counter/_orgacount.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch
from torchvision.transforms import ToTensor

from urllib.request import urlretrieve
from napari.utils import progress

from napari_organoid_counter._utils import *
from napari_organoid_counter import settings

#update_version_in_mmdet_init_file('mmdet', '2.2.0', '2.3.0')
import torch
import mmdet
from mmdet.apis import DetInferencer

class OrganoiDL():
'''
Expand All @@ -19,8 +20,6 @@ class OrganoiDL():
The confidence threshold of the model
cur_min_diam: float
The minimum diameter of the organoids
transfroms: torchvision.transforms.ToTensor
The transformation for converting numpy image to tensor so it can be given as an input to the model
model: frcnn
The Faster R-CNN model
img_scale: list of floats
Expand All @@ -45,7 +44,6 @@ def __init__(self, handle_progress):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.cur_confidence = 0.05
self.cur_min_diam = 30
self.transfroms = ToTensor()

self.model = None
self.img_scale = [0., 0.]
Expand All @@ -60,24 +58,23 @@ def set_scale(self, img_scale):

def set_model(self, model_name):
''' Initialise model instance and load model checkpoint and send to device. '''
self.model = frcnn(num_classes=2, rpn_score_thresh=0, box_score_thresh = self.cur_confidence)
self.load_model_checkpoint(model_name)
self.model = self.model.to(self.device)

def download_model(self, model='default'):
model_checkpoint = join_paths(str(settings.MODELS_DIR), settings.MODELS[model_name]["filename"])
mmdet_path = os.path.dirname(mmdet.__file__)
config_dst = join_paths(mmdet_path, str(settings.CONFIGS[model_name]["destination"]))
# download the corresponding config if it doesn't exist already
if not os.path.exists(config_dst):
urlretrieve(settings.CONFIGS[model_name]["source"], config_dst, self.handle_progress)
self.model = DetInferencer(config_dst, model_checkpoint, self.device, show_progress=False)

def download_model(self, model_name='yolov3'):
''' Downloads the model from zenodo and stores it in settings.MODELS_DIR '''
# specify the url of the file which is to be downloaded
down_url = settings.MODELS[model]["source"]
# specify the url of the model which is to be downloaded
down_url = settings.MODELS[model_name]["source"]
# specify save location where the file is to be saved
save_loc = join_paths(str(settings.MODELS_DIR), settings.MODELS[model]["filename"])
# Downloading using urllib
urlretrieve(down_url,save_loc, self.handle_progress)

def load_model_checkpoint(self, model_name):
''' Loads the model checkpoint for the model specified in model_name '''
model_checkpoint = join_paths(settings.MODELS_DIR, settings.MODELS[model_name]["filename"])
ckpt = torch.load(model_checkpoint, map_location=self.device)
self.model.load_state_dict(ckpt) #.state_dict())
save_loc = join_paths(str(settings.MODELS_DIR), settings.MODELS[model_name]["filename"])
# downloading using urllib
urlretrieve(down_url, save_loc, self.handle_progress)

def sliding_window(self,
test_img,
Expand Down Expand Up @@ -120,20 +117,20 @@ def sliding_window(self,
for i in progress(range(0, prepadded_height, step)):
for j in progress(range(0, prepadded_width, step)):
# crop
img_crop = test_img[:, :, i:(i+window_size), j:(j+window_size)]
img_crop = test_img[i:(i+window_size), j:(j+window_size)]
# get predictions
output = self.model(img_crop.float())
preds = output[0]['boxes']
if preds.size(0)==0: continue
output = self.model(img_crop)
preds = output['predictions'][0]['bboxes']
if len(preds)==0: continue
else:
for bbox_id in range(preds.size(0)):
y1, x1, y2, x2 = preds[bbox_id].cpu().detach() # predictions from model will be in form x1,y1,x2,y2
for bbox_id in range(len(preds)):
y1, x1, y2, x2 = preds[bbox_id] # predictions from model will be in form x1,y1,x2,y2
x1_real = torch.div(x1+i, rescale_factor, rounding_mode='floor')
x2_real = torch.div(x2+i, rescale_factor, rounding_mode='floor')
y1_real = torch.div(y1+j, rescale_factor, rounding_mode='floor')
y2_real = torch.div(y2+j, rescale_factor, rounding_mode='floor')
pred_bboxes.append(torch.Tensor([x1_real, y1_real, x2_real, y2_real]))
scores_list.append(output[0]['scores'][bbox_id].cpu().detach())
scores_list.append(output['predictions'][0]['scores'][bbox_id])
return pred_bboxes, scores_list

def run(self,
Expand Down Expand Up @@ -170,9 +167,7 @@ def run(self,
ready_img, prepadded_height, prepadded_width = prepare_img(img,
step,
window_size,
rescale_factor,
self.transfroms,
self.device)
rescale_factor)
# and run sliding window over whole image
bboxes, scores = self.sliding_window(ready_img,
step,
Expand All @@ -184,7 +179,7 @@ def run(self,
scores)
# stack results
bboxes = torch.stack(bboxes)
scores = torch.stack(scores)
scores = torch.Tensor(scores)
# apply NMS to remove overlaping boxes
bboxes, pred_scores = apply_nms(bboxes, scores)
self.pred_bboxes[shapes_name] = bboxes
Expand Down
51 changes: 26 additions & 25 deletions napari_organoid_counter/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from contextlib import contextmanager
import os
from pathlib import Path
import pkgutil

import numpy as np
import math
Expand All @@ -10,9 +11,6 @@
from skimage.color import gray2rgb

import torch
import torch.nn as nn
from torchvision.models import detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import nms

from napari_organoid_counter import settings
Expand Down Expand Up @@ -104,7 +102,7 @@ def squeeze_img(img):
""" Squeeze image - all dims that have size one will be removed """
return np.squeeze(img)

def prepare_img(test_img, step, window_size, rescale_factor, trans, device):
def prepare_img(test_img, step, window_size, rescale_factor):
""" The original image is prepared for running model inference """
# squeeze and resize image
test_img = squeeze_img(test_img)
Expand All @@ -119,10 +117,8 @@ def prepare_img(test_img, step, window_size, rescale_factor, trans, device):
test_img = (255*test_img).astype(np.uint8)
test_img = gray2rgb(test_img) #[H,W,C]

# convert to tensor and send to device
test_img = trans(test_img)
test_img = torch.unsqueeze(test_img, axis=0) #[B, C, H, W]
test_img = test_img.to(device)
# convert from RGB to GBR - expected from DetInferencer
test_img = test_img[..., ::-1]

return test_img, img_height, img_width

Expand Down Expand Up @@ -175,20 +171,25 @@ def apply_normalization(img):
img_norm = (255 * (img - img_min) / (img_max - img_min)).astype(np.uint8)
return img_norm

class frcnn(nn.Module):
def __init__(self, num_classes,rpn_score_thresh=0,box_score_thresh=0.05):
""" An FRCNN module loads the pretrained FasterRCNN model """
super(frcnn, self).__init__()
# define classes and load pretrained model
self.num_classes = num_classes
self.model = detection.fasterrcnn_resnet50_fpn(pretrained=True, rpn_score_thresh = rpn_score_thresh, box_score_thresh = box_score_thresh)
# get number of input features for the classifier
self.in_features = self.model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
self.model.roi_heads.box_predictor = FastRCNNPredictor(self.in_features, self.num_classes)
self.model.eval()

def forward(self, x, return_all=False):
""" A forward pass through the model """
return self.model(x)

def get_package_init_file(package_name):
loader = pkgutil.get_loader(package_name)
if loader is None or not hasattr(loader, 'get_filename'):
raise ImportError(f"Cannot find package {package_name}")
package_path = loader.get_filename(package_name)
# Determine the path to the __init__.py file
if os.path.isdir(package_path):
init_file_path = os.path.join(package_path, '__init__.py')
else:
init_file_path = package_path
if not os.path.isfile(init_file_path):
raise FileNotFoundError(f"__init__.py file not found for package {package_name}")
return init_file_path

def update_version_in_mmdet_init_file(package_name, old_version, new_version):
init_file_path = get_package_init_file(package_name)
with open(init_file_path, 'r') as file:
lines = file.readlines()
with open(init_file_path, 'w') as file:
for line in lines:
if f"mmcv_maximum_version = '{old_version}'" in line:
file.write(line.replace(old_version, new_version))
Loading

0 comments on commit 7f7518a

Please sign in to comment.