diff --git a/dorado/basecall/CudaCaller.cpp b/dorado/basecall/CudaCaller.cpp index c321cd9b..af637edc 100644 --- a/dorado/basecall/CudaCaller.cpp +++ b/dorado/basecall/CudaCaller.cpp @@ -19,6 +19,7 @@ #include #include #include +#include using namespace std::chrono_literals; @@ -228,12 +229,17 @@ void CudaCaller::determine_batch_dims(const BasecallerCreationParams &params) { c10::cuda::CUDACachingAllocator::emptyCache(); int64_t available = utils::available_memory(m_options.device()); spdlog::debug("{} memory available: {:.2f}GB", m_device, available / GB); - const int scale_factor = m_config.scale_factor(); const int granularity = get_batch_size_granularity(m_config); + const int cs_granularity = static_cast(m_config.chunk_size_granularity()); { // First set of batch dimensions. Adjust chunk size to be a multiple of stride_inner. // Batch size defaults to `granularity` but will be increased further down if memory allows. - int T_out = (requested_chunk_size / m_config.stride_inner()) * scale_factor; + int T_out = (requested_chunk_size / m_config.stride / cs_granularity) * cs_granularity; + if (T_out < 1) { + spdlog::error("Cannot determine batch dims for chunksize of {} - too small", + requested_chunk_size); + throw std::runtime_error("chunksize too small"); + } m_batch_dims.push_back({granularity, T_out * m_config.stride, T_out}); } #ifdef DORADO_TX2 @@ -259,8 +265,9 @@ void CudaCaller::determine_batch_dims(const BasecallerCreationParams &params) { constexpr char SEPARATOR = ';'; std::string env_string(env_extra_chunk_sizes); for (size_t start = 0, end = 0; end != std::string::npos; start = end + 1) { - int T_out = (std::atoi(env_string.c_str() + start) / m_config.stride_inner()) * - scale_factor; + int T_out = + (std::atoi(env_string.c_str() + start) / m_config.stride / cs_granularity) * + cs_granularity; if (T_out > 0) { m_batch_dims.push_back({granularity, T_out * m_config.stride, T_out}); } @@ -271,7 +278,14 @@ void CudaCaller::determine_batch_dims(const 
BasecallerCreationParams &params) { // TODO: determine the best set of chunk sizes for (float fraction : {0.5f}) { // First chunk is already divided by stride - int T_out = int(m_batch_dims[0].T_out * fraction / scale_factor) * scale_factor; + int T_out = int(m_batch_dims[0].T_out * fraction / cs_granularity) * cs_granularity; + if (T_out < 1) { + spdlog::error( + "Cannot determine batch dims for partial chunksize of {}*{} - too " + "small", + m_batch_dims[0].T_out, fraction); + throw std::runtime_error("partial chunksize too small"); + } m_batch_dims.push_back({granularity, T_out * m_config.stride, T_out}); } }