Correlation - Python Implementation #121

Open
ndimitriou opened this issue Jun 22, 2021 · 2 comments

Comments


ndimitriou commented Jun 22, 2021

Why is there an external dependency for computing the cross-correlation between feature pyramids? If we assume that f1 and f2 are the features for image1 and image2, I believe code more or less like the one below would do the job and be much simpler:

cost_vol_lev = torch.empty((B, 81, H, W), device=self.device)  # the cost volume for a single level
k = 0
for i in range(-4, 5):  # assuming a 9x9 window
    for j in range(-4, 5):
        f2_rolled = torch.roll(f2, shifts=(i, j), dims=(2, 3))  # shifting the second tensor
        product = f1 * f2_rolled
        f1_norm = torch.sqrt(torch.sum(f1 ** 2, 1) + 1e-10)  # adding small constant to avoid division by zero
        f2_rolled_norm = torch.sqrt(torch.sum(f2_rolled ** 2, 1) + 1e-10)
        corr = torch.mean(product, 1)
        norm_fac = f1_norm * f2_rolled_norm
        corr = corr / norm_fac  # normalizing
        cost_vol_lev[:, k, :, :] = corr
        k = k + 1

Am I missing something (on the backpropagation step perhaps)?
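One behavioral detail of the loop above, apart from backpropagation: torch.roll shifts circularly, so features leaving one border re-enter at the opposite border. Many correlation layers zero-pad instead, and whether the dependency in this repository does so is an assumption here, not something confirmed in this thread. The sketch below contrasts the two; shift_zero_pad is a hypothetical helper written for illustration only.

```python
import torch
import torch.nn.functional as F


def shift_zero_pad(x, di, dj):
    """Shift x by (di, dj) along (H, W), filling vacated positions with zeros.

    Hypothetical helper: same shift direction as torch.roll, but without
    the wrap-around at the borders.
    """
    _, _, H, W = x.shape
    m = max(abs(di), abs(dj))
    padded = F.pad(x, (m, m, m, m))  # zero-pad W and H by m on each side
    return padded[:, :, m - di:m - di + H, m - dj:m - dj + W]


x = torch.arange(1.0, 5.0).view(1, 1, 1, 4)       # a single row: [1, 2, 3, 4]
rolled = torch.roll(x, shifts=(0, 1), dims=(2, 3))  # circular shift
padded = shift_zero_pad(x, 0, 1)                     # zero-filled shift
print(rolled.flatten().tolist())  # [4.0, 1.0, 2.0, 3.0]
print(padded.flatten().tolist())  # [0.0, 1.0, 2.0, 3.0]
```

Replacing torch.roll with a zero-filled shift would only change the cost volume within 4 pixels of the image borders for a 9x9 window.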


xwjabc commented Aug 11, 2021

I think with the current PyTorch library, it would be even simpler to use torch.nn.functional.unfold to implement the correlation function.

@liyuxuan89

import torch
import torch.nn.functional as F

def corr(f1, f2, md=4):
    b, c, h, w = f1.shape
    # 1. normalize features along the channel dimension
    f1 = f1 / torch.norm(f1, dim=1, keepdim=True)
    f2 = f2 / torch.norm(f2, dim=1, keepdim=True)
    # 2. compute correlation matrix over a (2*md+1) x (2*md+1) window
    f1 = F.unfold(f1, kernel_size=(md * 2 + 1, md * 2 + 1), padding=(md, md), stride=(1, 1))
    f1 = f1.view([b, c, -1, h, w])  # (b, c, (2*md+1)**2, h, w)
    f2 = f2.view([b, c, 1, h, w])
    w = torch.sum(f1 * f2, dim=1)  # (b, (2*md+1)**2, h, w)
    return w

Is it possible to implement this operation like this?
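One way to check an unfold-based version is to compare it with a plain loop over zero-padded shifts. The following is a minimal sketch under assumptions, not the repository's implementation: corr_unfold and corr_loop are hypothetical names, F.normalize is swapped in for the explicit division by the norm, and the window is taken around f2 (the snippet above unfolds f1 instead), so channel k holds f1 at a pixel against f2 at an offset.

```python
import torch
import torch.nn.functional as F


def corr_unfold(f1, f2, md=4):
    """Cosine-similarity cost volume via F.unfold (hypothetical sketch)."""
    b, c, h, w = f1.shape
    k = 2 * md + 1
    f1 = F.normalize(f1, dim=1)  # unit L2 norm per pixel, with a small eps
    f2 = F.normalize(f2, dim=1)
    # (b, c*k*k, h*w) -> (b, c, k*k, h, w): a k x k window of f2 around each pixel
    f2_win = F.unfold(f2, kernel_size=k, padding=md).view(b, c, k * k, h, w)
    return (f1.unsqueeze(2) * f2_win).sum(dim=1)  # (b, k*k, h, w)


def corr_loop(f1, f2, md=4):
    """Reference version: explicit loop over zero-padded shifts of f2."""
    b, c, h, w = f1.shape
    f1 = F.normalize(f1, dim=1)
    f2 = F.pad(F.normalize(f2, dim=1), (md, md, md, md))
    out = []
    for i in range(2 * md + 1):
        for j in range(2 * md + 1):
            out.append((f1 * f2[:, :, i:i + h, j:j + w]).sum(dim=1))
    return torch.stack(out, dim=1)


torch.manual_seed(0)
f1 = torch.randn(2, 8, 6, 7)
f2 = torch.randn(2, 8, 6, 7)
a = corr_unfold(f1, f2, md=2)
b = corr_loop(f1, f2, md=2)
print(a.shape, torch.allclose(a, b, atol=1e-6))
```

Note that the unfolded tensor holds (2*md+1)**2 copies of the feature map (81 for md=4), so memory use grows accordingly. That may be one reason to prefer a dedicated kernel, though this thread does not confirm it.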
