add bfloat16 support for ConcatTraining and SplitTraining ops (microsoft#18280)

### Description

Updates the input/output type constraints on the training operators
ConcatTraining and SplitTraining to include bfloat16, which was
introduced in ONNX IR version 4.
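For background on why bfloat16 matters for large-model training: bfloat16 keeps float32's 8 exponent bits but only 7 mantissa bits, so converting float32 to bfloat16 is just truncation to the top 16 bits, preserving float32's dynamic range while halving memory. A minimal illustrative sketch (not part of this commit, which only touches schema type constraints):

```python
import struct

def float32_to_bfloat16_bits(x: float) -> int:
    """Truncate an IEEE-754 float32 to its bfloat16 bit pattern.

    bfloat16 is the top 16 bits of a float32: 1 sign bit, the same
    8 exponent bits, and the 7 highest mantissa bits. Illustrative
    helper only, assumed for this example.
    """
    # Reinterpret the float as its 32-bit little-endian pattern,
    # then keep the upper half.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16

# 1.0 is 0x3F800000 as float32, so its bfloat16 pattern is 0x3F80.
print(hex(float32_to_bfloat16_bits(1.0)))
```

This range-preserving truncation is what makes bfloat16 practical for finetuning very large models such as Llama-2-70b without the loss-scaling machinery float16 typically needs.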

### Motivation and Context

Enables `meta-llama/Llama-2-70b` to be fine-tuned with ONNX Runtime
training.

Co-authored-by: Prathik Rao <[email protected]@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
2 people authored and kleiti committed Mar 22, 2024
1 parent c240c4f commit be30ebb
Showing 1 changed file with 2 additions and 2 deletions.
orttraining/orttraining/core/graph/training_op_defs.cc
@@ -2193,7 +2193,7 @@ Example 4:
             OpSchema::Variadic)
         .TypeConstraint(
             "T",
-            OpSchema::all_tensor_types(),
+            OpSchema::all_tensor_types_ir4(),
             "Constrain input and output types to all tensor types.")
         .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
           for (int i = 0; i < static_cast<int>(ctx.getNumOutputs()); ++i) {
@@ -2270,7 +2270,7 @@ Example 4:
             OpSchema::Optional)
         .TypeConstraint(
             "T",
-            OpSchema::all_tensor_types(),
+            OpSchema::all_tensor_types_ir4(),
             "Constrain output types to any tensor type.")
         .TypeConstraint(
             "Tint",
