Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Nov 4, 2023
1 parent 4e40cd3 commit ecc2566
Showing 1 changed file with 0 additions and 40 deletions.
40 changes: 0 additions & 40 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,46 +365,6 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
return std::unique_lock<OrtMutex>(singleton);
}

Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor,
std::vector<int32_t>& shape_values,
nvinfer1::ICudaEngine* trt_engine,
int binding_index,
cudaStream_t stream) {
auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
const auto tensor_shapes = tensor_info.GetShape();
const auto tensor_type = tensor_info.GetElementType();
nvinfer1::Dims dims = trt_engine->getBindingDimensions(static_cast<int>(binding_index));
int nb_dims = dims.nbDims;
int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]); // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
shape_values.resize(shape_size, 1);

switch (tensor_type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
auto input = std::make_unique<int32_t[]>(shape_size);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int32_t>(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
for (int j = 0; j < shape_size; ++j) {
shape_values[j] = input[j];
}
break;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
auto input = std::make_unique<int64_t[]>(shape_size);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData<int64_t>(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
for (int j = 0; j < shape_size; ++j) {
shape_values[j] = static_cast<int32_t>(input[j]);
}
break;
}
default: {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
}
}
return Status::OK();
}

/*
* Get the shape of "shape tensor" input
*/
Expand Down

0 comments on commit ecc2566

Please sign in to comment.