Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for histopathology models #834

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions doc/bioimageio/histopathology_v1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Segment Anything for Histopathology

This is a [Segment Anything](https://segment-anything.com/) model that was specialized for histopathology with [micro_sam](https://github.com/computational-cell-analytics/micro-sam).
This model uses a %s vision transformer as image encoder.

Segment Anything is a model for interactive and automatic instance segmentation.
We improve it for histopathology by finetuning on a large and diverse microscopy dataset.
It should perform well for nucleus segmentation in histopathology datasets.

See [the dataset overview](https://github.com/computational-cell-analytics/micro-sam/blob/master/doc/datasets/histopathology_v%i.md) for further information on the training data and the [micro_sam documentation](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html) for details on how to use the model for interactive and automatic segmentation.

## Validation

The easiest way to validate the model is to visually check the segmentation quality for your data.
If you have annotations you can use for validation, you can also run a quantitative validation, see [here for details](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#9-how-can-i-evaluate-a-model-i-have-finetuned).
Please note that the required quality for segmentation always depends on the analysis task you want to solve.
22 changes: 13 additions & 9 deletions micro_sam/bioimageio/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
"tags": ["segment-anything", "instance-segmentation"],
}

# Reference: https://github.com/bioimage-io/spec-bioimage-io/commit/39d343681d427ec93cf69eef7597d9eb9678deb1#diff-0bbdaa8196fa31f945afabcf04a4295ff098f1f24400ef9e59b0f684d411905eL269 # noqa
# We had this parameter in bioimageio.spec. This has been removed. We just make a copy of the same parameter.
ARBITRARY_SIZE = spec.ParameterizedSize(min=1, step=1)


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path, tmp_dir):

Expand Down Expand Up @@ -204,7 +208,7 @@ def _check_model(model_description, input_paths, result_paths):
image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx"))
embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx"))
box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic"))
point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic"))
point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("bhwc"))
point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic"))
mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx"))

Expand Down Expand Up @@ -292,8 +296,8 @@ def export_sam_model(
# NOTE: to support 1 and 3 channels we can add another preprocessing.
# Best solution: Have a pre-processing for this! (1C -> RGB)
spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=ARBITRARY_SIZE),
spec.SpaceInputAxis(id=spec.AxisId("x"), size=ARBITRARY_SIZE),
],
test_tensor=spec.FileDescr(source=input_paths["image"]),
data=spec.IntervalOrRatioDataDescr(type="uint8")
Expand All @@ -307,7 +311,7 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]),
],
Expand All @@ -323,11 +327,11 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.IndexInputAxis(
id=spec.AxisId("point"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]),
],
Expand All @@ -343,11 +347,11 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.IndexInputAxis(
id=spec.AxisId("point"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
],
test_tensor=spec.FileDescr(source=input_paths["point_labels"]),
Expand All @@ -362,7 +366,7 @@ def export_sam_model(
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
size=ARBITRARY_SIZE
),
spec.ChannelAxis(channel_names=["channel"]),
spec.SpaceInputAxis(id=spec.AxisId("y"), size=256),
Expand Down
14 changes: 14 additions & 0 deletions micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def models():
"vit_l_em_organelles": "xxh128:096c9695966803ca6fde24f4c1e3c3fb",
"vit_b_em_organelles": "xxh128:f6f6593aeecd0e15a07bdac86360b6cc",
"vit_t_em_organelles": "xxh128:253474720c497cce605e57c9b1d18fd9",
# Histopathology models:
"vit_b_histopathology": "xxh128:ffd1a2cd84570458b257bd95fdd8f974",
"vit_l_histopathology": "xxh128:b591833c89754271023e901281dee3f2",
"vit_h_histopathology": "xxh128:bd1856dafc156a43fb3aa705f1a6e92e",
}
# Additional decoders for instance segmentation.
decoder_registry = {
Expand All @@ -123,6 +127,10 @@ def models():
"vit_l_em_organelles_decoder": "xxh128:d60fd96bd6060856f6430f29e42568fb",
"vit_b_em_organelles_decoder": "xxh128:b2d4dcffb99f76d83497d39ee500088f",
"vit_t_em_organelles_decoder": "xxh128:8f897c7bb93174a4d1638827c4dd6f44",
# Histopathology models:
"vit_b_histopathology_decoder": "xxh128:6a66194dcb6e36199cbee2214ecf7213",
"vit_l_histopathology_decoder": "xxh128:46aab7765d4400e039772d5a50b55c04",
"vit_h_histopathology_decoder": "xxh128:3ed9f87e46ad5e16935bd8d722c8dc47",
}
registry = {**encoder_registry, **decoder_registry}

