
[DRAFT] Use TensorLayout in Tensor #15028

Draft · wants to merge 21 commits into base: main
Conversation

sminakov-tt (Contributor)

Ticket

Link to Github Issue

Problem description

Provide context for the problem.

What's changed

Describe the approach used to solve the problem.
Summarize the changes made and their impact.

Checklist

  • Post commit CI passes
  • Blackhole Post commit (if applicable)
  • Model regression CI testing passes (if applicable)
  • Device performance regression CI testing passes (if applicable)
  • New/Existing tests provide coverage for changes

Comment on lines +205 to +208
inline ttnn::Shape shape() const { return this->tensor_attributes->tensor_spec.compute_shape(); };
inline DataType dtype() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_data_type(); };
inline Layout layout() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_layout(); };
inline Tile tile() const { return this->tensor_attributes->tensor_spec.tensor_layout().get_page_config().get_tile(); };
Member:
Let's discuss this. Ideally we should not compute these every time; they don't change dynamically.

Contributor (Author):

The only computed one out of these getters is shape(). I'm planning to look into how often each compute_* method is called and cache some of them.
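One common way to cache such a computed value is a lazily-initialized member. The sketch below is hypothetical and simplified (the types `Shape` and `TensorSpecSketch` are illustrative stand-ins, not the real ttnn classes); it only shows the caching pattern under discussion: compute once on first access, return the cached result afterwards.

```cpp
#include <mutex>
#include <optional>
#include <vector>

// Illustrative stand-in for a shape type.
struct Shape {
    std::vector<int> dims;
};

class TensorSpecSketch {
public:
    explicit TensorSpecSketch(std::vector<int> logical_dims)
        : logical_dims_(std::move(logical_dims)) {}

    // First call computes and caches; later calls return the cached value.
    const Shape& shape() const {
        std::call_once(shape_once_, [this] { cached_shape_ = compute_shape(); });
        return *cached_shape_;
    }

private:
    Shape compute_shape() const {
        // Stand-in for the real padding/alignment computation.
        return Shape{logical_dims_};
    }

    std::vector<int> logical_dims_;
    mutable std::once_flag shape_once_;
    mutable std::optional<Shape> cached_shape_;
};
```

The `std::call_once` makes the lazy initialization thread-safe; for single-threaded access a plain `if (!cached_shape_)` check would suffice.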

Comment on lines +416 to +417
for (const auto& output_spec : output_specs) {
output_tensors.emplace_back(create_device_tensor(output_spec, device));
Member:
unexpectedly nice

}));

});
weight_tensor_ = ttnn::reshape(weight_tensor_, target_shape);
Member:
@mywoodstock please take a look

@@ -99,8 +100,7 @@ ttnn::Tensor RepeatOperation::invoke(

auto padded_to_tiled_shape = ttnn::Shape(sliced_logical_shape.view(),
tiled_output.get_padded_shape().view());
tiled_output.set_shape(padded_to_tiled_shape);
return tiled_output;
return ttnn::reshape(tiled_output, padded_to_tiled_shape);
Member:
@ntarafdar , please take a look

@@ -920,8 +920,8 @@ Matmul create_matmul_struct(
bool broadcast_batch =
parameters.bcast_batch.value_or(get_broadcast_batch(input_tensor_a, input_tensor_b, parameters.program_config));
TT_FATAL(!(has_user_grid && has_program_config), "Cannot use both user core grid/coordinates and a program config");
const auto& in0_tile = input_tensor_a.get_tile();
const auto& in1_tile = input_tensor_b.get_tile();
auto in0_tile = input_tensor_a.get_tile();
Member:
what do you think about this change?

output_tensor->tensor_attributes->dtype = local_tensor->tensor_attributes->dtype;
output_tensor->tensor_attributes->layout = local_tensor->tensor_attributes->layout;
output_tensor->tensor_attributes->tile = local_tensor->tensor_attributes->tile;
output_tensor->set_tensor_spec(local_tensor->tensor_spec());
Member:
not a good public method, would like to discuss this.

{
if(std::holds_alternative<TilePageConfig>(config_)) {
return std::get<TilePageConfig>(config_).get_tile();
}

return std::nullopt;
return Tile();
Member:
would like to discuss why
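The diff above changes a variant-based getter from returning `std::nullopt` (caller must handle "no tile") to returning a default-constructed `Tile()`. The sketch below contrasts the two behaviors side by side; all type names (`PageConfig`, `TilePageConfig`, `RowMajorPageConfig`, the 32x32 default) are illustrative assumptions, not the real ttnn definitions.

```cpp
#include <optional>
#include <variant>

// Illustrative stand-ins for the real types.
struct Tile {
    int height = 32;
    int width = 32;
};

struct TilePageConfig {
    Tile tile;
    const Tile& get_tile() const { return tile; }
};

struct RowMajorPageConfig {};

class PageConfig {
public:
    explicit PageConfig(TilePageConfig cfg) : config_(cfg) {}
    explicit PageConfig(RowMajorPageConfig cfg) : config_(cfg) {}

    // Variant A (before the diff): callers must handle the
    // row-major case explicitly.
    std::optional<Tile> get_tile_opt() const {
        if (std::holds_alternative<TilePageConfig>(config_)) {
            return std::get<TilePageConfig>(config_).get_tile();
        }
        return std::nullopt;
    }

    // Variant B (after the diff): always return a Tile, silently
    // falling back to the default for row-major layouts.
    Tile get_tile() const {
        if (std::holds_alternative<TilePageConfig>(config_)) {
            return std::get<TilePageConfig>(config_).get_tile();
        }
        return Tile();
    }

private:
    std::variant<TilePageConfig, RowMajorPageConfig> config_;
};
```

The trade-off: Variant A forces every call site to decide what "no tile" means, while Variant B keeps call sites simpler at the cost of hiding the row-major case behind a default.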

@@ -18,7 +18,7 @@ size_t element_size_bytes(DataType dtype) {
case DataType::UINT8: return sizeof(uint8_t);
case DataType::BFLOAT8_B:
case DataType::BFLOAT4_B:
TT_THROW("element_size_bytes() should not be used for BFLOAT8_B and BFLOAT4_B types because of how they are packed");
return sizeof(float);
Member:
would like to discuss why
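For context on the "why": block-float formats like BFLOAT8_B pack groups of values with a shared exponent, so a single per-element byte size is ill-defined. The sketch below shows the strict behavior the diff removes (throwing rather than answering). The enum and function name are simplified stand-ins for the real ttnn code.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>

// Illustrative subset of the real DataType enum.
enum class DataType { UINT8, FLOAT32, BFLOAT8_B, BFLOAT4_B };

// Strict variant: refuse to answer for packed block-float formats,
// since their storage is shared-exponent blocks, not independent elements.
std::size_t element_size_bytes_strict(DataType dtype) {
    switch (dtype) {
        case DataType::UINT8: return sizeof(std::uint8_t);
        case DataType::FLOAT32: return sizeof(float);
        case DataType::BFLOAT8_B:
        case DataType::BFLOAT4_B:
            throw std::runtime_error(
                "element_size_bytes() is ill-defined for block-float formats");
    }
    throw std::runtime_error("unknown dtype");
}
```

The diff instead returns `sizeof(float)` for these cases, which presumably reflects the size after unpacking to float; whether that is the right contract is exactly the open question in this thread.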

@@ -36,6 +36,12 @@ Alignment legacyShapeToAlignment(const ttnn::Shape& shape) {
values[i] = legacy_padded_shape[i] * values[i + 1];
}

for (auto& value : values) {
Member:
Will this conflict with @TT-BrianLiu's change?

@@ -51,6 +51,15 @@ class TensorLayout {
// H is all dimensions except W multiplied and aligned to tile and shard height
Size compute_physical_shape(const ttnn::SimpleShape& shape) const;

void set_memory_config(MemoryConfig memory_config) {
Member:
I wouldn't want it to be mutable like this. Let's discuss.
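A common immutable alternative to a `set_memory_config()` setter is a "wither" method that returns a copy with the field replaced, leaving the original layout untouched. This is a hypothetical sketch under assumed, simplified types (`TensorLayoutSketch`, a string-tagged `MemoryConfig`), not the real ttnn API.

```cpp
#include <string>
#include <utility>

// Illustrative stand-in for a memory config.
struct MemoryConfig {
    std::string buffer_type = "DRAM";
};

class TensorLayoutSketch {
public:
    explicit TensorLayoutSketch(MemoryConfig memory_config)
        : memory_config_(std::move(memory_config)) {}

    const MemoryConfig& memory_config() const { return memory_config_; }

    // "Wither" style: the original object stays unchanged; callers
    // get a new layout with the replaced config.
    TensorLayoutSketch with_memory_config(MemoryConfig memory_config) const {
        TensorLayoutSketch copy = *this;
        copy.memory_config_ = std::move(memory_config);
        return copy;
    }

private:
    MemoryConfig memory_config_;
};
```

With this shape, a `TensorLayout` held inside a `TensorSpec` can stay value-immutable, and any change produces a fresh object rather than mutating shared state.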


namespace tt::tt_metal {

class TensorSpec final {
Member:
tbd, discuss next steps with caching some of the values
