forked from wang-xinyu/tensorrtx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalibrator.hpp
116 lines (104 loc) · 3.93 KB
/
calibrator.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#pragma once
#include <algorithm>
#include <cassert>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

#include "NvInfer.h"

#include "common.hpp"
//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 {
public:
Int8EntropyCalibrator2(int batchsize, int input_w, int input_h,
const char* img_dir, const char* calib_table_name,
const char* input_blob_name, bool read_cache = true);
virtual ~Int8EntropyCalibrator2();
int getBatchSize() const override;
bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
const void* readCalibrationCache(size_t& length) override;
void writeCalibrationCache(const void* cache, size_t length) override;
private:
int batchsize_;
int input_w_;
int input_h_;
int img_idx_;
std::string img_dir_;
std::vector<std::string> img_files_;
size_t input_count_;
std::string calib_table_name_;
const char* input_blob_name_;
bool read_cache_;
void* device_input_;
std::vector<char> calib_cache_;
};
//! Stores the configuration, allocates the device buffer for one batch and
//! collects the image file names found in img_dir.
//!
//! \param batchsize         number of images per calibration batch
//! \param input_w           input width
//! \param input_h           input height
//! \param img_dir           directory passed to read_files_in_dir
//! \param calib_table_name  path of the calibration table file
//! \param input_blob_name   expected input binding name (pointer is stored, not copied)
//! \param read_cache        whether readCalibrationCache() may load an existing table
//!
//! Defined `inline` because this definition lives in a header; without it,
//! including the header from several translation units breaks the link.
inline Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize,
        int input_w, int input_h, const char* img_dir,
        const char* calib_table_name, const char* input_blob_name,
        bool read_cache)
    : batchsize_(batchsize)
    , input_w_(input_w)
    , input_h_(input_h)
    , img_idx_(0)
    , img_dir_(img_dir)
    , calib_table_name_(calib_table_name)
    , input_blob_name_(input_blob_name)
    , read_cache_(read_cache) {
    // Widen to size_t before multiplying so the element count is not
    // computed (and possibly overflowed) in int arithmetic.
    input_count_ = static_cast<size_t>(3) * static_cast<size_t>(input_w) *
                   static_cast<size_t>(input_h) * static_cast<size_t>(batchsize);
    CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float)));
    read_files_in_dir(img_dir, img_files_);
}
//! Frees the device buffer that the constructor allocated with cudaMalloc.
//! Defined `inline` because this definition lives in a header.
inline Int8EntropyCalibrator2::~Int8EntropyCalibrator2() {
    CUDA_CHECK(cudaFree(device_input_));
}
//! Returns the batch size supplied to the constructor.
//! Defined `inline` because this definition lives in a header.
inline int Int8EntropyCalibrator2::getBatchSize() const {
    return batchsize_;
}
bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) {
if (img_idx_ + batchsize_ > static_cast<int>(img_files_.size())) {
return false;
}
std::vector<float> input_imgs_(input_count_, 0);
for (int i = img_idx_; i < img_idx_ + batchsize_; i++) {
std::cout << img_files_[i] << " " << i << std::endl;
cv::Mat temp = cv::imread(img_dir_ + img_files_[i]);
if (temp.empty()) {
std::cerr << "Fatal error: image cannot open!" << std::endl;
return false;
}
preprocessImg(temp, input_w_, input_h_);
for (int c = 0; c < 3; c++) {
for (int h = 0; h < input_h_; h++) {
for (int w = 0; w < input_w_; w++) {
input_imgs_[(i-img_idx_)*input_w_*input_h_*3 +
c * input_h_ * input_w_ + h * input_w_ + w] = temp.at<cv::Vec3f>(h, w)[c];
}
}
}
}
img_idx_ += batchsize_;
CUDA_CHECK(cudaMemcpy(device_input_, input_imgs_.data(), input_count_ * sizeof(float), cudaMemcpyHostToDevice));
assert(!strcmp(names[0], input_blob_name_));
bindings[0] = device_input_;
return true;
}
//! Loads the calibration table file into calib_cache_ and returns its bytes.
//!
//! \param length  receives the number of bytes returned (0 when there is no cache)
//! \return pointer into calib_cache_, or nullptr when reading is disabled, the
//!         file cannot be opened, or the file is empty
//!
//! The returned pointer refers to member storage and is overwritten by the
//! next call. Defined `inline` because this definition lives in a header.
inline const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) {
    std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
    calib_cache_.clear();

    // Do not touch the file at all when cache reuse is switched off.
    if (read_cache_) {
        std::ifstream input(calib_table_name_, std::ios::binary);
        if (input.good()) {
            // istreambuf_iterator copies raw bytes without formatted-input
            // whitespace handling.
            calib_cache_.assign(std::istreambuf_iterator<char>(input),
                                std::istreambuf_iterator<char>());
        }
    }

    length = calib_cache_.size();
    return length ? calib_cache_.data() : nullptr;
}
//! Writes the calibration table bytes to the calibration table file,
//! replacing any previous content.
//!
//! \param cache   pointer to the table bytes
//! \param length  number of bytes to write
//!
//! Failures to open or write the file are reported on std::cerr; the
//! function has no return value to signal them through.
//! Defined `inline` because this definition lives in a header.
inline void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) {
    std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
    std::ofstream output(calib_table_name_, std::ios::binary);
    if (!output) {
        std::cerr << "failed to open calib cache for writing: " << calib_table_name_ << std::endl;
        return;
    }
    output.write(reinterpret_cast<const char*>(cache), static_cast<std::streamsize>(length));
    // Flush so a failed write is visible in the stream state before checking it.
    output.flush();
    if (!output) {
        std::cerr << "failed to write calib cache: " << calib_table_name_ << std::endl;
    }
}