Expand All @@ -137,6 +145,9 @@ def models():
"vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa
"vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt",
"vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa
"vit_b_histopathology": "https://owncloud.gwdg.de/index.php/s/sBB4H8CTmIoBZsQ/download",
"vit_l_histopathology": "https://owncloud.gwdg.de/index.php/s/IZgnn1cpBq2PHod/download",
"vit_h_histopathology": "https://owncloud.gwdg.de/index.php/s/L7AcvVz7DoWJ2RZ/download",
}

decoder_urls = {
Expand All @@ -146,6 +157,9 @@ def models():
"vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa
"vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa
"vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa
"vit_b_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/KO9AWqynI7SFOBj/download",
"vit_l_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/oIs6VSmkOp7XrKF/download",
"vit_h_histopathology_decoder": "https://owncloud.gwdg.de/index.php/s/1qAKxy5H0jgwZvM/download",
}
urls = {**encoder_urls, **decoder_urls}

Expand Down
132 changes: 132 additions & 0 deletions scripts/model_export/export_histopathology_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import os
import xxhash
import argparse
import warnings
from glob import glob

import h5py

import bioimageio.spec.model.v0_5 as spec

from micro_sam.bioimageio import export_sam_model

from models import get_id_and_emoji


# Maps the model key (f"{model_type}_histopathology", as built in export_model)
# to the human-readable name passed to export_sam_model for the exported model.
MODEL_TO_NAME = {
    "vit_b_histopathology": "SAM Histopathology Generalist (ViT-B)",
    "vit_l_histopathology": "SAM Histopathology Generalist (ViT-L)",
    "vit_h_histopathology": "SAM Histopathology Generalist (ViT-H)",
}

# Number of bytes read per iteration when hashing a file in compute_checksum.
BUF_SIZE = 65536  # Read files in 64 KiB chunks.
# Base folder for the exported models; export_model writes into its "histopathology" subfolder.
OUTPUT_FOLDER = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/exported_models/"
# Root folder that get_data searches for the preprocessed test .h5 files.
PUMA_ROOT = "/mnt/vast-nhr/projects/cidas/cca/experiments/patho_sam/data/puma"


def create_doc(model_type, version):
    """Render the model documentation from the versioned markdown template.

    The template is looked up at "../../doc/bioimageio/histopathology_v{version}.md"
    relative to this script and filled via %-formatting with
    ``(model_type, version)``.

    Args:
        model_type: Value substituted for the first placeholder of the template.
        version: Template version; selects the template file and is substituted
            for the second placeholder.

    Returns:
        The formatted documentation text.

    Raises:
        AssertionError: If the template file does not exist.
    """
    script_dir = os.path.dirname(__file__)
    template_file = os.path.join(script_dir, "../../doc/bioimageio", f"histopathology_v{version}.md")
    assert os.path.exists(template_file), template_file

    with open(template_file, "r") as handle:
        template_text = handle.read()

    return template_text % (model_type, version)


def get_data():
    """Load one test image and its nucleus labels from the PUMA test split.

    The files matching "training_set_*.h5" under PUMA_ROOT/test/preprocessed
    are sorted and the first one is used, so that repeated exports pick the
    same sample.

    Returns:
        A tuple ``(image, label_image)``: the "raw" dataset with its first axis
        moved to the end, and the "labels/nuclei" dataset.

    Raises:
        FileNotFoundError: If no matching file exists.
    """
    pattern = os.path.join(PUMA_ROOT, "test", "preprocessed", "training_set_*.h5")
    # glob returns matches in arbitrary order; sort so the chosen sample is reproducible.
    input_paths = sorted(glob(pattern))
    if not input_paths:
        raise FileNotFoundError(f"No input files found for the pattern: {pattern}")
    input_path = input_paths[0]

    with h5py.File(input_path, "r") as f:
        image = f["raw"][:]
        label_image = f["labels/nuclei"][:]

    # Move axis 0 to the last position.
    # NOTE(review): presumably "raw" is stored channels-first (C, H, W), which makes
    # the result channels-last (H, W, C) - confirm against the preprocessing script.
    image = image.transpose(1, 2, 0)

    return image, label_image


