diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index bc2f561b17a..2ab4686b5f4 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -42,11 +42,12 @@ inline Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::vector& op_chain, const std::optional& memory_config = std::nullopt) { + DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype(); bool fp32_dest_acc_en = input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to // DST directly, fp32 is converted to fp16b return operation::run( - EltwiseUnary{op_chain, memory_config.value_or(input_tensor.memory_config()), fp32_dest_acc_en}, + EltwiseUnary{op_chain, memory_config.value_or(input_tensor.memory_config()), fp32_dest_acc_en, output_dtype}, {input_tensor}) .at(0); }