
#3900: Add prod support for batch and channels #4933

Closed
wants to merge 2 commits into from

Conversation

ruthreshx
Contributor

Added batch and channel support for the Prod op.

Working on H & W.


(tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device)

torch_output = torch.sum(torch_input, dims, True)
Contributor
Why is this correct? Isn't it supposed to be torch.prod? https://pytorch.org/docs/stable/generated/torch.prod.html

Contributor

@muthutt muthutt left a comment

some early comments

Contributor

@muthutt muthutt left a comment

LGTM

#include "tt_eager/tt_dnn/op_library/prod/kernels/utils.hpp"

inline uint32_t get_read_tile_id(uint32_t tile_id, uint32_t dim, uint32_t input_tile_offset, uint32_t HtWt) {
    return (dim == 0) ? (tile_id) : (tile_id / HtWt * input_tile_offset) + (tile_id % HtWt);
}
Contributor

In general, we want to avoid multiplications (especially for GS) and especially avoid divisions, since they're not performant on the riscs. You should try to refactor your code to avoid this, especially since you're calling this every loop iteration. It could be refactored to use addition and an if check.
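A hedged, self-contained sketch of that refactor (the struct and all names are illustrative, not from this PR): for consecutive tile ids, the div/mod pair can be replaced by two counters that are advanced with addition and a single compare.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative sketch: track (tile_id / HtWt * input_tile_offset) and
// (tile_id % HtWt) incrementally instead of recomputing them with a
// division and a modulo on every loop iteration.
struct TileIdTracker {
    uint32_t HtWt;
    uint32_t input_tile_offset;
    uint32_t within_hw = 0;   // equals tile_id % HtWt, tracked incrementally
    uint32_t block_base = 0;  // equals (tile_id / HtWt) * input_tile_offset

    // Returns the read tile id for the next consecutive tile_id,
    // using only addition and one compare.
    uint32_t next() {
        uint32_t read_tile_id = block_base + within_hw;
        within_hw++;
        if (within_hw == HtWt) {  // crossed into the next block
            within_hw = 0;
            block_base += input_tile_offset;
        }
        return read_tile_id;
    }
};
```

Calling next() once per tile in the reader loop produces the same sequence as the div/mod formula for the dim != 0 case.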

Comment on lines +49 to +53
if (input_is_dram) {
    noc_async_read_tile(read_tile_id, dram_input_addrg, l1_write_addr_in0);
} else {
    noc_async_read_tile(read_tile_id, l1_input_addrg, l1_write_addr_in0);
}
Contributor

I would use a compile time arg for the dram flag so you create a single addr gen with the correct template, instead of paying for this cost at runtime in a loop
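For reference, a rough sketch of that suggestion, assuming the usual tt-metal dataflow helpers (get_compile_time_arg_val and the InterleavedAddrGenFast template); the argument index and variable names are illustrative, not this PR's code:

```cpp
// Read the DRAM/L1 flag at compile time so a single address generator is
// instantiated with the right template parameter, instead of branching on
// input_is_dram inside the read loop at runtime.
constexpr bool input_is_dram = get_compile_time_arg_val(0) == 1;

const InterleavedAddrGenFast<input_is_dram> input_addrg = {
    .bank_base_address = input_addr,
    .page_size = tile_bytes,
    .data_format = data_format,
};

// Inside the loop, one unconditional call:
noc_async_read_tile(read_tile_id, input_addrg, l1_write_addr_in0);
```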

Comment on lines +9 to +16
void fill_cb_with_value(uint32_t cb_id, uint32_t value, int32_t num_of_elems = 1024) {
    cb_reserve_back(cb_id, 1);
    auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id));
    for (int j = 0; j < num_of_elems; j++) {
        ptr[j] = uint16_t(value >> 16);
    }
    cb_push_back(cb_id, 1);
}
Contributor

Any reason why we use int for num_of_elems/the loop?

For pointers to L1 memory you should create them as
volatile tt_l1_ptr std::uint16_t* ptr = (volatile tt_l1_ptr uint16_t*)(get_write_ptr(cb_id));
This is to avoid a potential hang due to a HW bug.

You can also store/use the u16 value like you do in later code:
const auto u16_value = uint16_t(value >> 16);


// mask_h
// first tile ptr
auto mask_h_ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_mask_h_w));
Contributor

Same comment on l1 ptr as above

Contributor

The generate_mask_h_w function is not used. Hence, I'm removing it.

Comment on lines +45 to +51
uint32_t h = 0;
for (; h < mask_h_0; h++) {
    mask_h_ptr[h * 16 + w] = u16_one;
}
for (; h < 16; h++) {
    mask_h_ptr[h * 16 + w] = u16_zero;
}
Contributor

Avoid using multiplies
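As an illustrative, self-contained sketch of that suggestion (the function name and the 16x16 face size are assumptions for the example), the h * 16 + w index can be kept in a running offset that only ever sees additions:

```cpp
#include <cassert>
#include <cstdint>

// Fill one column (w) of a 16x16 mask face: rows [0, mask_h_0) get u16_one,
// the remaining rows get u16_zero. A running offset replaces the per-row
// multiply h * 16.
void fill_mask_column(uint16_t *mask_ptr, uint32_t w, uint32_t mask_h_0,
                      uint16_t u16_one, uint16_t u16_zero) {
    uint32_t idx = w;                 // row 0, column w
    for (uint32_t h = 0; h < mask_h_0; h++) {
        mask_ptr[idx] = u16_one;
        idx += 16;                    // advance one row: replaces h * 16 + w
    }
    for (uint32_t h = mask_h_0; h < 16; h++) {
        mask_ptr[idx] = u16_zero;
        idx += 16;
    }
}
```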

tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp Outdated Show resolved Hide resolved
Comment on lines +43 to +49
Tensor prod(
    const Tensor &input,
    const Tensor &output,
    std::vector<int64_t> &dims,
    const MemoryConfig &mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
Contributor

This function should not take a mem_config since it already takes output

Contributor Author

mem_config is required for the create_output_tensor function.

Contributor

So this is used for the intermediate mem config then, right? I think if you want to keep this arg you should rename it in the binding, since it is currently labelled output_mem_config, but your output tensor already has a mem_config that is used for the final output.

tt_eager/tt_dnn/op_library/prod/prod_nc_op.cpp Outdated Show resolved Hide resolved
Comment on lines +21 to +22
const auto& input = inputs.at(0);
const auto& output = inputs.at(1);
Contributor

I think you should assert for rank 4 tensors here since you are hardcoding indices like 2, 3 elsewhere.
Otherwise you should update it to support tensors other than rank 4 if possible.
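A hypothetical sketch of such a check (TT_FATAL is tt-metal's usual validation macro; the shape accessor and message here are assumptions, not this PR's code):

```cpp
// Guard the hardcoded dim indices (2, 3) used elsewhere in the op.
TT_FATAL(input.get_legacy_shape().rank() == 4, "Prod op currently supports only rank-4 tensors");
```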

Contributor Author

Added the rank 4 assert in the code accordingly.

@ruthreshx ruthreshx force-pushed the ruthresh/prod_op branch 2 times, most recently from d1fff01 to 9006951 Compare January 31, 2024 08:37
@ruthreshx
Contributor Author

Hi @tt-aho,
Addressed the comments.
Regarding the L1 ptr and avoiding the multiplies:
I took reference from moreh sum to add the N & C support for prod, and replicated the same approach here.

Comment on lines +11 to +12
ALWI void ACQ() { acquire_dst(tt::DstMode::Half); }
ALWI void REL() { release_dst(tt::DstMode::Half); }
Contributor

Switch to new APIs:

  • tile_regs_commit, tile_regs_release
  • tile_regs_acquire, tile_regs_wait

Docs: tt_metal/include/compute_kernel_api/reg_api.h
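For reference, a hedged sketch of the newer sequence (following reg_api.h; the math and pack calls shown are placeholders, not this PR's code):

```cpp
tile_regs_acquire();   // MATH thread: wait for DST registers to be free
// ... issue math ops that write into DST, e.g. mul_tiles(...) ...
tile_regs_commit();    // MATH thread: mark DST ready for the packer

tile_regs_wait();      // PACK thread: wait for committed DST registers
// ... pack_tile(...) results into the output CB ...
tile_regs_release();   // PACK thread: hand DST back to math
```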

Comment on lines +47 to +64
@pytest.mark.parametrize(
    "dims",
    (
        [
            0,
        ],
        [
            1,
        ],
        [
            2,
        ],
        [
            3,
        ],
    ),
    ids=["0", "1", "2", "3"],
)
Contributor

fix this

Comment on lines +9 to +16
void mask_tile_in_reader(uint32_t l1_addr, uint32_t mask_w = 32, uint32_t mask_h = 32) {
    union {
        float f;
        uint32_t u;
    } zero;
    zero.f = 0.0f;
    auto ptr = reinterpret_cast<uint16_t *>(l1_addr);
    for (uint32_t h = 0; h < 16; h++) {
Contributor

do we have to do this in the kernel?

Contributor

@VirdhatchaniKN VirdhatchaniKN Mar 20, 2024

File removed as HW support has been given using NC support

Comment on lines +73 to +79
if (scaler != 0) {
    auto ptr = reinterpret_cast<uint16_t *>(get_write_ptr(cb_id_in2));
    for (int j = 0; j < 1024; j++) ptr[j] = uint16_t(0);

    for (int k = 0; k < 4; k++)
        for (int j = 0; j < 16; j++) ptr[k * 256 + j] = uint16_t(scaler >> 16);
}
Contributor

Do we have to do this in kernel?

@muthutt
Contributor

muthutt commented Feb 9, 2024 via email

Filling CBs is done in the kernel by others, including MOREH, in recent code. I don't see anything objectionable. @TT-BrianLiu

@TT-BrianLiu
Contributor

We should look into better ways of lowering constants like this. Convs do this with small sharded tensors. If you want to leave it like this, please document it better and clean it up (for example, what are those hardcoded numbers, and why are the loop indices ints?).
