diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 5e1d933154045..ff81fe23530b4 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ -231,6 +231,50 @@ static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, cons return Status::OK(); } +enum TypeGroup { + Unknown = -1, + Bool = 0, + Integer = 1, + Unsigned = 2, + Float = 3, +}; + +TypeGroup GetTypeGroup(DataType type) { + if (*type == "tensor(bool)") { + return Bool; + } + + if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") { + return Integer; + } + + if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") { + return Unsigned; + } + + if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") { + return Float; + } + + return Unknown; +} + +int BitLength(DataType type) { + if (*type == "tensor(bool)") { + return 1; + } else if (*type == "tensor(uint8)" || *type == "tensor(int8)") { + return 8; + } else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") { + return 16; + } else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") { + return 32; + } else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") { + return 64; + } else { + return -1; + } +} + /** Transformer to remove duplicate Cast nodes. */ class RemoveDuplicateCastTransformer : public GraphTransformer { public: @@ -238,20 +282,46 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { } private: - static constexpr std::array cast_ordering{ - "tensor(bool)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)", - "tensor(int32)", "tensor(int64)", "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"}; + static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) { + // This is not a complete cast optimisation pass, and is more conservative than it could be. + // For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass. - static bool LossOfPrecision(DataType src_type, DataType dst_type, const Node& node) { // The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer. // Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support. - auto src_pos = std::find(cast_ordering.begin(), cast_ordering.end(), *src_type); - auto dst_pos = std::find(cast_ordering.begin(), cast_ordering.end(), *dst_type); - if (src_pos == cast_ordering.end() || dst_pos == cast_ordering.end()) { + auto src_type_group = GetTypeGroup(src_type); + auto dst_type_group = GetTypeGroup(dst_type); + if (Unknown == src_type_group || Unknown == dst_type_group) { + return true; + } + + // Do not remove any signed -> unsigned cast. + if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) { + return true; + } + + // Do not remove any floating point -> non floating point cast. + if (Float == src_type_group && Float != dst_type_group) { + return true; + } + + auto src_bit_length = BitLength(src_type); + auto dst_bit_length = BitLength(dst_type); + + // unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer. + if (Unsigned == src_type_group && Integer == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + // integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise. + if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) { + return dst_bit_length <= src_bit_length; + } + + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { return true; } - return std::distance(src_pos, dst_pos) < 0 && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); + return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")); } Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override { @@ -286,7 +356,10 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { auto src_type = node.InputDefs()[0]->Type(); auto dst_type = node.OutputDefs()[0]->Type(); - bool loss_precision_cast = LossOfPrecision(src_type, dst_type, node); + bool loss_precision_cast = UnsafeCast(src_type, dst_type, node); + std::cout << "--------------\n"; + std::cout << "First cast: " << *src_type << " -> " << *dst_type << std::endl; + std::cout << "Loss precision cast: " << loss_precision_cast << std::endl; size_t num_children = node.GetOutputEdgesCount(); @@ -296,8 +369,11 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { if (output_node.OpType() == "Cast") { auto src_type1 = output_node.InputDefs()[0]->Type(); auto dst_type1 = output_node.OutputDefs()[0]->Type(); - if (loss_precision_cast && LossOfPrecision(dst_type1, src_type1, output_node)) { + std::cout << "Second cast: " << *src_type1 << " -> " << *dst_type1 << std::endl; + std::cout << "Loss precision cast (reverse): " << UnsafeCast(dst_type1, src_type1, output_node) << std::endl; + if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) { inconsistent_casts = true; + std::cout << "Inconsistent casts" << std::endl; break; }