diff --git a/ttnn/cpp/pybind11/operations/matmul.hpp b/ttnn/cpp/pybind11/operations/matmul.hpp index 2fde4b10d8f..56bb82534d1 100644 --- a/ttnn/cpp/pybind11/operations/matmul.hpp +++ b/ttnn/cpp/pybind11/operations/matmul.hpp @@ -28,7 +28,7 @@ void py_module(py::module& module) { const std::optional compute_kernel_config = std::nullopt, const std::optional core_grid = std::nullopt) -> ttnn::Tensor { return ttnn::operations::matmul::matmul( - input_tensor_a, input_tensor_b, /*bias=*/std::nullopt, program_config, memory_config, dtype, activation, compute_kernel_config, core_grid); + input_tensor_a, input_tensor_b, /*bias=*/std::nullopt, program_config, memory_config, dtype, activation, compute_kernel_config, core_grid, /*propagate_is_b_batched=*/true); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), diff --git a/ttnn/cpp/ttnn/operations/matmul.cpp b/ttnn/cpp/ttnn/operations/matmul.cpp index 584ed000386..6566f8f316a 100644 --- a/ttnn/cpp/ttnn/operations/matmul.cpp +++ b/ttnn/cpp/ttnn/operations/matmul.cpp @@ -58,7 +58,8 @@ ttnn::Tensor matmul( std::optional dtype, const std::optional& activation, const std::optional compute_kernel_config, - const std::optional core_grid) { + const std::optional core_grid, + const bool propagate_is_b_batched) { ttnn::validate_input_tensor("ttnn.linear", input_tensor_a, input_tensor_schemas()[0]); ttnn::validate_input_tensor("ttnn.linear", input_tensor_b, input_tensor_schemas()[1]); ttnn::validate_input_tensor("ttnn.linear", bias, input_tensor_schemas()[2]); @@ -92,7 +93,7 @@ ttnn::Tensor matmul( } auto output_tensor = tt::operations::primary::matmul( - input_tensor_a, input_tensor_b, post_process_bias ? std::nullopt : bias, program_config, memory_config, dtype, compute_kernel_config, false /*untilize_out*/, user_core_coord, get_fused_activation(activation)); + input_tensor_a, input_tensor_b, post_process_bias ? std::nullopt : bias, program_config, memory_config, dtype, compute_kernel_config, false /*untilize_out*/, user_core_coord, get_fused_activation(activation), propagate_is_b_batched && input_b_is_batched); if (post_process_bias) { output_tensor = tt::operations::primary::bcast( diff --git a/ttnn/cpp/ttnn/operations/matmul.hpp b/ttnn/cpp/ttnn/operations/matmul.hpp index a54058ad789..1b89ee82412 100644 --- a/ttnn/cpp/ttnn/operations/matmul.hpp +++ b/ttnn/cpp/ttnn/operations/matmul.hpp @@ -41,7 +41,8 @@ ttnn::Tensor matmul( std::optional dtype = std::nullopt, const std::optional& activation = std::nullopt, const std::optional compute_kernel_config = std::nullopt, - const std::optional core_grid = std::nullopt); + const std::optional core_grid = std::nullopt, + const bool propagate_is_b_batched = false); } // namespace matmul } // namespace operations