-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
#5389: Move ttnn.repeat_interleave to c++ #8961
Conversation
struct RepeatInterleave { | ||
static inline const std::array<TensorSchema, 1> input_tensor_schemas() { | ||
return {ttnn::TensorSchema{ | ||
2, // min rank |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be 4 according to the actual code
c8d2e27
to
ac9da88
Compare
) | ||
); | ||
|
||
// tests/ttnn/unit_tests/operations/test_repeat_interleave.py proves that it should work over dim 1 too |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Who can help to understand why this fails?
I suspect that the way I compare with expected result is wrong due to tiling or interleaving or..
namespace data_movement { | ||
namespace test { | ||
|
||
void run_repeat_interleave_test(tt::tt_metal::Device* device, const uint32_t repeats, const uint32_t dim) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not the best way to write op test. @tt-asaigal directed to tests/tt_eager/ops/test_bmm_op.cpp
, but I wanted to explore this path to better understand whats going on under the hood
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think @eyonland had an template example of a ttnn
unit test
@@ -26,4 +29,20 @@ class TTNNFixture : public ::testing::Test { | |||
|
|||
void TearDown() override { tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(false); } | |||
}; | |||
|
|||
class TTNNFixtureWithDevice : public TTNNFixture { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this helps to make sure that the device is properly closed even if the test throws an unhandled exception.
module, | ||
ttnn::repeat_interleave, | ||
R"doc( | ||
repeat_interleave(input_tensor: ttnn.Tensor, repeats : int, dim: int = 0) -> ttnn.Tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Important change to the inteface - repeats
only accept int
.
Before this change, the interface lies like it can support Tensor or Int, but in reality it only supported Int
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html
We should probably flag this in a separate issue to note the difference relative to torch
ac9da88
to
5b8752c
Compare
5b8752c
to
3ea17c1
Compare
What happens
Moving ttnn.repeat_interleave to C++ as a part of #5389
Todo: