
Commit

print tensor helper
jslhcl committed Feb 17, 2023
1 parent 85ac6f5 commit c414d7e
Showing 1 changed file with 143 additions and 3 deletions.
146 changes: 143 additions & 3 deletions onnxruntime/core/framework/sequential_executor.cc
@@ -7,6 +7,8 @@
#include <thread>
#include <vector>
#include <sstream>
#include <fstream>
#include <iostream>
#include "core/common/common.h"
#include "core/common/logging/logging.h"
#include "core/framework/allocation_planner.h"
@@ -15,6 +17,7 @@
#include "core/framework/session_state.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/utils.h"
#include "core/framework/print_tensor_utils.h"

#if defined DEBUG_NODE_INPUTS_OUTPUTS
#include "core/framework/debug_node_inputs_outputs_utils.h"
@@ -184,9 +187,13 @@ class SessionScope {
    VLOGS(logger, 1) << "Size of execution plan vector: " << exec_plan_vec.size();

    // Enable TRACE_EXECUTION compile flag to dump execution plan
#if defined(TRACE_EXECUTION)
    std::cout << std::make_pair(&seq_exec_plan, &session_state) << std::endl;
#endif
//#if defined(TRACE_EXECUTION)
    std::ifstream fs("dump.txt");
    if (fs.is_open()) {
      std::cout << std::make_pair(&seq_exec_plan, &session_state) << std::endl;
      fs.close();
    }
//#endif
#if defined(ORT_MINIMAL_BUILD) || !defined(ORT_MEMORY_PROFILE)
    ORT_UNUSED_PARAMETER(frame);
#endif
@@ -412,6 +419,113 @@ class KernelScope {
#endif
};

void PrintTensorHelper(const Tensor& t, StreamExecutionContext& ctx) {
  MLDataType dt = t.DataType();
  if (t.Location().device.Type() == OrtDevice::GPU) {
    auto cpu_allocator = ctx.GetSessionState().GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);
    const auto& data_transfer_mgr = ctx.GetSessionState().GetDataTransferMgr();
    Tensor cpu_tensor{dt, t.Shape(), cpu_allocator};
    Status status = data_transfer_mgr.CopyTensor(t, cpu_tensor);
    if (status != onnxruntime::Status::OK()) {
      std::cout << "cannot copy tensor to cpu";
      return;
    }
    switch (dt->AsPrimitiveDataType()->GetDataType()) {
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
        onnxruntime::utils::PrintCpuTensor<float>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
        onnxruntime::utils::PrintCpuTensor<bool>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
        onnxruntime::utils::PrintCpuTensor<double>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_STRING:
        onnxruntime::utils::PrintCpuTensor<std::string>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
        onnxruntime::utils::PrintCpuTensor<int8_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
        onnxruntime::utils::PrintCpuTensor<uint8_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT16:
        onnxruntime::utils::PrintCpuTensor<int16_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
        onnxruntime::utils::PrintCpuTensor<uint16_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
        onnxruntime::utils::PrintCpuTensor<int32_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
        onnxruntime::utils::PrintCpuTensor<uint32_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
        onnxruntime::utils::PrintCpuTensor<int64_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
        onnxruntime::utils::PrintCpuTensor<uint64_t>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
        onnxruntime::utils::PrintCpuTensor<MLFloat16>(cpu_tensor);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
        onnxruntime::utils::PrintCpuTensor<BFloat16>(cpu_tensor);
        break;
      default:
        ORT_ENFORCE(false, "Unknown tensor type of ", dt->AsPrimitiveDataType()->GetDataType());
    }
  } else {
    switch (dt->AsPrimitiveDataType()->GetDataType()) {
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
        onnxruntime::utils::PrintCpuTensor<float>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
        onnxruntime::utils::PrintCpuTensor<bool>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE:
        onnxruntime::utils::PrintCpuTensor<double>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_STRING:
        onnxruntime::utils::PrintCpuTensor<std::string>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT8:
        onnxruntime::utils::PrintCpuTensor<int8_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
        onnxruntime::utils::PrintCpuTensor<uint8_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT16:
        onnxruntime::utils::PrintCpuTensor<int16_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT16:
        onnxruntime::utils::PrintCpuTensor<uint16_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT32:
        onnxruntime::utils::PrintCpuTensor<int32_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT32:
        onnxruntime::utils::PrintCpuTensor<uint32_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_INT64:
        onnxruntime::utils::PrintCpuTensor<int64_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_UINT64:
        onnxruntime::utils::PrintCpuTensor<uint64_t>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
        onnxruntime::utils::PrintCpuTensor<MLFloat16>(t);
        break;
      case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
        onnxruntime::utils::PrintCpuTensor<BFloat16>(t);
        break;
      default:
        ORT_ENFORCE(false, "Unknown tensor type of ", dt->AsPrimitiveDataType()->GetDataType());
    }
  }
}

onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx,
                                  NodeIndex idx,
                                  size_t stream_idx,
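
Note: the helper above repeats an identical type switch twice, once for the CPU staging copy of a GPU tensor and once for a tensor that is already on the CPU. Below is a rough sketch, not part of this commit, of how the duplication could be collapsed by resolving the tensor to print first and switching once. PrintTensorHelperCompact is a hypothetical name, it additionally needs <optional>, and only two representative cases are written out; the rest would mirror the switch in the committed code.

// Sketch only, not part of this commit. Pick the tensor to print (the original
// or a CPU staging copy), then run a single type switch over it.
static void PrintTensorHelperCompact(const Tensor& t, StreamExecutionContext& ctx) {
  const Tensor* to_print = &t;
  std::optional<Tensor> cpu_copy;  // staging buffer, only populated for GPU tensors
  if (t.Location().device.Type() == OrtDevice::GPU) {
    auto cpu_allocator = ctx.GetSessionState().GetExecutionProviders().Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault);
    cpu_copy.emplace(t.DataType(), t.Shape(), cpu_allocator);
    if (!ctx.GetSessionState().GetDataTransferMgr().CopyTensor(t, *cpu_copy).IsOK()) {
      std::cout << "cannot copy tensor to cpu\n";
      return;
    }
    to_print = &*cpu_copy;
  }
  switch (t.DataType()->AsPrimitiveDataType()->GetDataType()) {
    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
      onnxruntime::utils::PrintCpuTensor<float>(*to_print);
      break;
    case ONNX_NAMESPACE::TensorProto_DataType_INT64:
      onnxruntime::utils::PrintCpuTensor<int64_t>(*to_print);
      break;
    // ... remaining TensorProto_DataType_* cases exactly as in the commit ...
    default:
      std::cout << "tensor type not handled in this sketch\n";
  }
}
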
@@ -440,6 +554,19 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx,
ORT_THROW("Async Kernel Support is not implemented yet.");
} else {
KernelScope kernel_scope(session_scope, kernel_ctx, *p_kernel);
std::ifstream fstream("dump.txt");
if (fstream.is_open()) {
std::cout << "execute node:" << idx << "\n Inputs (len:" << kernel_ctx.InputCount() << ") :\n";
for (int i = 0; i < kernel_ctx.InputCount(); i++) {
const OrtValue* v = kernel_ctx.GetInputMLValue(i);
if (v->IsTensor()) {
const Tensor& t = v->Get<Tensor>();
std::cout << "i:" << i << " size:" << t.SizeInBytes() << " shape:"<<t.Shape().ToString()<<" element size:"<<t.DataType()->Size()<< "\n";
PrintTensorHelper(t, ctx);
}
}
fstream.close();
}
ORT_TRY {
#ifdef ENABLE_TRAINING
// AllocateInputsContiguously - is only required for NCCL kernels
@@ -479,6 +606,19 @@ onnxruntime::Status ExecuteKernel(StreamExecutionContext& ctx,
        status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what());
      });
    }
    std::ifstream fstream2("dump.txt");
    if (fstream2.is_open()) {
      std::cout << " Outputs (len:" << kernel_ctx.OutputCount() << ") :\n";
      for (int i = 0; i < kernel_ctx.OutputCount(); i++) {
        const OrtValue* v = kernel_ctx.GetOutputMLValue(i);
        if (v->IsTensor()) {
          const Tensor& t = v->Get<Tensor>();
          std::cout << "i:" << i << " size:" << t.SizeInBytes() << " shape:" << t.Shape().ToString() << " element size:" << t.DataType()->Size() << "\n";
          PrintTensorHelper(t, ctx);
        }
      }
      fstream2.close();
    }
  }
  if (!status.IsOK()) {
    std::ostringstream ss;

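For reference, here is a minimal sketch of how the new dump path might be triggered from an application; it is not part of the commit. Two assumptions: the executor opens the probe file with the relative path "dump.txt", so the file must sit in the process's current working directory, and "model.onnx" is a hypothetical model path. The Ort::Env/Ort::SessionOptions/Ort::Session calls are the standard onnxruntime C++ API.

// Sketch: create the probe file, then run a session; while "dump.txt" exists,
// the modified executor prints the execution plan and each executed node's
// input and output tensors to stdout.
#include <fstream>
#include "onnxruntime_cxx_api.h"

int main() {
  {
    std::ofstream probe("dump.txt");  // an empty file is enough; its content is never read
  }
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "tensor-dump-demo");
  Ort::SessionOptions options;
  Ort::Session session(env, "model.onnx", options);  // on Windows the path would be a wide string
  // ... build input OrtValues and call session.Run(...) as usual; the per-node
  // dump is written to stdout during the run ...
  return 0;
}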