-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathplotting.py
34 lines (27 loc) · 928 Bytes
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Plotting utilities.

Provides ``plot_confusion_matrix``, which renders a confusion matrix as a
matplotlib heat map and shows it.
'''
# Module metadata: authorship, licensing and release status.
__author__ = "Johannes Bjerva, and Malvina Nissim (modified by Mike Zhang)"
__credits__ = ["Johannes Bjerva", "Malvina Nissim"]
__license__ = "GPL v3"
__version__ = "0.2"
__maintainer__ = "Mike Zhang"
# NOTE(review): the value below looks like a redaction placeholder rather
# than a real address -- confirm the maintainer's actual contact.
__email__ = "[email protected]"
__status__ = "early alpha"
from typing import List
import numpy as np
import matplotlib.pyplot as plt
def plot_confusion_matrix(cm: np.ndarray, test_y: List[str], title='Confusion matrix', cmap=plt.cm.Blues) -> None:
    '''
    Draw a confusion matrix as a heat map in a new figure and show it.

    :param cm: square confusion matrix; row ``i`` / column ``i`` are labelled
        with the ``i``-th entry of the sorted distinct values of ``test_y``,
        so ``cm`` is expected to have one row and column per distinct label.
    :param test_y: gold labels; only their distinct values are used, as the
        tick labels of both axes.
    :param title: text placed above the plot.
    :param cmap: matplotlib colour map used for the cells.
    :return: nothing; the figure is displayed through ``plt.show()``.
    '''
    # Distinct labels in sorted order, computed once and shared by both axes.
    labels = sorted(set(test_y))

    # Keep zero inside the colour range by setting the limits explicitly.
    # Stacking an all-zero row under the matrix achieved the same scale but
    # also painted an extra, unlabeled row in the image.
    lowest = min(0, cm.min()) if cm.size else 0
    highest = max(0, cm.max()) if cm.size else 0

    plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap, vmin=lowest, vmax=highest)
    plt.title(title)
    plt.colorbar()

    tick_marks = np.arange(cm.shape[0])
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

    # tight_layout() adjusts the layout once, at call time, so it has to run
    # after the axis labels exist or they are left out of the computation.
    plt.tight_layout()
    plt.show()