-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_lib.py
45 lines (39 loc) · 1.51 KB
/
plot_lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from matplotlib import pyplot as plt
import numpy as np
import torch
def set_default():
    """Apply the dark plotting theme used by the helpers in this module.

    Loads the 'dark_background' and 'bmh' styles, then forces black axes and
    figure backgrounds and a default figure size of (10, 10). This changes
    matplotlib's global rcParams for the rest of the session.
    """
    plt.style.use(['dark_background', 'bmh'])
    rc_overrides = (
        ('axes', {'facecolor': 'k'}),
        ('figure', {'facecolor': 'k', 'figsize': (10, 10)}),
    )
    for group, settings in rc_overrides:
        plt.rc(group, **settings)
def plot_data(X, y, d=0, auto=False, zoom=1):
    """Scatter-plot 2-D points coloured by label on an axis-free canvas.

    Args:
        X: points to plot; only columns 0 and 1 are used. A torch tensor (on
            any device, with or without autograd history) or anything
            ``torch.as_tensor`` accepts.
        y: per-point colour values passed to ``plt.scatter(c=...)`` and mapped
            through the Spectral colormap. Tensors are copied to the CPU first;
            other values are passed through unchanged.
        d: unused; kept so existing callers keep working.
        auto: if truthy, switch to ``plt.axis('equal')`` after the fixed
            limits have been applied.
        zoom: scale factor applied to the default [-1.1, 1.1] limits on
            both axes.
    """
    # Detach and move to the CPU so tensors that track gradients or live on
    # an accelerator can be plotted; a plain CPU tensor is returned as-is.
    points = torch.as_tensor(X).detach().cpu().numpy()
    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()

    plt.scatter(points[:, 0], points[:, 1], c=y, s=20, cmap=plt.cm.Spectral)
    plt.axis('square')
    plt.axis(np.array((-1.1, 1.1, -1.1, 1.1)) * zoom)
    if auto:
        plt.axis('equal')
    plt.axis('off')

    # Dark-grey lines through the origin, drawn beneath the points (zorder=0).
    plt.axvline(0, ymin=0, color='.15', lw=1, zorder=0)
    plt.axhline(0, xmin=0, color='.15', lw=1, zorder=0)
def plot_model(X, y, model, zoom=1):
    """Shade a model's predicted class regions, then overlay the data.

    The model is evaluated without gradient tracking on a regular grid with
    step 0.01 covering [-1.1, 1.1) * zoom in both coordinates. Each grid point
    is coloured by the index of the largest entry along dim 1 of the model
    output. The data are then drawn on top with ``plot_data``.

    Args:
        X, y: data points and labels, forwarded to ``plot_data``.
        model: callable taking an (N, 2) float32 CPU tensor of grid points.
            Its output is assumed to be (N, C) class scores (TODO confirm
            against callers).
        zoom: scale factor for both the grid extent and the plot limits; the
            default of 1 reproduces the original fixed view.
    """
    mesh = np.arange(-1.1, 1.1, 0.01) * zoom
    xx, yy = np.meshgrid(mesh, mesh)
    grid = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()
    with torch.no_grad():
        scores = model(grid)
    # Reduce with torch and copy to the CPU explicitly instead of relying on
    # np.argmax's duck-typing fallback, which fails for non-CPU outputs.
    Z = scores.argmax(dim=1).cpu().numpy().reshape(xx.shape)
    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral, alpha=0.3)
    plot_data(X, y, zoom=zoom)
def show_scatterplot(X, colors, title=''):
    """Open a new figure and scatter-plot 2-D points with explicit colours.

    Args:
        X: points to plot; only columns 0 and 1 are used. A torch tensor (any
            device, with or without autograd history) or anything
            ``torch.as_tensor`` accepts.
        colors: per-point colours for ``plt.scatter(c=...)``, converted the
            same way as ``X``.
        title: figure title.
    """
    # Detach and move to the CPU so outputs of a model (which may track
    # gradients or live on an accelerator) can be passed in directly.
    colors = torch.as_tensor(colors).detach().cpu().numpy()
    X = torch.as_tensor(X).detach().cpu().numpy()
    plt.figure()
    plt.axis('equal')
    plt.scatter(X[:, 0], X[:, 1], c=colors, s=30)
    plt.title(title)
    plt.axis('off')
def plot_bases(bases, width=0.04):
    """Draw two vectors as a red and a green arrow.

    Args:
        bases: array-like of shape (4, 2). Rows 0 and 1 are the tail points of
            the red and green arrows; rows 2 and 3 are their head points.
            Tensors may track gradients or live on an accelerator.
            ``bases`` itself is not modified.
        width: arrow shaft width passed to ``plt.arrow``.

    Raises:
        ValueError: if ``bases`` does not have shape (4, 2).
    """
    points = torch.as_tensor(bases).detach().cpu().numpy()
    if points.shape != (4, 2):
        raise ValueError(f'bases must have shape (4, 2), got {points.shape}')
    tails = points[:2]
    # Out-of-place subtraction: the caller's array is left untouched, so
    # repeated calls on the same data draw the same arrows.
    deltas = points[2:] - tails
    for tail, delta, colour in zip(tails, deltas, ((1, 0, 0), (0, 1, 0))):
        plt.arrow(*tail, *delta, width=width, color=colour, zorder=10,
                  alpha=1., length_includes_head=True)