forked from jjzha/foml
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy path plotting.py
42 lines (35 loc) · 1.11 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plotting utilities.

Provides ``plot_confusion_matrix``, which draws a confusion matrix with
matplotlib and saves it as a PNG file under the ``plot_images/`` directory.
"""
__author__ = "Johannes Bjerva, and Malvina Nissim (modified by Mike Zhang)"
__credits__ = ["Johannes Bjerva", "Malvina Nissim"]
__license__ = "GPL v3"
__version__ = "0.2"
__maintainer__ = "Mike Zhang"
# NOTE(review): this value is not a usable e-mail address; it looks like an
# obfuscation placeholder left over from a web copy. Restore the real address.
__email__ = "[email protected]"
__status__ = "early alpha"
import os
from datetime import datetime
from typing import List
import numpy as np
import matplotlib.pyplot as plt
def plot_confusion_matrix(
    cm: np.ndarray,
    test_y: List[str],
    title="Confusion matrix",
    cmap=plt.cm.Blues,
    *,
    close: bool = False,
):
    """Draw a confusion matrix as a heatmap and save it as a PNG file.

    The image is written to ``plot_images/<timestamp>-<title>-plot.png``.
    The directory is created if it does not exist. The figure is left open
    (e.g. for interactive display) unless ``close=True``.

    Args:
        cm: Confusion matrix. Rows are drawn against the "True label" axis
            and columns against the "Predicted label" axis.
        test_y: Labels whose sorted unique values become the tick labels on
            both axes. NOTE(review): assumes ``cm`` is ordered by these same
            sorted labels and that their count equals ``cm.shape[0]``.
            Confirm against the code that builds ``cm``.
        title: Plot title. It is also embedded in the output file name, with
            characters that are invalid in file names replaced by ``_``.
        cmap: Matplotlib colormap used for the heatmap.
        close: If True, close the figure after saving so repeated calls do
            not accumulate open figures.

    Returns:
        The path of the saved PNG file.
    """
    labels = sorted(set(test_y))
    tick_marks = np.arange(cm.shape[0])

    # Anchor the colour scale at zero so the lightest colour always means
    # "zero". Non-finite entries (e.g. NaN from normalising an empty row)
    # are ignored, as matplotlib's own autoscaling ignores them.
    finite = cm[np.isfinite(cm)]
    if finite.size:
        vmin = min(0.0, float(finite.min()))
        vmax = max(0.0, float(finite.max()))
    else:
        vmin = vmax = 0.0

    fig = plt.figure()
    plt.imshow(cm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.colorbar()
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    # Lay out only after every label exists, otherwise the axis labels
    # are not taken into account and can be clipped.
    plt.tight_layout()

    # isoformat() contains ':', which is not allowed in Windows file names;
    # keep the same ordering but use '-' as the time separator.
    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
    # A '/' or other reserved character in the title would otherwise break
    # the output path.
    safe_title = "".join("_" if ch in '<>:"/\\|?*' else ch for ch in title)

    os.makedirs("plot_images", exist_ok=True)
    path = os.path.join("plot_images", timestamp + "-" + safe_title + "-plot.png")
    fig.savefig(path)
    if close:
        plt.close(fig)
    return path