-
Notifications
You must be signed in to change notification settings - Fork 31
/
example_prediction.py
47 lines (38 loc) · 1.87 KB
/
example_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from PIL import Image
import torch
from modeling.BaseModel import BaseModel
from modeling import build_model
from utilities.distributed import init_distributed
from utilities.arguments import load_opt_from_config_files
from utilities.constants import BIOMED_CLASSES
import numpy as np
from inference_utils.inference import interactive_infer_image
from inference_utils.output_processing import check_mask_stats
# Build the run options from the inference YAML config, then pass them
# through init_distributed, which returns the options object used below.
opt = load_opt_from_config_files(["configs/biomedparse_inference.yaml"])
opt = init_distributed(opt)

# Load model from pretrained weights.
# Only one weight source is used per run. This example loads the
# 'hf_hub:'-prefixed identifier (presumably resolved against the Hugging Face
# Hub by from_pretrained -- confirm); to use a local checkpoint instead, set
# pretrained_pth to 'pretrained/biomed_parse.pt'.
pretrained_pth = 'hf_hub:microsoft/BiomedParse'
model = BaseModel(opt, build_model(opt)).from_pretrained(pretrained_pth).eval().cuda()

# Compute text embeddings for every known class plus "background" with
# gradient tracking disabled. The return value is discarded, so the encoder
# presumably stores the embeddings internally for later inference -- confirm.
with torch.no_grad():
    model.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(
        BIOMED_CLASSES + ["background"], is_eval=True
    )
# Load the input image for inference.
# The model takes an RGB image of shape (H, W, 3); only batch size 1 is
# currently supported.
image = Image.open('examples/Part_1_516_pathology_breast.png', formats=['png']).convert('RGB')

# Text prompts naming the objects to look for in the image. Several prompts
# can be given at once.
prompts = ['neoplastic cells', 'inflammatory cells']


def _read_binary_mask(mask_path):
    """Read a PNG mask and return a 0/1 integer array.

    A pixel is 1 where the first (red) channel of the RGB-converted mask is
    greater than zero, and 0 elsewhere.
    """
    first_channel = np.array(Image.open(mask_path, formats=['png']).convert('RGB'))[:, :, 0]
    return 1 * (first_channel > 0)


# Ground-truth masks, one per prompt and in the same order as `prompts`.
# The mask file name is the image stem followed by the prompt, with spaces
# replaced by '+'.
gt_masks = [
    _read_binary_mask(f"examples/Part_1_516_pathology_breast_{prompt.replace(' ', '+')}.png")
    for prompt in prompts
]
# Run text-prompted inference. The loop below pairs each entry of the result
# with the prompt and ground-truth mask at the same position.
pred_mask = interactive_infer_image(model, image, prompts)

# Compare every prediction with its ground-truth mask.
for prompt, pred, gt in zip(prompts, pred_mask, gt_masks):
    # Threshold the prediction once and reuse it for both Dice terms.
    pred_binary = 1 * (pred > 0.5)
    overlap = (pred_binary & gt).sum()
    total = pred_binary.sum() + gt.sum()
    # Dice = 2 * |P & G| / (|P| + |G|). When both masks are empty the ratio
    # is 0/0; treat that as perfect agreement instead of printing NaN.
    dice = overlap * 2.0 / total if total > 0 else 1.0
    print(f'Dice score for {prompt}: {dice:.4f}')

    # check_mask_stats receives the image as an array, the prediction scaled
    # by 255, the modality name and the prompt; its result is reported as a
    # p-value.
    p_value = check_mask_stats(np.array(image), pred*255, 'Pathology', prompt)
    print(f'p-value for {prompt}: {p_value:.4f}')