diff --git a/core/src/floats.rs b/core/src/floats.rs index c4228eb351..09ce00a966 100644 --- a/core/src/floats.rs +++ b/core/src/floats.rs @@ -4,6 +4,7 @@ use crate::internal::translator::Translate; use crate::internal::*; use crate::ops::array::{Pad, PadMode}; use crate::ops::einsum::EinSum; +use crate::ops::cast::Cast; use crate::ops::konst::Const; use crate::ops::scan::Scan; use crate::ops::source::TypedSource; @@ -39,6 +40,12 @@ impl Box::new(TypedSource::new(fact_float_precision_conversion::(&source.fact))) } else if let Some(konst) = node.op_as::() { Box::new(Const(tensor_float_precision_conversion::(&konst.0))) + } else if let Some(cast) = node.op_as::() { + if cast.to == T1::datum_type() { + Box::new(Cast { to: T2::datum_type() }) + } else { + node.op.clone() + } } else if let Some(op) = node.op_as::() { let body = FloatPrecisionTranslator::::default().translate_model(&op.body)?; Box::new(Scan { body, ..op.clone() })