-
Notifications
You must be signed in to change notification settings - Fork 5
/
loss.py
55 lines (45 loc) · 1.02 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
import numpy as np
import torch.nn.functional as F
# def pair_similarity(x, y):
# '''
# x: n * dx
# y: m * dy
# '''
#
# n = x.size(0)
# m = y.size(0)
# d = x.size(1)
#
# x = x.unsqueeze(1).expand(n, m, d)
# y = y.unsqueeze(0).expand(n, m, d)
# ps = torch.eq(x,y).squeeze(2)
# ps[ps==0] = -1
# return ps
def cdist(x, y):
    '''Pairwise *squared* Euclidean distances between the rows of x and y.

    x: n * d
    y: m * d (same number of columns as x)

    Returns an (n, m) tensor whose [i, j] entry is sum_k (x[i, k] - y[j, k]) ** 2.
    Unlike torch.cdist, no square root is taken.

    Note: an (n, m, d) difference tensor is materialised, so memory grows
    as n * m * d.
    '''
    rows, cols, feats = x.size(0), y.size(0), x.size(1)
    grid = (rows, cols, feats)
    # Line every row of x up against every row of y along a new middle axis.
    left = x.unsqueeze(1).expand(*grid)
    right = y.unsqueeze(0).expand(*grid)
    squared = torch.pow(left - right, 2)
    return squared.sum(2)
def pair_similarity(x, y):
    '''Pairwise equality mask between the rows of x and y.

    x: n * d, or a 1-D tensor of length n (treated as n * 1)
    y: m * d, or a 1-D tensor of length m (treated as m * 1)

    Returns ps where ps[i, j] = (x[i] == y[j]), as produced by torch.eq
    (a bool tensor on current PyTorch).
      - d == 1 (e.g. label vectors): ps has shape (n, m).
      - d > 1: squeeze(2) is a no-op, so ps keeps the feature axis and has
        shape (n, m, d) holding element-wise comparisons.
        NOTE(review): confirm whether any caller relies on this, or whether
        a per-row `.all(2)` was intended.
    '''
    # Accept flat label vectors such as torch.tensor([0, 1, 1]); the 2-D-only
    # version failed on x.size(1) with an IndexError for these.
    if x.dim() == 1:
        x = x.unsqueeze(1)
    if y.dim() == 1:
        y = y.unsqueeze(1)
    n = x.size(0)
    m = y.size(0)
    d = x.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    # Only drops the trailing axis when d == 1 (see docstring).
    ps = torch.eq(x, y).squeeze(2)
    return ps
def relation_loss(relation_score, labelS):
    '''Sum of squared errors between relation_score and labelS, scaled by 1/sqrt(N).

    labelS is cast to float before comparison, so boolean/integer targets are
    accepted. The summed squared error is divided by sqrt(labelS.size(0)),
    i.e. the square root of the size of labelS's first dimension.

    NOTE(review): labelS presumably comes from pair_similarity (0/1 targets);
    confirm against the caller.
    '''
    target = labelS.float()
    total_sq_err = F.mse_loss(relation_score, target, reduction='sum')
    scale = np.sqrt(labelS.size(0))
    return total_sq_err / scale