diff --git a/src/main.cpp b/src/main.cpp index 39bab77f..5142d6a9 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -922,27 +922,32 @@ void autotune_threads(RaxmlInstance& instance) max_workers = std::min(max_workers, (size_t) opts.num_workers_max); if (opts.num_workers == 0) { - auto rank_threads = opts.num_threads > 0 ? opts.num_threads : opts.num_threads_max; - auto opt_workers = rank_threads / res.num_threads_throughput; - opt_workers *= num_ranks; - opts.num_workers = std::min(opt_workers, max_workers); + if (max_workers > 1) + { + auto rank_threads = opts.num_threads > 0 ? opts.num_threads : opts.num_threads_max; + auto opt_workers = std::max(rank_threads / res.num_threads_throughput, 1ul); + opt_workers *= num_ranks; + opts.num_workers = std::min(opt_workers, max_workers); - while (num_ranks*opts.num_threads % opts.num_workers != 0) - opts.num_workers--; + while (num_ranks*opts.num_threads % opts.num_workers != 0) + opts.num_workers--; - /* make sure we have integer number of workers per rank */ - opts.num_workers -= opts.num_workers % num_ranks; + /* make sure we have integer number of workers per rank */ + opts.num_workers -= opts.num_workers % num_ranks; - opts.num_workers = std::max(opts.num_workers, 1u); + opts.num_workers = std::max(opts.num_workers, 1u); - /* workers spanning multiple MPI ranks are not supported atm -> check for this */ - if (opts.num_workers > 1 && opts.num_workers < num_ranks) - { - if (num_ranks <= max_workers) - opts.num_workers = num_ranks; - else - opts.num_workers = 1; + /* workers spanning multiple MPI ranks are not supported atm -> check for this */ + if (opts.num_workers > 1 && opts.num_workers < num_ranks) + { + if (num_ranks <= max_workers) + opts.num_workers = num_ranks; + else + opts.num_workers = 1; + } } + else + opts.num_workers = 1; } assert(opts.num_workers > 0);