Commit 052e170
#4489: fixed bugs in the program caching of eltwise unary and eltwise binary. Updated bloom to use L1 memory config
arakhmati committed Jan 24, 2024
1 parent d764d4a commit 052e170
Showing 16 changed files with 138 additions and 210 deletions.
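
For context on the title: in ttnn, a memory config determines whether a device tensor lives in DRAM or in the cores' on-chip L1 (SRAM). A minimal sketch of that distinction, assuming the stock ttnn constants; the BLOOM_MEMORY_CONFIG used throughout the diff below is defined elsewhere in the model file and is presumed to be an L1 config:

    import ttnn

    # Interleaved configs for the two memory spaces. The commit swaps explicit
    # ttnn.DRAM_MEMORY_CONFIG arguments for BLOOM_MEMORY_CONFIG so that bloom's
    # intermediate activations stay in L1 instead of spilling to DRAM.
    dram_config = ttnn.DRAM_MEMORY_CONFIG  # the config the removed lines pass explicitly
    l1_config = ttnn.L1_MEMORY_CONFIG      # assumed stand-in for BLOOM_MEMORY_CONFIG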
@@ -78,6 +78,7 @@ def create_query_key_value(hidden_states, weight, bias, num_heads):
query_key_value = ttnn.linear(
hidden_states, weight, bias=bias, core_grid=(9, 12), memory_config=BLOOM_MEMORY_CONFIG, dtype=BLOOM_DTYPE
)
ttnn.deallocate(hidden_states)
query, key, value = split_query_key_value_and_split_heads(query_key_value, num_heads=num_heads)
ttnn.deallocate(query_key_value)

@@ -161,8 +162,6 @@ def multi_head_attention(
query_layer, key_layer, value_layer = create_query_key_value(
hidden_states, query_key_value_weight, query_key_value_bias, num_heads=num_heads
)
value_layer = ttnn.reallocate(value_layer)

attention_scores = compute_attention_scores(query_layer, key_layer, alibi)
attention_probs = compute_attention_probs(attention_scores, causal_mask)
context_layer = compute_context_layer(attention_probs, value_layer)
@@ -186,6 +185,8 @@ def mlp(
memory_config=BLOOM_MEMORY_CONFIG,
dtype=BLOOM_DTYPE,
)
ttnn.deallocate(hidden_states)

ff2_output = ttnn.linear(
ff1_output,
dense_4h_to_h_weight,
@@ -211,12 +212,11 @@ def bloom(
layout=ttnn.TILE_LAYOUT,
)

# TODO(arakhmati): put hidden_states in L1
hidden_states = ttnn.layer_norm(
inputs_embeds,
weight=parameters.transformer.word_embeddings_layernorm.weight,
bias=parameters.transformer.word_embeddings_layernorm.bias,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
memory_config=BLOOM_MEMORY_CONFIG,
)
ttnn.deallocate(inputs_embeds)

@@ -238,10 +238,8 @@ def bloom(
layer_parameters.self_attention.dense.bias,
num_heads=num_heads,
)
ttnn.deallocate(normalized_hidden_states)

# TODO(arakhmati): put attention_output in L1
attention_output = ttnn.add(attention_output, hidden_states, memory_config=ttnn.DRAM_MEMORY_CONFIG)
attention_output = ttnn.add(attention_output, hidden_states, memory_config=BLOOM_MEMORY_CONFIG)
ttnn.deallocate(hidden_states)

normalized_attention_output = ttnn.layer_norm(
@@ -258,10 +256,8 @@ def bloom(
layer_parameters.mlp.dense_4h_to_h.weight,
layer_parameters.mlp.dense_4h_to_h.bias,
)
ttnn.deallocate(normalized_attention_output)

# TODO(arakhmati): put mlp_output in L1
mlp_output = ttnn.add(mlp_output, attention_output, memory_config=ttnn.DRAM_MEMORY_CONFIG)
mlp_output = ttnn.add(mlp_output, attention_output, memory_config=BLOOM_MEMORY_CONFIG)
ttnn.deallocate(attention_output)

hidden_states = mlp_output
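The bloom changes above follow a single pattern: write each intermediate to the L1 memory config and deallocate its inputs as soon as they have been consumed. A hedged sketch of that pattern, not taken verbatim from the model (the function name and the ttnn.L1_MEMORY_CONFIG default are assumptions):

    import ttnn

    def add_residual_in_l1(block_output, residual, memory_config=ttnn.L1_MEMORY_CONFIG):
        # Allocate the sum in L1 rather than DRAM, matching the
        # memory_config=BLOOM_MEMORY_CONFIG arguments introduced in this commit.
        output = ttnn.add(block_output, residual, memory_config=memory_config)
        # Free the consumed input immediately so its L1 space can be reused,
        # mirroring the ttnn.deallocate calls added or moved in the diff above.
        ttnn.deallocate(residual)
        return output
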
@@ -186,7 +186,7 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
mem_config = ttl.tensor.MemoryConfig(
ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, ttl.tensor.BufferType.L1, input_shard_spec
)

# shape change also changes shard_spec as shard_shape is dependent on input_shape (resulting in CACHE MISS)
transpose(
input_shape,
device,
@@ -195,7 +195,7 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
input_mem_config=mem_config,
output_mem_config=mem_config,
input_dtype=input_dtype,
expected_program_cache_size=1,
expected_program_cache_size=2,
)

# changing shape
@@ -222,6 +222,7 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
ttl.tensor.TensorMemoryLayout.HEIGHT_SHARDED, ttl.tensor.BufferType.L1, input_shard_spec
)

# shape change also changes shard_spec as shard_shape is dependent on input_shape (resulting in CACHE MISS)
transpose(
input_shape,
device,
@@ -230,5 +231,5 @@ def test_transpose_wh_sharded_program_cache(device, use_program_cache):
input_mem_config=mem_config,
output_mem_config=mem_config,
input_dtype=input_dtype,
expected_program_cache_size=1,
expected_program_cache_size=3,
)
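
The raised expectations (2 and then 3 cached programs instead of 1) follow from the comments added above: the shard shape is derived from the input shape, so a new input shape produces a new shard spec, a new program hash, and therefore a cache miss. An illustrative sketch of that dependency with assumed shapes and core count, not the test's actual values:

    def height_sharded_shard_shape(input_shape, num_cores):
        # Height sharding splits the flattened (N*C*H, W) tensor row-wise across
        # cores, so the shard shape is a direct function of the input shape.
        batch, channels, height, width = input_shape
        return ((batch * channels * height) // num_cores, width)

    # A different input shape yields a different shard shape, hence a different
    # program hash and the extra cache entries the test now expects.
    assert height_sharded_shard_shape((1, 1, 1024, 64), 8) != height_sharded_shard_shape((1, 1, 2048, 64), 8)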
7 changes: 3 additions & 4 deletions tt_eager/tensor/tensor.cpp
@@ -25,8 +25,8 @@ namespace tt {

namespace tt_metal {

Tensor::Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout layout, std::optional<ShardSpec> shard_spec)
: storage_(storage), shape_(shape), dtype_(dtype), layout_(layout), shard_spec_(shard_spec) {
Tensor::Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout layout) :
storage_(storage), shape_(shape), dtype_(dtype), layout_(layout) {
std::visit(
[&] (auto&& storage) {
using StorageType = std::decay_t<decltype(storage)>;
@@ -48,7 +48,6 @@ Tensor::Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout
);
}


Tensor::~Tensor() {
this->deallocate();
}
@@ -343,7 +342,7 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layout
data_type, layout, memory_config,
std::make_optional<ShardSpecBuffer>(shard_spec_buffer)
);
return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, shard_spec);
return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
}

} // namespace tt_metal
148 changes: 64 additions & 84 deletions tt_eager/tensor/tensor.hpp
@@ -29,93 +29,100 @@ class Tensor {
// ======================================================================================
// Hi Level APIs
// ======================================================================================
Tensor(const Storage& storage, const Shape& shape, DataType dtype, Layout layout, std::optional<ShardSpec> shard_spec = std::nullopt);
Tensor(const Storage &storage, const Shape &shape, DataType dtype, Layout layout);

Tensor(const Tensor &other) = default;
Tensor& operator=(const Tensor &other) = default;
Tensor(const Tensor &other) = default;
Tensor &operator=(const Tensor &other) = default;

Tensor(Tensor &&other) = default;
Tensor& operator=(Tensor &&other) = default;
Tensor(Tensor &&other) = default;
Tensor &operator=(Tensor &&other) = default;

~Tensor();
~Tensor();

void deallocate(bool force=false);
void deallocate(bool force = false);

Tensor to(Device *target_device, const MemoryConfig &mem_config={.memory_layout=tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const;
Tensor to(
Device *target_device,
const MemoryConfig &mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const;

Tensor to(Layout target_layout) const;
Tensor to(Layout target_layout) const;

Tensor pad(const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) const;
Tensor pad(const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) const;

Tensor cpu() const;
Tensor cpu() const;

Tensor cpu_sharded() const;
Tensor cpu_sharded() const;

Tensor unpad(const Shape &output_tensor_start, const Shape &output_tensor_end) const;
Tensor unpad(const Shape &output_tensor_start, const Shape &output_tensor_end) const;

Tensor pad_to_tile(float pad_value) const;
Tensor pad_to_tile(float pad_value) const;

Tensor unpad_from_tile(const Shape &output_tensor_shape) const;
Tensor unpad_from_tile(const Shape &output_tensor_shape) const;

const std::string write_to_string(Layout print_layout = Layout::ROW_MAJOR, bool pretty_print = false) const;
void print(Layout print_layout=Layout::ROW_MAJOR, bool pretty_print=false) const;
const std::string write_to_string(Layout print_layout = Layout::ROW_MAJOR, bool pretty_print = false) const;
void print(Layout print_layout = Layout::ROW_MAJOR, bool pretty_print = false) const;

Tensor extract_shard(const CoreCoord & core) const;
Tensor extract_shard(const uint32_t & core_id) const;
Tensor extract_shard(const CoreCoord &core) const;
Tensor extract_shard(const uint32_t &core_id) const;

// ======================================================================================
// Low Level APIs
// ======================================================================================
Tensor reshape(int N, int C, int H, int W) const;
Tensor reshape(const Shape& new_shape) const;
// ======================================================================================
// Low Level APIs
// ======================================================================================
Tensor reshape(int N, int C, int H, int W) const;
Tensor reshape(const Shape &new_shape) const;

// ======================================================================================
// Getters
// ======================================================================================
const Storage& storage() const;
const Shape& shape() const { return this->shape_; }
DataType dtype() const { return this->dtype_; }
Layout layout() const { return this->layout_; }
const std::optional<ShardSpec>& shard_spec() const { return this->shard_spec_; }
// ======================================================================================
// Getters
// ======================================================================================
const Storage &storage() const;
const Shape &shape() const { return this->shape_; }
DataType dtype() const { return this->dtype_; }
Layout layout() const { return this->layout_; }

// ======================================================================================
// Extra Helper Functions
// ======================================================================================
// ======================================================================================
// Extra Helper Functions
// ======================================================================================

StorageType storage_type() const;
const Shape strides() const;
uint32_t volume() const;
StorageType storage_type() const;
const Shape strides() const;
uint32_t volume() const;

bool is_allocated() const;
bool is_allocated() const;

// TODO(arakhmati): clean up the methods below
Buffer* buffer() const { return std::get<DeviceStorage>(this->storage_).buffer.get(); }
Device *device() const { return this->buffer()->device(); }
const MemoryConfig memory_config() const { return std::get<DeviceStorage>(this->storage_).memory_config(); }
// TODO(arakhmati): clean up the methods below
Buffer *buffer() const { return std::get<DeviceStorage>(this->storage_).buffer.get(); }
Device *device() const { return this->buffer()->device(); }
const MemoryConfig memory_config() const {
return std::visit(
[](const auto &storage) -> MemoryConfig {
using T = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<T, DeviceStorage>) {
return storage.memory_config();
} else {
TT_THROW("MemoryConfig can only be obtained for a tensor with DeviceStorage");
}
},
this->storage_);
}
const std::optional<ShardSpec> shard_spec() const { return this->memory_config().shard_spec; }

const bool is_sharded() const { return this->memory_config().is_sharded(); }
const bool is_sharded() const { return this->memory_config().is_sharded(); }

// Size in bytes of a single element held in tensor
uint32_t element_size() const;
// Size in bytes of a single element held in tensor
uint32_t element_size() const;

static constexpr auto attribute_names = std::make_tuple("storage", "shape", "dtype", "layout", "shard_spec");
const auto attribute_values() const {
return std::make_tuple(
std::cref(this->storage_),
std::cref(this->shape_),
std::cref(this->dtype_),
std::cref(this->layout_),
std::cref(this->shard_spec_));
}
static constexpr auto attribute_names = std::make_tuple("storage", "shape", "dtype", "layout");
const auto attribute_values() const {
return std::make_tuple(
std::cref(this->storage_), std::cref(this->shape_), std::cref(this->dtype_), std::cref(this->layout_));
}

std::vector<uint32_t> host_page_ordering();
private:
Storage storage_;
Shape shape_;
DataType dtype_;
Layout layout_;
std::optional<ShardSpec> shard_spec_;

};


@@ -127,30 +134,3 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layout
} // namespace tt_metal

} // namespace tt


namespace std {

template <>
struct hash<tt::tt_metal::Tensor> {
uint64_t operator()(const tt::tt_metal::Tensor &tensor) const {
if (std::holds_alternative<tt::tt_metal::DeviceStorage>(tensor.storage())) {
return tt::stl::hash::hash_objects(0,
typeid(tt::tt_metal::Tensor).hash_code(),
tensor.storage_type(),
std::holds_alternative<tt::tt_metal::DeviceStorage>(tensor.storage()),
tensor.shape(),
tensor.layout(),
tensor.dtype());
}
else {
return tt::stl::hash::hash_objects(0,
typeid(tt::tt_metal::Tensor).hash_code(),
tensor.storage_type(),
tensor.shape(),
tensor.layout(),
tensor.dtype());
}
}
};
}
2 changes: 1 addition & 1 deletion tt_eager/tensor/tensor_impl.hpp
@@ -574,7 +574,7 @@ inline Tensor to_device(const Tensor &tensor, Device *target_device, const MemoryConfig
data_type, layout, memory_config,
shard_spec_buffer_opt
);
return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, memory_config.shard_spec);
return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
}


