[DRAFT] Use TensorLayout in Tensor #15028
base: main
Conversation
inline ttnn::Shape shape() const { return this->tensor_attributes->tensor_spec.compute_shape(); };
inline DataType dtype() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_data_type(); };
inline Layout layout() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_layout(); };
inline Tile tile() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_page_config().get_tile(); };
Let's discuss this. Ideally we should not calculate these every time; they are not changing dynamically.
The only calculated one out of these getters is shape(). I'm planning to look into how often each compute_* method is called and cache some of them.
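For reference, a minimal sketch of what lazy caching inside TensorSpec could look like. The `cached_shape_` member and the stand-in `Shape` type are hypothetical, not part of this PR:

```cpp
#include <optional>

// Stand-in for ttnn::Shape, for illustration only.
struct Shape {};

class TensorSpec final {
public:
    // Compute the shape on first access, then reuse the cached value.
    // This stays safe as long as TensorSpec is immutable after construction.
    const Shape& compute_shape() const {
        if (!cached_shape_) {
            cached_shape_ = compute_shape_impl();
        }
        return *cached_shape_;
    }

private:
    Shape compute_shape_impl() const { return Shape{}; }  // the expensive computation
    mutable std::optional<Shape> cached_shape_;
};
```

One caveat with this shape of caching: the mutable member makes compute_shape() non-thread-safe without extra synchronization, which may matter for how Tensor is used.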
for (const auto& output_spec : output_specs) {
    output_tensors.emplace_back(create_device_tensor(output_spec, device));
unexpectedly nice
}));
});
weight_tensor_ = ttnn::reshape(weight_tensor_, target_shape);
@mywoodstock please take a look
@@ -99,8 +100,7 @@ ttnn::Tensor RepeatOperation::invoke(
     auto padded_to_tiled_shape = ttnn::Shape(sliced_logical_shape.view(),
                                              tiled_output.get_padded_shape().view());
-    tiled_output.set_shape(padded_to_tiled_shape);
-    return tiled_output;
+    return ttnn::reshape(tiled_output, padded_to_tiled_shape);
@ntarafdar, please take a look
@@ -920,8 +920,8 @@ Matmul create_matmul_struct(
     bool broadcast_batch =
         parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config));
     TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config");
-    const auto& in0_tile = input_tensor_a.get_tile();
-    const auto& in1_tile = input_tensor_b.get_tile();
+    auto in0_tile = input_tensor_a.get_tile();
what do you think about this change?
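Presumably the motivation for switching from const auto& to auto is that get_tile() now returns a Tile by value, so the reference would only bind a temporary. A standalone illustration with a stand-in Tile type (not the real tt-metal headers):

```cpp
#include <cstdint>

struct Tile { std::uint32_t height = 32, width = 32; };

// Assume get_tile() returns by value, as after this PR's PageConfig change.
Tile get_tile() { return Tile{}; }

int main() {
    const auto& t1 = get_tile();  // legal: the temporary's lifetime is extended,
                                  // but a full Tile is still materialized
    auto t2 = get_tile();         // same cost, and the copy is explicit

    return (t1.height == t2.height) ? 0 : 1;
}
```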
-    output_tensor->tensor_attributes->dtype = local_tensor->tensor_attributes->dtype;
-    output_tensor->tensor_attributes->layout = local_tensor->tensor_attributes->layout;
-    output_tensor->tensor_attributes->tile = local_tensor->tensor_attributes->tile;
+    output_tensor->set_tensor_spec(local_tensor->tensor_spec());
Not a good public method; would like to discuss this.
{
    if (std::holds_alternative<TilePageConfig>(config_)) {
        return std::get<TilePageConfig>(config_).get_tile();
    }
-    return std::nullopt;
+    return Tile();
would like to discuss why
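To frame that discussion, here is a simplified side-by-side of the two behaviors, using stand-in types rather than the real page-config classes: returning std::nullopt forces callers to decide what "no tile" means for row-major layouts, while returning Tile() silently reports a default tile.

```cpp
#include <optional>
#include <variant>

// Simplified stand-ins for the real page-config types.
struct Tile { int height = 32, width = 32; };
struct TilePageConfig { Tile tile; Tile get_tile() const { return tile; } };
struct RowMajorPageConfig {};

using PageConfigVariant = std::variant<TilePageConfig, RowMajorPageConfig>;

// Before: callers must handle the "no tile" case explicitly.
std::optional<Tile> get_tile_optional(const PageConfigVariant& config) {
    if (std::holds_alternative<TilePageConfig>(config)) {
        return std::get<TilePageConfig>(config).get_tile();
    }
    return std::nullopt;
}

// After (as in this diff): row-major layouts report a default tile instead.
Tile get_tile_or_default(const PageConfigVariant& config) {
    if (std::holds_alternative<TilePageConfig>(config)) {
        return std::get<TilePageConfig>(config).get_tile();
    }
    return Tile();
}
```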
@@ -18,7 +18,7 @@ size_t element_size_bytes(DataType dtype) {
         case DataType::UINT8: return sizeof(uint8_t);
         case DataType::BFLOAT8_B:
         case DataType::BFLOAT4_B:
-            TT_THROW("element_size_bytes() should not be used for BFLOAT8_B and BFLOAT4_B types because of how they are packed");
+            return sizeof(float);
would like to discuss why
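One way to read the trade-off, sketched with illustrative types (not asserting this is the PR's rationale):

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

enum class DataType { UINT8, BFLOAT8_B, BFLOAT4_B };

std::size_t element_size_bytes(DataType dtype) {
    switch (dtype) {
        case DataType::UINT8:
            return sizeof(std::uint8_t);
        case DataType::BFLOAT8_B:
        case DataType::BFLOAT4_B:
            // Block-float formats pack several elements around a shared
            // exponent, so a true per-element byte size is not well defined.
            // Returning sizeof(float) treats them as the floats they unpack
            // to; throwing (the old behavior) forces callers to special-case.
            return sizeof(float);
    }
    throw std::invalid_argument("unknown dtype");
}
```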
@@ -36,6 +36,12 @@ Alignment legacyShapeToAlignment(const ttnn::Shape& shape) {
         values[i] = legacy_padded_shape[i] * values[i + 1];
     }

+    for (auto& value : values) {
Will this conflict with @TT-BrianLiu's change?
@@ -51,6 +51,15 @@ class TensorLayout {
     // H is all dimensions except W multiplied and aligned to tile and shard height
     Size compute_physical_shape(const ttnn::SimpleShape& shape) const;

+    void set_memory_config(MemoryConfig memory_config) {
I wouldn't want it to be mutable like this. Let's discuss.
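A common alternative, if the discussion goes that way, is a with_-style copy rather than an in-place setter. A minimal sketch with illustrative names (with_memory_config is not an existing API here):

```cpp
#include <utility>

// Stand-in for the real MemoryConfig.
struct MemoryConfig {};

class TensorLayout {
public:
    explicit TensorLayout(MemoryConfig memory_config)
        : memory_config_(std::move(memory_config)) {}

    // Return a modified copy instead of mutating in place,
    // so TensorLayout values stay immutable once constructed.
    TensorLayout with_memory_config(MemoryConfig memory_config) const {
        TensorLayout copy = *this;
        copy.memory_config_ = std::move(memory_config);
        return copy;
    }

private:
    MemoryConfig memory_config_;
};
```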
namespace tt::tt_metal {

class TensorSpec final {
TBD; let's discuss next steps for caching some of the values.
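If the outcome is to cache eagerly rather than lazily, one shape it could take (all names illustrative) is precomputing derived values in the constructor, so getters stay trivial and the class needs no mutable members:

```cpp
// Stand-ins for the real tt-metal types, for illustration only.
struct SimpleShape {};
struct TensorLayout {};
struct Shape {};

class TensorSpec final {
public:
    TensorSpec(SimpleShape logical_shape, TensorLayout layout)
        : logical_shape_(logical_shape),
          layout_(layout),
          shape_(compute_shape(logical_shape_, layout_)) {}  // computed once, up front

    const Shape& shape() const { return shape_; }

private:
    static Shape compute_shape(const SimpleShape&, const TensorLayout&) { return Shape{}; }

    SimpleShape logical_shape_;
    TensorLayout layout_;
    Shape shape_;  // cached at construction
};
```

The trade-off versus lazy caching is paying the computation cost even for specs whose cached values are never queried.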
Ticket
Link to GitHub Issue
Problem description
Provide context for the problem.
What's changed
Describe the approach used to solve the problem.
Summarize the changes made and their impact.
Checklist