Skip to content

Commit

Permalink
Replace some ORT_ENFORCE with ORT_THROW_IF_ERROR (#18812)
Browse files Browse the repository at this point in the history
### Description
Replace some ORT_ENFORCE with ORT_THROW_IF_ERROR to get better error
messages.
  • Loading branch information
snnn authored Dec 14, 2023
1 parent 95193cb commit 7386e21
Show file tree
Hide file tree
Showing 15 changed files with 34 additions and 35 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/image_scaler.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ template <typename T>
class ImageScaler final : public OpKernel {
public:
ImageScaler(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttr<float>("scale", &scale_).IsOK());
ORT_ENFORCE(info.GetAttrs<float>("bias", bias_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<float>("scale", &scale_));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("bias", bias_));
}

Status Compute(OpKernelContext* context) const override {
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cuda/collective/sharding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info
// stored on a 1-D mesh with 2 devices and the second input on another 1-D
// mesh with 1 device.
std::vector<std::string> attr_input_device_mesh_shapes;
ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_shapes", attr_input_device_mesh_shapes).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("input_device_mesh_shapes", attr_input_device_mesh_shapes));

// input_device_mesh_elements[i] is the flattened device mesh for the i-th input.
// Note that its actual shape is input_device_mesh_shapes[i].
Expand All @@ -255,12 +255,12 @@ DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info
// Then the first input is stored on a 1-D mesh with 2 devices and the second
// input on another 1-D mesh with 1 device.
std::vector<std::string> attr_input_device_mesh_elements;
ORT_ENFORCE(info.GetAttrs<std::string>("input_device_mesh_elements", attr_input_device_mesh_elements).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("input_device_mesh_elements", attr_input_device_mesh_elements));

// input_shard_specs[i] is the sharding spec of the i-th input; e.g.,
// "RR" if the i-th input is not sharded.
std::vector<std::string> input_shard_specs;
ORT_ENFORCE(info.GetAttrs<std::string>("input_shard_specs", input_shard_specs).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("input_shard_specs", input_shard_specs));

ORT_ENFORCE(attr_input_device_mesh_shapes.size() == attr_input_device_mesh_elements.size());
ORT_ENFORCE(attr_input_device_mesh_shapes.size() == input_shard_specs.size());
Expand All @@ -274,13 +274,13 @@ DistributedKernel::DistributedKernel(const OpKernelInfo& info) : NcclKernel(info
}

std::vector<std::string> attr_output_device_mesh_shapes;
ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_shapes", attr_output_device_mesh_shapes).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("output_device_mesh_shapes", attr_output_device_mesh_shapes));

std::vector<std::string> attr_output_device_mesh_elements;
ORT_ENFORCE(info.GetAttrs<std::string>("output_device_mesh_elements", attr_output_device_mesh_elements).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("output_device_mesh_elements", attr_output_device_mesh_elements));

std::vector<std::string> output_shard_specs;
ORT_ENFORCE(info.GetAttrs<std::string>("output_shard_specs", output_shard_specs).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("output_shard_specs", output_shard_specs));

ORT_ENFORCE(attr_output_device_mesh_shapes.size() == attr_output_device_mesh_elements.size());
ORT_ENFORCE(attr_output_device_mesh_shapes.size() == output_shard_specs.size());
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/tensor/image_scaler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ REGISTER_KERNEL_TYPED(MLFloat16)

template <typename T>
ImageScaler<T>::ImageScaler(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(info.GetAttr<float>("scale", &scale_).IsOK());
ORT_ENFORCE(info.GetAttrs<float>("bias", bias_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<float>("scale", &scale_));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("bias", bias_));

b_data_ = GetScratchBuffer<float>(bias_.size(), nullptr);
// the transfer in kernel construction need to be sync on default stream.
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/codegen/passes/op_ir_creator/nn/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ Status GENERIC_OP_IR_CREATOR_CLASS(Conv)::Evaluate(
info.GetAttrOrDefault<int64_t>("group", &group, 1);
info.GetAttrOrDefault<std::string>("auto_pad", &auto_pad, "NOTSET");

ORT_ENFORCE(info.GetAttrs<int64_t>("kernel_shape", kernel_shape).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("kernel_shape", kernel_shape));
ORT_ENFORCE(kernel_shape.size() <= 2, "Only support 1D/2D convolution currently!");
ORT_ENFORCE(info.GetAttrs<int64_t>("strides", strides).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("strides", strides));

dilations = info.GetAttrs<int64_t>("dilations", dilations).IsOK() ? dilations : std::vector<int64_t>(kernel_shape.size(), 1);
ORT_ENFORCE(dilations == std::vector<int64_t>(kernel_shape.size(), 1), "Only support dilation is 1 currently");
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/codegen/passes/op_ir_creator/tensor/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ Status GENERIC_OP_IR_CREATOR_CLASS(Pad)::Evaluate(
std::vector<int64_t> pads;
float value;

ORT_ENFORCE(attrs.GetAttr<std::string>("mode", &mode).IsOK());
ORT_ENFORCE(attrs.GetAttrs<int64_t>("pads", pads).IsOK());
ORT_ENFORCE(attrs.GetAttr<float>("value", &value).IsOK());
ORT_THROW_IF_ERROR(attrs.GetAttr<std::string>("mode", &mode));
ORT_THROW_IF_ERROR(attrs.GetAttrs<int64_t>("pads", pads));
ORT_THROW_IF_ERROR(attrs.GetAttr<float>("value", &value));

if (mode != "constant" && mode != "edge" && mode != "reflect")
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Pad: Unsupported padding mode!");
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/ml/category_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ class CategoryMapper final : public OpKernel {
std::vector<std::string> string_categories;
std::vector<int64_t> int_categories;

ORT_ENFORCE(info.GetAttrs<std::string>("cats_strings", string_categories).IsOK());
ORT_ENFORCE(info.GetAttrs<int64_t>("cats_int64s", int_categories).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("cats_strings", string_categories));
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("cats_int64s", int_categories));

ORT_ENFORCE(info.GetAttr<std::string>("default_string", &default_string_).IsOK());
ORT_ENFORCE(info.GetAttr<int64_t>("default_int64", &default_int_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("default_string", &default_string_));
ORT_THROW_IF_ERROR(info.GetAttr<int64_t>("default_int64", &default_int_));

auto num_entries = string_categories.size();

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/ml/label_encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class LabelEncoder final : public OpKernel {
LabelEncoder(const OpKernelInfo& info) : OpKernel(info) {
std::vector<std::string> string_classes;

ORT_ENFORCE(info.GetAttrs<std::string>("classes_strings", string_classes).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<std::string>("classes_strings", string_classes));

ORT_ENFORCE(info.GetAttr<std::string>("default_string", &default_string_).IsOK());
ORT_ENFORCE(info.GetAttr<int64_t>("default_int64", &default_int_).IsOK());
Expand Down Expand Up @@ -53,8 +53,8 @@ class LabelEncoder_2 final : public OpKernel {
std::vector<TKey> keys;
std::vector<TValue> values;

ORT_ENFORCE(info.GetAttrs<TKey>(_key_field_name, keys).IsOK());
ORT_ENFORCE(info.GetAttrs<TValue>(_value_field_name, values).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<TKey>(_key_field_name, keys));
ORT_THROW_IF_ERROR(info.GetAttrs<TValue>(_value_field_name, values));

auto num_keys = keys.size();
auto num_values = values.size();
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cpu/ml/linearregressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ LinearRegressor::LinearRegressor(const OpKernelInfo& info)
: OpKernel(info),
intercepts_(info.GetAttrsOrDefault<float>("intercepts")),
post_transform_(MakeTransform(info.GetAttrOrDefault<std::string>("post_transform", "NONE"))) {
ORT_ENFORCE(info.GetAttr<int64_t>("targets", &num_targets_).IsOK());
ORT_ENFORCE(info.GetAttrs<float>("coefficients", coefficients_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<int64_t>("targets", &num_targets_));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("coefficients", coefficients_));

// use the intercepts_ if they're valid
use_intercepts_ = intercepts_.size() == static_cast<size_t>(num_targets_);
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cpu/ml/svmclassifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ SVMClassifier::SVMClassifier(const OpKernelInfo& info)
probb_(info.GetAttrsOrDefault<float>("prob_b")),
support_vectors_(info.GetAttrsOrDefault<float>("support_vectors")),
post_transform_(MakeTransform(info.GetAttrOrDefault<std::string>("post_transform", "NONE"))) {
ORT_ENFORCE(info.GetAttrs<float>("rho", rho_).IsOK());
ORT_ENFORCE(info.GetAttrs<float>("coefficients", coefficients_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<float>("rho", rho_));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("coefficients", coefficients_));

// prob_a and prob_b are optional for Z output
ORT_ENFORCE(proba_.size() == probb_.size());
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/ml/svmclassifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SVMCommon {
SVMCommon(const OpKernelInfo& info)
: kernel_type_(MakeKernel(info.GetAttrOrDefault<std::string>("kernel_type", "LINEAR"))) {
std::vector<float> kernel_params;
ORT_ENFORCE(info.GetAttrs<float>("kernel_params", kernel_params).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<float>("kernel_params", kernel_params));

if (!kernel_params.empty()) {
gamma_ = kernel_params[0];
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cpu/ml/svmregressor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ SVMRegressor<T>::SVMRegressor(const OpKernelInfo& info)
support_vectors_(info.GetAttrsOrDefault<float>("support_vectors")),
post_transform_(MakeTransform(info.GetAttrOrDefault<std::string>("post_transform", "NONE"))) {
int64_t vector_count = 0;
ORT_ENFORCE(info.GetAttr<int64_t>("n_supports", &vector_count).IsOK());
ORT_THROW_IF_ERROR(info.GetAttr<int64_t>("n_supports", &vector_count));
vector_count_ = narrow<ptrdiff_t>(vector_count);
ORT_ENFORCE(info.GetAttrs<float>("rho", rho_).IsOK());
ORT_ENFORCE(info.GetAttrs<float>("coefficients", coefficients_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<float>("rho", rho_));
ORT_THROW_IF_ERROR(info.GetAttrs<float>("coefficients", coefficients_));
ORT_ENFORCE(!coefficients_.empty());

auto onec = info.GetAttrOrDefault<int64_t>("one_class", 0);
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/nn/roi_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class RoiPool : public OpKernel {
public:
RoiPool(const OpKernelInfo& info) : OpKernel(info) {
std::vector<int64_t> pooled_shape;
ORT_ENFORCE(info.GetAttrs<int64_t>("pooled_shape", pooled_shape).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("pooled_shape", pooled_shape));
ORT_ENFORCE(pooled_shape.size() == 2);

pooled_height_ = pooled_shape[0];
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/cpu/nn/unpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ namespace onnxruntime {
class MaxUnpool : public OpKernel {
public:
MaxUnpool(const OpKernelInfo& info) : OpKernel(info) {
ORT_ENFORCE(info.GetAttrs<int64_t>("kernel_shape", kernel_shape_).IsOK(),
"No kernel shape is set.");
ORT_THROW_IF_ERROR(info.GetAttrs<int64_t>("kernel_shape", kernel_shape_));

num_inputs_ = OpKernel::Node().InputDefs().size();

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/tensor/upsamplebase.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class UpsampleBase {

auto input_count = info.GetInputCount();
if (input_count == 1) { // opset < 10
ORT_ENFORCE(info.GetAttrs<float>("scales", scales_).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<float>("scales", scales_));
ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_));
scales_cached_ = true;
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/operators/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class ConvBase : public JsKernel {
}
if (is_fused_conv) {
ORT_THROW_IF_ERROR(info.GetAttr<std::string>("activation", &conv_attrs_.activation));
ORT_ENFORCE(info.GetAttrs<float>("activation_params", activation_params).IsOK());
ORT_THROW_IF_ERROR(info.GetAttrs<float>("activation_params", activation_params));
} else {
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
activation_params = info.GetAttrsOrDefault<float>("activation_params", activation_params);
Expand Down

0 comments on commit 7386e21

Please sign in to comment.