Merge branch 'smalton/DOR-760-tx-cpu' into 'master'
[DOR-760] Fix CPU calling for TX model

Closes DOR-760

See merge request machine-learning/dorado!1076
iiSeymour committed Jun 17, 2024
2 parents eec6f34 + bb45084 commit b8df2ce
Showing 1 changed file with 17 additions and 8 deletions.
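
The fix is the same guard applied at several Koi call sites: in a CUDA-enabled build, the TX model dispatched to Koi's GPU kernels based only on compile-time flags and dev options, so a run on CPU tensors still took the GPU path. Each site now also checks x.is_cuda() and falls through to the plain PyTorch implementation otherwise. A minimal sketch of the pattern, with koi_forward and torch_fallback_forward as hypothetical stand-ins for the real Koi and ATen paths:

#include <ATen/ATen.h>

// Hypothetical stand-ins for this sketch; in TxModel.cpp the real call sites
// are Koi's fused SwiGLU, rotary-embedding, attention, and RMSNorm kernels.
static bool koi_can_use_cutlass() { return true; }
static at::Tensor koi_forward(const at::Tensor &x) { return x; }
static at::Tensor torch_fallback_forward(const at::Tensor &x) { return x; }

at::Tensor forward(const at::Tensor &x) {
#if DORADO_CUDA_BUILD
    // A CUDA build can still be handed CPU tensors (e.g. basecalling with
    // --device cpu), so the device check must be part of the runtime
    // dispatch, not just the compile-time #if guard.
    if (x.is_cuda() && koi_can_use_cutlass()) {
        return koi_forward(x);  // CUDA-only fused kernel
    }
#endif
    return torch_fallback_forward(x);  // stock ATen ops, works on CPU
}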
dorado/basecall/nn/TxModel.cpp
@@ -84,7 +84,9 @@ GatedMLPImpl::GatedMLPImpl(int in_features_, int hidden_features_)
 at::Tensor GatedMLPImpl::forward(const at::Tensor &x) {
     at::Tensor t;
 #if DORADO_CUDA_BUILD && !defined(DORADO_TX2)
-    if (utils::get_dev_opt<bool>("use_koi_swiglu", true) && koi_can_use_cutlass()) {
+    auto use_koi_swiglu = x.is_cuda() && utils::get_dev_opt<bool>("use_koi_swiglu", true) &&
+                          koi_can_use_cutlass();
+    if (use_koi_swiglu) {
         utils::ScopedProfileRange spr("FC1+SILU", 3);
         auto N = x.size(0);
         auto T = x.size(1);
@@ -262,7 +264,8 @@ at::Tensor MultiHeadAttentionImpl::forward(at::Tensor x) {
     const int64_t C = x.size(2);

 #if DORADO_CUDA_BUILD
-    bool use_koi_rote = utils::get_dev_opt<bool>("use_koi_rote", true) && d_model <= 512;
+    bool use_koi_rote =
+            x.is_cuda() && utils::get_dev_opt<bool>("use_koi_rote", true) && d_model <= 512;
     if (use_koi_rote) {
         if (!wqkv_transposed) {
             auto w = wqkv->weight;
@@ -311,7 +314,9 @@ at::Tensor MultiHeadAttentionImpl::forward(at::Tensor x) {
     attn_output_ntc = at::empty({N, T, C}, x.options());
 #if DORADO_CUDA_BUILD && !defined(DORADO_TX2)
     int res = KOI_NOT_SUPPORTED;
-    if (utils::get_dev_opt<bool>("use_koi_attention", true) && koi_can_use_cutlass()) {
+    bool use_koi_attention = x.is_cuda() && utils::get_dev_opt<bool>("use_koi_attention", true) &&
+                             koi_can_use_cutlass();
+    if (use_koi_attention) {
         utils::ScopedProfileRange spr("KOI_MEA", 3);
         auto stream = at::cuda::getCurrentCUDAStream().stream();
         const auto [win_upper, win_lower] = attn_window;
@@ -377,19 +382,23 @@ at::Tensor TxEncoderImpl::forward(at::Tensor x) {
     const int N = static_cast<int>(x.size(0));
     const int T = static_cast<int>(x.size(1));
     const int C = static_cast<int>(x.size(2));
-    auto stream = at::cuda::getCurrentCUDAStream().stream();
     x = x.contiguous();  // If using koi, make sure x is NTC order in memory
     const int num_rows = N * T;
 #endif

     auto run_norm = [&](RMSNorm norm, const at::Tensor &in) {
 #if DORADO_CUDA_BUILD
-        int res = host_fused_residual_rmsnorm_f16(stream, C, num_rows, in.data_ptr(), x.data_ptr(),
+        int res = KOI_NOT_SUPPORTED;
+        if (x.is_cuda()) {
+            auto stream = at::cuda::getCurrentCUDAStream().stream();
+            res = host_fused_residual_rmsnorm_f16(stream, C, num_rows, in.data_ptr(), x.data_ptr(),
                                                   deepnorm_alpha.data_ptr(),
                                                   norm->weight.data_ptr(), x.data_ptr());
-        if (res != KOI_SUCCESS && res != KOI_NOT_SUPPORTED) {
-            throw std::runtime_error("Koi error during layer norm");
-        } else if (res == KOI_NOT_SUPPORTED)
+            if (res != KOI_SUCCESS && res != KOI_NOT_SUPPORTED) {
+                throw std::runtime_error("Koi error during layer norm");
+            }
+        }
+        if (res == KOI_NOT_SUPPORTED)
 #endif
         {
             x = norm(in + (x * deepnorm_alpha));
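
When res is left at KOI_NOT_SUPPORTED — after this change, the preset value whenever x lives on the CPU — the braced fallback below the #endif computes the residual RMSNorm with ordinary ATen ops: x = norm(in + (x * deepnorm_alpha)). A sketch of what that amounts to, assuming a conventional RMSNorm formulation (the epsilon and parameter handling are illustrative, not lifted from dorado's RMSNorm):

#include <ATen/ATen.h>

// Sketch of the CPU fallback: a DeepNorm-style scaled residual followed by
// RMSNorm. The fused host_fused_residual_rmsnorm_f16 kernel computes the
// equivalent in one pass on the GPU; eps and the weight layout are assumed.
at::Tensor residual_rmsnorm(const at::Tensor &in, const at::Tensor &x,
                            const at::Tensor &deepnorm_alpha,
                            const at::Tensor &weight, double eps = 1e-5) {
    // Residual, with the DeepNorm alpha scaling the running activations.
    at::Tensor h = in + x * deepnorm_alpha;
    // Normalize by the root mean square over the channel dimension.
    at::Tensor inv_rms = (h * h).mean({-1}, /*keepdim=*/true).add(eps).rsqrt();
    return h * inv_rms * weight;
}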
