diff --git a/metal/src/ops/cast.rs b/metal/src/ops/cast.rs index 7f8def6b3e..47a36e2232 100644 --- a/metal/src/ops/cast.rs +++ b/metal/src/ops/cast.rs @@ -1,5 +1,5 @@ +use crate::kernels; use crate::tensor::MetalTensorExt; -use crate::{kernels, MetalTensor}; use tract_core::internal::*; #[derive(Debug, Clone, Hash, PartialEq, Eq)] @@ -8,8 +8,12 @@ pub struct MetalCast { } impl MetalCast { + pub fn is_supported_dt(dt: DatumType) -> bool { + kernels::array::Cast::is_supported_dt(dt) + } + pub fn new(to: DatumType) -> Option { - MetalTensor::is_supported_dt(to).then(|| Self { to }) + Self::is_supported_dt(to).then_some(Self { to }) } } diff --git a/metal/src/transform.rs b/metal/src/transform.rs index 9256f533d6..b5bec6e776 100644 --- a/metal/src/transform.rs +++ b/metal/src/transform.rs @@ -180,7 +180,10 @@ impl Translate, TypedFact, Box> for Met } else if let Some(op) = node.op_as::() { ops::MetalConst::new(op.0.clone())?.map(|o| -> Box { Box::new(o) }) } else if let Some(op) = node.op_as::() { - ops::MetalCast::new(op.to).map(|o| -> Box { Box::new(o) }) + check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt)? + .then(|| ops::MetalCast::new(op.to)) + .flatten() + .map(|o| -> Box { Box::new(o) }) } else if let Some(op) = node.op_as::() { ops::MetalAxisOp::from_tract_core(op.clone()) .map(|o| -> Box { Box::new(o) })