Skip to content

Commit

Permalink
Towards unified device & host tasks
Browse files Browse the repository at this point in the history
We now template the `ttg::device::Task` on the `ExecutionSpace` so that
we can determine whether it's a host or device task based on the space.
We can then optimize away the select and kernel-wait suspension points.
We could remove the send suspension point but we use coroutines for
storing the final sends anyway and we don't have access to the
task return type in `ttg::device::send()`.

This allows tasks to be written once for both host and device without
duplicating much of the code. Host tasks that are not coroutines will
continue to be supported.

Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Nov 8, 2024
1 parent 2ec6982 commit dc08fcf
Show file tree
Hide file tree
Showing 13 changed files with 646 additions and 374 deletions.
20 changes: 18 additions & 2 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ if (TARGET tiledarray)
COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2)

add_ttg_executable(testing_dpotrf potrf/testing_dpotrf.cc LINK_LIBRARIES tiledarray lapackpp)
add_ttg_executable(testing_dpotrf_host potrf/testing_dpotrf.cc
LINK_LIBRARIES tiledarray lapackpp
COMPILE_DEFINITIONS TTG_ENABLE_DEV_HOST=1)
add_ttg_executable(testing_dtrtri potrf/testing_dtrtri.cc LINK_LIBRARIES tiledarray lapackpp)
add_ttg_executable(testing_dlauum potrf/testing_dlauum.cc LINK_LIBRARIES tiledarray lapackpp)
add_ttg_executable(testing_dpoinv potrf/testing_dpoinv.cc LINK_LIBRARIES tiledarray lapackpp)
Expand Down Expand Up @@ -50,14 +53,27 @@ if (TARGET tiledarray)
endif()

if (TTG_HAVE_CUDA)
add_ttg_executable(chain-ttg-cuda task-benchmarks/chain-ttg-dev.cc LINK_LIBRARIES tiledarray RUNTIMES "parsec")
add_ttg_executable(chain-ttg-dev-cuda task-benchmarks/chain-ttg-dev.cc
COMPILE_DEFINITIONS CHAIN_CUDA=1
LINK_LIBRARIES tiledarray
RUNTIMES "parsec")
endif(TTG_HAVE_CUDA)

if (TTG_HAVE_HIP)
add_ttg_executable(chain-ttg-hip task-benchmarks/chain-ttg-dev.cc LINK_LIBRARIES tiledarray RUNTIMES "parsec")
add_ttg_executable(chain-ttg-dev-hip task-benchmarks/chain-ttg-dev.cc
COMPILE_DEFINITIONS CHAIN_HIP=1
LINK_LIBRARIES tiledarray
RUNTIMES "parsec")
endif(TTG_HAVE_HIP)
endif()

add_ttg_executable(chain-ttg-host task-benchmarks/chain-ttg.cc)

add_ttg_executable(chain-ttg-dev-host task-benchmarks/chain-ttg-dev.cc
COMPILE_DEFINITIONS CHAIN_HOST=1
LINK_LIBRARIES tiledarray
RUNTIMES "parsec")

if (TARGET MADworld)
add_ttg_executable(madness-1d madness/madness-1d/madness-1d.cc RUNTIMES "mad")
if (TARGET blaspp) #(CBLAS_FOUND AND MKL_FOUND)
Expand Down
107 changes: 62 additions & 45 deletions examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,19 @@
#include "util.h"
#include "../devblas_helper.h"

#if (defined(TTG_ENABLE_CUDA) || defined(TTG_ENABLE_HIP))
#if (defined(TTG_ENABLE_CUDA) || defined(TTG_ENABLE_HIP) || defined(TTG_ENABLE_DEV_HOST))
#define ENABLE_DEVICE_KERNEL 1
#endif

