Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 27, 2023
1 parent f1ac17d commit afe4bb4
Showing 1 changed file with 86 additions and 10 deletions.
96 changes: 86 additions & 10 deletions onnxruntime/core/optimizer/insert_cast_transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,27 +231,97 @@ static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, cons
return Status::OK();
}

enum TypeGroup {
Unknown = -1,
Bool = 0,
Integer = 1,
Unsigned = 2,
Float = 3,
};

TypeGroup GetTypeGroup(DataType type) {
if (*type == "tensor(bool)") {
return Bool;
}

if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") {
return Integer;
}

if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") {
return Unsigned;
}

if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") {
return Float;
}

return Unknown;
}

int BitLength(DataType type) {
if (*type == "tensor(bool)") {
return 1;
} else if (*type == "tensor(uint8)" || *type == "tensor(int8)") {
return 8;
} else if (*type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") {
return 16;
} else if (*type == "tensor(int32)" || *type == "tensor(uint32)" || *type == "tensor(float)") {
return 32;
} else if (*type == "tensor(int64)" || *type == "tensor(uint64)" || *type == "tensor(double)") {
return 64;
} else {
return -1;
}
}

/** Transformer to remove duplicate Cast nodes. */
class RemoveDuplicateCastTransformer : public GraphTransformer {
public:
RemoveDuplicateCastTransformer() : GraphTransformer("RemoveDuplicateCastTransformer") {
}

private:
static constexpr std::array<std::string_view, 13> cast_ordering{
"tensor(bool)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)",
"tensor(int32)", "tensor(int64)", "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"};
static bool UnsafeCast(DataType src_type, DataType dst_type, const Node& node) {
// This is not a complete cast optimisation pass, and is more conservative than it could be.
// For instance, certain integral -> floating point casts could be optimised but this is left to an explicit cast optimisation pass.

static bool LossOfPrecision(DataType src_type, DataType dst_type, const Node& node) {
// The comparison with "InsertedPrecisionFreeCast_" reflects cast nodes that are inserted by InsertCastTransformer.
// Such casts should not be considered as loss of precision - the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) are inserted to support kernels when on a CPU EP without F16 support.
auto src_pos = std::find(cast_ordering.begin(), cast_ordering.end(), *src_type);
auto dst_pos = std::find(cast_ordering.begin(), cast_ordering.end(), *dst_type);
if (src_pos == cast_ordering.end() || dst_pos == cast_ordering.end()) {
auto src_type_group = GetTypeGroup(src_type);
auto dst_type_group = GetTypeGroup(dst_type);
if (Unknown == src_type_group || Unknown == dst_type_group) {
return true;
}

// Do not remove any signed -> unsigned cast.
if ((src_type_group != Bool && src_type_group != Unsigned) && Unsigned == dst_type_group) {
return true;
}

// Do not remove any floating point -> non floating point cast.
if (Float == src_type_group && Float != dst_type_group) {
return true;
}

auto src_bit_length = BitLength(src_type);
auto dst_bit_length = BitLength(dst_type);

// unsigned integer -> integer cast may overflow if the destination integer is smaller or equal to the source integer.
if (Unsigned == src_type_group && Integer == dst_type_group) {
return dst_bit_length <= src_bit_length;
}

// integral -> floating cast may overflow if integer cannot be encoded in the mantissa. This check could be more precise.
if ((Integer == src_type_group || Unsigned == src_type_group) && Float == dst_type_group) {
return dst_bit_length <= src_bit_length;
}

if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
return true;
}

return std::distance(src_pos, dst_pos) < 0 && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_"));
return src_bit_length > dst_bit_length && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_"));
}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override {
Expand Down Expand Up @@ -286,7 +356,10 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
auto src_type = node.InputDefs()[0]->Type();
auto dst_type = node.OutputDefs()[0]->Type();

bool loss_precision_cast = LossOfPrecision(src_type, dst_type, node);
bool loss_precision_cast = UnsafeCast(src_type, dst_type, node);
std::cout << "--------------\n";
std::cout << "First cast: " << *src_type << " -> " << *dst_type << std::endl;
std::cout << "Loss precision cast: " << loss_precision_cast << std::endl;

size_t num_children = node.GetOutputEdgesCount();

Expand All @@ -296,8 +369,11 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
if (output_node.OpType() == "Cast") {
auto src_type1 = output_node.InputDefs()[0]->Type();
auto dst_type1 = output_node.OutputDefs()[0]->Type();
if (loss_precision_cast && LossOfPrecision(dst_type1, src_type1, output_node)) {
std::cout << "Second cast: " << *src_type1 << " -> " << *dst_type1 << std::endl;
std::cout << "Loss precision cast (reverse): " << UnsafeCast(dst_type1, src_type1, output_node) << std::endl;
if (loss_precision_cast && UnsafeCast(dst_type1, src_type1, output_node)) {
inconsistent_casts = true;
std::cout << "Inconsistent casts" << std::endl;
break;
}

Expand Down

0 comments on commit afe4bb4

Please sign in to comment.