Skip to content

Commit

Permalink
Minor Triton Fix (#19589)
Browse files Browse the repository at this point in the history
Including removing a unnecessary assert, and add support of passing
string attribute from ONNX node attribute to python functoin kwargs
(mainly for passing debug info from graph to python for now).
  • Loading branch information
centwang authored Feb 22, 2024
1 parent 5197db1 commit 3d88487
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl
PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first)));
} else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first)));
} else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) {
PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str()));
} else {
ORT_THROW("Unsupported kwargs data type: ", kv.second.second);
}
Expand Down
3 changes: 2 additions & 1 deletion orttraining/orttraining/python/training/ort_triton/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl


def next_power_of_2(n: int) -> int:
assert n <= 2**32, "32-bit only"
"""Return the smallest power of 2 greater than or equal to n"""
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n += 1
return n

Expand Down
5 changes: 4 additions & 1 deletion orttraining/orttraining/training_ops/cpu/triton/triton_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ class TritonOp final : public OpKernel {
attr.first == "onnx_string") {
continue;
}
// Support int64 and float only for now, skip other types.
// Support int64, float and string only for now, skip other types.
if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) {
kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}});
} else if (attr.second.type() ==
ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) {
kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}});
} else if (attr.second.type() ==
ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) {
kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}});
}
}
}
Expand Down

0 comments on commit 3d88487

Please sign in to comment.