#if defined(TTG_HAVE_CUDART)
#define ES ttg::ExecutionSpace::CUDA
#define TASKRET -> ttg::device::Task
#include <cusolverDn.h>
#elif defined(TTG_ENABLE_HIP)
#define ES ttg::ExecutionSpace::HIP
#define TASKRET -> ttg::device::Task
#include <hipsolver/hipsolver.h>
#include <hipblas/hipblas.h>
#else
#define ES ttg::ExecutionSpace::Host
#define TASKRET -> void
#endif

namespace potrf {
Expand All @@ -35,21 +32,21 @@ namespace potrf {
#if defined(ENABLE_DEVICE_KERNEL)
static int device_potrf_workspace_size(MatrixTile<double> &A) {
int Lwork;
#if defined(TTG_ENABLE_CUDA)
#if defined(TTG_ENABLE_CUDA)
cusolverDnDpotrf_bufferSize(cusolver_handle(),
CUBLAS_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
&Lwork);
return Lwork;
#elif defined(TTG_ENABLE_HIP)
#elif defined(TTG_ENABLE_HIP)
hipsolverDnDpotrf_bufferSize(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
&Lwork);
return Lwork;
#else
#else
return 0;
#endif
#endif
}

static void device_potrf(MatrixTile<double> &A, double *workspace, int Lwork, int *devInfo) {
Expand All @@ -64,13 +61,16 @@ namespace potrf {
A.buffer().current_device_ptr(), A.lda(),
workspace, Lwork,
devInfo);
#elif defined(TTG_ENABLE_HIP)
#elif defined(TTG_ENABLE_HIP)
hipsolverDpotrf(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
A.buffer().current_device_ptr(), A.lda(),
workspace, Lwork,
devInfo);
#endif
#else
auto info = lapack::potrf(lapack::Uplo::Lower, A.rows(), A.buffer().current_device_ptr(), A.lda());
assert(info == 0);
#endif
}

static void device_norm(const MatrixTile<double> &A, double *norm) {
Expand All @@ -81,9 +81,11 @@ namespace potrf {
auto handle = cublas_handle();
//double n = 1.0;
cublasDnrm2(handle, size, buffer, 1, norm);
#elif defined(TTG_ENABLE_HIP)
#elif defined(TTG_ENABLE_HIP)
hipblasDnrm2(hipblas_handle(), size, buffer, 1, norm);
#endif
#else
*norm = blas::nrm2(size, buffer, 1);
#endif
}
#endif // ENABLE_DEVICE_KERNEL

Expand All @@ -99,7 +101,8 @@ namespace potrf {
//std::cout << "Creating CUDA POTRF task " << std::endl;
auto f_dev = [=, iallocator = std::move(iallocator)]
(const Key1& key, MatrixTile<T>&& tile_kk,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out) TASKRET {
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out)
-> ttg::device::Task<ES> {
const auto K = key[0];

/* compute successors before submitting the kernel
Expand Down Expand Up @@ -186,7 +189,7 @@ namespace potrf {
ttg::abort();
}
};
return ttg::make_tt<ES>(f_dev, ttg::edges(ttg::fuse(input, input_disp)), ttg::edges(output_result, output_trsm), "POTRF",
return ttg::make_tt(f_dev, ttg::edges(ttg::fuse(input, input_disp)), ttg::edges(output_result, output_trsm), "POTRF",
{"tile_kk/dispatcher"}, {"output_result", "output_trsm"});
#else /* defined(ENABLE_DEVICE_KERNEL) */
auto f = [=](const Key1& key, MatrixTile<T>&& tile_kk,
Expand Down Expand Up @@ -234,7 +237,7 @@ namespace potrf {
#if defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key2& key, const MatrixTile<T>& tile_kk, MatrixTile<T>&& tile_mk,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key3, MatrixTile<T>>,
ttg::Out<Key3, MatrixTile<T>>>& out) TASKRET {
ttg::Out<Key3, MatrixTile<T>>>& out) -> ttg::device::Task<ES> {
const int M = key[0];
const int K = key[1]; // the column equals the outer most look K (same as PO)

Expand Down Expand Up @@ -302,6 +305,9 @@ namespace potrf {
mb, nb, &alpha,
tile_kk.buffer().current_device_ptr(), tile_kk.lda(),
tile_mk.buffer().current_device_ptr(), tile_mk.lda());
#else
blas::trsm(blas::Layout::ColMajor, blas::Side::Right, lapack::Uplo::Lower, blas::Op::Trans, blas::Diag::NonUnit,
mb, nb, 1.0, tile_kk.data(), tile_kk.lda(), tile_mk.data(), tile_mk.lda());

#endif

Expand All @@ -320,7 +326,7 @@ namespace potrf {
co_await ttg::device::forward(ttg::device::broadcast<0, 1, 2, 3>(std::make_tuple(key, Key2(K, M), keylist_row, keylist_col),
std::move(tile_mk), out));
};
return ttg::make_tt<ES>(f, ttg::edges(input_kk, ttg::fuse(input_mk, input_disp)),
return ttg::make_tt(f, ttg::edges(input_kk, ttg::fuse(input_mk, input_disp)),
ttg::edges(output_result, output_diag, output_row, output_col), "TRSM",
{"tile_kk", "tile_mk/dispatcher"}, {"output_result", "tile_mk", "output_row", "output_col"});
#else // defined(ENABLE_DEVICE_KERNEL)
Expand Down Expand Up @@ -386,8 +392,8 @@ namespace potrf {
ttg::Edge<Key2, MatrixTile<typename MatrixT::element_type>>& output_syrk) {
using T = typename MatrixT::element_type;
#if defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk,
std::tuple<ttg::Out<Key1, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out) TASKRET {
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk)
-> ttg::device::Task<ES> {
const int K = key[0];
const int M = key[1];

Expand Down Expand Up @@ -432,6 +438,9 @@ namespace potrf {
mb, nb, &alpha,
tile_mk.buffer().current_device_ptr(), tile_mk.lda(), &beta,
tile_kk.buffer().current_device_ptr(), tile_kk.lda());
#else
blas::syrk(blas::Layout::ColMajor, lapack::Uplo::Lower, blas::Op::NoTrans, mb, nb, -1.0, tile_mk.data(),
tile_mk.lda(), 1.0, tile_kk.data(), tile_kk.lda());
#endif

#ifdef DEBUG_TILES_VALUES
Expand All @@ -449,18 +458,17 @@ namespace potrf {
if (M == K + 1) {
/* send the tile to potrf */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to POTRF(", Key1{K + 1}, ")");
co_await ttg::device::send<0>(Key1(K + 1), std::move(tile_kk), out);
co_await ttg::device::send<0>(Key1(K + 1), std::move(tile_kk));
} else {
/* send output to next syrk */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to SYRK(", Key2{K + 1, M}, ")");
co_await ttg::device::send<1>(Key2(K + 1, M), std::move(tile_kk), out);
co_await ttg::device::send<1>(Key2(K + 1, M), std::move(tile_kk));
}
};
return ttg::make_tt<ES>(f, ttg::edges(input_mk, ttg::fuse(input_kk, input_disp)), ttg::edges(output_potrf, output_syrk),
return ttg::make_tt(f, ttg::edges(input_mk, ttg::fuse(input_kk, input_disp)), ttg::edges(output_potrf, output_syrk),
"SYRK", {"tile_mk", "tile_kk/dispatcher"}, {"output_potrf", "output_syrk"});
#else // defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk,
std::tuple<ttg::Out<Key1, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>>& out) {
auto f = [=](const Key2& key, const MatrixTile<T>& tile_mk, MatrixTile<T>&& tile_kk) {
const int K = key[0];
const int M = key[1];

Expand All @@ -487,11 +495,11 @@ namespace potrf {
if (M == K + 1) {
/* send the tile to potrf */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to POTRF(", Key1{K + 1}, ")");
ttg::send<0>(Key1(K + 1), std::move(tile_kk), out);
ttg::send<0>(Key1(K + 1), std::move(tile_kk));
} else {
/* send output to next syrk */
if (ttg::tracing()) ttg::print("SYRK(", key, "): sending output to SYRK(", Key2{K + 1, M}, ")");
ttg::send<1>(Key2(K + 1, M), std::move(tile_kk), out);
ttg::send<1>(Key2(K + 1, M), std::move(tile_kk));
}
};
return ttg::make_tt(f, ttg::edges(input_mk, ttg::fuse(input_kk, input_disp)), ttg::edges(output_potrf, output_syrk),
Expand All @@ -509,8 +517,8 @@ namespace potrf {
ttg::Edge<Key3, MatrixTile<typename MatrixT::element_type>>& output_gemm) {
using T = typename MatrixT::element_type;
#if defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key3, MatrixTile<T>>>& out) TASKRET {
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn)
-> ttg::device::Task<ES> {
const int M = key[0];
const int N = key[1];
const int K = key[2];
Expand Down Expand Up @@ -559,6 +567,10 @@ namespace potrf {
tile_mk.buffer().current_device_ptr(), tile_mk.lda(),
tile_nk.buffer().current_device_ptr(), tile_nk.lda(), &beta,
tile_mn.buffer().current_device_ptr(), tile_mn.lda());
#else
blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::Trans, tile_mk.rows(), tile_nk.rows(),
tile_nk.cols(), -1.0, tile_mk.data(), tile_mk.lda(), tile_nk.data(), tile_nk.lda(), 1.0,
tile_mn.data(), tile_mn.lda());
#endif


Expand All @@ -578,19 +590,18 @@ namespace potrf {
if (N == K + 1) {
/* send the tile to trsm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to TRSM(", Key2{M, N}, ")");
co_await ttg::device::send<0>(Key2(M, N), std::move(tile_mn), out);
co_await ttg::device::send<0>(Key2(M, N), std::move(tile_mn));
} else {
/* send the tile to the next gemm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to GEMM(", Key3{M, N, K + 1}, ")");
co_await ttg::device::send<1>(Key3(M, N, K + 1), std::move(tile_mn), out);
co_await ttg::device::send<1>(Key3(M, N, K + 1), std::move(tile_mn));
}
};
return ttg::make_tt<ES>(f, ttg::edges(input_mk, input_nk, ttg::fuse(input_disp, input_mn)),
return ttg::make_tt(f, ttg::edges(input_mk, input_nk, ttg::fuse(input_disp, input_mn)),
ttg::edges(output_trsm, output_gemm), "GEMM", {"input_mk", "input_kn", "input_mn/dispatcher"},
{"output_trsm", "outout_gemm"});
#else // defined(ENABLE_DEVICE_KERNEL)
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn,
std::tuple<ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key3, MatrixTile<T>>>& out) {
auto f = [=](const Key3& key, const MatrixTile<T>& tile_mk, const MatrixTile<T>& tile_nk, MatrixTile<T>&& tile_mn) {
const int M = key[0];
const int N = key[1];
const int K = key[2];
Expand All @@ -617,11 +628,11 @@ namespace potrf {
if (N == K + 1) {
/* send the tile to trsm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to TRSM(", Key2{M, N}, ")");
ttg::send<0>(Key2(M, N), std::move(tile_mn), out);
ttg::send<0>(Key2(M, N), std::move(tile_mn));
} else {
/* send the tile to the next gemm */
if (ttg::tracing()) ttg::print("GEMM(", key, "): sending output to GEMM(", Key3{M, N, K + 1}, ")");
ttg::send<1>(Key3(M, N, K + 1), std::move(tile_mn), out);
ttg::send<1>(Key3(M, N, K + 1), std::move(tile_mn));
}
};
return ttg::make_tt(f, ttg::edges(input_mk, input_nk, ttg::fuse(input_disp, input_mn)),
Expand All @@ -634,33 +645,31 @@ namespace potrf {
auto make_dispatcher(ttg::Edge<Key2, MatrixTile<T>>& input, ttg::Edge<Key1, MatrixTile<T>>& to_potrf,
ttg::Edge<Key2, MatrixTile<T>>& to_trsm, ttg::Edge<Key2, MatrixTile<T>>& to_syrk,
ttg::Edge<Key3, MatrixTile<T>>& to_gemm) {
auto f = [=](const Key2& key, const MatrixTile<T>& tile,
std::tuple<ttg::Out<Key1, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>, ttg::Out<Key2, MatrixTile<T>>,
ttg::Out<Key3, MatrixTile<T>>>& out) {
auto f = [=](const Key2& key, const MatrixTile<T>& tile) {
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ")");
if (0 == key[0] && 0 == key[1]) {
// First element goes to POTRF
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to POTRF(", Key1{key[0]}, ")");
ttg::send<0>(Key1{key[0]}, tile, out);
ttg::send<0>(Key1{key[0]}, tile);
return;
}
if (key[0] == key[1]) {
// Other diagonal elements go to SYRK
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to SYRK(", Key2{0, key[0]}, ")");
ttg::send<2>(Key2{0, key[0]}, tile, out);
ttg::send<2>(Key2{0, key[0]}, tile);
return;
}
// We only consider the lower triangular
assert(key[0] > key[1]);
if (0 == key[1]) {
// First column goes to TRSM
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to TRSM(", key, ")");
ttg::send<1>(key, tile, out);
ttg::send<1>(key, tile);
return;
}
// Rest goes to GEMM
if (ttg::tracing()) ttg::print("POTRF_Dispatch(", key, ") sending to GEMM(", Key3{key[0], key[1], 0}, ")");
ttg::send<3>(Key3{key[0], key[1], 0}, tile, out);
ttg::send<3>(Key3{key[0], key[1], 0}, tile);
};

return ttg::make_tt(f, ttg::edges(input), ttg::edges(to_potrf, to_trsm, to_syrk, to_gemm), "POTRF Dispatch",
Expand Down Expand Up @@ -705,28 +714,36 @@ namespace potrf {
tt_potrf->set_keymap(keymap1);
tt_potrf->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_potrf->set_devicemap(devmap1);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_potrf->set_devicemap(devmap1);
}
#endif // 0

auto tt_trsm = make_trsm(A, disp_trsm, potrf_trsm, gemm_trsm, trsm_syrk, trsm_gemm_row, trsm_gemm_col, output);
tt_trsm->set_keymap(keymap2a);
tt_trsm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_trsm->set_devicemap(devmap2a);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_trsm->set_devicemap(devmap2a);
}
#endif // 0

auto tt_syrk = make_syrk(A, disp_syrk, trsm_syrk, syrk_syrk, syrk_potrf, syrk_syrk);
tt_syrk->set_keymap(keymap2b);
tt_syrk->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_syrk->set_devicemap(devmap2b);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_syrk->set_devicemap(devmap2b);
}
#endif // 0

auto tt_gemm = make_gemm(A, disp_gemm, trsm_gemm_row, trsm_gemm_col, gemm_gemm, gemm_trsm, gemm_gemm);
tt_gemm->set_keymap(keymap3);
tt_gemm->set_defer_writer(defer_write);
#ifdef ENABLE_DEVICE_KERNEL
tt_gemm->set_devicemap(devmap3);
if constexpr (ES != ttg::ExecutionSpace::Host) {
tt_gemm->set_devicemap(devmap3);
}
#endif // 0

/* Priorities taken from DPLASMA */
Expand Down
Loading

0 comments on commit dc08fcf

Please sign in to comment.