-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
75 lines (68 loc) · 3.07 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import streamlit as st
import torchxrayvision as xrv
import cv2
import torch
import torchvision
import PIL
import numpy as np
def load_xray(data_path):
    """Load an X-ray image and turn it into a single-channel array for torchxrayvision.

    The image is resized to 350x350 (aspect ratio is NOT preserved) and shown on the
    page via ``show_XRAY``. It is then scaled from the 8-bit range to [-1024, 1024]
    with ``xrv.datasets.normalize``, and a leading channel axis is added.

    Args:
        data_path: anything ``PIL.Image.open`` accepts -- a filesystem path or a
            binary file-like object (e.g. the object returned by
            ``st.file_uploader``).

    Returns:
        numpy array of shape ``(1, 350, 350)`` for greyscale or colour inputs.
    """
    # `import PIL` alone does not guarantee the `PIL.Image` submodule is loaded.
    from PIL import Image

    image_loaded = Image.open(data_path)
    # Palette ("P") and bilevel ("1") pixels are indices/flags, not grey levels.
    # Alpha-carrying and CMYK modes would otherwise have their extra channel averaged
    # into the intensity. Flatten all of these to plain RGB first.
    if image_loaded.mode in ("1", "P", "LA", "RGBA", "CMYK"):
        image_loaded = image_loaded.convert("RGB")
    image_loaded = image_loaded.resize((350, 350))
    img = np.array(image_loaded)
    show_XRAY(img)
    img = xrv.datasets.normalize(img, 255)  # convert 8-bit image to [-1024, 1024] range
    if img.ndim == 2:
        # Greyscale (H, W): only the channel axis is missing. Averaging over axis 2
        # after adding that axis would collapse the image width instead.
        img = img[None, ...]
    else:
        # Colour (H, W, C): average the channels into one, then add the channel axis.
        img = img.mean(2)[None, ...]
    return img
def findings(results, thresh):
    """Write every result whose score exceeds ``thresh`` to the Streamlit page.

    Args:
        results: mapping of label -> numeric score (e.g. pathology name -> model
            output). Entries are reported in the mapping's iteration order.
        thresh: only scores strictly greater than this value are reported.

    If no score exceeds ``thresh``, the single line 'This XRAY is normal' is written
    instead.
    """
    # Filter once and reuse the list for both the emptiness check and the output.
    positives = [(name, score) for name, score in results.items() if score > thresh]
    if not positives:
        st.write('This XRAY is normal')
        return
    for name, score in positives:
        st.write(name, 'with probability', '{:.3f}'.format(score))
def show_XRAY(image):
    """Display ``image`` on the Streamlit page after dropping all length-1 axes."""
    st.image(np.squeeze(image))
def _load_model():
    """Build the pretrained torchxrayvision DenseNet used for triage."""
    return xrv.models.DenseNet(weights="densenet121-res224-all")


if hasattr(st, "cache_resource"):
    # Streamlit re-executes this whole script on every widget interaction. Caching
    # keeps the loaded model in memory instead of rebuilding it on every rerun.
    # The hasattr guard keeps the script working on Streamlit versions that lack
    # cache_resource.
    _load_model = st.cache_resource(_load_model)


def _run_triage(banner_path, show_title):
    """Render one triage page: optional title, banner, uploader and findings.

    Args:
        banner_path: image file shown at the top of the page, read with OpenCV
            (relative paths resolve against the working directory).
        show_title: whether to write the page title above the banner.
    """
    if show_title:
        st.title("Triage of Chest X-RAYs")
    with st.container():
        bio_image = cv2.imread(banner_path)
        # cv2.imread signals a missing/unreadable file by returning None rather than
        # raising, and cvtColor would then fail with an opaque error.
        if bio_image is None:
            st.warning(f"Banner image '{banner_path}' could not be loaded.")
        else:
            # OpenCV decodes to BGR channel order; convert to RGB for display.
            st.image(cv2.cvtColor(bio_image, cv2.COLOR_BGR2RGB))
    uploaded_file = st.file_uploader("Choose an XRAY image (not DICOM) ", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is None:
        return
    img = load_xray(uploaded_file)
    transform = torchvision.transforms.Compose(
        [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
    )
    img = torch.from_numpy(transform(img))
    model = _load_model()
    with torch.no_grad():  # inference only: no autograd graph needed
        outputs = model(img[None, ...])  # add the batch axis
    results = dict(zip(model.pathologies, outputs[0].numpy()))
    findings(results, 0.6)


# NOTE(review): the access codes are plaintext constants in source, so anyone with
# read access to the repository can use them.
code = st.text_input("Enter code")
if code == 'epivlepsi':
    _run_triage('epivlepsi.png', show_title=False)
elif code == 'wings':
    _run_triage('wings.png', show_title=True)