From e1c1cc6c97bd050d47318527c9bbc481ce3990e5 Mon Sep 17 00:00:00 2001
From: SimonsXuan <1070321083@qq.com>
Date: Sat, 31 Jul 2021 14:38:20 +0800
Subject: [PATCH 1/2] add gpu re-ranking

---
 GETTING_STARTED.md                            | 13 +++++
 fastreid/config/defaults.py                   |  6 +++
 .../build_adjacency_matrix.cpp                | 19 +++++++
 .../build_adjacency_matrix_kernel.cu          | 31 ++++++++++++
 .../extension/adjacency_matrix/setup.py       | 34 +++++++++++++
 fastreid/evaluation/extension/make.sh         |  4 ++
 .../extension/propagation/gnn_propagate.cpp   | 21 ++++++++
 .../propagation/gnn_propagate_kernel.cu       | 36 ++++++++++++++
 .../evaluation/extension/propagation/setup.py | 34 +++++++++++++
 fastreid/evaluation/gpu_rerank.py             | 49 +++++++++++++++++++
 fastreid/evaluation/reid_evaluation.py        | 17 +++++++
 11 files changed, 264 insertions(+)
 create mode 100644 fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix.cpp
 create mode 100644 fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu
 create mode 100644 fastreid/evaluation/extension/adjacency_matrix/setup.py
 create mode 100644 fastreid/evaluation/extension/make.sh
 create mode 100644 fastreid/evaluation/extension/propagation/gnn_propagate.cpp
 create mode 100644 fastreid/evaluation/extension/propagation/gnn_propagate_kernel.cu
 create mode 100644 fastreid/evaluation/extension/propagation/setup.py
 create mode 100644 fastreid/evaluation/gpu_rerank.py

diff --git a/GETTING_STARTED.md b/GETTING_STARTED.md
index 172006a0..f6644726 100644
--- a/GETTING_STARTED.md
+++ b/GETTING_STARTED.md
@@ -12,6 +12,11 @@ If you want to use other pre-train models, such as MoCo pre-train, you can downl
 ```bash
 cd fastreid/evaluation/rank_cylib; make all
 ```
+## Compile the GPU re-ranking extension to enable real-time re-ranking
+
+```bash
+cd fastreid/evaluation/extension; sh make.sh
+```
 
 ## Training & Evaluation in Command Line
 
