#3900: Add prod support for batch and channels #4933
Conversation
```python
(tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device)

torch_output = torch.sum(torch_input, dims, True)
```
Why is this correct? Isn't it supposed to be torch.prod? https://pytorch.org/docs/stable/generated/torch.prod.html
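For reference, a minimal sketch of what a torch.prod-based check would look like (shapes here are illustrative, not the actual test's; note torch.prod reduces a single dim at a time, unlike torch.sum, so a list of dims needs a loop):

```python
import torch

# Illustrative input; the real test gets its tensors from get_tensors.
torch_input = torch.randn(1, 2, 32, 32)
dims = [1]

# torch.prod takes a single int dim, so reduce over each dim in turn.
torch_output = torch_input
for dim in dims:
    torch_output = torch.prod(torch_output, dim, keepdim=True)
```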
some early comments
LGTM
```cpp
#include "tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp"

inline uint32_t get_read_tile_id(uint32_t tile_id, uint32_t dim, uint32_t input_tile_offset, uint32_t HtWt) {
    return (dim == 0) ? (tile_id) : (tile_id / HtWt * input_tile_offset) + (tile_id % HtWt);
}
```
In general, we want to avoid multiplications (especially for GS) and especially avoid divisions, since they're not performant on the riscs. You should try to refactor your code to avoid this, especially since you're calling this on every loop iteration. It could be refactored to use addition and an if check.
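For illustration, a minimal sketch of that addition-and-if-check refactor (the loop structure and names like start_tile_id/num_tiles are assumptions, not the actual kernel):

```cpp
// Loop-carried state replaces the per-iteration divide/modulo.
uint32_t read_tile_id = start_tile_id;  // assumed starting tile id
uint32_t inner_idx = 0;                 // position within the current HtWt block
for (uint32_t i = 0; i < num_tiles; ++i) {
    // ... read the tile at read_tile_id ...
    ++read_tile_id;
    if (++inner_idx == HtWt) {
        // Crossed an HtWt boundary: jump to the start of the next block.
        inner_idx = 0;
        read_tile_id += input_tile_offset - HtWt;
    }
}
```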
```cpp
if (input_is_dram) {
    noc_async_read_tile(read_tile_id, dram_input_addrg, l1_write_addr_in0);
} else {
    noc_async_read_tile(read_tile_id, l1_input_addrg, l1_write_addr_in0);
}
```
I would use a compile-time arg for the DRAM flag so you create a single addr gen with the correct template, instead of paying this cost at runtime in a loop.
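A sketch of that pattern, assuming the usual tt-metal dataflow APIs (get_compile_time_arg_val, InterleavedAddrGenFast); the surrounding variable names are assumptions:

```cpp
// DRAM flag resolved at compile time, so only one addr gen is instantiated.
constexpr bool input_is_dram = get_compile_time_arg_val(0) == 1;

const InterleavedAddrGenFast<input_is_dram> input_addrg = {
    .bank_base_address = input_addr,  // assumed runtime arg
    .page_size = tile_bytes,          // assumed tile size in bytes
    .data_format = data_format,       // assumed data format
};

// No runtime branch needed inside the read loop:
noc_async_read_tile(read_tile_id, input_addrg, l1_write_addr_in0);
```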
```cpp
void fill_cb_with_value(uint32_t cb_id, uint32_t value, int32_t num_of_elems = 1024) {
    cb_reserve_back(cb_id, 1);
    auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id));
    for (int j = 0; j < num_of_elems; j++) {
        ptr[j] = uint16_t(value >> 16);
    }
    cb_push_back(cb_id, 1);
}
```
Any reason why we use int for num_of_elems and the loop index?

For pointers to L1 memory you should create them as:

```cpp
volatile tt_l1_ptr std::uint16_t* ptr = (volatile tt_l1_ptr uint16_t*)(get_write_ptr(cb_id));
```

This is to avoid a potential hang due to a HW bug.

You can also store/use the u16 value like you do in later code:

```cpp
const auto u16_value = uint16_t(value >> 16);
```
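Putting those suggestions together, a sketch of what fill_cb_with_value could look like (not the author's final code):

```cpp
void fill_cb_with_value(uint32_t cb_id, uint32_t value, uint32_t num_of_elems = 1024) {
    cb_reserve_back(cb_id, 1);
    // volatile tt_l1_ptr avoids the potential HW hang mentioned above.
    volatile tt_l1_ptr uint16_t* ptr = (volatile tt_l1_ptr uint16_t*)(get_write_ptr(cb_id));
    // Hoist the truncated u16 value out of the loop.
    const auto u16_value = uint16_t(value >> 16);
    for (uint32_t j = 0; j < num_of_elems; j++) {
        ptr[j] = u16_value;
    }
    cb_push_back(cb_id, 1);
}
```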
```cpp
// mask_h
// first tile ptr
auto mask_h_ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_mask_h_w));
```
Same comment on l1 ptr as above
The generate_mask_h_w function is not used, so I'm removing it.
```cpp
uint32_t h = 0;
for (; h < mask_h_0; h++) {
    mask_h_ptr[h * 16 + w] = u16_one;
}
for (; h < 16; h++) {
    mask_h_ptr[h * 16 + w] = u16_zero;
}
```
Avoid using multiplies
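One multiply-free form of that loop (a sketch: h * 16 becomes a running offset advanced by 16 per iteration):

```cpp
uint32_t offset = w;  // offset tracks h * 16 + w without a multiply
uint32_t h = 0;
for (; h < mask_h_0; h++, offset += 16) {
    mask_h_ptr[offset] = u16_one;
}
for (; h < 16; h++, offset += 16) {
    mask_h_ptr[offset] = u16_zero;
}
```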
```cpp
Tensor prod(
    const Tensor &input,
    const Tensor &output,
    std::vector<int64_t> &dims,
    const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
```
This function should not take a mem_config since it already takes output
mem_config is required for the create_output_tensor function.
So this is used for the intermediate mem config then, right? If you want to keep this arg, I think you should change its name in the binding: it is currently labelled output_mem_config, but your output tensor already has a mem_config that is used for the final output.
```cpp
const auto& input = inputs.at(0);
const auto& output = inputs.at(1);
```
I think you should assert for rank 4 tensors here since you are hardcoding indices like 2, 3 elsewhere.
Otherwise you should update it to support tensors other than rank 4 if possible.
Added the rank 4 assert in the code accordingly.
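One plausible form of that check, assuming the usual TT_FATAL macro and that the tensor shape exposes rank() (the exact accessors in the PR may differ):

```cpp
// Guard the hardcoded dim indices (2, 3) used elsewhere in the op.
TT_FATAL(input.shape().rank() == 4, "prod: input must be a rank-4 tensor");
TT_FATAL(output.shape().rank() == 4, "prod: output must be a rank-4 tensor");
```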
Hi @tt-aho,
```cpp
ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }
```
Switch to new APIs:
- tile_regs_commit, tile_regs_release
- tile_regs_acquire, tile_regs_wait
Docs: tt_metal/include/compute_kernel_api/reg_api.h
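A sketch of the equivalent flow with the newer APIs (the compute body and CB names are placeholders, not the PR's actual kernel):

```cpp
tile_regs_acquire();   // replaces ACQ(): math thread waits for dest regs
// ... compute into dest regs, e.g. mul_tiles(cb_in0, cb_in1, 0, 0, 0); ...
tile_regs_commit();    // math thread signals that dest regs are ready

tile_regs_wait();      // pack thread waits for committed dest regs
pack_tile(0, cb_out);  // cb_out is a placeholder output CB
tile_regs_release();   // replaces REL(): pack thread frees dest regs
```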
```python
@pytest.mark.parametrize(
    "dims",
    (
        [
            0,
        ],
        [
            1,
        ],
        [
            2,
        ],
        [
            3,
        ],
    ),
    ids=["0", "1", "2", "3"],
)
```
fix this
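Presumably this refers to the verbose formatting; a compact equivalent (same four single-dim cases) might be:

```python
@pytest.mark.parametrize("dims", ([0], [1], [2], [3]), ids=["0", "1", "2", "3"])
```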
```cpp
void mask_tile_in_reader(uint32_t l1_addr, uint32_t mask_w = 32, uint32_t mask_h = 32) {
    union {
        float f;
        uint32_t u;
    } zero;
    zero.f = 0.0f;
    auto ptr = reinterpret_cast<uint16_t *>(l1_addr);
    for (uint32_t h = 0; h < 16; h++) {
```
do we have to do this in the kernel?
File removed, as HW support is now provided via NC support.
```cpp
if (scaler != 0) {
    auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id_in2));
    for (int j = 0; j < 1024; j++) ptr[j] = uint16_t(0);

    for (int k = 0; k < 4; k++)
        for (int j = 0; j < 16; j++) ptr[k * 256 + j] = uint16_t(scaler >> 16);
}
```
Do we have to do this in kernel?
Filling CBs is done in the kernel by others, including MOREH, in recent code. I don't see anything objectionable. @brian Liu
We should look into better ways of lowering constants like this. Convs do this with small sharded tensors. If you want to leave it like this, please document it better and clean it up (for example, what are those hardcoded numbers, and why are the loop indices ints?).
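For illustration, a sketch of the same fill with the constants documented (assuming the standard 32x32 tile stored as four 16x16 faces; that the scaler goes into the first row of each face is taken from the original code):

```cpp
if (scaler != 0) {
    constexpr uint32_t TILE_ELEMS = 1024;  // 32 * 32 elements per tile
    constexpr uint32_t NUM_FACES  = 4;     // a tile is packed as four faces
    constexpr uint32_t FACE_ELEMS = 256;   // 16 * 16 elements per face
    constexpr uint32_t FACE_ROW   = 16;    // one row within a face

    volatile tt_l1_ptr uint16_t* ptr = (volatile tt_l1_ptr uint16_t*)(get_write_ptr(cb_id_in2));

    // Zero the whole tile, then write the u16 scaler into the first
    // row of each face.
    for (uint32_t j = 0; j < TILE_ELEMS; j++) ptr[j] = 0;
    const uint16_t u16_scaler = uint16_t(scaler >> 16);
    for (uint32_t k = 0; k < NUM_FACES; k++)
        for (uint32_t j = 0; j < FACE_ROW; j++) ptr[k * FACE_ELEMS + j] = u16_scaler;
}
```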
Added batch and channel support for the prod op. Working on H & W.