Skip to content

Commit

Permalink
convert librett unit tests to use default stream provided by TA runtime, not CUDA runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
evaleev committed Sep 27, 2023
1 parent 5ea446b commit fdaf8bc
Showing 1 changed file with 72 additions and 48 deletions.
120 changes: 72 additions & 48 deletions tests/librett.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,18 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem) {
iter++;
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

int* a_device;
TiledArray::device::malloc(&a_device, A * A * sizeof(int));
int* b_device;
TiledArray::device::malloc(&b_device, A * A * sizeof(int));

TiledArray::device::memcpy(a_device, a_host, A * A * sizeof(int),
TiledArray::device::MemcpyHostToDevice);
TiledArray::device::memcpyAsync(a_device, a_host, A * A * sizeof(int),
TiledArray::device::MemcpyHostToDevice,
q.stream);

std::vector<int> extent({A, A});
TiledArray::extent_to_col_major(extent);
Expand All @@ -70,22 +75,23 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem) {
TiledArray::permutation_to_col_major(perm);

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int),
stream.stream);
status =
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), q.stream);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_device, b_device);

BOOST_CHECK(status == LIBRETT_SUCCESS);
librettDestroy(plan);

TiledArray::device::memcpy(b_host, b_device, A * A * sizeof(int),
TiledArray::device::MemcpyDeviceToHost);
TiledArray::device::memcpyAsync(b_host, b_device, A * A * sizeof(int),
TiledArray::device::MemcpyDeviceToHost,
q.stream);
TiledArray::device::streamSynchronize(q.stream);

librettDestroy(plan);

iter = 0;
for (std::size_t i = 0; i < A; i++) {
Expand Down Expand Up @@ -113,17 +119,19 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym) {
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

int* a_device;
TiledArray::device::malloc(&a_device, A * B * sizeof(int));
int* b_device;
TiledArray::device::malloc(&b_device, A * B * sizeof(int));

TiledArray::device::memcpy(a_device, a_host, A * B * sizeof(int),
TiledArray::device::MemcpyHostToDevice);
TiledArray::device::memcpyAsync(a_device, a_host, A * B * sizeof(int),
TiledArray::device::MemcpyHostToDevice,
q.stream);

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

std::vector<int> extent({B, A});
Expand All @@ -132,18 +140,21 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym) {
std::vector<int> perm({1, 0});
TiledArray::permutation_to_col_major(perm);

status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int),
stream.stream);
status =
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), q.stream);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_device, b_device);

BOOST_CHECK(status == LIBRETT_SUCCESS);
librettDestroy(plan);

TiledArray::device::memcpy(b_host, b_device, A * B * sizeof(int),
TiledArray::device::MemcpyDeviceToHost);
TiledArray::device::memcpyAsync(b_host, b_device, A * B * sizeof(int),
TiledArray::device::MemcpyDeviceToHost,
q.stream);
TiledArray::device::streamSynchronize(q.stream);

librettDestroy(plan);

iter = 0;
for (std::size_t i = 0; i < B; i++) {
Expand Down Expand Up @@ -173,19 +184,21 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_column_major) {
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

int* a_device;
TiledArray::device::malloc(&a_device, A * B * C * sizeof(int));
int* b_device;
TiledArray::device::malloc(&b_device, A * B * C * sizeof(int));

TiledArray::device::memcpy(a_device, a_host, A * B * C * sizeof(int),
TiledArray::device::MemcpyHostToDevice);
TiledArray::device::memcpyAsync(a_device, a_host, A * B * C * sizeof(int),
TiledArray::device::MemcpyHostToDevice,
q.stream);

// b(j,i,k) = a(i,j,k)

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

std::vector<int> extent3{int(A), int(B), int(C)};
Expand All @@ -194,16 +207,18 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_column_major) {
// std::vector<int> perm3{0, 2, 1};

status = librettPlanMeasure(&plan, 3, extent3.data(), perm3.data(),
sizeof(int), stream.stream, a_device, b_device);
sizeof(int), q.stream, a_device, b_device);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_device, b_device);

BOOST_CHECK(status == LIBRETT_SUCCESS);

TiledArray::device::memcpy(b_host, b_device, A * B * C * sizeof(int),
TiledArray::device::MemcpyDeviceToHost);
TiledArray::device::memcpyAsync(b_host, b_device, A * B * C * sizeof(int),
TiledArray::device::MemcpyDeviceToHost,
q.stream);
TiledArray::device::streamSynchronize(q.stream);

