Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into snnn/update_fetch_con…
Browse files Browse the repository at this point in the history
…tent
  • Loading branch information
snnn committed Jul 12, 2024
2 parents 046b29d + 42b7ced commit d4264a4
Showing 1 changed file with 35 additions and 11 deletions.
46 changes: 35 additions & 11 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,7 @@ struct ProviderHostImpl : ProviderHost {

const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; }

static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
auto* shape = ctx.getAttribute("shape");
auto* data_type = ctx.getAttribute("data_type");
static int32_t convert_elem_type(const ONNX_NAMESPACE::AttributeProto* data_type) {
int32_t elemType = 0;
if (data_type->s() == "float32") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
Expand Down Expand Up @@ -650,17 +648,43 @@ struct ProviderHostImpl : ProviderHost {
elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4;
} else if (data_type->s() == "int4") {
elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4;
} else {
return;
}
ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType);
if (shape != nullptr) {
for (auto i = 0; i < shape->ints_size(); ++i) {
ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
return elemType;
}

static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
auto num_output = ctx.getNumOutputs();
if (num_output == 1) {
auto* shape = ctx.getAttribute("shape");
auto* data_type = ctx.getAttribute("data_type");
if (data_type == nullptr) {
std::cerr << "Custom op is missing `data_type` attr." << std::endl;
return;
}
int32_t elemType = convert_elem_type(data_type);
ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType);
if (shape != nullptr) {
for (auto i = 0; i < shape->ints_size(); ++i) {
ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
}
} else {
// set scalar type.
ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim();
}
} else {
// set scalar type.
ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim();
for (auto idx = 0u; idx < num_output; idx++) {
auto* shape = ctx.getAttribute("shape_" + std::to_string(idx));
auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx));
if (shape == nullptr || data_type == nullptr) {
// this output is optional
} else {
int32_t elemType = convert_elem_type(data_type);
ONNX_NAMESPACE::updateOutputElemType(ctx, idx, elemType);
for (auto i = 0; i < shape->ints_size(); ++i) {
ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
}
}
}
}
}

Expand Down

0 comments on commit d4264a4

Please sign in to comment.