alm.py
"""
solve the following problem with Augmented Lagrange Multiplier method
min f(x) = -3x[0] - 5x[1]
s.t. x[0] + x[2] = 4
2x[1] + x[3] = 12
3x[0] + 2x[1] + x[4] = 18
x[0], x[1], x[2], x[3], x[4] >= 0
"""
import torch
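
# The method of multipliers implemented below alternates two updates:
#   x      <- max(0, x - eta * grad_x L(x, lambda))   (projected gradient step)
#   lambda <- lambda + alpha * (Ax - b)               (dual ascent step)
# where L(x, lambda) = f(x) + lambda^T (Ax - b) + (alpha / 2) * ||Ax - b||^2
# is the augmented Lagrangian of the LP above.
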
def lagrangian_function(x, lambda_):
    """augmented Lagrangian: f(x) + lambda^T (Ax - b) + alpha/2 * ||Ax - b||^2"""
    return f(x) + lambda_ @ (A @ x - b) + alpha / 2 * ((A @ x - b) ** 2).sum()

def f(x):
    """linear objective c^T x"""
    return c @ x

def update_x(x, lambda_):
    """update x with one projected gradient step on the augmented Lagrangian"""
    lagrangian_function(x, lambda_).backward()
    with torch.no_grad():
        new_x = x - eta * x.grad
        x.data = new_x.clamp(min=0)  # project back onto x >= 0
    x.grad.zero_()

def update_lambda(x, lambda_):
    """dual ascent step: lambda <- lambda + alpha * (Ax - b)"""
    with torch.no_grad():
        lambda_ += alpha * (A @ x - b)

def pprint(i, x, lambda_):
    print(
        f'\niter {i + 1}: L = {lagrangian_function(x, lambda_).item():.2f}, '
        f'f = {f(x).item():.2f}'
    )
    print(f'x: {x.data}')
    print(f'lambda: {lambda_}')
    print('constraint violation:')
    print((A @ x - b).detach())

def solve(x, lambda_):
    for i in range(500):
        pprint(i, x, lambda_)
        update_x(x, lambda_)
        update_lambda(x, lambda_)

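
# The loop above uses a fixed 500-iteration budget; a tolerance-based stopping
# rule is a common alternative. A minimal sketch (`converged` is a hypothetical
# helper, not part of the original script):
def converged(x, tol=1e-4):
    """stop once the largest primal residual |Ax - b| falls below tol"""
    with torch.no_grad():
        return (A @ x - b).abs().max().item() < tol
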
if __name__ == '__main__':
    eta = 0.03   # primal step size
    alpha = 1.0  # penalty parameter
    # LP data in standard form:
    #   min  c^T x
    #   s.t. Ax = b, x >= 0
    c = torch.tensor([-3, -5, 0, 0, 0], dtype=torch.float32)
    A = torch.tensor([[1, 0, 1, 0, 0],
                      [0, 2, 0, 1, 0],
                      [3, 2, 0, 0, 1]], dtype=torch.float32)
    b = torch.tensor([4, 12, 18], dtype=torch.float32)
    lambda_ = torch.tensor([0, 0, 0], dtype=torch.float32)  # initial multipliers
    x = torch.tensor([2, 0, 2, 0, 0], dtype=torch.float32, requires_grad=True)
    solve(x, lambda_)
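
    # Optional cross-check (a sketch, assuming SciPy is available; it is not a
    # dependency of this script). The exact optimum of this LP is
    # x* = (2, 6, 2, 0, 0) with f(x*) = -36, which the iterates above should
    # approach.
    try:
        from scipy.optimize import linprog
        res = linprog(c=c.numpy(), A_eq=A.numpy(), b_eq=b.numpy(),
                      bounds=[(0, None)] * 5)
        print(f'\nlinprog reference: x = {res.x}, f = {res.fun:.2f}')
    except ImportError:
        pass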