Skip to content

Commit

Permalink
[ODLA/TRT] Fix bug and reuse device memory
Browse files Browse the repository at this point in the history
The scratchpad for int64->int32 is fixed.

For static shape and batch size, we can reuse device memory.
  • Loading branch information
Weiming Zhao authored and weimingzha0 committed Oct 1, 2021
1 parent 54e80f3 commit a0a1213
Showing 1 changed file with 49 additions and 31 deletions.
80 changes: 49 additions & 31 deletions ODLA/platforms/tensorrt/odla_tensorrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,15 @@ struct _odla_context {
#endif

typedef struct {
void* host_ptr;
void* dev_ptr;
size_t len;
void* host_ptr = nullptr;
void* dev_ptr = nullptr;
size_t len = 0;
odla_value_type vt;
} OutputPtrInfo;

typedef struct {
const void* host_ptr;
void* dev_ptr;
const void* host_ptr = nullptr;
void* dev_ptr = nullptr;
} InputPtrInfo;
std::unordered_map<std::string, OutputPtrInfo> output_ptrs;
std::unordered_map<std::string, InputPtrInfo> input_ptrs;
Expand Down Expand Up @@ -387,7 +387,7 @@ static nvinfer1::Dims SqueezeNVDims(const nvinfer1::Dims dims, int index) {

thread_local odla_computation g_comp;
static std::vector<std::unique_ptr<_odla_computation>> g_comps;
static std::vector<int> g_workspace;
static std::vector<std::unique_ptr<int[]>> g_workspace;

static nvinfer1::DataType GetNVDataType(odla_element_type type) {
switch (type) {
Expand Down Expand Up @@ -433,19 +433,20 @@ static odla_value_type ValidateValueType(const odla_value_type& type) {
return type;
}

static void* ValidateValuePtr(const odla_value_type& type, void* ptr) {
static std::unique_ptr<int[]> ConvertData(const odla_value_type& type,
const void* ptr) {
if (type.element_type == ODLA_INT64) {
int64_t* src = static_cast<int64_t*>(ptr);
const int64_t* src = static_cast<const int64_t*>(ptr);
auto num_elements = GetTotalElements(type.shape);
auto workspace_size = g_workspace.size();
assert(workspace_size + num_elements < MAX_INT64_CONVERTION_NUM);
int* tmp = g_workspace.data() + workspace_size;
auto buf = std::make_unique<int[]>(num_elements);
int* tmp = buf.get();
for (int i = 0; i < num_elements; ++i) {
g_workspace.push_back(static_cast<int>(*src++));
assert(*src < MAX_INT64_CONVERTION_NUM);
tmp[i] = (static_cast<int>(*src++));
}
return tmp;
return buf;
}
return ptr;
return nullptr;
}

template <typename T>
Expand Down Expand Up @@ -496,7 +497,6 @@ odla_status odla_CreateComputation(odla_computation* computation) {
g_comps.push_back(std::make_unique<_odla_computation>());
g_comp = g_comps.back().get();
*computation = g_comp;
g_workspace.reserve(MAX_INT64_CONVERTION_NUM);
return ODLA_SUCCESS;
}

Expand Down Expand Up @@ -622,10 +622,15 @@ odla_status odla_GetArgFromComputationByIdx(const odla_computation computation,

odla_value odla_CreateConstant(odla_value_type type, const void* ptr,
const odla_value_id id) {
nvinfer1::Weights weight{
.type = GetNVDataType(type.element_type),
.values = ValidateValuePtr(type, const_cast<void*>(ptr)),
.count = GetTotalElements(type.shape)};
void* host_ptr = const_cast<void*>(ptr);
auto buf = ConvertData(type, ptr);
if (buf != nullptr) {
host_ptr = buf.get();
g_workspace.push_back(std::move(buf));
}
nvinfer1::Weights weight{.type = GetNVDataType(type.element_type),
.values = host_ptr,
.count = GetTotalElements(type.shape)};
auto c = g_comp->network->addConstant(GetNVDims(type.shape), weight);
odla_value v = CreateValue(c->getOutput(0), ValidateValueType(type), id);
v->const_layer = c;
Expand Down Expand Up @@ -661,20 +666,27 @@ odla_status odla_GetOutputFromComputationByIdx(

odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr,
odla_context context) {
void* dev_ptr = nullptr;
odla_value_shape real_shape = value->type.shape;
bool dynamic_input_size = false;
if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) {
real_shape.dims[0] = context->run_batch_size;
dynamic_input_size = true;
}
size_t bytes =
GetTotalElements(real_shape) * GetElementSize(value->type.element_type);
CHECK(cudaMalloc(&dev_ptr, bytes));
void* validated_data_ptr =
ValidateValuePtr(value->type, const_cast<void*>(data_ptr));
CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));

void* validated_data_ptr = const_cast<void*>(data_ptr);
auto buf = ConvertData(value->type, data_ptr);
if (buf != nullptr) {
validated_data_ptr = buf.get();
}
void* dev_ptr = context->input_ptrs[value->name].dev_ptr;
if (dev_ptr == nullptr) {
CHECK(cudaMalloc(&dev_ptr, bytes));
}
context->input_ptrs[value->name] = {.host_ptr = data_ptr, .dev_ptr = dev_ptr};

CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));

return ODLA_SUCCESS;
}

Expand All @@ -688,15 +700,17 @@ odla_status odla_BindToArgumentById(const odla_value_id value_id,

odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr,
odla_context context) {
void* dst = nullptr;
odla_value_shape real_shape = value->type.shape;
if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) {
real_shape.dims[0] = context->run_batch_size;
}
size_t bytes =
GetTotalElements(real_shape) * GetElementSize(value->type.element_type);

CHECK(cudaMalloc(&dst, bytes));
// TODO: convert to int64 for int64 outputs?
void* dst = context->output_ptrs[value->name].dev_ptr;
if (dst == nullptr) {
CHECK(cudaMalloc(&dst, bytes));
}

context->output_ptrs[value->name] = {
.host_ptr = data_ptr, .dev_ptr = dst, .len = bytes, .vt = value->type};
Expand Down Expand Up @@ -959,7 +973,9 @@ odla_status odla_ExecuteComputation(odla_computation comp, odla_context context,
cudaMemcpyDeviceToHost));
}
}

if (!comp->is_dynamic_batch) {
return ODLA_SUCCESS;
}
// copy results and free temp buffers.
for (auto& ptr : buffers) {
CHECK(cudaFree(ptr));
Expand Down Expand Up @@ -1948,7 +1964,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format,
return g_comp->network->addConstant(GetNVDims(dim), weight)->getOutput(0);
};
nvinfer1::ITensor* init_hidden_t = getInitTensor(initial_h);
// LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions()));
// LOG_VERBOSE("init_hidden dim:" +
// gen_str(init_hidden_t->getDimensions()));
nvinfer1::ITensor* init_cell_t = getInitTensor(initial_c);
rnn_layer->setHiddenState(*init_hidden_t);
rnn_layer->setCellState(*init_cell_t);
Expand Down Expand Up @@ -2090,7 +2107,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format,
return g_comp->network->addConstant(GetNVDims(dim), weight)->getOutput(0);
};
nvinfer1::ITensor* init_hidden_t = getInitTensor(initial_h);
// LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions()));
// LOG_VERBOSE("init_hidden dim:" +
// gen_str(init_hidden_t->getDimensions()));
nvinfer1::ITensor* init_cell_t = getInitTensor(initial_c);
// LOG("init_cell dim:" << init_cell_t->getDimensions());

Expand Down

0 comments on commit a0a1213

Please sign in to comment.