status = librettDestroy(plan);

Expand Down Expand Up @@ -239,19 +254,21 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_row_major) {
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

int* a_device;
TiledArray::device::malloc(&a_device, A * B * C * sizeof(int));
int* b_device;
TiledArray::device::malloc(&b_device, A * B * C * sizeof(int));

TiledArray::device::memcpy(a_device, a_host, A * B * C * sizeof(int),
TiledArray::device::MemcpyHostToDevice);
TiledArray::device::memcpyAsync(a_device, a_host, A * B * C * sizeof(int),
TiledArray::device::MemcpyHostToDevice,
q.stream);

// b(j,i,k) = a(i,j,k)

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

std::vector<int> extent({A, B, C});
Expand All @@ -261,16 +278,18 @@ BOOST_AUTO_TEST_CASE(librett_gpu_mem_nonsym_rank_three_row_major) {
TiledArray::permutation_to_col_major(perm);

status = librettPlanMeasure(&plan, 3, extent.data(), perm.data(), sizeof(int),
stream.stream, a_device, b_device);
q.stream, a_device, b_device);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_device, b_device);

BOOST_CHECK(status == LIBRETT_SUCCESS);

TiledArray::device::memcpy(b_host, b_device, A * B * C * sizeof(int),
TiledArray::device::MemcpyDeviceToHost);
TiledArray::device::memcpyAsync(b_host, b_device, A * B * C * sizeof(int),
TiledArray::device::MemcpyDeviceToHost,
q.stream);
TiledArray::device::streamSynchronize(q.stream);

status = librettDestroy(plan);

Expand Down Expand Up @@ -308,9 +327,10 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem) {
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

std::vector<int> extent({A, A});
Expand All @@ -319,18 +339,18 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem) {
std::vector<int> perm({1, 0});
TiledArray::permutation_to_col_major(perm);

status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int),
stream.stream);
status =
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), q.stream);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_um, b_um);

BOOST_CHECK(status == LIBRETT_SUCCESS);

librettDestroy(plan);
TiledArray::device::streamSynchronize(q.stream);

TiledArray::device::deviceSynchronize();
librettDestroy(plan);

iter = 0;
for (std::size_t i = 0; i < A; i++) {
Expand Down Expand Up @@ -359,9 +379,10 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_nonsym) {
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

std::vector<int> extent({B, A});
Expand All @@ -370,17 +391,18 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_nonsym) {
std::vector<int> perm({1, 0});
TiledArray::permutation_to_col_major(perm);

status = librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int),
stream.stream);
status =
librettPlan(&plan, 2, extent.data(), perm.data(), sizeof(int), q.stream);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_um, b_um);

BOOST_CHECK(status == LIBRETT_SUCCESS);

TiledArray::device::streamSynchronize(q.stream);

librettDestroy(plan);
TiledArray::device::deviceSynchronize();

iter = 0;
for (std::size_t i = 0; i < B; i++) {
Expand Down Expand Up @@ -410,9 +432,10 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_rank_three) {
}
}

auto q = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(q.device));

librettHandle plan;
auto stream = TiledArray::deviceEnv::instance()->stream(0);
DeviceSafeCall(TiledArray::device::setDevice(stream.device));
librettResult status;

// b(k,i,j) = a(i,j,k)
Expand All @@ -423,17 +446,18 @@ BOOST_AUTO_TEST_CASE(librett_unified_mem_rank_three) {
std::vector<int> perm({2, 0, 1});
TiledArray::permutation_to_col_major(perm);

status = librettPlan(&plan, 3, extent.data(), perm.data(), sizeof(int),
stream.stream);
status =
librettPlan(&plan, 3, extent.data(), perm.data(), sizeof(int), q.stream);

BOOST_CHECK(status == LIBRETT_SUCCESS);

status = librettExecute(plan, a_um, b_um);

BOOST_CHECK(status == LIBRETT_SUCCESS);

TiledArray::device::streamSynchronize(q.stream);

librettDestroy(plan);
TiledArray::device::deviceSynchronize();

iter = 0;
for (std::size_t i = 0; i < A; i++) {
Expand Down

0 comments on commit fdaf8bc

Please sign in to comment.