#4232: added test for test_raw_host_memory_pointer. Added memcpy function
arakhmati committed Jan 25, 2024
1 parent 0acac56 commit 6d9ea6a
Showing 13 changed files with 303 additions and 27 deletions.
1 change: 1 addition & 0 deletions tests/tt_eager/module.mk
@@ -27,6 +27,7 @@ TT_EAGER_TESTS += \
tests/tt_eager/ops/test_sfpu \
tests/tt_eager/tensors/test_copy_and_move \
tests/tt_eager/tensors/test_host_device_loopback \
tests/tt_eager/tensors/test_raw_host_memory_pointer \
tests/tt_eager/tensors/test_sharded_loopback \
tests/tt_eager/integration_tests/test_bert \

188 changes: 188 additions & 0 deletions tests/tt_eager/tensors/test_raw_host_memory_pointer.cpp
@@ -0,0 +1,188 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include <algorithm>
#include <cstdlib>  // malloc/free, used by the numpy::ndarray helper below
#include <functional>
#include <random>

#include "common/bfloat16.hpp"
#include "common/constants.hpp"
#include "tensor/owned_buffer.hpp"
#include "tensor/owned_buffer_functions.hpp"
#include "tensor/tensor.hpp"
#include "tensor/tensor_impl.hpp"
#include "tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp"
#include "tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_numpy/functions.hpp"

/*
Equivalent PyTorch/NumPy flow, for reference:

import torch
import numpy as np

device = torch.device("cuda:0")

# define tensors on the CPU
# (np.bfloat16 stands in for a bfloat16 dtype; stock numpy does not ship one)
a_cpu = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=np.bfloat16)

# define tensors on the device from CPU tensors
a_dev = torch.from_numpy(a_cpu).to(device)

c_dev = torch.sqrt(a_dev)

print(c_dev[1][0])

d_cpu = np.array([[11, 12, 13, 14], [15, 16, 17, 18]])
d_dev = torch.from_numpy(d_cpu).to(device)

e_dev = c_dev + d_dev
print(e_dev)
*/
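// The C++ below mirrors that flow with tt-metal tensors: borrow host memory
// into a Tensor, move it to device, run ops, and copy results back to host.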

namespace numpy {

template <typename DataType>
struct ndarray {
Shape shape;
void* data;

ndarray(Shape shape) : shape(shape), data(malloc(tt::tt_metal::compute_volume(shape) * sizeof(DataType))) {}
~ndarray() { free(data); }
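// Note: copy control is omitted for brevity; copying this test helper
// would double-free data.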

std::size_t size() const { return tt::tt_metal::compute_volume(shape); }
};

} // namespace numpy

void test_raw_host_memory_pointer() {
using tt::tt_metal::BorrowedStorage;
using tt::tt_metal::DataType;
using tt::tt_metal::OwnedStorage;
using tt::tt_metal::Shape;
using tt::tt_metal::Tensor;
using namespace tt::tt_metal::borrowed_buffer;
using namespace tt::tt_metal::owned_buffer;

int device_id = 0;
tt::tt_metal::Device* device = tt::tt_metal::CreateDevice(device_id);

Shape shape = {1, 1, tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH};

// Host tensor to print the output
Tensor tensor_for_printing = Tensor(
OwnedStorage{owned_buffer::create<bfloat16>(tt::tt_metal::compute_volume(shape))},
shape,
DataType::BFLOAT16,
Layout::TILE);

/* Borrow Data from Numpy Start */
// Create a numpy-style ndarray whose host allocation the Tensor below will borrow
auto a_np_array = numpy::ndarray<bfloat16>(shape);
void* a_np_array_data = a_np_array.data;
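// In a real binding these callbacks would pin and release the external owner
// (e.g. Py_INCREF/Py_DECREF for a numpy array); no-op lambdas suffice here.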
auto on_creation_callback = [] {};
auto on_destruction_callback = [] {};
Tensor a_cpu = Tensor(
BorrowedStorage{
borrowed_buffer::Buffer(static_cast<bfloat16*>(a_np_array_data), a_np_array.size()),
on_creation_callback,
on_destruction_callback},
shape,
DataType::BFLOAT16,
Layout::ROW_MAJOR);
/* Borrow Data from Numpy End */

/* Sanity Check Start */

// Test that borrowing the numpy array's data works correctly

// Set every value of tt Tensor to the same non-zero number
bfloat16 a_value = 4.0f;

for (auto& element : borrowed_buffer::get_as<bfloat16>(a_cpu)) {
element = a_value;
}

// Check that the numpy array's data is now set to that value
for (std::size_t index = 0; index < a_np_array.size(); index++) {
auto a_np_array_element = static_cast<bfloat16*>(a_np_array_data)[index];
TT_ASSERT(a_np_array_element == a_value);
}
/* Sanity Check End */

/* Run and Print Start */
Tensor a_dev = a_cpu.to(device);

Tensor c_dev = tt::tt_metal::sqrt(a_dev);

tt::tt_metal::memcpy(tensor_for_printing, c_dev);

// Check that cpu tensor has correct data
bfloat16 output_value = 1.99219f; // Not exactly 2.0f because of rounding errors
for (auto& element : owned_buffer::get_as<bfloat16>(tensor_for_printing)) {
TT_ASSERT(element == output_value);
}

tensor_for_printing.print();
/* Run and Print End */

/* Alternative Way to Print Start */
// Alternatively, we could allocate memory manually and create Tensors with BorrowedStorage on the fly to print the
// data
void* alternative_tensor_for_printing_data = malloc(tt::tt_metal::compute_volume(shape) * sizeof(bfloat16));
Tensor alternative_tensor_for_printing = Tensor(
BorrowedStorage{
borrowed_buffer::Buffer(
static_cast<bfloat16*>(alternative_tensor_for_printing_data), tt::tt_metal::compute_volume(shape)),
on_creation_callback,
on_destruction_callback},
shape,
DataType::BFLOAT16,
Layout::TILE);
tt::tt_metal::memcpy(alternative_tensor_for_printing, c_dev);
alternative_tensor_for_printing.print();

for (auto& element : borrowed_buffer::get_as<bfloat16>(alternative_tensor_for_printing)) {
TT_ASSERT(element == output_value);
}

free(alternative_tensor_for_printing_data);
/* Alternative Way to Print End */

auto d_np_array = numpy::ndarray<bfloat16>(shape);
void* d_np_array_data = d_np_array.data;
Tensor d_cpu = Tensor(
BorrowedStorage{
borrowed_buffer::Buffer(static_cast<bfloat16*>(d_np_array_data), d_np_array.size()),
on_creation_callback,
on_destruction_callback},
shape,
DataType::BFLOAT16,
Layout::ROW_MAJOR);

bfloat16 d_value = 8.0f;
for (auto& element : borrowed_buffer::get_as<bfloat16>(d_cpu)) {
element = d_value;
}

// d_dev aliases a_dev: copying a Tensor shares its device storage, so this
// memcpy overwrites the existing device buffer in place instead of allocating.
Tensor d_dev = a_dev;
tt::tt_metal::memcpy(d_dev, d_cpu);

Tensor e_dev = tt::tt_metal::add(c_dev, d_dev);

tt::tt_metal::memcpy(tensor_for_printing, e_dev);
tensor_for_printing.print();

for (auto& element : owned_buffer::get_as<bfloat16>(tensor_for_printing)) {
TT_ASSERT(element == bfloat16(10.0f));
}

TT_FATAL(tt::tt_metal::CloseDevice(device));
}

int main() {
test_raw_host_memory_pointer();
return 0;
}
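For reference, a minimal sketch of the zero-copy pattern this test exercises, borrowing a std::vector's memory instead of a malloc'd ndarray. This is a hypothetical helper, not part of the commit; it assumes the same headers and using-declarations as the test above, plus <vector>:

void borrow_vector_example(tt::tt_metal::Device* device) {
    using namespace tt::tt_metal;

    Shape shape = {1, 1, tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH};
    std::vector<bfloat16> host_data(compute_volume(shape), bfloat16(3.0f));

    // The Tensor borrows host_data's memory; host_data must outlive the Tensor.
    Tensor borrowed = Tensor(
        BorrowedStorage{
            borrowed_buffer::Buffer(host_data.data(), host_data.size()),
            [] {},   // on_creation callback
            [] {}},  // on_destruction callback
        shape,
        DataType::BFLOAT16,
        Layout::ROW_MAJOR);

    // Writes through the borrowed buffer are visible in host_data and vice versa.
    Tensor on_device = borrowed.to(device);
}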
5 changes: 4 additions & 1 deletion tt_eager/tensor/borrowed_buffer.hpp
@@ -30,7 +30,10 @@ struct Buffer {
inline const T* begin() const noexcept { return this->data_ptr_; }
inline const T* end() const noexcept { return this->data_ptr_ + this->size(); }

private:
inline void* data() noexcept { return static_cast<void*>(this->data_ptr_); }
inline const void* data() const noexcept { return static_cast<void*>(this->data_ptr_); }

private:
T* data_ptr_;
std::size_t size_;
};
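A hypothetical sketch of what the new type-erased data() accessors allow (assumes the includes from the test above, plus <utility> and <vector>):

std::vector<bfloat16> scratch(64, bfloat16(0.0f));
tt::tt_metal::borrowed_buffer::Buffer<bfloat16> buffer(scratch.data(), scratch.size());

void* bytes = buffer.data();                        // mutable, type-erased
const void* cbytes = std::as_const(buffer).data();  // const overload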
14 changes: 6 additions & 8 deletions tt_eager/tensor/borrowed_buffer_functions.hpp
@@ -40,32 +40,30 @@ template<typename T>
Buffer<T> get_as(Tensor& tensor) {
validate_datatype<T>(tensor);
return std::visit(
[&] (auto&& storage) {
[](auto&& storage) -> Buffer<T> {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
return get_as<T>(storage.buffer);
} else {
TT_THROW("Must be a BorrowedStorage");
TT_THROW("Tensor must have BorrowedStorage");
}
},
tensor.storage()
);
tensor.storage());
}

template<typename T>
const Buffer<T> get_as(const Tensor& tensor) {
validate_datatype<T>(tensor);
return std::visit(
[] (auto&& storage) {
[](auto&& storage) -> Buffer<T> {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, BorrowedStorage>) {
return get_as<T>(storage.buffer);
} else {
TT_THROW("Must be an BorrowedStorage");
TT_THROW("Tensor must have BorrowedStorage");
}
},
tensor.storage()
);
tensor.storage());
}

} // namespace borrowed_buffer
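A note on the added trailing return type: with if constexpr, only the matching branch of the visitor is instantiated per storage alternative, so for non-BorrowedStorage alternatives the lambda body contains no return statement and its return type would deduce to void. std::visit requires every alternative to yield the same type, so the deduced visitor would not compile; spelling out -> Buffer<T> fixes this, and the TT_THROW branch is fine because it never returns.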
5 changes: 4 additions & 1 deletion tt_eager/tensor/owned_buffer.hpp
@@ -36,7 +36,10 @@ struct Buffer {
inline const std::vector<T>& get() const { return *this->shared_vector_; }
inline void reset() { this->shared_vector_.reset(); }

private:
inline void* data() noexcept { return static_cast<void*>(this->pointer_for_faster_access_); }
inline const void* data() const noexcept { return static_cast<void*>(this->pointer_for_faster_access_); }

private:
std::shared_ptr<std::vector<T>> shared_vector_;
T* pointer_for_faster_access_;
std::size_t size_;
4 changes: 2 additions & 2 deletions tt_eager/tensor/owned_buffer_functions.hpp
@@ -55,7 +55,7 @@ Buffer<T> get_as(Tensor& tensor) {
if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
return get_as<T>(storage.buffer);
} else {
TT_THROW("Must be an owned storage");
TT_THROW("Tensor must have OwnedStorage");
}
},
tensor.storage()
@@ -71,7 +71,7 @@ const Buffer<T> get_as(const Tensor& tensor) {
if constexpr (std::is_same_v<StorageType, OwnedStorage>) {
return get_as<T>(storage.buffer);
} else {
TT_THROW("Must be a owned storage");
TT_THROW("Tensor must have OwnedStorage");
}
},
tensor.storage()
30 changes: 28 additions & 2 deletions tt_eager/tensor/tensor.cpp
@@ -131,9 +131,11 @@ void Tensor::print(Layout print_layout, bool pretty_print) const {
std::cout << write_to_string(print_layout, pretty_print);
}

Tensor Tensor::pad(const Shape &output_tensor_shape, const Shape &input_tensor_start, float pad_value) const {
Tensor Tensor::pad(const Shape& output_tensor_shape, const Shape& input_tensor_start, float pad_value) const {
ZoneScoped;
TT_ASSERT(this->storage_type() == StorageType::OWNED or this->storage_type() == StorageType::BORROWED && "Tensor must be on host for padding");
TT_ASSERT(
this->storage_type() == StorageType::OWNED or
this->storage_type() == StorageType::BORROWED && "Tensor must be on host for padding");
TT_ASSERT(this->layout() == Layout::ROW_MAJOR && "Tensor layout must be ROW_MAJOR for padding");

auto input_shape = this->shape();
@@ -345,6 +347,30 @@ Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layo
return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout);
}

void* get_raw_host_data_ptr(const Tensor& tensor) {
const static std::unordered_map<DataType, std::function<void*(const Tensor&)>> dispatch_map = {
{DataType::BFLOAT16, &tensor_impl::get_raw_host_data_ptr<bfloat16>},
{DataType::FLOAT32, &tensor_impl::get_raw_host_data_ptr<float>},
{DataType::UINT32, &tensor_impl::get_raw_host_data_ptr<uint32_t>},
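// BFLOAT8_B data is stored on host as packed uint32_t words, hence the
// uint32_t handler for it below.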
{DataType::BFLOAT8_B, &tensor_impl::get_raw_host_data_ptr<uint32_t>},
{DataType::UINT16, &tensor_impl::get_raw_host_data_ptr<uint16_t>},
};
return dispatch_map.at(tensor.dtype())(tensor);
}

void memcpy(Tensor& dst, const Tensor& src) {
ZoneScoped;

const static std::unordered_map<DataType, std::function<void(Tensor&, const Tensor&)>> dispatch_map = {
{DataType::BFLOAT16, &tensor_impl::memcpy<bfloat16>},
{DataType::FLOAT32, &tensor_impl::memcpy<float>},
{DataType::UINT32, &tensor_impl::memcpy<uint32_t>},
{DataType::BFLOAT8_B, &tensor_impl::memcpy<uint32_t>},
{DataType::UINT16, &tensor_impl::memcpy<uint16_t>},
};
dispatch_map.at(dst.dtype())(dst, src);
}

} // namespace tt_metal

} // namespace tt
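For context, a hedged sketch of how the two new helpers are meant to be called (hypothetical snippet; the setup mirrors the test added in this commit):

using namespace tt::tt_metal;

Shape shape = {1, 1, tt::constants::TILE_HEIGHT, tt::constants::TILE_WIDTH};
Tensor host = Tensor(
    OwnedStorage{owned_buffer::create<bfloat16>(compute_volume(shape))},
    shape, DataType::BFLOAT16, Layout::TILE);

// Raw, type-erased access to the host tensor's allocation:
void* bytes = get_raw_host_data_ptr(host);

// Dtype-dispatched copy, e.g. from a device tensor into this host tensor
// (as the test does before printing):
// memcpy(host, device_tensor);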
5 changes: 5 additions & 0 deletions tt_eager/tensor/tensor.hpp
@@ -131,6 +131,11 @@ Tensor create_device_tensor(const Shape& shape, DataType dtype, Layout layout, D

Tensor create_sharded_device_tensor(const Shape& shape, DataType data_type, Layout layout, Device *device, const MemoryConfig& memory_config);

// template<typename Buffer>
// void *get_host_buffer(const Tensor &tensor);
void *get_raw_host_data_ptr(const Tensor &tensor);
void memcpy(Tensor &dst, const Tensor &src);

} // namespace tt_metal

} // namespace tt