Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the code for GPU-Reranking #542

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ If you want to use other pre-train models, such as MoCo pre-train, you can downl
```bash
cd fastreid/evaluation/rank_cylib; make all
```
## Compile with gpu_reranking to achieve real-time re-ranking

```bash
cd fastreid/evaluation/extension; sh make.sh
```

## Training & Evaluation in Command Line

Expand Down Expand Up @@ -59,4 +64,12 @@ python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml
MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
```


To evaluate a model's performance, use

```bash
python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
TEST.GPU_RERANK.ENABLED True MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
```

For more options, see `python3 tools/train_net.py -h`.
6 changes: 6 additions & 0 deletions fastreid/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@
_C.TEST.RERANK.K2 = 6
_C.TEST.RERANK.LAMBDA = 0.3

# Real-time Re-rank
_C.TEST.GPU_RERANK = CN({"ENABLED": False})
_C.TEST.GPU_RERANK.K1 = 26
_C.TEST.GPU_RERANK.K2 = 7
_C.TEST.GPU_RERANK.LAMBDA = 0.1

# Precise batchnorm
_C.TEST.PRECISE_BN = CN({"ENABLED": False})
_C.TEST.PRECISE_BN.DATASET = 'Market1501'
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <torch/extension.h>
#include <iostream>
#include <set>

at::Tensor build_adjacency_matrix_forward(torch::Tensor initial_rank);


#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor build_adjacency_matrix(at::Tensor initial_rank) {
CHECK_INPUT(initial_rank);
return build_adjacency_matrix_forward(initial_rank);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &build_adjacency_matrix, "build_adjacency_matrix (CUDA)");
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

#define CUDA_1D_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)


__global__ void build_adjacency_matrix_kernel(float* initial_rank, float* A, const int total_num, const int topk, const int nthreads, const int all_num) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < all_num; i += stride) {
int ii = i / topk;
A[ii * total_num + int(initial_rank[i])] = float(1.0);
}
}

at::Tensor build_adjacency_matrix_forward(at::Tensor initial_rank) {
const auto total_num = initial_rank.size(0);
const auto topk = initial_rank.size(1);
const auto all_num = total_num * topk;
auto A = torch::zeros({total_num, total_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));

const int threads = 1024;
const int blocks = (all_num + threads - 1) / threads;

build_adjacency_matrix_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), total_num, topk, threads, all_num);
return A;

}
34 changes: 34 additions & 0 deletions fastreid/evaluation/extension/adjacency_matrix/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
name='build_adjacency_matrix',
ext_modules=[
CUDAExtension('build_adjacency_matrix', [
'build_adjacency_matrix.cpp',
'build_adjacency_matrix_kernel.cu',
]),
],
cmdclass={
'build_ext':BuildExtension
})
4 changes: 4 additions & 0 deletions fastreid/evaluation/extension/make.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
cd adjacency_matrix
python setup.py install
cd ../propagation
python setup.py install
21 changes: 21 additions & 0 deletions fastreid/evaluation/extension/propagation/gnn_propagate.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include <torch/extension.h>
#include <iostream>
#include <set>

at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S);


#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor gnn_propagate(at::Tensor A ,at::Tensor initial_rank, at::Tensor S) {
CHECK_INPUT(A);
CHECK_INPUT(initial_rank);
CHECK_INPUT(S);
return gnn_propagate_forward(A, initial_rank, S);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &gnn_propagate, "gnn propagate (CUDA)");
}
36 changes: 36 additions & 0 deletions fastreid/evaluation/extension/propagation/gnn_propagate_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <iostream>

__global__ void gnn_propagate_forward_kernel(float* initial_rank, float* A, float* A_qe, float* S, const int sample_num, const int topk, const int total_num) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (int i = index; i < total_num; i += stride) {
int fea = i % sample_num;
int sample_index = i / sample_num;
float sum = 0.0;
for (int j = 0; j < topk ; j++) {
int topk_fea_index = int(initial_rank[sample_index*topk+j]) * sample_num + fea;
sum += A[ topk_fea_index] * S[sample_index*topk+j];
}
A_qe[i] = sum;
}
}

at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S) {
const auto sample_num = A.size(0);
const auto topk = initial_rank.size(1);

const auto total_num = sample_num * sample_num ;
auto A_qe = torch::zeros({sample_num, sample_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));

const int threads = 1024;
const int blocks = (total_num + threads - 1) / threads;

gnn_propagate_forward_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), A_qe.data_ptr<float>(), S.data_ptr<float>(), sample_num, topk, total_num);
return A_qe;

}
34 changes: 34 additions & 0 deletions fastreid/evaluation/extension/propagation/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

from setuptools import setup

from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
name='gnn_propagate',
ext_modules=[
CUDAExtension('gnn_propagate', [
'gnn_propagate.cpp',
'gnn_propagate_kernel.cu',
]),
],
cmdclass={
'build_ext':BuildExtension
})
49 changes: 49 additions & 0 deletions fastreid/evaluation/gpu_rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective

Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang

Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking

Paper: https://arxiv.org/abs/2012.07620v2

======================================================================

On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
that our method achieves comparable or even better retrieval results on the other four
image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
with limited time cost.
"""

import torch

import build_adjacency_matrix
import gnn_propagate


def gpu_reranking(X_q, X_g, k1, k2):
query_num, gallery_num = X_q.shape[0], X_g.shape[0]

X_u = torch.cat((X_q, X_g), axis = 0)
original_score = torch.mm(X_u, X_u.t())
del X_u, X_q, X_g

# initial ranking list
S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True)

# stage 1
A = build_adjacency_matrix.forward(initial_rank.float())
S = S * S

# stage 2
if k2 != 1:
for i in range(2):
A = A + A.T
A = gnn_propagate.forward(A, initial_rank[:, :k2].contiguous().float(), S[:, :k2].contiguous().float())
A_norm = torch.norm(A, p=2, dim=1, keepdim=True)
A = A.div(A_norm.expand_as(A))

query_features, gallery_features = A[:query_num,], A[query_num:, ]
return query_features, gallery_features

17 changes: 17 additions & 0 deletions fastreid/evaluation/reid_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .query_expansion import aqe
from .rank import evaluate_rank
from .roc import evaluate_roc
from .gpu_rerank import gpu_reranking

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,6 +103,22 @@ def evaluate(self):
rerank_dist = build_dist(query_features, gallery_features, metric="jaccard", k1=k1, k2=k2)
dist = rerank_dist * (1 - lambda_value) + dist * lambda_value

if self.cfg.TEST.GPU_RERANK.ENABLED:
logger.info("Test with gpu real-time rerank setting")
k1 = self.cfg.TEST.GPU_RERANK.K1
k2 = self.cfg.TEST.GPU_RERANK.K2
lambda_value = self.cfg.TEST.GPU_RERANK.LAMBDA

if self.cfg.TEST.METRIC == "cosine":
query_features = F.normalize(query_features, dim=1)
gallery_features = F.normalize(gallery_features, dim=1)

query_features = query_features.cuda()
gallery_features = gallery_features.cuda()
query_features, gallery_features = gpu_reranking(query_features, gallery_features, k1, k2)
rerank_dist = build_dist(query_features, gallery_features, self.cfg.TEST.METRIC)
dist = dist * (1 - lambda_value) + rerank_dist * lambda_value

cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)

mAP = np.mean(all_AP)
Expand Down