Skip to content

Commit

Permalink
#14032: set bfp8_pack_precise based on op and dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Nov 13, 2024
1 parent 3a3b034 commit abcbdf0
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 8 deletions.
11 changes: 6 additions & 5 deletions ttnn/cpp/ttnn/operations/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,19 @@ inline Tensor copy_impl(
const std::vector<ttnn::operations::unary::UnaryWithParam>& op_chain,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<Tensor>& optional_output_tensor = std::nullopt) {
DataType output_dtype = (op_chain[0].op_type == ttnn::operations::unary::UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[1]) : input_tensor.get_dtype();
bool preserve_fp32_precision = (op_chain[0].op_type == ttnn::operations::unary::UnaryOpType::TYPECAST) and (input_tensor.get_dtype() == DataType::FLOAT32);
DataType output_dtype = (op_chain[0].op_type == unary::UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[1]) : input_tensor.get_dtype();
auto arch = input_tensor.device()->arch();
bool preserve_fp32_precision = (arch != tt::ARCH::GRAYSKULL) and (input_tensor.get_dtype() == DataType::FLOAT32);
bool fp32_dest_acc_en = preserve_fp32_precision or
output_dtype == DataType::UINT32 or
output_dtype == DataType::INT32 or
output_dtype == DataType::FLOAT32 or
input_tensor.get_dtype() == DataType::UINT32 or
input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to
// DST directly, fp32 is converted to fp16b
input_tensor.get_dtype() == DataType::INT32;
bool bfp8_pack_precise = (op_chain[0].op_type == unary::UnaryOpType::TYPECAST && output_dtype == DataType::BFLOAT8_B);

auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config());
return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor);
return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, bfp8_pack_precise, optional_output_tensor);
}
} // namespace detail

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ UnaryDeviceOperation::invoke(
const MemoryConfig& output_memory_config,
bool fp32_dest_acc_en,
bool preserve_fp32_precision,
bool bfp8_pack_precise,
const std::optional<Tensor>& preallocated_output) {
return {
operation_attributes_t{
Expand All @@ -200,6 +201,7 @@ UnaryDeviceOperation::invoke(
.output_memory_config = output_memory_config,
.fp32_dest_acc_en = fp32_dest_acc_en,
.preserve_fp32_precision = preserve_fp32_precision,
.bfp8_pack_precise = bfp8_pack_precise,
},
tensor_args_t{.input = input, .preallocated_output = preallocated_output}};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ struct UnaryDeviceOperation {
const MemoryConfig& output_memory_config,
bool fp32_dest_acc_en,
bool preserve_fp32_precision,
bool bfp8_pack_precise,
const std::optional<Tensor>& preallocated_output);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct operation_attributes_t {
const MemoryConfig output_memory_config;
const bool fp32_dest_acc_en = false;
const bool preserve_fp32_precision = false;
const bool bfp8_pack_precise = false;
};

struct tensor_args_t {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create(
.math_fidelity = MathFidelity::HiFi4,
.fp32_dest_acc_en = args.fp32_dest_acc_en,
.unpack_to_dest_mode = unpack_to_dest_mode,
.bfp8_pack_precise = true,
.bfp8_pack_precise = args.bfp8_pack_precise,
.math_approx_mode = math_approx_mode,
.compile_args = compute_kernel_args_group_1,
.defines = unary_defines});
Expand All @@ -120,7 +120,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create(
.math_fidelity = MathFidelity::HiFi4,
.fp32_dest_acc_en = args.fp32_dest_acc_en,
.unpack_to_dest_mode = unpack_to_dest_mode,
.bfp8_pack_precise = true,
.bfp8_pack_precise = args.bfp8_pack_precise,
.math_approx_mode = math_approx_mode,
.compile_args = compute_kernel_args_group_2,
.defines = unary_defines});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ UnaryShardedProgramFactory::cached_program_t UnaryShardedProgramFactory::create(
.math_fidelity = MathFidelity::HiFi4,
.fp32_dest_acc_en = args.fp32_dest_acc_en,
.unpack_to_dest_mode = unpack_to_dest_mode,
.bfp8_pack_precise = args.bfp8_pack_precise,
.math_approx_mode = math_approx_mode,
.compile_args = compute_kernel_args_group_1,
.defines = unary_defines});
Expand Down
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ inline Tensor unary_impl(
output_dtype == DataType::FLOAT32 or
input_tensor.get_dtype() == DataType::UINT32 or
input_tensor.get_dtype() == DataType::INT32;
bool bfp8_pack_precise = (op_chain[0].op_type == UnaryOpType::TYPECAST && output_dtype == DataType::BFLOAT8_B);

auto output_memory_config = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(input_tensor.memory_config());
return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, optional_output_tensor);
return prim::unary(queue_id, input_tensor, op_chain, output_dtype, output_memory_config, fp32_dest_acc_en, preserve_fp32_precision, bfp8_pack_precise, optional_output_tensor);
}

} // namespace detail
Expand Down

0 comments on commit abcbdf0

Please sign in to comment.