diff --git a/ODLA/platforms/tensorrt/odla_tensorrt.cc b/ODLA/platforms/tensorrt/odla_tensorrt.cc index 68585bfd6..79d049123 100644 --- a/ODLA/platforms/tensorrt/odla_tensorrt.cc +++ b/ODLA/platforms/tensorrt/odla_tensorrt.cc @@ -216,15 +216,15 @@ struct _odla_context { #endif typedef struct { - void* host_ptr; - void* dev_ptr; - size_t len; + void* host_ptr = nullptr; + void* dev_ptr = nullptr; + size_t len = 0; odla_value_type vt; } OutputPtrInfo; typedef struct { - const void* host_ptr; - void* dev_ptr; + const void* host_ptr = nullptr; + void* dev_ptr = nullptr; } InputPtrInfo; std::unordered_map output_ptrs; std::unordered_map input_ptrs; @@ -387,7 +387,7 @@ static nvinfer1::Dims SqueezeNVDims(const nvinfer1::Dims dims, int index) { thread_local odla_computation g_comp; static std::vector> g_comps; -static std::vector g_workspace; +static std::vector> g_workspace; static nvinfer1::DataType GetNVDataType(odla_element_type type) { switch (type) { @@ -433,19 +433,20 @@ static odla_value_type ValidateValueType(const odla_value_type& type) { return type; } -static void* ValidateValuePtr(const odla_value_type& type, void* ptr) { +static std::unique_ptr ConvertData(const odla_value_type& type, + const void* ptr) { if (type.element_type == ODLA_INT64) { - int64_t* src = static_cast(ptr); + const int64_t* src = static_cast(ptr); auto num_elements = GetTotalElements(type.shape); - auto workspace_size = g_workspace.size(); - assert(workspace_size + num_elements < MAX_INT64_CONVERTION_NUM); - int* tmp = g_workspace.data() + workspace_size; + auto buf = std::make_unique(num_elements); + int* tmp = buf.get(); for (int i = 0; i < num_elements; ++i) { - g_workspace.push_back(static_cast(*src++)); + assert(*src < MAX_INT64_CONVERTION_NUM); + tmp[i] = (static_cast(*src++)); } - return tmp; + return buf; } - return ptr; + return nullptr; } template @@ -496,7 +497,6 @@ odla_status odla_CreateComputation(odla_computation* computation) { g_comps.push_back(std::make_unique<_odla_computation>()); g_comp = g_comps.back().get(); *computation = g_comp; - g_workspace.reserve(MAX_INT64_CONVERTION_NUM); return ODLA_SUCCESS; } @@ -622,10 +622,15 @@ odla_status odla_GetArgFromComputationByIdx(const odla_computation computation, odla_value odla_CreateConstant(odla_value_type type, const void* ptr, const odla_value_id id) { - nvinfer1::Weights weight{ - .type = GetNVDataType(type.element_type), - .values = ValidateValuePtr(type, const_cast(ptr)), - .count = GetTotalElements(type.shape)}; + void* host_ptr = const_cast(ptr); + auto buf = ConvertData(type, ptr); + if (buf != nullptr) { + host_ptr = buf.get(); + g_workspace.push_back(std::move(buf)); + } + nvinfer1::Weights weight{.type = GetNVDataType(type.element_type), + .values = host_ptr, + .count = GetTotalElements(type.shape)}; auto c = g_comp->network->addConstant(GetNVDims(type.shape), weight); odla_value v = CreateValue(c->getOutput(0), ValidateValueType(type), id); v->const_layer = c; @@ -661,20 +666,27 @@ odla_status odla_GetOutputFromComputationByIdx( odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr, odla_context context) { - void* dev_ptr = nullptr; odla_value_shape real_shape = value->type.shape; + bool dynamic_input_size = false; if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) { real_shape.dims[0] = context->run_batch_size; + dynamic_input_size = true; } size_t bytes = GetTotalElements(real_shape) * GetElementSize(value->type.element_type); - CHECK(cudaMalloc(&dev_ptr, bytes)); - void* validated_data_ptr = - ValidateValuePtr(value->type, const_cast(data_ptr)); - CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice)); - + void* validated_data_ptr = const_cast(data_ptr); + auto buf = ConvertData(value->type, data_ptr); + if (buf != nullptr) { + validated_data_ptr = buf.get(); + } + void* dev_ptr = context->input_ptrs[value->name].dev_ptr; + if (dev_ptr == nullptr) { + CHECK(cudaMalloc(&dev_ptr, bytes)); + } context->input_ptrs[value->name] = {.host_ptr = data_ptr, .dev_ptr = dev_ptr}; + CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice)); + return ODLA_SUCCESS; } @@ -688,15 +700,17 @@ odla_status odla_BindToArgumentById(const odla_value_id value_id, odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr, odla_context context) { - void* dst = nullptr; odla_value_shape real_shape = value->type.shape; if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) { real_shape.dims[0] = context->run_batch_size; } size_t bytes = GetTotalElements(real_shape) * GetElementSize(value->type.element_type); - - CHECK(cudaMalloc(&dst, bytes)); + // TODO: convert to int64 for int64 outputs? + void* dst = context->output_ptrs[value->name].dev_ptr; + if (dst == nullptr) { + CHECK(cudaMalloc(&dst, bytes)); + } context->output_ptrs[value->name] = { .host_ptr = data_ptr, .dev_ptr = dst, .len = bytes, .vt = value->type}; @@ -959,7 +973,9 @@ odla_status odla_ExecuteComputation(odla_computation comp, odla_context context, cudaMemcpyDeviceToHost)); } } - + if (!comp->is_dynamic_batch) { + return ODLA_SUCCESS; + } // copy results and free temp buffers. for (auto& ptr : buffers) { CHECK(cudaFree(ptr)); @@ -1948,7 +1964,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format, return g_comp->network->addConstant(GetNVDims(dim), weight)->getOutput(0); }; nvinfer1::ITensor* init_hidden_t = getInitTensor(initial_h); - // LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions())); + // LOG_VERBOSE("init_hidden dim:" + + // gen_str(init_hidden_t->getDimensions())); nvinfer1::ITensor* init_cell_t = getInitTensor(initial_c); rnn_layer->setHiddenState(*init_hidden_t); rnn_layer->setCellState(*init_cell_t); @@ -2090,7 +2107,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format, return g_comp->network->addConstant(GetNVDims(dim), weight)->getOutput(0); }; nvinfer1::ITensor* init_hidden_t = getInitTensor(initial_h); - // LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions())); + // LOG_VERBOSE("init_hidden dim:" + + // gen_str(init_hidden_t->getDimensions())); nvinfer1::ITensor* init_cell_t = getInitTensor(initial_c); // LOG("init_cell dim:" << init_cell_t->getDimensions());