Skip to content

Commit

Permalink
Fix Triton C API mode missing infer requested output datatype bug
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewkotila committed Aug 7, 2024
1 parent c7b1642 commit f787972
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
6 changes: 3 additions & 3 deletions src/client_backend/triton_c_api/triton_c_api_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ TritonCApiInferRequestedOutput::Create(
const size_t class_count, const std::string& datatype)
{
TritonCApiInferRequestedOutput* local_infer_output =
new TritonCApiInferRequestedOutput(name);
new TritonCApiInferRequestedOutput(name, datatype);

tc::InferRequestedOutput* triton_infer_output;
RETURN_IF_TRITON_ERROR(tc::InferRequestedOutput::Create(
Expand All @@ -427,8 +427,8 @@ TritonCApiInferRequestedOutput::SetSharedMemory(
}

TritonCApiInferRequestedOutput::TritonCApiInferRequestedOutput(
const std::string& name)
: InferRequestedOutput(BackendKind::TRITON_C_API, name)
const std::string& name, const std::string& datatype)
: InferRequestedOutput(BackendKind::TRITON_C_API, name, datatype)
{
}

Expand Down
3 changes: 2 additions & 1 deletion src/client_backend/triton_c_api/triton_c_api_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ class TritonCApiInferRequestedOutput : public InferRequestedOutput {
const std::string& name, size_t byte_size, size_t offset = 0) override;

private:
explicit TritonCApiInferRequestedOutput(const std::string& name);
explicit TritonCApiInferRequestedOutput(
const std::string& name, const std::string& datatype);

std::unique_ptr<tc::InferRequestedOutput> output_;
};
Expand Down
29 changes: 10 additions & 19 deletions src/client_backend/triton_c_api/triton_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,25 +338,16 @@ TritonLoader::StartTriton()
// Create the allocator that will be used to allocate buffers for
// the result tensors.
RETURN_IF_TRITONSERVER_ERROR(
GetSingleton()->response_allocator_new_fn_(
&allocator_,
reinterpret_cast<
TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator,
const char* tensor_name, size_t byte_size,
TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id, void* userp,
void** buffer, void** buffer_userp,
TRITONSERVER_MemoryType*
actual_memory_type,
int64_t* actual_memory_type_id)>(
ResponseAlloc),
reinterpret_cast<
TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator,
void* buffer, void* buffer_userp,
size_t byte_size,
TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)>(ResponseRelease),
nullptr /* start_fn */),
GetSingleton()
->response_allocator_new_fn_(
&allocator_,
reinterpret_cast<
TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator * allocator, const char* tensor_name, size_t byte_size, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, void* userp, void** buffer, void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, int64_t* actual_memory_type_id)>(
ResponseAlloc),
reinterpret_cast<
TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator * allocator, void* buffer, void* buffer_userp, size_t byte_size, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)>(
ResponseRelease),
nullptr /* start_fn */),
"creating response allocator");

return Error::Success;
Expand Down

0 comments on commit f787972

Please sign in to comment.