@@ -59,4 +64,12 @@ python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml
 MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
 ```
+
+To evaluate a model's performance with GPU re-ranking enabled, use
+
+```bash
+python3 tools/train_net.py --config-file ./configs/Market1501/bagtricks_R50.yml --eval-only \
+TEST.GPU_RERANK.ENABLED True MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0"
+```
+
 For more options, see `python3 tools/train_net.py -h`.
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index 2eefca22..01f0f9ea 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -312,6 +312,12 @@
 _C.TEST.RERANK.K2 = 6
 _C.TEST.RERANK.LAMBDA = 0.3
 
+# Real-time GPU re-rank
+_C.TEST.GPU_RERANK = CN({"ENABLED": False})
+_C.TEST.GPU_RERANK.K1 = 26
+_C.TEST.GPU_RERANK.K2 = 7
+_C.TEST.GPU_RERANK.LAMBDA = 0.1
+
 # Precise batchnorm
 _C.TEST.PRECISE_BN = CN({"ENABLED": False})
 _C.TEST.PRECISE_BN.DATASET = 'Market1501'
diff --git a/fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix.cpp b/fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix.cpp
new file mode 100644
index 00000000..4c496041
--- /dev/null
+++ b/fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix.cpp
@@ -0,0 +1,19 @@
+#include <torch/extension.h>
+#include <iostream>
+#include <set>
+
+at::Tensor build_adjacency_matrix_forward(torch::Tensor initial_rank);
+
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+at::Tensor build_adjacency_matrix(at::Tensor initial_rank) {
+    CHECK_INPUT(initial_rank);
+    return build_adjacency_matrix_forward(initial_rank);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &build_adjacency_matrix, "build_adjacency_matrix (CUDA)");
+}
diff --git a/fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu b/fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu
new file mode 100644
index 00000000..4973ddef
--- /dev/null
+++ b/fastreid/evaluation/extension/adjacency_matrix/build_adjacency_matrix_kernel.cu
@@ -0,0 +1,31 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <iostream>
+
+#define CUDA_1D_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x)
+
+
+__global__ void build_adjacency_matrix_kernel(float* initial_rank, float* A, const int total_num, const int topk, const int nthreads, const int all_num) {
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < all_num; i += stride) {
+        int ii = i / topk;
+        A[ii * total_num + int(initial_rank[i])] = float(1.0);
+    }
+}
+
+at::Tensor build_adjacency_matrix_forward(at::Tensor initial_rank) {
+    const auto total_num = initial_rank.size(0);
+    const auto topk = initial_rank.size(1);
+    const auto all_num = total_num * topk;
+    auto A = torch::zeros({total_num, total_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));
+
+    const int threads = 1024;
+    const int blocks = (all_num + threads - 1) / threads;
+
+    build_adjacency_matrix_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), total_num, topk, threads, all_num);
+    return A;
+
+}
diff --git a/fastreid/evaluation/extension/adjacency_matrix/setup.py b/fastreid/evaluation/extension/adjacency_matrix/setup.py
new file mode 100644
index 00000000..e66c9773
--- /dev/null
+++ b/fastreid/evaluation/extension/adjacency_matrix/setup.py
@@ -0,0 +1,34 @@
+"""
+    Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
+
+    Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
+
+    Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
+
+    Paper: https://arxiv.org/abs/2012.07620v2
+
+    ======================================================================
+
+    On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
+    with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
+    that our method achieves comparable or even better retrieval results on the other four
+    image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
+    with limited time cost.
+"""
+
+from setuptools import setup
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+setup(
+    name='build_adjacency_matrix',
+    ext_modules=[
+        CUDAExtension('build_adjacency_matrix', [
+            'build_adjacency_matrix.cpp',
+            'build_adjacency_matrix_kernel.cu',
+        ]),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/fastreid/evaluation/extension/make.sh b/fastreid/evaluation/extension/make.sh
new file mode 100644
index 00000000..f0197ff9
--- /dev/null
+++ b/fastreid/evaluation/extension/make.sh
@@ -0,0 +1,4 @@
+cd adjacency_matrix
+python setup.py install
+cd ../propagation
+python setup.py install
\ No newline at end of file
diff --git a/fastreid/evaluation/extension/propagation/gnn_propagate.cpp b/fastreid/evaluation/extension/propagation/gnn_propagate.cpp
new file mode 100644
index 00000000..10a939ff
--- /dev/null
+++ b/fastreid/evaluation/extension/propagation/gnn_propagate.cpp
@@ -0,0 +1,21 @@
+#include <torch/extension.h>
+#include <iostream>
+#include <set>
+
+at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S);
+
+
+#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+at::Tensor gnn_propagate(at::Tensor A, at::Tensor initial_rank, at::Tensor S) {
+    CHECK_INPUT(A);
+    CHECK_INPUT(initial_rank);
+    CHECK_INPUT(S);
+    return gnn_propagate_forward(A, initial_rank, S);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &gnn_propagate, "gnn propagate (CUDA)");
+}
\ No newline at end of file
diff --git a/fastreid/evaluation/extension/propagation/gnn_propagate_kernel.cu b/fastreid/evaluation/extension/propagation/gnn_propagate_kernel.cu
new file mode 100644
index 00000000..8bdebf16
--- /dev/null
+++ b/fastreid/evaluation/extension/propagation/gnn_propagate_kernel.cu
@@ -0,0 +1,36 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <iostream>
+#include <math.h>
+
+__global__ void gnn_propagate_forward_kernel(float* initial_rank, float* A, float* A_qe, float* S, const int sample_num, const int topk, const int total_num) {
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < total_num; i += stride) {
+        int fea = i % sample_num;
+        int sample_index = i / sample_num;
+        float sum = 0.0;
+        for (int j = 0; j < topk; j++) {
+            int topk_fea_index = int(initial_rank[sample_index * topk + j]) * sample_num + fea;
+            sum += A[topk_fea_index] * S[sample_index * topk + j];
+        }
+        A_qe[i] = sum;
+    }
+}
+
+at::Tensor gnn_propagate_forward(at::Tensor A, at::Tensor initial_rank, at::Tensor S) {
+    const auto sample_num = A.size(0);
+    const auto topk = initial_rank.size(1);
+
+    const auto total_num = sample_num * sample_num;
+    auto A_qe = torch::zeros({sample_num, sample_num}, at::device(initial_rank.device()).dtype(at::ScalarType::Float));
+
+    const int threads = 1024;
+    const int blocks = (total_num + threads - 1) / threads;
+
+    gnn_propagate_forward_kernel<<<blocks, threads>>>(initial_rank.data_ptr<float>(), A.data_ptr<float>(), A_qe.data_ptr<float>(), S.data_ptr<float>(), sample_num, topk, total_num);
+    return A_qe;
+
+}
\ No newline at end of file
diff --git a/fastreid/evaluation/extension/propagation/setup.py b/fastreid/evaluation/extension/propagation/setup.py
new file mode 100644
index 00000000..af595ec6
--- /dev/null
+++ b/fastreid/evaluation/extension/propagation/setup.py
@@ -0,0 +1,34 @@
+"""
+    Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
+
+    Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
+
+    Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
+
+    Paper: https://arxiv.org/abs/2012.07620v2
+
+    ======================================================================
+
+    On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
+    with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
+    that our method achieves comparable or even better retrieval results on the other four
+    image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
+    with limited time cost.
+"""
+
+from setuptools import setup
+
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+setup(
+    name='gnn_propagate',
+    ext_modules=[
+        CUDAExtension('gnn_propagate', [
+            'gnn_propagate.cpp',
+            'gnn_propagate_kernel.cu',
+        ]),
+    ],
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/fastreid/evaluation/gpu_rerank.py b/fastreid/evaluation/gpu_rerank.py
new file mode 100644
index 00000000..140305c4
--- /dev/null
+++ b/fastreid/evaluation/gpu_rerank.py
@@ -0,0 +1,49 @@
+"""
+    Understanding Image Retrieval Re-Ranking: A Graph Neural Network Perspective
+
+    Xuanmeng Zhang, Minyue Jiang, Zhedong Zheng, Xiao Tan, Errui Ding, Yi Yang
+
+    Project Page : https://github.com/Xuanmeng-Zhang/gnn-re-ranking
+
+    Paper: https://arxiv.org/abs/2012.07620v2
+
+    ======================================================================
+
+    On the Market-1501 dataset, we accelerate the re-ranking processing from 89.2s to 9.4ms
+    with one K40m GPU, facilitating the real-time post-processing. Similarly, we observe
+    that our method achieves comparable or even better retrieval results on the other four
+    image retrieval benchmarks, i.e., VeRi-776, Oxford-5k, Paris-6k and University-1652,
+    with limited time cost.
+""" + +import torch + +import build_adjacency_matrix +import gnn_propagate + + +def gpu_reranking(X_q, X_g, k1, k2): + query_num, gallery_num = X_q.shape[0], X_g.shape[0] + + X_u = torch.cat((X_q, X_g), axis = 0) + original_score = torch.mm(X_u, X_u.t()) + del X_u, X_q, X_g + + # initial ranking list + S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True) + + # stage 1 + A = build_adjacency_matrix.forward(initial_rank.float()) + S = S * S + + # stage 2 + if k2 != 1: + for i in range(2): + A = A + A.T + A = gnn_propagate.forward(A, initial_rank[:, :k2].contiguous().float(), S[:, :k2].contiguous().float()) + A_norm = torch.norm(A, p=2, dim=1, keepdim=True) + A = A.div(A_norm.expand_as(A)) + + query_features, gallery_features = A[:query_num,], A[query_num:, ] + return query_features, gallery_features + diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py index 921b39bb..ee61eb0e 100644 --- a/fastreid/evaluation/reid_evaluation.py +++ b/fastreid/evaluation/reid_evaluation.py @@ -19,6 +19,7 @@ from .query_expansion import aqe from .rank import evaluate_rank from .roc import evaluate_roc +from .gpu_rerank import gpu_reranking logger = logging.getLogger(__name__) @@ -102,6 +103,22 @@ def evaluate(self): rerank_dist = build_dist(query_features, gallery_features, metric="jaccard", k1=k1, k2=k2) dist = rerank_dist * (1 - lambda_value) + dist * lambda_value + if self.cfg.TEST.GPU_RERANK.ENABLED: + logger.info("Test with gpu real-time rerank setting") + k1 = self.cfg.TEST.RERANK.K1 + k2 = self.cfg.TEST.RERANK.K2 + lambda_value = self.cfg.TEST.RERANK.LAMBDA + + if self.cfg.TEST.METRIC == "cosine": + query_features = F.normalize(query_features, dim=1) + gallery_features = F.normalize(gallery_features, dim=1) + + query_features = query_features.cuda() + gallery_features = gallery_features.cuda() + query_features, gallery_features = gpu_reranking(query_features, gallery_features, k1, k2) + rerank_dist = build_dist(query_features, gallery_features, self.cfg.TEST.METRIC) + dist = dist * (1 - lambda_value) + rerank_dist * lambda_value + cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids) mAP = np.mean(all_AP) From 2fdddd96408676b6972866803514a0111ca8e349 Mon Sep 17 00:00:00 2001 From: SimonsXuan <1070321083@qq.com> Date: Sat, 31 Jul 2021 15:10:54 +0800 Subject: [PATCH 2/2] add gpu re-ranking --- fastreid/evaluation/reid_evaluation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py index ee61eb0e..67006ecd 100644 --- a/fastreid/evaluation/reid_evaluation.py +++ b/fastreid/evaluation/reid_evaluation.py @@ -105,9 +105,9 @@ def evaluate(self): if self.cfg.TEST.GPU_RERANK.ENABLED: logger.info("Test with gpu real-time rerank setting") - k1 = self.cfg.TEST.RERANK.K1 - k2 = self.cfg.TEST.RERANK.K2 - lambda_value = self.cfg.TEST.RERANK.LAMBDA + k1 = self.cfg.TEST.GPU_RERANK.K1 + k2 = self.cfg.TEST.GPU_RERANK.K2 + lambda_value = self.cfg.TEST.GPU_RERANK.LAMBDA if self.cfg.TEST.METRIC == "cosine": query_features = F.normalize(query_features, dim=1)