Skip to content

Commit

Permalink
Fix unsupported dt for Metal Cast
Browse files Browse the repository at this point in the history
  • Loading branch information
hubertdelajonquieresonos authored and kali committed Oct 1, 2024
1 parent 9a53232 commit 8441bf3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
8 changes: 6 additions & 2 deletions metal/src/ops/cast.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::kernels;
use crate::tensor::MetalTensorExt;
use crate::{kernels, MetalTensor};
use tract_core::internal::*;

#[derive(Debug, Clone, Hash, PartialEq, Eq)]
Expand All @@ -8,8 +8,12 @@ pub struct MetalCast {
}

impl MetalCast {
pub fn is_supported_dt(dt: DatumType) -> bool {
kernels::array::Cast::is_supported_dt(dt)
}

pub fn new(to: DatumType) -> Option<Self> {
MetalTensor::is_supported_dt(to).then(|| Self { to })
Self::is_supported_dt(to).then_some(Self { to })
}
}

Expand Down
5 changes: 4 additions & 1 deletion metal/src/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for Met
} else if let Some(op) = node.op_as::<Const>() {
ops::MetalConst::new(op.0.clone())?.map(|o| -> Box<dyn TypedOp> { Box::new(o) })
} else if let Some(op) = node.op_as::<Cast>() {
ops::MetalCast::new(op.to).map(|o| -> Box<dyn TypedOp> { Box::new(o) })
check_in_dts_are_supported(source, node.id, ops::MetalCast::is_supported_dt)?
.then(|| ops::MetalCast::new(op.to))
.flatten()
.map(|o| -> Box<dyn TypedOp> { Box::new(o) })
} else if let Some(op) = node.op_as::<AxisOp>() {
ops::MetalAxisOp::from_tract_core(op.clone())
.map(|o| -> Box<dyn TypedOp> { Box::new(o) })
Expand Down

0 comments on commit 8441bf3

Please sign in to comment.