-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#4904: Add support for 1d width sharded LN
Refactored out code for creating tiles for bcast and reduce into common header files
- Loading branch information
Showing
30 changed files
with
1,873 additions
and
1,619 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
44 changes: 44 additions & 0 deletions
44
tt_eager/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "dataflow_api.h" | ||
|
||
// W-bcast scalar.
// Reserves one page at the back of circular buffer `cb_id`, stores the upper
// 16 bits of `scalar` into every 16th uint16_t slot of the two 256-slot regions
// that begin at uint16_t offsets 0 and 512 of the page, then pushes the page.
// Slots that are not written keep whatever the page already held.
// NOTE(review): the 256/16 pattern presumably targets the first column of two
// 16x16 faces of a tile of 16-bit elements -- confirm against the tile layout.
FORCE_INLINE void generate_bcast_col_scalar(const uint32_t cb_id, const uint32_t scalar) {
    constexpr uint32_t region_stride = 512;  // distance between the two regions, in uint16_t slots
    constexpr uint32_t region_length = 256;  // slots scanned inside each region
    constexpr uint32_t slot_stride = 16;     // only every 16th slot is written

    const uint16_t upper_half = static_cast<uint16_t>(scalar >> 16);

    cb_reserve_back(cb_id, 1);
    volatile tt_l1_ptr uint16_t* page =
        reinterpret_cast<volatile tt_l1_ptr uint16_t*>(get_write_ptr(cb_id));

    for (uint32_t region_base = 0; region_base < 2 * region_stride; region_base += region_stride) {
        for (uint32_t offset = 0; offset < region_length; offset += slot_stride) {
            page[region_base + offset] = upper_half;
        }
    }
    cb_push_back(cb_id, 1);
}
|
||
// H-bcast scalar | ||
FORCE_INLINE void generate_bcast_row_scalar(const uint32_t cb_id, const uint32_t scalar) { | ||
const uint32_t scalar_val = scalar>>16; | ||
cb_reserve_back(cb_id, 1); | ||
volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id)); | ||
for (int k = 0; k < 2; ++k) { | ||
uint32_t idx = k << 7; | ||
for (int j = 0; j < 8; ++j) { | ||
ptr[idx + j] = scalar_val; | ||
} | ||
} | ||
cb_push_back(cb_id, 1); | ||
} | ||
|
||
// HW-bcast scalar.
// Reserves one page at the back of circular buffer `cb_id`, stores
// `scalar >> 16` into the first uint32_t word of the page, then pushes the page.
// Only that single word is written; the rest of the page keeps whatever it
// already held.
FORCE_INLINE void generate_bcast_unary_scalar(const uint32_t cb_id, const uint32_t scalar) {
    const uint32_t scalar_val = scalar >> 16;
    cb_reserve_back(cb_id, 1);
    volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(get_write_ptr(cb_id));
    // Store the precomputed value instead of recomputing the shift, so the
    // local is actually used and the function reads like its col/row siblings.
    ptr[0] = scalar_val;
    cb_push_back(cb_id, 1);
}
33 changes: 33 additions & 0 deletions
33
tt_eager/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "dataflow_api.h" | ||
|
||
// Builds one scaler page in circular buffer `cb_id`.
// Steps, in order:
//   1. Reserve one page at the back of the buffer.
//   2. Issue 2048 / MEM_ZEROS_SIZE NoC reads of MEM_ZEROS_SIZE each from
//      MEM_ZEROS_BASE into consecutive destinations starting at the page's
//      write pointer, then wait on the read barrier.
//   3. If `scaler` is non-zero, store it into the first 8 uint32_t words of the
//      four regions that begin at uint32_t offsets 0, 128, 256 and 384.
//   4. Push the page.
// NOTE(review): 2048 is presumably the byte size of the tile being filled --
// confirm it matches the page size configured for `cb_id`.
FORCE_INLINE void generate_reduce_scaler(const uint32_t cb_id, const uint32_t scaler) {
    cb_reserve_back(cb_id, 1);

    constexpr uint32_t zero_chunk_count = 2048 / MEM_ZEROS_SIZE;
    constexpr uint32_t region_stride = 128;  // distance between written regions, in uint32_t words
    constexpr uint32_t region_count = 4;
    constexpr uint32_t words_per_region = 8;

    const uint64_t zeros_src = get_noc_addr(MEM_ZEROS_BASE);
    const uint32_t page_addr = get_write_ptr(cb_id);

    // Zero-fill: each read lands directly after the previous one.
    uint32_t dst_addr = page_addr;
    for (uint32_t remaining = zero_chunk_count; remaining != 0; --remaining) {
        noc_async_read(zeros_src, dst_addr, MEM_ZEROS_SIZE);
        dst_addr += MEM_ZEROS_SIZE;
    }
    // The scaler stores below must not start until every zero read has completed.
    noc_async_read_barrier();

    if (scaler != 0) {
        volatile tt_l1_ptr uint32_t* page = reinterpret_cast<volatile tt_l1_ptr uint32_t*>(page_addr);
        for (uint32_t region_base = 0; region_base < region_count * region_stride; region_base += region_stride) {
            for (uint32_t i = 0; i < words_per_region; ++i) {
                page[region_base + i] = scaler;
            }
        }
    }
    cb_push_back(cb_id, 1);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.