Skip to content

Commit

Permalink
More work
Browse files Browse the repository at this point in the history
  • Loading branch information
sagarwalTT committed Nov 14, 2024
1 parent c320040 commit 384da63
Show file tree
Hide file tree
Showing 25 changed files with 113 additions and 282 deletions.
2 changes: 1 addition & 1 deletion tests/scripts/run_cpp_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fi

kernel_path="/tmp/kernels"
mkdir -p $kernel_path
TT_METAL_KERNEL_PATH=$kernel_path ./build/test/tt_metal/test_kernel_path_env_var
TT_METAL_KERNEL_PATH=$kernel_path ./build/test/tt_metal/ --gtest_filter=CompileProgramWithKernelPathEnvVarFixture.*
rm -rf $kernel_path

if [[ ! -z "$TT_METAL_SLOW_DISPATCH_MODE" ]]; then
Expand Down
2 changes: 0 additions & 2 deletions tests/tt_metal/tt_metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ set(TT_METAL_TESTS_SRCS
test_core_range_set.cpp
test_compile_sets_kernel_binaries.cpp
test_compile_program.cpp
test_kernel_path_env_var.cpp
test_clean_init.cpp
test_create_kernel_from_string.cpp
)

foreach(TEST_SRC ${TT_METAL_TESTS_SRCS})
Expand Down
1 change: 1 addition & 0 deletions tests/tt_metal/tt_metal/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ target_include_directories(
${UMD_HOME}
${PROJECT_SOURCE_DIR}
${PROJECT_SOURCE_DIR}/tt_metal
${PROJECT_SOURCE_DIR}/tt_metal/common
${PROJECT_SOURCE_DIR}/tests
${PROJECT_SOURCE_DIR}/tests/tt_metal/tt_metal/common
${CMAKE_CURRENT_SOURCE_DIR}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,12 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include <exception>
#include <filesystem>
#pragma once

#include "assert.hpp"
#include "core_coord.hpp"
#include "detail/tt_metal.hpp"
#include <gtest/gtest.h>
#include "host_api.hpp"
#include "impl/kernels/data_types.hpp"
#include "impl/program/program.hpp"
#include "llrt/rtoptions.hpp"
#include "tt_cluster_descriptor_types.h"

using namespace tt;
using namespace tt::tt_metal;
using namespace tt::llrt;

class CompileProgramWithKernelPathEnvVarFixture : public ::testing::Test {
protected:
Expand All @@ -42,18 +31,18 @@ class CompileProgramWithKernelPathEnvVarFixture : public ::testing::Test {
}

void setup_kernel_dir(const string &orig_kernel_file, const string &new_kernel_file) {
const string &kernel_dir = OptionsG.get_kernel_dir();
const string &kernel_dir = llrt::OptionsG.get_kernel_dir();
const std::filesystem::path &kernel_file_path_under_kernel_dir(kernel_dir + new_kernel_file);
const std::filesystem::path &dirs_under_kernel_dir = kernel_file_path_under_kernel_dir.parent_path();
std::filesystem::create_directories(dirs_under_kernel_dir);

const string &metal_root = OptionsG.get_root_dir();
const string &metal_root = llrt::OptionsG.get_root_dir();
const std::filesystem::path &kernel_file_path_under_metal_root(metal_root + orig_kernel_file);
std::filesystem::copy(kernel_file_path_under_metal_root, kernel_file_path_under_kernel_dir);
}

void cleanup_kernel_dir() {
const string &kernel_dir = OptionsG.get_kernel_dir();
const string &kernel_dir = llrt::OptionsG.get_kernel_dir();
for (const std::filesystem::directory_entry &entry : std::filesystem::directory_iterator(kernel_dir)) {
std::filesystem::remove_all(entry);
}
Expand All @@ -69,10 +58,10 @@ class CompileProgramWithKernelPathEnvVarFixture : public ::testing::Test {
}

void validate_env_vars_are_set() {
if (!OptionsG.is_root_dir_specified()) {
if (!llrt::OptionsG.is_root_dir_specified()) {
GTEST_SKIP() << "Skipping test: TT_METAL_HOME must be set";
}
if (!OptionsG.is_kernel_dir_specified()) {
if (!llrt::OptionsG.is_kernel_dir_specified()) {
GTEST_SKIP() << "Skipping test: TT_METAL_KERNEL_PATH must be set";
}
}
Expand Down Expand Up @@ -103,32 +92,3 @@ class CompileProgramWithKernelPathEnvVarFixture : public ::testing::Test {
return std::filesystem::is_empty(file_path);
}
};

TEST_F(CompileProgramWithKernelPathEnvVarFixture, TensixKernelUnderMetalRootDir) {
const string &kernel_file = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_unary_push_4.cpp";
create_kernel(kernel_file);
detail::CompileProgram(this->device_, this->program_);
}

TEST_F(CompileProgramWithKernelPathEnvVarFixture, TensixKernelUnderKernelRootDir) {
const string &orig_kernel_file = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_unary_push_4.cpp";
const string &new_kernel_file = "tests/tt_metal/tt_metal/test_kernels/dataflow/new_kernel.cpp";
this->setup_kernel_dir(orig_kernel_file, new_kernel_file);
this->create_kernel(new_kernel_file);
detail::CompileProgram(this->device_, this->program_);
this->cleanup_kernel_dir();
}

TEST_F(CompileProgramWithKernelPathEnvVarFixture, TensixKernelUnderMetalRootDirAndKernelRootDir) {
const string &kernel_file = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_unary_push_4.cpp";
this->setup_kernel_dir(kernel_file, kernel_file);
this->create_kernel(kernel_file);
detail::CompileProgram(this->device_, this->program_);
this->cleanup_kernel_dir();
}

TEST_F(CompileProgramWithKernelPathEnvVarFixture, TensixNonExistentKernel) {
const string &kernel_file = "tests/tt_metal/tt_metal/test_kernels/dataflow/non_existent_kernel.cpp";
this->create_kernel(kernel_file);
EXPECT_THROW(detail::CompileProgram(this->device_, this->program_), std::exception);
}
47 changes: 21 additions & 26 deletions tests/tt_metal/tt_metal/api/test_kernel_creation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,53 +2,53 @@
//
// SPDX-License-Identifier: Apache-2.0

#include "dispatch_fixture.hpp"
#include "gtest/gtest.h"
#include "tt_metal/host_api.hpp"
#include "common/core_coord.hpp"
#include "dispatch_fixture.hpp"
#include "tt_metal/detail/tt_metal.hpp"

#include "host_api.hpp"
#include "compile_program_with_kernel_path_env_var_fixture.hpp"

using namespace tt;

// Ensures we can successfully create kernels on available compute grid
TEST_F(DispatchFixture, TensixCreateKernelsOnComputeCores) {
for (unsigned int id = 0; id < devices_.size(); id++) {
for (unsigned int id = 0; id < this->devices_.size(); id++) {
tt_metal::Program program = CreateProgram();
CoreCoord compute_grid = devices_.at(id)->compute_with_storage_grid_size();
CoreCoord compute_grid = this->devices_.at(id)->compute_with_storage_grid_size();
EXPECT_NO_THROW(
auto test_kernel = tt_metal::CreateKernel(
program,
"tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp",
CoreRange(CoreCoord(0, 0), CoreCoord(compute_grid.x, compute_grid.y)),
{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}
);
);
{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}););
}
}

// Ensure we cannot create kernels on storage cores
TEST_F(DispatchFixture, TensixCreateKernelsOnStorageCores) {
for (unsigned int id=0; id < devices_.size(); id++) {
if (devices_.at(id)->storage_only_cores().empty()) {
for (unsigned int id = 0; id < this->devices_.size(); id++) {
if (this->devices_.at(id)->storage_only_cores().empty()) {
GTEST_SKIP() << "This test only runs on devices with storage only cores";
}
CoreRangeSet storage_core_range_set = CoreRangeSet(devices_.at(id)->storage_only_cores());
tt_metal::Program program = CreateProgram();
CoreRangeSet storage_core_range_set = CoreRangeSet(this->devices_.at(id)->storage_only_cores());
EXPECT_ANY_THROW(
auto test_kernel = tt_metal::CreateKernel(
program,
"tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp",
storage_core_range_set,
{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}
);
);
{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}););
}
}

TEST_F(DispatchFixture, TensixIdleEthCreateKernelsOnDispatchCores) {
if (getenv("TT_METAL_SLOW_DISPATCH_MODE")) {
GTEST_SKIP() << "This test is only supported in fast dispatch mode";
}
for (unsigned int id=0; id < devices_.size(); id++) {
for (unsigned int id = 0; id < this->devices_.size(); id++) {
tt_metal::Program program = CreateProgram();
Device* device = this->devices_.at(id);
std::vector<CoreCoord> dispatch_cores = tt::get_logical_dispatch_cores(device->id(), device->num_hw_cqs());
CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
std::set<CoreCoord> dispatch_core_range_set(dispatch_cores.begin(), dispatch_cores.end());
Expand All @@ -59,18 +59,13 @@ TEST_F(DispatchFixture, TensixIdleEthCreateKernelsOnDispatchCores) {
program,
"tests/tt_metal/tt_metal/test_kernels/dataflow/dram_copy.cpp",
dispatch_core_range_set,
{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}
);
);
{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = tt_metal::NOC::RISCV_0_default}););
} else if (dispatch_core_type == CoreType::ETH) {
EXPECT_ANY_THROW(
auto test_kernel = tt_metal::CreateKernel(
program,
"tests/tt_metal/tt_metal/test_kernels/misc/erisc_print.cpp",
dispatch_core_range_set,
{.noc = tt_metal::NOC::NOC_0, .eth_mode = Eth::IDLE}
);
);
EXPECT_ANY_THROW(auto test_kernel = tt_metal::CreateKernel(
program,
"tests/tt_metal/tt_metal/test_kernels/misc/erisc_print.cpp",
dispatch_core_range_set,
{.noc = tt_metal::NOC::NOC_0, .eth_mode = Eth::IDLE}););
}
}
}
Expand Down
9 changes: 0 additions & 9 deletions tests/tt_metal/tt_metal/common/buffer_fixture.hpp

