I tried to reproduce SOSR based on HardNet, but ran into some questions. Can you give me some advice? #8

Open
Ahandsomenaive opened this issue Jul 18, 2020 · 5 comments

Comments

Ahandsomenaive commented Jul 18, 2020

I tried to reproduce SOSR based on HardNet, but ran into some questions. Can you give me some advice?

The loss code is here:

```python
import sys
import torch


def distance_matrix_vector(anchor, positive):
    """Given a batch of anchor and positive descriptors, compute the pairwise L2 distance matrix."""
    d1_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1)
    d2_sq = torch.sum(positive * positive, dim=1).unsqueeze(-1)

    eps = 1e-6
    return torch.sqrt(d1_sq.repeat(1, positive.size(0))
                      + torch.t(d2_sq.repeat(1, anchor.size(0)))
                      - 2.0 * torch.mm(anchor, torch.t(positive))
                      + eps)


def inner_dot_matrix(anchor, positive):
    """Distance matrix for unit-norm descriptors via the inner product: d = sqrt(2 - 2 * <a, p>)."""
    inner = torch.mm(anchor, torch.t(positive))
    mask = torch.eye(inner.size(1), device=inner.device)
    inner = inner - 1e-6 * mask  # keep the diagonal strictly below 1 for numerical stability
    return torch.sqrt(2.0 * (1.0 - inner) + 1e-8)


def loss_SosNet(anchor, positive, anchor_swap=False, anchor_ave=False,
                margin=1.0, batch_reduce='min', loss_type="triplet_margin", k=8):
    """HardNet margin loss plus SOSR: the triplet term is computed from the positive
    distance and the closest negative distance in the distance matrix."""
    assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
    assert anchor.dim() == 2, "Inputs must be 2D matrices."
    eps = 1e-8
    dist_matrix = distance_matrix_vector(anchor, positive) + eps
    eye = torch.eye(dist_matrix.size(1), device=dist_matrix.device)

    # Filter out identical patches that occur in the distance matrix as negatives.
    pos1 = torch.diag(dist_matrix)
    dist_without_min_on_diag = dist_matrix + eye * 10
    mask = (dist_without_min_on_diag.ge(0.008).float() - 1.0) * (-1)
    mask = mask.type_as(dist_without_min_on_diag) * 10
    dist_without_min_on_diag = dist_without_min_on_diag + mask

    if batch_reduce == 'min':
        min_neg = torch.min(dist_without_min_on_diag, 1)[0]
        if anchor_swap:
            min_neg2 = torch.min(dist_without_min_on_diag, 0)[0]
            min_neg = torch.min(min_neg, min_neg2)
        pos = pos1
    else:
        print('Unknown batch reduce mode. Try min, average or random')
        sys.exit(1)

    if loss_type == "triplet_margin":
        loss = torch.clamp(margin + pos - min_neg, min=0.0) ** 2
    else:
        print('Unknown loss type. Try triplet_margin, softmax or contrastive')
        sys.exit(1)

    # SOSR: compare each descriptor's k-nearest-neighbour distance profile
    # within the anchor batch against the same profile within the positive batch.
    dist_matrix_a = inner_dot_matrix(anchor, anchor) + eps
    dist_without_min_on_diag_a = dist_matrix_a + eye * 10
    mask_a = (dist_without_min_on_diag_a.ge(0.008).float() - 1.0) * (-1)
    mask_a = mask_a.type_as(dist_without_min_on_diag_a) * 10
    dist_without_min_on_diag_a = dist_without_min_on_diag_a + mask_a

    cur_a = torch.topk(dist_without_min_on_diag_a, k=k, dim=1, largest=False).indices
    mask_aa = torch.zeros_like(dist_without_min_on_diag_a).scatter(1, cur_a, 1)

    dist_matrix_p = inner_dot_matrix(positive, positive) + eps
    dist_without_min_on_diag_p = dist_matrix_p + eye * 10
    mask_p = (dist_without_min_on_diag_p.ge(0.008).float() - 1.0) * (-1)
    mask_p = mask_p.type_as(dist_without_min_on_diag_p) * 10
    dist_without_min_on_diag_p = dist_without_min_on_diag_p + mask_p

    cur_p = torch.topk(dist_without_min_on_diag_p, k=k, dim=1, largest=False).indices
    mask_pp = torch.zeros_like(dist_without_min_on_diag_p).scatter(1, cur_p, 1)

    sosr = torch.sqrt(1e-8 + torch.sum(
        (dist_without_min_on_diag_a * mask_aa - dist_without_min_on_diag_p * mask_pp) ** 2,
        dim=0))
    sosr = torch.mean(sosr)

    loss = torch.mean(loss)
    return loss + sosr, loss, sosr
```
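For intuition, the SOSR term penalises differences between how each descriptor relates to its batch neighbours in the anchor set versus the positive set. Here is a minimal pure-Python sketch of that idea (simplified: it compares distances over all neighbours rather than only the k nearest, and skips the diagonal masking used above):

```python
import math

def pairwise_dist(points):
    """Pairwise Euclidean distance matrix for a list of vectors."""
    n = len(points)
    return [[math.sqrt(sum((a - b) ** 2 for a, b in zip(points[i], points[j])))
             for j in range(n)] for i in range(n)]

def sosr_penalty(anchors, positives):
    """Mean over descriptors of the L2 difference between their intra-batch
    distance profiles in the anchor set and in the positive set."""
    da = pairwise_dist(anchors)
    dp = pairwise_dist(positives)
    n = len(anchors)
    return sum(
        math.sqrt(sum((da[i][j] - dp[i][j]) ** 2 for j in range(n)))
        for i in range(n)
    ) / n

# Identical geometry in both batches -> zero penalty.
a = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
print(sosr_penalty(a, a))       # 0.0
# Distorting one positive changes its distance profile -> positive penalty.
b = [(0.0, 0.0), (2.0, 0.0), (0.0, 1.0)]
print(sosr_penalty(a, b) > 0)   # True
```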
@Ahandsomenaive (Author)

The loss does not converge! FPR95 is very high!

@Ahandsomenaive (Author)

Now I have found the coding bug.
But the best score is 1.15; I have not reached the 1.03 described in your paper.
Could you give me some advice about the training settings?


@ACuOoOoO

Hi @Ahandsomenaive. I have embedded SOSR into HardNet (https://github.com/DagnyT/hardnet), changed the optimizer as the paper advised, and used the model published here. This way, the best score I reproduced reaches ~1.10, but it still falls short of 1.03. Interestingly, though, my reproduction trained on Liberty slightly outperforms the results the paper reports on HPatches.
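For anyone trying the same route, here is a hedged sketch of wiring a SOSR-style term into a HardNet-style training step. `tiny_net`, the batch shapes, and the SGD settings are placeholders rather than the paper's actual configuration, and the loss is a simplified stand-in (hardest-in-batch triplet term plus an SOSR term over all neighbours instead of the k nearest):

```python
import torch

torch.manual_seed(0)

# Stand-in for the HardNet CNN; real training uses the conv architecture.
tiny_net = torch.nn.Linear(32, 128)
# Placeholder optimizer settings -- not the paper's exact values.
optimizer = torch.optim.SGD(tiny_net.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=1e-4)

def describe(patches):
    """L2-normalised descriptors, as HardNet/SOSNet produce."""
    d = tiny_net(patches)
    return d / (d.norm(dim=1, keepdim=True) + 1e-10)

def triplet_plus_sosr(anchor, positive, margin=1.0):
    """Simplified combined loss: hardest-in-batch triplet + full-neighbourhood SOSR."""
    d_ap = torch.cdist(anchor, positive)   # cross distance matrix
    pos = torch.diag(d_ap)                 # distances of matching pairs
    # Mask the diagonal so a patch is never its own negative.
    neg = (d_ap + 10.0 * torch.eye(len(anchor))).min(dim=1).values
    triplet = torch.clamp(margin + pos - neg, min=0.0).pow(2).mean()
    diff = torch.cdist(anchor, anchor) - torch.cdist(positive, positive)
    sosr = diff.pow(2).sum(dim=1).add(1e-8).sqrt().mean()
    return triplet + sosr

patches_a = torch.randn(16, 32)
patches_p = patches_a + 0.05 * torch.randn(16, 32)  # synthetic "same patch" pairs

loss = triplet_plus_sosr(describe(patches_a), describe(patches_p))
optimizer.zero_grad()
loss.backward()
optimizer.step()
```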

@Ahandsomenaive (Author)

> Hi @Ahandsomenaive. I have embedded SOSR into HardNet (https://github.com/DagnyT/hardnet), changed the optimizer as the paper advised, and used the model published here. This way, the best score I reproduced reaches ~1.10, but it still falls short of 1.03. Interestingly, though, my reproduction trained on Liberty slightly outperforms the results the paper reports on HPatches.

Hi, thank you!
One small question: I can't find your code via the link you provided.
Could you share more detailed information about your code?
Thank you!
