-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7477d8f
commit 7b9aae9
Showing
8 changed files
with
294 additions
and
194 deletions.
There are no files selected for viewing
170 changes: 170 additions & 0 deletions
170
tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include <stdint.h> | ||
|
||
#include "dataflow_api.h" | ||
|
||
void fill_cb_with_value(uint32_t cb_id, uint32_t value) { | ||
cb_reserve_back(cb_id, 1); | ||
auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id)); | ||
for (int j = 0; j < 1024; j++) { | ||
ptr[j] = uint16_t(value >> 16); | ||
} | ||
cb_push_back(cb_id, 1); | ||
} | ||
|
||
void generate_mask_h_w(uint32_t cb_mask_h_w, uint32_t mask_h, uint32_t mask_w, uint32_t single_tile_size = 2048) { | ||
union { | ||
float f; | ||
uint32_t u; | ||
} one; | ||
one.f = 1.0f; | ||
union { | ||
float f; | ||
uint32_t u; | ||
} zero; | ||
zero.f = 0.0f; | ||
|
||
const auto u16_one = uint16_t(one.u >> 16); | ||
const auto u16_zero = uint16_t(zero.u >> 16); | ||
|
||
cb_reserve_back(cb_mask_h_w, 2); | ||
|
||
// mask_h | ||
// first tile ptr | ||
auto mask_h_ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_mask_h_w)); | ||
for (uint32_t w = 0; w < 16; w++) { | ||
// sub tile 0 | ||
{ | ||
uint32_t mask_h_0 = mask_h; | ||
if (mask_h_0 >= 16) { | ||
mask_h_0 = 16; | ||
} | ||
uint32_t h = 0; | ||
for (; h < mask_h_0; h++) { | ||
mask_h_ptr[h * 16 + w] = u16_one; | ||
} | ||
for (; h < 16; h++) { | ||
mask_h_ptr[h * 16 + w] = u16_zero; | ||
} | ||
} | ||
|
||
// sub tile 1 | ||
{ | ||
uint32_t mask_h_0 = mask_h; | ||
if (mask_h_0 >= 16) { | ||
mask_h_0 = 16; | ||
} | ||
uint32_t h = 0; | ||
for (; h < mask_h_0; h++) { | ||
mask_h_ptr[h * 16 + w + 256] = u16_one; | ||
} | ||
for (; h < 16; h++) { | ||
mask_h_ptr[h * 16 + w + 256] = u16_zero; | ||
} | ||
} | ||
|
||
// sub tile 2 | ||
{ | ||
uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; | ||
uint32_t h = 0; | ||
for (; h < mask_h_1; h++) { | ||
mask_h_ptr[h * 16 + w + 512] = u16_one; | ||
} | ||
for (; h < 16; h++) { | ||
mask_h_ptr[h * 16 + w + 512] = u16_zero; | ||
} | ||
} | ||
|
||
// sub tile 3 | ||
{ | ||
uint32_t mask_h_1 = (mask_h < 16) ? 0 : mask_h - 16; | ||
uint32_t h = 0; | ||
for (; h < mask_h_1; h++) { | ||
mask_h_ptr[h * 16 + w + 768] = u16_one; | ||
} | ||
for (; h < 16; h++) { | ||
mask_h_ptr[h * 16 + w + 768] = u16_zero; | ||
} | ||
} | ||
} | ||
|
||
// mask_w | ||
// second tile ptr | ||
auto mask_w_ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_mask_h_w) + single_tile_size); | ||
for (uint32_t h = 0; h < 16; h++) { | ||
// sub tile 0 | ||
{ | ||
uint32_t mask_w_0 = mask_w; | ||
if (mask_w_0 >= 16) { | ||
mask_w_0 = 16; | ||
} | ||
uint32_t w = 0; | ||
for (; w < mask_w_0; w++) { | ||
mask_w_ptr[h * 16 + w] = u16_one; | ||
} | ||
for (; w < 16; w++) { | ||
mask_w_ptr[h * 16 + w] = u16_zero; | ||
} | ||
} | ||
|
||
// sub tile 1 | ||
{ | ||
uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16; | ||
uint32_t w = 0; | ||
for (; w < mask_w_1; w++) { | ||
mask_w_ptr[h * 16 + w + 256] = u16_one; | ||
} | ||
for (; w < 16; w++) { | ||
mask_w_ptr[h * 16 + w + 256] = u16_zero; | ||
} | ||
} | ||
|
||
// sub tile 2 | ||
{ | ||
uint32_t mask_w_0 = mask_w; | ||
if (mask_w_0 >= 16) { | ||
mask_w_0 = 16; | ||
} | ||
uint32_t w = 0; | ||
for (; w < mask_w_0; w++) { | ||
mask_w_ptr[h * 16 + w + 512] = u16_one; | ||
} | ||
for (; w < 16; w++) { | ||
mask_w_ptr[h * 16 + w + 512] = u16_zero; | ||
} | ||
} | ||
|
||
// sub tile 3 | ||
{ | ||
uint32_t mask_w_1 = (mask_w < 16) ? 0 : mask_w - 16; | ||
uint32_t w = 0; | ||
for (; w < mask_w_1; w++) { | ||
mask_w_ptr[h * 16 + w + 768] = u16_one; | ||
} | ||
for (; w < 16; w++) { | ||
mask_w_ptr[h * 16 + w + 768] = u16_zero; | ||
} | ||
} | ||
} | ||
|
||
cb_push_back(cb_mask_h_w, 2); | ||
} | ||
|
||
void generate_mask_h_w_if_needed(uint32_t cb_mask_h_w, uint32_t origin_h, uint32_t origin_w) { | ||
constexpr uint32_t TILE_H = 32; | ||
constexpr uint32_t TILE_W = 32; | ||
|
||
const bool do_mask_h = (origin_h % TILE_H) != 0; | ||
const uint32_t mask_h = do_mask_h ? (origin_h % TILE_H) : TILE_H; | ||
|
||
const bool do_mask_w = (origin_w % TILE_W) != 0; | ||
const uint32_t mask_w = do_mask_w ? (origin_w % TILE_W) : TILE_W; | ||
|
||
if (do_mask_h || do_mask_w) { | ||
const uint32_t mask_tile_bytes = get_tile_size(cb_mask_h_w); | ||
generate_mask_h_w(cb_mask_h_w, mask_h, mask_w, mask_tile_bytes); | ||
} | ||
} |
107 changes: 107 additions & 0 deletions
107
tt_eager/tt_dnn/op_library/moreh_clip_grad_norm/kernel_utils/common_ckernels.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
#include <cstdint> | ||
|
||
#include "compute_kernel_api.h" | ||
#include "compute_kernel_api/bcast.h" | ||
#include "compute_kernel_api/eltwise_binary.h" | ||
#include "compute_kernel_api/eltwise_unary/exp.h" | ||
#include "compute_kernel_api/eltwise_unary/recip.h" | ||
#include "compute_kernel_api/mask.h" | ||
#include "compute_kernel_api/reduce.h" | ||
#include "compute_kernel_api/tile_move_copy.h" | ||
|
||
ALWI void ACQ() { acquire_dst(tt::DstMode::Half); } | ||
ALWI void REL() { release_dst(tt::DstMode::Half); } | ||
|
||
namespace ckernel { | ||
|
||
ALWI void power_tile_to_cb( | ||
std::uint8_t cb_x, | ||
std::uint8_t cb_xpow, | ||
std::uint8_t cb_logx, | ||
std::uint8_t cb_decimal, | ||
std::uint8_t cb_exp_lxmd, | ||
std::uint8_t cb_correct_xpow, | ||
uint32_t p, | ||
bool p_is_negative) { | ||
constexpr uint32_t onetile = 1; | ||
constexpr uint32_t dst0 = 0; | ||
|
||
// x^p | ||
ACQ(); | ||
cb_wait_front(cb_x, onetile); | ||
cb_reserve_back(cb_xpow, onetile); | ||
|
||
copy_tile_init(); | ||
copy_tile(cb_x, 0, dst0); | ||
|
||
power_tile_init(); | ||
power_tile(dst0, p); | ||
|
||
if (p_is_negative) { | ||
recip_tile_init(); | ||
recip_tile(dst0); | ||
} | ||
|
||
pack_tile(dst0, cb_xpow); | ||
|
||
cb_push_back(cb_xpow, onetile); | ||
REL(); | ||
// We don't pop cb_x here. | ||
|
||
// log(x) | ||
ACQ(); | ||
cb_reserve_back(cb_logx, onetile); | ||
|
||
copy_tile_init(); | ||
copy_tile(cb_x, 0, dst0); | ||
|
||
log_tile_init(); | ||
log_tile(dst0); | ||
|
||
pack_tile(dst0, cb_logx); | ||
|
||
cb_pop_front(cb_x, onetile); | ||
cb_push_back(cb_logx, onetile); | ||
REL(); | ||
|
||
// exp(log(x) * decimal) | ||
ACQ(); | ||
cb_wait_front(cb_logx, onetile); | ||
cb_reserve_back(cb_exp_lxmd, onetile); | ||
|
||
mul_tiles_init(); | ||
mul_tiles(cb_logx, cb_decimal, 0, 0, dst0); | ||
|
||
exp_tile_init(); | ||
exp_tile(dst0); | ||
|
||
pack_tile(dst0, cb_exp_lxmd); | ||
|
||
cb_pop_front(cb_logx, onetile); | ||
cb_push_back(cb_exp_lxmd, onetile); | ||
REL(); | ||
|
||
// x^p * exp(log(x) * decimal)(==(x + decimal)^p) | ||
ACQ(); | ||
cb_wait_front(cb_xpow, onetile); | ||
cb_wait_front(cb_exp_lxmd, onetile); | ||
cb_reserve_back(cb_correct_xpow, onetile); | ||
|
||
mul_tiles_init(); | ||
mul_tiles(cb_xpow, cb_exp_lxmd, 0, 0, dst0); | ||
|
||
pack_tile(dst0, cb_correct_xpow); | ||
|
||
cb_pop_front(cb_xpow, onetile); | ||
cb_pop_front(cb_exp_lxmd, onetile); | ||
cb_push_back(cb_correct_xpow, onetile); | ||
REL(); | ||
} | ||
|
||
} // namespace ckernel |
Oops, something went wrong.