Skip to content

Commit

Permalink
Merge branch 'tx-koi-attention' into 'master'
Browse files Browse the repository at this point in the history
TxModel: Enable use of koi masked attention kernel.

See merge request machine-learning/dorado!1016
  • Loading branch information
tijyojwad committed May 21, 2024
2 parents 97363b0 + 461d4ba commit 5e73b55
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cmake/Koi.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ endfunction()

if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR WIN32)

set(KOI_VERSION 0.4.7)
set(KOI_VERSION 0.4.8)
if(BUILD_KOI_FROM_SOURCE)
set(KOI_DIR "${DORADO_3RD_PARTY_SOURCE}/koi")
if(NOT EXISTS ${KOI_DIR})
Expand Down
10 changes: 4 additions & 6 deletions dorado/basecall/BasecallerParams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,18 @@ void BasecallerParams::update(const BasecallerParams &other) {
merge(m_batch_size, other.m_batch_size, "batchsize");
}

void BasecallerParams::normalise(size_t divisor) {
const int div = static_cast<int>(divisor);

void BasecallerParams::normalise(size_t chunk_size_divisor, size_t overlap_divisor) {
// Apply normalised value with FORCE
auto normalise_param = [&, div](Value &self, const std::string &name) {
auto normalise_param = [&](Value &self, const std::string &name, int div) {
const int before_val = self.val;
const int new_val = (self.val / div) * div;
if (set_value(self, Value{new_val, Priority::FORCE})) {
spdlog::info("Normalised: {} {} -> {}", name, before_val, new_val);
}
};

normalise_param(m_chunk_size, "chunksize");
normalise_param(m_overlap, "overlap");
normalise_param(m_chunk_size, "chunksize", static_cast<int>(chunk_size_divisor));
normalise_param(m_overlap, "overlap", static_cast<int>(overlap_divisor));
}

} // namespace dorado::basecall
2 changes: 1 addition & 1 deletion dorado/basecall/BasecallerParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class BasecallerParams {
void update(const BasecallerParams& other);

// Normalise `chunk_size` and `overlap` by truncating each to a whole multiple of its divisor (derived from the model stride)
void normalise(size_t divisor);
void normalise(size_t chunk_size_divisor, size_t overlap_divisor);

std::string to_string() const {
std::string str = "BasecallerParams {";
Expand Down
4 changes: 3 additions & 1 deletion dorado/basecall/CRFModelConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ struct CRFModelConfig {
int stride_inner() const { return stride * scale_factor(); };

// Normalise the basecaller parameters `chunk_size` and `overlap` to the `stride_inner`
void normalise_basecaller_params() { basecaller.normalise(stride_inner()); }
void normalise_basecaller_params() {
basecaller.normalise(stride_inner() * (is_tx_model() ? 16 : 1), stride_inner());
}
// True if `chunk_size` and `overlap` are evenly divisible by the `stride_inner`
bool has_normalised_basecaller_params() const;

Expand Down
3 changes: 2 additions & 1 deletion dorado/basecall/CudaCaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,8 @@ void CudaCaller::determine_batch_dims(float memory_limit_fraction,
// trade-off between getting more accurate measurements and avoiding excessive startup time.
const int chunk_size = std::min(m_batch_dims.back().T_in, m_config.stride * 300);
// We limit the maximum when doing benchmarking to avoid excessive startup time.
const int max_batch_size_limit = 10240;
// The limit for transformer models should be increased at a later time.
const int max_batch_size_limit = m_config.is_tx_model() ? 512 : 10240;
int max_batch_size = *std::max_element(max_batch_sizes.begin(), max_batch_sizes.end());
max_batch_size = std::min(max_batch_size, max_batch_size_limit);
spdlog::debug("Auto batchsize {}: testing up to {} in steps of {}", m_device, max_batch_size,
Expand Down
20 changes: 18 additions & 2 deletions dorado/basecall/nn/TxModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ at::Tensor MultiHeadAttentionImpl::forward(at::Tensor x) {
stream, static_cast<int>(N), static_cast<int>(T), nhead, head_dim,
rotary_emb->theta, qkv.data_ptr(), out.data_ptr());
if (res != KOI_SUCCESS) {
throw std::runtime_error("Koi windowed attention failed.");
throw std::runtime_error("Koi rotary embedding failed.");
}
qkv = out;
} else
Expand All @@ -308,10 +308,26 @@ at::Tensor MultiHeadAttentionImpl::forward(at::Tensor x) {
qkv = rotary_emb(qkv);
}
}
attn_output_ntc = at::empty({N, T, C}, x.options());
#if DORADO_CUDA_BUILD && !defined(DORADO_TX2)
int res = KOI_NOT_SUPPORTED;
if (utils::get_dev_opt<bool>("use_koi_attention", true) && koi_can_use_cutlass()) {
utils::ScopedProfileRange spr("KOI_MEA", 3);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const auto [win_upper, win_lower] = attn_window;
res = host_masked_attention_f16(stream, static_cast<int>(N), static_cast<int>(T), nhead,
head_dim, win_upper, win_lower, qkv[0].data_ptr(),
qkv[1].data_ptr(), qkv[2].data_ptr(),
attn_output_ntc.data_ptr());
if (res != KOI_SUCCESS && res != KOI_NOT_SUPPORTED) {
throw std::runtime_error("Koi windowed attention failed.");
}
}
if (res == KOI_NOT_SUPPORTED)
#endif
{
utils::ScopedProfileRange spr("MEA", 3);
auto attn_window_mask = get_attn_window_mask(T);
attn_output_ntc = at::empty({N, T, C}, x.options());
auto attn_output = attn_output_ntc.view({N, T, nhead, head_dim}).transpose(1, 2);
const auto [win_upper, win_lower] = attn_window;
for (int i = 0; i < num_splits; ++i) {
Expand Down
8 changes: 4 additions & 4 deletions tests/CRFModelConfigTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ TEST_CASE(CUT_TAG ": test normalise BasecallerParams", CUT_TAG) {
fs::path(get_data_dir("model_configs/[email protected]"));
CRFModelConfig config = load_crf_model_config(path);

// Set chunksize to (12 * 10) + 1 to ensure it's not mod12
config.basecaller.set_chunk_size(121);
// Set chunksize to (12 * 16 * 10) + 1 to ensure it's not mod192
config.basecaller.set_chunk_size(1921);
CHECK_FALSE(config.has_normalised_basecaller_params());

config.normalise_basecaller_params();
CHECK(config.has_normalised_basecaller_params());
CHECK(config.basecaller.chunk_size() % config.stride_inner() == 0);
// Expected (121 / 12) * 12
CHECK(config.basecaller.chunk_size() == 120);
// Expected (1921 / 192) * 192
CHECK(config.basecaller.chunk_size() == 1920);
}
}

Expand Down

0 comments on commit 5e73b55

Please sign in to comment.