-
Notifications
You must be signed in to change notification settings - Fork 2
/
sgm.py
123 lines (73 loc) · 2.78 KB
/
sgm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# torch version
import torch
from scipy.optimize import linear_sum_assignment
def sgraphmatch(A, B, m, iteration):
    """Seeded graph matching between adjacency matrices A and B.

    Relaxes the matching problem to doubly-stochastic matrices and runs a
    Frank-Wolfe style iteration: at each step the gradient of the quadratic
    objective is linearized, a permutation direction is obtained with the
    Hungarian algorithm (``linear_sum_assignment``), and an exact line search
    chooses the step size.  The final doubly-stochastic iterate is projected
    back onto permutation matrices with one more assignment solve.

    Parameters
    ----------
    A, B : (totv, totv) torch tensors
        Adjacency matrices.  The first ``m`` vertices of A and B are assumed
        to be already-matched seed vertices.
    m : int
        Number of seed vertices (may be 0).
    iteration : int
        Maximum number of Frank-Wolfe iterations.

    Returns
    -------
    corr : (n,) LongTensor
        Matching of the ``n = totv - m`` non-seed vertices: non-seed vertex
        ``i`` of A is matched to non-seed vertex ``corr[i]`` of B.
    P : (totv, totv) tensor
        Full permutation matrix (identity on the seed block).
    """
    device = A.device  # follow the input's device instead of hard-coding CUDA
    totv = A.shape[0]
    n = totv - m  # number of non-seed vertices
    # uniform doubly-stochastic starting point
    P = torch.full((n, n), 1.0 / n, device=device)
    if m != 0:
        # seed / non-seed cross blocks
        A12 = A[:m, m:totv]
        A21 = A[m:totv, :m]
        B12 = B[:m, m:totv]
        B21 = B[m:totv, :m]
    else:
        # No seeds: the cross blocks contribute nothing to the gradient.
        # Bug fix: the original called torch.zeros_like(n, n), which raises a
        # TypeError because zeros_like expects a tensor, not two ints.
        A12 = A21 = B12 = B21 = torch.zeros(n, n, device=device)
    if n == 1:
        # keep the blocks 2-D when only one non-seed vertex remains
        A12 = A12.T
        A21 = A21.T
        B12 = B12.T
        B21 = B21.T
    # non-seed / non-seed blocks
    A22 = A[m:totv, m:totv]
    B22 = B[m:totv, m:totv]
    tol = 1
    # constant part of the gradient, contributed by the seed blocks
    x = torch.mm(A21, B21.T)
    y = torch.mm(A12.T, B12)
    toggle = 1
    it = 0  # renamed from `iter` (shadowed the builtin)
    while toggle == 1 and it < iteration:
        it += 1
        z = torch.mm(torch.mm(A22, P), B22.T)
        w = torch.mm(torch.mm(A22.T, P), B22)
        Grad = x + y + z + w  # gradient of the relaxed objective w.r.t. P
        # shift so all entries are non-negative before the assignment solve
        obj = Grad + torch.ones(n, n, device=device) * abs(Grad).max()
        _, ind = linear_sum_assignment(-obj.cpu().numpy())  # maximize obj
        Tt = torch.eye(n, device=device)[ind]  # permutation matrix from LAP
        wt = torch.mm(torch.mm(A22.T, Tt), B22)
        # coefficients of the 1-D quadratic along the segment P -> Tt
        c = torch.sum(torch.diag(torch.mm(w, P.T)))
        d = torch.sum(torch.diag(torch.mm(wt, P.T))) + torch.sum(torch.diag(torch.mm(wt, Tt.T)))
        e = torch.sum(torch.diag(torch.mm(wt, Tt.T)))
        u = torch.sum(torch.diag(torch.mm(P.T, x) + torch.mm(P.T, y)))
        v = torch.sum(torch.diag(torch.mm(Tt.T, x) + torch.mm(Tt.T, y)))
        # exact line search: stationary point of the quadratic, guarded
        # against the degenerate (flat) case
        if c - d + e == 0 and d - 2 * e + u - v == 0:
            alpha = 0
        else:
            alpha = -(d - 2 * e + u - v) / (2 * (c - d + e))
        f0 = 0
        f1 = c - e + u - v
        falpha = (c - d + e) * alpha ** 2 + (d - 2 * e + u - v) * alpha
        if alpha < tol and alpha > 0 and falpha > f0 and falpha > f1:
            # interior maximum: move part of the way toward Tt
            P = alpha * P + (1 - alpha) * Tt
        elif f0 > f1:
            P = Tt
        else:
            # NOTE(review): this branch assigns P = Tt exactly like the one
            # above, differing only in that it stops the iteration.  Kept as
            # in the original, but reference SGM implementations typically
            # leave P unchanged here — confirm against the source algorithm.
            P = Tt
            toggle = 0
            break
    # project the doubly-stochastic iterate onto the permutation matrices
    _, corr = linear_sum_assignment(-P.cpu().numpy())
    corr = torch.as_tensor(corr, dtype=torch.long, device=device)
    Pn = torch.eye(n, device=device).index_select(0, corr)
    # assemble the full permutation: identity on seeds, Pn on non-seeds
    top = torch.cat([torch.eye(m, device=device), torch.zeros(m, n, device=device)], 1)
    bottom = torch.cat([torch.zeros(n, m, device=device), Pn], 1)
    P = torch.cat((top, bottom), 0)
    return corr, P