-
Notifications
You must be signed in to change notification settings - Fork 2
/
log_compare.py
67 lines (49 loc) · 1.64 KB
/
log_compare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/env python3.7
from typing import Callable
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from functools import partial
def ext_log_barrier_quadra_2(t, z):
    """Log-barrier with a quadratic extension beyond z = -1/t**2.

    For z <= -1/t**2 this is the plain barrier -log(-z)/t. Past that point it
    switches to t*z**2 - log(1/t**2)/t - 1/t**3. At z = -1/t**2 that
    expression equals -log(1/t**2)/t, which is the log branch's value there,
    so the function is continuous and finite for every real z (given t > 0).

    NOTE(review): the slopes differ at the switch point (t on the log side,
    -2/t on the quadratic side), so the extension decreases between -1/t**2
    and 0. Confirm that this shape is intended.

    Args:
        t: barrier sharpness parameter (assumed > 0).
        z: scalar constraint value; feasible region is z < 0.

    Returns:
        The penalty value for z.
    """
    threshold = -1 / t**2
    if z <= threshold:
        return -np.log(-z) / t
    # The offset must be the log branch's value at the threshold, i.e.
    # -log(-threshold)/t = -log(1/t**2)/t. Taking the log of the (negative)
    # threshold itself would yield NaN for every t > 0.
    return t * z**2 - np.log(1 / t**2) / t - 1 / t**3
def ext_log_barrier_quadra(t, z):
    """Log-barrier with a quadratic extension beyond z = -1/t.

    For z <= -1/t this is the plain barrier -log(-z)/t. Past that point it
    switches to t*z**2 - log(1/t)/t - 1/t. That expression also equals
    log(t)/t at z = -1/t, so the function is continuous at the switch point.

    NOTE(review): the slopes differ at the switch point (1 on the log side,
    -2 on the quadratic side), so the extension decreases between -1/t and 0.
    Confirm that this shape is intended.

    Args:
        t: barrier sharpness parameter (assumed > 0).
        z: scalar constraint value; feasible region is z < 0.

    Returns:
        The penalty value for z.
    """
    switch_point = -1 / t
    if z <= switch_point:
        return -np.log(-z) / t
    quadratic_part = t * z**2
    return quadratic_part - np.log(1 / t) / t - 1 / t
def ext_log_barrier(t, z):
    """Extended log-barrier: -log(-z)/t, continued linearly past z = -1/t**2.

    For z <= -1/t**2 the value is the standard barrier -log(-z)/t. Beyond
    that point the function continues as the straight line
    t*z - log(1/t**2)/t + 1/t. Both its value and its slope (t) equal those
    of the log branch at z = -1/t**2, so the result is finite and
    continuously differentiable for every real z.

    Args:
        t: barrier sharpness parameter (assumed > 0).
        z: scalar constraint value; feasible region is z < 0.

    Returns:
        The penalty value for z.
    """
    threshold = -1 / t**2
    if z <= threshold:
        return -np.log(-z) / t
    linear_term = t * z
    intercept = -np.log(1 / (t**2)) / t
    return linear_term + intercept + 1 / t
def log_barrier(t, z):
    """Log-barrier -log(-z)/t with a linear continuation past z = -1/t**2.

    For z <= -1/t**2 the value is -log(-z)/t. Beyond that point the function
    is the line t*z - log(1/t**2)/t + 1/t, which has the same value and
    slope as the log branch at the switch point.

    NOTE(review): this is expression-for-expression identical to
    ext_log_barrier. The name suggests the plain (non-extended) barrier, which
    would be +inf for z >= 0. Confirm which one is intended.

    Args:
        t: barrier sharpness parameter (assumed > 0).
        z: scalar constraint value; feasible region is z < 0.

    Returns:
        The penalty value for z.
    """
    boundary = -1 / t**2
    if z <= boundary:
        return -np.log(-z) / t
    slope_part = t * z
    offset = -np.log(1 / (t**2)) / t
    return slope_part + offset + 1 / t
def quadratic(z):
    """Squared hinge penalty: max(z, 0)**2, applied elementwise.

    Built from NumPy ufuncs, so it accepts scalars or arrays.
    """
    positive_part = np.maximum(z, 0)
    squared = positive_part ** 2
    return squared
def relu(z):
    """Rectified linear penalty: max(z, 0), applied elementwise.

    Built from a NumPy ufunc, so it accepts scalars or arrays.
    """
    lower_bound = 0
    clipped = np.maximum(z, lower_bound)
    return clipped
if __name__ == "__main__":
    import shutil

    # t values for which an extended log-barrier curve is overlaid. Empty by
    # default, so only the two penalty curves are drawn; add values such as
    # [1, 5, 10] to compare them with the barrier.
    barrier_ts: list = []

    matplotlib.rc('font', size=12)
    # usetex renders all text through an external LaTeX toolchain: Matplotlib
    # needs a `latex` executable, plus `dvipng` for the Agg-based backends.
    # Without them drawing raises an error, so fall back to Matplotlib's
    # built-in mathtext, which also handles the simple $...$ labels below.
    use_tex = all(shutil.which(exe) is not None for exe in ('latex', 'dvipng'))
    matplotlib.rc('text', usetex=use_tex)

    xs: np.ndarray = np.linspace(-2, 2, 2000, dtype=np.float32)
    # quadratic() and relu() are built from NumPy ufuncs and already work
    # elementwise on arrays, so they need no np.vectorize wrapper.
    plt.plot(xs, quadratic(xs), label='$f$ = Quadratic')
    plt.plot(xs, relu(xs), label='$f$ = ReLU')
    for t in barrier_ts:
        # ext_log_barrier branches with a Python `if` on a scalar z, so it
        # does need vectorizing to accept an array.
        v_barrier = np.vectorize(partial(ext_log_barrier, t))
        plt.plot(xs, v_barrier(xs), label=f'Extended Log-barrier (t={t})')

    # Dashed gray guides along the positive y-axis and the negative x-axis.
    plt.plot([0, 0], [0, 100], linewidth=1, color='gray', linestyle='--')
    plt.plot([-100, 0], [0, 0], linewidth=1, color='gray', linestyle='--')
    plt.xlabel('$z$')
    plt.ylabel('$f(z)$', rotation=0)
    plt.ylim(-1, 5)
    plt.xlim(-2, 2)
    plt.legend(loc='upper left')
    plt.show()