def compute_checksum(path):
    """Compute the xxh128 hex digest of a file.

    The file is read in binary mode in pieces of BUF_SIZE bytes, so it is
    never loaded into memory as a whole.

    Args:
        path: Path of the file to hash.

    Returns:
        The hexadecimal xxh128 digest of the file content.
    """
    hasher = xxhash.xxh128()
    with open(path, "rb") as stream:
        # The two-argument form of iter() keeps calling read() until it
        # returns the sentinel b"", i.e. until the end of the file.
        for chunk in iter(lambda: stream.read(BUF_SIZE), b""):
            hasher.update(chunk)
    return hasher.hexdigest()


def export_model(model_path, model_type, version, email):
    """Export a histopathology model and print the checksums of its weights.

    The export is written to OUTPUT_FOLDER/histopathology/{model_type}_histopathology.
    If that path already exists, a message is printed and nothing else happens.
    Otherwise the model is exported with export_sam_model, the result is
    unzipped next to it (suffix ".unzip"), and the xxh128 checksums of the
    encoder and decoder weight files are printed.

    Args:
        model_path: Checkpoint path handed to export_sam_model.
        model_type: Model type handed to export_sam_model; also used to build
            the model name and the names of the weight files.
        version: Documentation template version, forwarded to create_doc.
        email: Email address used to create the uploader entry.
    """
    export_dir = os.path.join(OUTPUT_FOLDER, "histopathology")
    os.makedirs(export_dir, exist_ok=True)

    model_name = f"{model_type}_histopathology"
    output_path = os.path.join(export_dir, model_name)

    # Nothing to do if an export for this model is already present.
    if os.path.exists(output_path):
        print("The model", model_name, "has already been exported.")
        return

    image, label_image = get_data()
    covers = ["./covers/cover_lm.png"]  # HACK: We use existing covers.
    doc = create_doc(model_type, version)

    model_id, emoji = get_id_and_emoji(model_name)
    uploader = spec.Uploader(email=email)

    export_kwargs = dict(
        name=MODEL_TO_NAME[model_name],
        model_type=model_type,
        checkpoint_path=model_path,
        output_path=output_path,
        documentation=doc,
        covers=covers,
        id=model_id,
        id_emoji=emoji,
        uploader=uploader,
    )
    # Warnings raised during the export are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        export_sam_model(image, label_image, **export_kwargs)

    # NOTE: I needed to unzip the files myself. Not sure how this worked before. Maybe something changed in spec?
    from torch_em.data.datasets.util import unzip
    unzipped_dir = output_path + ".unzip"
    unzip(zip_path=output_path, dst=unzipped_dir)

    print("Exported model", model_id)

    # Report the checksum of each weight file together with its registry key.
    reports = (
        ("Encoder:", f"{model_type}.pt", model_name),
        ("Decoder:", f"{model_type}_decoder.pt", f"{model_name}_decoder"),
    )
    for heading, weight_file, registry_key in reports:
        checksum = compute_checksum(os.path.join(unzipped_dir, weight_file))
        print(heading)
        print(registry_key, f"xxh128:{checksum}")


def main():
    """Parse the command line arguments and export the requested model.

    Arguments:
        -e / --email: Email address of the uploader (required).
        -v / --version: Version of the documentation template (default: 1).
        -c / --checkpoint: Path to the model checkpoint (required).
        -m / --model_type: Model type forwarded to the export (required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--email", required=True)
    parser.add_argument("-v", "--version", default=1, type=int)
    parser.add_argument("-c", "--checkpoint", required=True, type=str)
    parser.add_argument("-m", "--model_type", required=True, type=str)
    args = parser.parse_args()

    export_model(
        model_path=args.checkpoint,
        model_type=args.model_type,
        # Forward the parsed value; it was previously hard-coded to 1,
        # which silently ignored the --version option.
        version=args.version,
        email=args.email,
    )


if __name__ == "__main__":
main()
5 changes: 2 additions & 3 deletions test/test_bioimageio/test_model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from shutil import rmtree

import bioimageio.spec

import micro_sam.util as util
from micro_sam.sample_data import synthetic_data

spec_minor = int(bioimageio.spec.__version__.split(".")[1])


@unittest.skipIf(spec_minor < 5, "Needs bioimagio.spec >= 0.5")
@unittest.expectedFailure
class TestModelExport(unittest.TestCase):
tmp_folder = "tmp"
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
Expand All @@ -20,9 +20,8 @@ def setUp(self):
os.makedirs(self.tmp_folder, exist_ok=True)

def tearDown(self):
rmtree(self.tmp_folder)
rmtree(self.tmp_folder, ignore_errors=True)

@unittest.expectedFailure
def test_model_export(self):
from micro_sam.bioimageio import export_sam_model
image, labels = synthetic_data(shape=(1024, 1022))
Expand Down
Loading