This file was deleted.

36 changes: 15 additions & 21 deletions tests/tt_metal/tt_metal/common/command_queue_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
#include "common/env_lib.hpp"
#include "gtest/gtest.h"
#include "dispatch_fixture.hpp"
#include "buffer_fixture.hpp"
#include "program_fixture.hpp"
#include "trace_fixture.hpp"
#include "event_fixture.hpp"
#include "hostdevcommon/common_runtime_address_map.h"
#include "hostdevcommon/common_values.hpp"
#include "impl/buffers/circular_buffer_types.hpp"
Expand All @@ -30,7 +27,7 @@
#include "test_utils.hpp"
#include "tt_soc_descriptor.h"

class CommandQueueFixture : virtual public DispatchFixture {
class CommandQueueFixture : public DispatchFixture {
protected:
tt::tt_metal::Device* device_;
void SetUp() override {
Expand All @@ -53,11 +50,11 @@ class CommandQueueFixture : virtual public DispatchFixture {
}
};

class CommandQueueEventFixture : virtual public CommandQueueFixture, virtual public EventFixture {};
class CommandQueueEventFixture : public CommandQueueFixture {};

class CommandQueueBufferFixture : virtual public CommandQueueFixture, virtual public BufferFixture {};
class CommandQueueBufferFixture : public CommandQueueFixture {};

class CommandQueueProgramFixture : virtual public CommandQueueFixture, virtual public ProgramFixture {};
class CommandQueueProgramFixture : public CommandQueueFixture {};

class CommandQueueSingleCardFixture : virtual public DispatchFixture {
protected:
Expand Down Expand Up @@ -96,10 +93,9 @@ class CommandQueueSingleCardFixture : virtual public DispatchFixture {
std::map<chip_id_t, tt::tt_metal::Device *> reserved_devices_;
};

class CommandQueueSingleCardBufferFixture : virtual public CommandQueueSingleCardFixture,
virtual public BufferFixture {};
class CommandQueueSingleCardBufferFixture : public CommandQueueSingleCardFixture {};

class CommandQueueSingleCardTraceFixture : virtual public CommandQueueSingleCardFixture, virtual public TraceFixture {
class CommandQueueSingleCardTraceFixture : virtual public CommandQueueSingleCardFixture {
protected:
void SetUp() override {
this->validate_dispatch_mode();
Expand All @@ -108,8 +104,7 @@ class CommandQueueSingleCardTraceFixture : virtual public CommandQueueSingleCard
}
};

class CommandQueueSingleCardProgramFixture : virtual public CommandQueueSingleCardFixture,
virtual public ProgramFixture {};
class CommandQueueSingleCardProgramFixture : virtual public CommandQueueSingleCardFixture {};

class CommandQueueMultiDeviceFixture : public DispatchFixture {
protected:
Expand Down Expand Up @@ -144,8 +139,7 @@ class CommandQueueMultiDeviceFixture : public DispatchFixture {
size_t num_devices_;
};

class CommandQueueMultiDeviceProgramFixture : virtual public CommandQueueMultiDeviceFixture,
virtual public ProgramFixture {};
class CommandQueueMultiDeviceProgramFixture : public CommandQueueMultiDeviceFixture {};

class RandomProgramFixture : virtual public CommandQueueSingleCardProgramFixture {
protected:
Expand Down Expand Up @@ -519,7 +513,7 @@ class RandomProgramTraceFixture : virtual public RandomProgramFixture, virtual p
}
};

class MultiCommandQueueSingleDeviceFixture : virtual public DispatchFixture {
class MultiCommandQueueSingleDeviceFixture : public DispatchFixture {
protected:
void SetUp() override {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
Expand Down Expand Up @@ -551,13 +545,13 @@ class MultiCommandQueueSingleDeviceFixture : virtual public DispatchFixture {
tt::ARCH arch_;
};

class MultiCommandQueueSingleDeviceEventFixture : public MultiCommandQueueSingleDeviceFixture, public EventFixture {};
class MultiCommandQueueSingleDeviceEventFixture : public MultiCommandQueueSingleDeviceFixture {};

class MultiCommandQueueSingleDeviceBufferFixture : public MultiCommandQueueSingleDeviceFixture, public BufferFixture {};
class MultiCommandQueueSingleDeviceBufferFixture : public MultiCommandQueueSingleDeviceFixture {};

class MultiCommandQueueSingleDeviceProgramFixture : public MultiCommandQueueSingleDeviceFixture, public ProgramFixture {};
class MultiCommandQueueSingleDeviceProgramFixture : public MultiCommandQueueSingleDeviceFixture {};

class MultiCommandQueueMultiDeviceFixture : virtual public DispatchFixture {
class MultiCommandQueueMultiDeviceFixture : public DispatchFixture {
protected:
void SetUp() override {
auto slow_dispatch = getenv("TT_METAL_SLOW_DISPATCH_MODE");
Expand Down Expand Up @@ -594,6 +588,6 @@ class MultiCommandQueueMultiDeviceFixture : virtual public DispatchFixture {
std::map<chip_id_t, tt::tt_metal::Device*> reserved_devices_;
};

class MultiCommandQueueMultiDeviceBufferFixture : public MultiCommandQueueMultiDeviceFixture, public BufferFixture {};
class MultiCommandQueueMultiDeviceBufferFixture : public MultiCommandQueueMultiDeviceFixture {};

class MultiCommandQueueMultiDeviceEventFixture : public MultiCommandQueueMultiDeviceFixture, public EventFixture {};
class MultiCommandQueueMultiDeviceEventFixture : public MultiCommandQueueMultiDeviceFixture {};
5 changes: 2 additions & 3 deletions tests/tt_metal/tt_metal/common/device_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <gtest/gtest.h>

#include "buffer_fixture.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/test_utils/env_vars.hpp"
Expand Down Expand Up @@ -53,7 +52,7 @@ class DeviceFixture : public ::testing::Test {
size_t num_devices_;
};

class DeviceSingleCardFixture : virtual public ::testing::Test {
class DeviceSingleCardFixture : public ::testing::Test {
protected:
void SetUp() override {
this->validate_dispatch_mode();
Expand Down Expand Up @@ -84,7 +83,7 @@ class DeviceSingleCardFixture : virtual public ::testing::Test {
size_t num_devices_;
};

class DeviceSingleCardBufferFixture : virtual public DeviceSingleCardFixture, virtual public BufferFixture {};
class DeviceSingleCardBufferFixture : public DeviceSingleCardFixture {};

class BlackholeSingleCardFixture : public DeviceSingleCardFixture {
protected:
Expand Down
2 changes: 1 addition & 1 deletion tests/tt_metal/tt_metal/common/dispatch_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "tt_metal/impl/device/device_pool.hpp"

// A dispatch-agnostic test fixture
class DispatchFixture : virtual public ::testing::Test {
class DispatchFixture : public ::testing::Test {
public:
// A function to run a program, according to which dispatch mode is set.
void RunProgram(tt::tt_metal::Device* device, tt::tt_metal::Program& program) {
Expand Down
9 changes: 0 additions & 9 deletions tests/tt_metal/tt_metal/common/event_fixture.hpp

This file was deleted.

9 changes: 0 additions & 9 deletions tests/tt_metal/tt_metal/common/program_fixture.hpp

This file was deleted.

4 changes: 1 addition & 3 deletions tests/tt_metal/tt_metal/common/trace_fixture.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
#include "tt_metal/test_utils/env_vars.hpp"
#include "impl/kernels/kernel.hpp"

class TraceFixture : virtual public ::testing::Test {};

class SingleDeviceTraceFixture : virtual public TraceFixture {
class SingleDeviceTraceFixture : public ::testing::Test {
protected:
Device* device_;
tt::ARCH arch_;
Expand Down
Loading

0 comments on commit 384da63

Please sign in to comment.