#8112: Require bias for linear #8912

Closed · wants to merge 1 commit
4 changes: 2 additions & 2 deletions ttnn/cpp/pybind11/operations/matmul.hpp
@@ -44,7 +44,7 @@ void py_module(py::module& module) {
         "linear",
         [](const ttnn::Tensor& input_tensor_a,
            const ttnn::Tensor& input_tensor_b,
-           const std::optional<const ttnn::Tensor>& bias,
+           const ttnn::Tensor& bias,
            const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
            const std::optional<const DataType> dtype = std::nullopt,
            const std::optional<const ttnn::MatmulProgramConfig> program_config = std::nullopt,
@@ -65,7 +65,7 @@ void py_module(py::module& module) {
         py::arg("input_tensor_a"),
         py::arg("input_tensor_b"),
         py::kw_only(),
-        py::arg("bias") = std::nullopt,
+        py::arg("bias"),
         py::arg("memory_config") = DRAM_MEMORY_CONFIG,
         py::arg("dtype") = std::nullopt,
         py::arg("program_config") = std::nullopt,
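Binding-level effect of the hunk above: dropping the "= std::nullopt" default after py::kw_only() turns bias into a required keyword-only argument, so pybind11 rejects any call that omits it at argument-parsing time. A minimal self-contained sketch of the same pattern (module and function names here are illustrative, not part of this PR):

    #include <pybind11/pybind11.h>

    namespace py = pybind11;

    // Stand-in for the bound lambda; the real op takes ttnn tensors.
    int linear_stub(int a, int b, int bias) { return a * b + bias; }

    PYBIND11_MODULE(demo, m) {
        // Everything after py::kw_only() is keyword-only. Because
        // py::arg("bias") carries no default, Python callers must pass
        // bias=... explicitly or get a TypeError.
        m.def("linear_stub", &linear_stub,
              py::arg("a"), py::arg("b"),
              py::kw_only(),
              py::arg("bias"));
    }

From Python, demo.linear_stub(2, 3) now raises TypeError, while demo.linear_stub(2, 3, bias=1) succeeds.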
13 changes: 11 additions & 2 deletions ttnn/cpp/ttnn/operations/conv2d.cpp
@@ -704,10 +704,19 @@ std::tuple<ttnn::Tensor, uint32_t, uint32_t, ttnn::Tensor, std::optional<ttnn::T
             ttnn::operations::core::deallocate(input_tensor_post_tm);
         }
     }
-    auto matmul_output = ttnn::operations::matmul::linear(
+    auto matmul_output = bias_tensor_on_device.has_value() ?
+        ttnn::operations::matmul::linear(
         matmul_input,
         weight_tensor_on_device,
-        bias_tensor_on_device,
+        bias_tensor_on_device.value(),
+        matmul_program_config,
+        conv_out_memory_config,
+        conv_config.dtype,
+        conv_config.activation == "" ? std::nullopt : std::optional<std::string>{conv_config.activation},
+        compute_kernel_config) :
+        ttnn::operations::matmul::matmul(
+        matmul_input,
+        weight_tensor_on_device,
         matmul_program_config,
         conv_out_memory_config,
         conv_config.dtype,
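Since linear no longer accepts an optional bias, the conv path now chooses the op at the call site: linear when a bias tensor is present, plain matmul when it is not. A condensed sketch of that dispatch, using trivial stand-in types rather than the ttnn API:

    #include <optional>

    struct Tensor {};  // stand-in for ttnn::Tensor

    // Trivial stand-ins for the two ttnn ops.
    Tensor linear(const Tensor& a, const Tensor&, const Tensor&) { return a; }
    Tensor matmul(const Tensor& a, const Tensor&) { return a; }

    Tensor fully_connected(const Tensor& input, const Tensor& weight,
                           const std::optional<Tensor>& bias) {
        // Branch once at the call site: the required-bias op when a bias
        // exists, the bias-free op otherwise (mirrors the ternary above).
        return bias.has_value() ? linear(input, weight, bias.value())
                                : matmul(input, weight);
    }

Note the hunk also converts conv_config.activation from an empty string to std::nullopt before forwarding it, since linear models "no activation" as an absent optional.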
15 changes: 6 additions & 9 deletions ttnn/cpp/ttnn/operations/matmul.cpp
@@ -106,7 +106,7 @@ std::optional<UnaryWithParam> get_fused_activation(const std::optional<const std
 ttnn::Tensor linear(
     const ttnn::Tensor& input_tensor_a,
     const ttnn::Tensor& input_tensor_b,
-    const std::optional<const ttnn::Tensor>& bias,
+    const ttnn::Tensor& bias,
     const std::optional<const MatmulProgramConfig> program_config,
     const ttnn::MemoryConfig& memory_config,
     std::optional<const DataType> dtype,
@@ -129,16 +129,13 @@ ttnn::Tensor linear(
     const auto input_tensor_a_4d = ttnn::unsqueeze_to_4D(input_tensor_a);
     const auto input_tensor_b_4d = ttnn::unsqueeze_to_4D(input_tensor_b);

-    std::optional<Tensor> bias_4d = std::nullopt;
     const bool has_user_grid = core_grid.has_value();
     const bool has_program_config = program_config.has_value();

     bool post_process_bias = false;
-    if (bias.has_value()) {
-        bias_4d = ttnn::unsqueeze_to_4D(bias.value());
-        if (!has_program_config && !has_user_grid) {
-            post_process_bias = true;
-        }
+    auto bias_4d = ttnn::unsqueeze_to_4D(bias);
+    if (!has_program_config && !has_user_grid) {
+        post_process_bias = true;
     }

     if (width_a != height_b) {
@@ -150,11 +147,11 @@
     }

     auto output_tensor = tt::operations::primary::matmul(
-        input_tensor_a_4d, input_tensor_b_4d, post_process_bias ? std::nullopt : bias_4d, program_config, memory_config, dtype, compute_kernel_config, false /*untilize_out*/, user_core_coord, get_fused_activation(activation));
+        input_tensor_a_4d, input_tensor_b_4d, post_process_bias ? std::nullopt : std::make_optional<const Tensor>(bias_4d), program_config, memory_config, dtype, compute_kernel_config, false /*untilize_out*/, user_core_coord, get_fused_activation(activation));

     if (post_process_bias) {
         output_tensor = tt::tt_metal::bcast(
-            output_tensor, bias_4d.value(), tt::tt_metal::BcastOpMath::ADD, tt::tt_metal::BcastOpDim::H, memory_config);
+            output_tensor, bias_4d, tt::tt_metal::BcastOpMath::ADD, tt::tt_metal::BcastOpDim::H, memory_config);
     }

     if (activation.has_value() && !has_user_grid) {
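Control flow this file preserves: even with the bias now guaranteed, linear only hands it to the fused device matmul when the caller supplied a program config or core grid; otherwise it withholds it (std::nullopt is passed down) and applies the bias afterwards as a broadcast add over H. A simplified sketch of that decision, again with stand-in types rather than the ttnn internals:

    #include <optional>

    struct Tensor {};
    struct ProgramConfig {};
    struct CoreGrid {};

    Tensor device_matmul(const Tensor& a, const Tensor&,
                         const std::optional<Tensor>&) { return a; }
    Tensor broadcast_add_h(const Tensor& out, const Tensor&) { return out; }

    Tensor linear_sketch(const Tensor& a, const Tensor& b, const Tensor& bias,
                         const std::optional<ProgramConfig>& program_config,
                         const std::optional<CoreGrid>& core_grid) {
        // Without an explicit kernel configuration, withhold the bias from
        // the fused matmul and add it in a second, broadcast step instead.
        const bool post_process_bias =
            !program_config.has_value() && !core_grid.has_value();

        Tensor out = device_matmul(
            a, b, post_process_bias ? std::nullopt : std::make_optional(bias));

        if (post_process_bias) {
            out = broadcast_add_h(out, bias);  // bias has H = 1, broadcast over rows
        }
        return out;
    }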
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/matmul.hpp
@@ -45,7 +45,7 @@ std::optional<UnaryWithParam> get_fused_activation(const std::optional<const std
 ttnn::Tensor linear(
     const ttnn::Tensor& input_tensor_a,
     const ttnn::Tensor& input_tensor_b,
-    const std::optional<const ttnn::Tensor>& bias,
+    const ttnn::Tensor& bias,
     const std::optional<const MatmulProgramConfig> program_config = std::nullopt,
     const ttnn::MemoryConfig& memory_config = ttnn::DRAM_MEMORY_CONFIG,
     std::optional<const DataType> dtype = std::nullopt,
4 changes: 2 additions & 2 deletions ttnn/ttnn/operations/matmul.py
@@ -171,7 +171,7 @@ def linear(
     input_tensor_a: ttnn.Tensor,
     input_tensor_b: ttnn.Tensor,
     *,
-    bias: Optional[ttnn.Tensor] = None,
+    bias: ttnn.Tensor,
     memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG,
     dtype: Optional[ttnn.DataType] = None,
     core_grid: Optional[ttnn.CoreGrid] = None,
@@ -190,7 +190,7 @@ def linear(
         * :attr:`input_tensor_b` (ttnn.Tensor): the second tensor to be multiplied

     Keyword Arguments:
-        * :attr:`bias` (Optional[ttnn.Tensor]): the bias tensor to be added. Defaults to None
+        * :attr:`bias` (ttnn.Tensor): the bias tensor to be added. Required.
         * :attr:`memory_config` (ttnn.MemoryConfig): the memory configuration of the output tensor. Defaults to ttnn.DRAM_MEMORY_CONFIG
         * :attr:`dtype` (Optional[ttnn.DataType]): the data type of the output tensor. Defaults to None
         * :attr:`core_grid` (Optional[ttnn.CoreGrid]): the grid on which to distribute the sharded tensor (writes to the cores' L1s). Defaults to None
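For reference, the arithmetic the docstring describes, with the now-mandatory bias broadcast across output rows (a plain-C++ sketch of the math, not the ttnn implementation):

    #include <array>
    #include <cstdio>

    // out[m][n] = sum_k a[m][k] * b[k][n] + bias[n]; the bias is added
    // once per output column and broadcast over every row.
    int main() {
        constexpr int M = 2, K = 3, N = 2;
        const std::array<std::array<float, K>, M> a{{{1, 2, 3}, {4, 5, 6}}};
        const std::array<std::array<float, N>, K> b{{{1, 0}, {0, 1}, {1, 1}}};
        const std::array<float, N> bias{0.5f, -0.5f};

        for (int m = 0; m < M; ++m) {
            for (int n = 0; n < N; ++n) {
                float acc = bias[n];
                for (int k = 0; k < K; ++k) acc += a[m][k] * b[k][n];
                std::printf("%g ", acc);
            }
            std::printf("\n");
        }
        return 0;
    }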