Skip to content

Commit

Permalink
add test for TimmBackbone
Browse files Browse the repository at this point in the history
1. Split `test_functional` into `test_engine` and `test_full_inference`
2. In `test_full_inference` use `@pytest.mark.parametrize` for `CNNBackbone` and `TimmBackbone` instead of making 2 copies of the function
  • Loading branch information
GeorgeBatch committed Oct 29, 2024
1 parent d30d7d8 commit d08127f
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions tests/models/test_feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from typing import Callable

import numpy as np
import pytest
import torch

from tiatoolbox.models.architecture.vanilla import CNNBackbone
from tiatoolbox.models.architecture.vanilla import CNNBackbone, TimmBackbone
from tiatoolbox.models.engine.semantic_segmentor import (
DeepFeatureExtractor,
IOSegmentorConfig,
Expand All @@ -22,8 +23,8 @@
# -------------------------------------------------------------------------------------


def test_functional(remote_sample: Callable, tmp_path: Path) -> None:
"""Test for feature extraction."""
def test_engine(remote_sample: Callable, tmp_path: Path) -> None:
"""Test feature extraction with DeepFeatureExtractor engine."""
save_dir = tmp_path / "output"
# # convert to pathlib Path to prevent wsireader complaint
mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))
Expand All @@ -41,12 +42,27 @@ def test_functional(remote_sample: Callable, tmp_path: Path) -> None:
wsi_0_root_path = output_list[0][1]
positions = np.load(f"{wsi_0_root_path}.position.npy")
features = np.load(f"{wsi_0_root_path}.features.0.npy")
assert len(positions.shape) == 2
assert len(features.shape) == 4

# * test same output between full infer and engine
# pre-emptive clean up
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test


@pytest.mark.parametrize(
"model", [CNNBackbone("resnet50"), TimmBackbone("efficientnet_b0", pretrained=True)]
)
def test_full_inference(
remote_sample: Callable, tmp_path: Path, model: Callable
) -> None:
"""Test full inference with CNNBackbone and TimmBackbone models."""
save_dir = tmp_path / "output"
# pre-emptive clean up
shutil.rmtree(save_dir, ignore_errors=True) # default output dir test

mini_wsi_svs = Path(remote_sample("wsi4_1k_1k_svs"))

ioconfig = IOSegmentorConfig(
input_resolutions=[
{"units": "mpp", "resolution": 0.25},
Expand All @@ -60,7 +76,6 @@ def test_functional(remote_sample: Callable, tmp_path: Path) -> None:
save_resolution={"units": "mpp", "resolution": 8.0},
)

model = CNNBackbone("resnet50")
extractor = DeepFeatureExtractor(batch_size=4, model=model)
# should still run because we skip exception
output_list = extractor.predict(
Expand Down

0 comments on commit d08127f

Please sign in to comment.