-
Notifications
You must be signed in to change notification settings - Fork 4
/
metrics.py
18 lines (18 loc) · 866 Bytes
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from tensorflow.keras.metrics import MeanIoU
import numpy as np
class m_iou():
    """Mean and per-class Intersection-over-Union built on Keras ``MeanIoU``.

    Args:
        classes: Number of classes. Passed to ``MeanIoU(num_classes=...)`` and
            used as the side length of the confusion matrix.
    """

    def __init__(self, classes: int) -> None:
        self.classes = classes

    def _updated_metric(self, y_true, y_pred):
        """Return a fresh ``MeanIoU`` metric updated with one batch.

        ``y_pred`` is reduced to label indices with ``argmax`` over its last
        axis. For the 4-D inputs this class originally supported, that is the
        same axis as the previous hard-coded ``axis=3``.
        NOTE(review): assumes ``y_pred`` holds per-class scores on its last
        axis and ``y_true`` holds integer labels - confirm against callers.
        """
        predicted_labels = np.argmax(y_pred, axis=-1)
        metric = MeanIoU(num_classes=self.classes)
        metric.update_state(y_true, predicted_labels)
        return metric

    @staticmethod
    def _iou_per_class(confusion):
        """Return the IoU of every class from a square confusion matrix.

        For class ``i`` the value is
        ``cm[i, i] / (sum(cm[i, :]) + sum(cm[:, i]) - cm[i, i])``.

        A class whose row and column are both all zero has a zero
        denominator; its IoU is NaN. The division runs under ``np.errstate``
        so that this case does not emit a RuntimeWarning (or raise when
        floating-point errors are configured to raise).
        """
        confusion = np.asarray(confusion)
        diagonal = np.diagonal(confusion)
        denominator = confusion.sum(axis=1) + confusion.sum(axis=0) - diagonal
        with np.errstate(divide='ignore', invalid='ignore'):
            return diagonal / denominator

    def mean_iou(self, y_true, y_pred):
        """Return the mean IoU reported by Keras ``MeanIoU`` for this batch.

        Args:
            y_true: Ground-truth labels, forwarded to ``update_state``.
            y_pred: Predictions; converted to labels by argmax over the last
                axis before being forwarded to ``update_state``.

        Returns:
            The value of ``MeanIoU.result().numpy()``.
        """
        return self._updated_metric(y_true, y_pred).result().numpy()

    def miou_class(self, y_true, y_pred):
        """Print the IoU of each class and return the values.

        Args:
            y_true: Ground-truth labels, forwarded to ``update_state``.
            y_pred: Predictions; converted to labels by argmax over the last
                axis before being forwarded to ``update_state``.

        Returns:
            A 1-D array with one IoU per class (index 0 is printed as
            class 1). An entry is NaN when its denominator is zero.
        """
        metric = self._updated_metric(y_true, y_pred)
        # The metric's weights hold the accumulated confusion matrix; reshape
        # drops the list wrapper added by get_weights().
        values = np.array(metric.get_weights()).reshape(self.classes, self.classes)
        per_class = self._iou_per_class(values)
        for number, class_iou in enumerate(per_class, start=1):
            print(f'IoU for class{str(number)} is: {class_iou}')
        return per_class