Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dks integration #67

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions libs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,23 @@ if (USE_LIBPLL_CMAKE)
${PLLMOD_SRC}/tree/pll_tree.h
${PLLMOD_SRC}/util/pllmod_util.h
${LIBPLL_SRC}/pll.h
${PLLMOD_SRC}/dks/benchmark.h
${PLLMOD_SRC}/dks/dks.h
${PLLMOD_SRC}/dks/model.h
${PLLMOD_SRC}/dks/partition.h
${PLLMOD_SRC}/dks/test_case.h
${PLLMOD_SRC}/dks/tree.h
DESTINATION ${RAXML_LOCALDEPS_DIR}/include/libpll)

# file(COPY ${PLLMOD_SRC}/dks/benchmark.h
# ${PLLMOD_SRC}/dks/dks.h
# ${PLLMOD_SRC}/dks/model.h
# ${PLLMOD_SRC}/dks/msa.h
# ${PLLMOD_SRC}/dks/partition.h
# ${PLLMOD_SRC}/dks/test_case.h
# ${PLLMOD_SRC}/dks/tree.h
# DESTINATION ${RAXML_LOCALDEPS_DIR}/include/libpll/dks)

#target_include_directories(raxml_module PRIVATE ${RAXML_LIBPLL_HEADERS})

else()
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ else()
${RAXML_LOCALDEPS_DIR}/lib/libpll_msa.a
${RAXML_LOCALDEPS_DIR}/lib/libpll_tree.a
${RAXML_LOCALDEPS_DIR}/lib/libpll_util.a
${RAXML_LOCALDEPS_DIR}/lib/libpll_dks.a
${RAXML_LOCALDEPS_DIR}/lib/libpll.a
)
endif()
Expand Down
7 changes: 7 additions & 0 deletions src/CommandLineParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ static struct option long_options[] =
{"rf", optional_argument, 0, 0 }, /* 51 */
{"consense", optional_argument, 0, 0 }, /* 52 */
{"ancestral", optional_argument, 0, 0 }, /* 53 */
{"nodks", no_argument, 0, 0 }, /* 54 */

{ 0, 0, 0, 0 }
};
Expand Down Expand Up @@ -463,6 +464,7 @@ void CommandLineParser::parse_options(int argc, char** argv, Options &opts)
}
break;
case 20: /* SIMD instruction set */
opts.simd_set = true;
if (strcasecmp(optarg, "none") == 0 || strcasecmp(optarg, "scalar") == 0)
{
opts.simd_arch = PLL_ATTRIB_ARCH_CPU;
Expand Down Expand Up @@ -834,6 +836,11 @@ void CommandLineParser::parse_options(int argc, char** argv, Options &opts)
num_commands++;
break;

case 54:
opts.dks_off = true;
break;


default:
throw OptionException("Internal error in option parsing");
}
Expand Down
6 changes: 4 additions & 2 deletions src/Options.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class Options
bootstop_interval(RAXML_BOOTSTOP_INTERVAL), bootstop_permutations(RAXML_BOOTSTOP_PERMUTES),
tbe_naive(false), consense_cutoff(ConsenseCutoff::MR),
tree_file(""), constraint_tree_file(""), msa_file(""), model_file(""), outfile_prefix(""),
num_threads(1), num_ranks(1), simd_arch(PLL_ATTRIB_ARCH_CPU), thread_pinning(false),
load_balance_method(LoadBalancing::benoit)
num_threads(1), num_ranks(1), simd_set(false), simd_arch(PLL_ATTRIB_ARCH_CPU), thread_pinning(false),
load_balance_method(LoadBalancing::benoit), dks_off(false)
{};

~Options() = default;
Expand Down Expand Up @@ -111,9 +111,11 @@ class Options
/* parallelization stuff */
unsigned int num_threads; /* number of threads */
unsigned int num_ranks; /* number of MPI ranks */
bool simd_set; /* did the user specify the simd flag*/
unsigned int simd_arch; /* vector instruction set */
bool thread_pinning; /* pin threads to cores */
LoadBalancing load_balance_method;
bool dks_off;

std::string simd_arch_name() const;
std::string consense_type_name() const;
Expand Down
233 changes: 196 additions & 37 deletions src/TreeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,55 +571,214 @@ void set_partition_tips(const Options& opts, const MSA& msa, const IDVector& tip
pll_set_pattern_weights(partition, comp_weights.data());
}

pll_partition_t* create_pll_partition(const Options& opts, const PartitionInfo& pinfo,
const IDVector& tip_msa_idmap,
const PartitionRange& part_region, const uintVector& weights)
{
const MSA& msa = pinfo.msa();
const Model& model = pinfo.model();
/* Build a per-taxon copy of the alignment restricted to the column range
 * [start, start+length), with all zero-weight columns removed.
 *
 * @param msa            source alignment (one sequence per taxon)
 * @param weights        per-column weights indexed in GLOBAL column space
 * @param tip_msa_idmap  optional tip-id -> msa-row mapping (empty = identity)
 * @param start,length   partition column range within the full alignment
 * @param trimmed_length number of columns in the range with weight > 0
 *                       (see calculate_part_length()) — used to size each row
 * @return one char vector of length trimmed_length per taxon
 */
std::vector<std::vector<char>> create_trimmed_msa(const MSA &msa,
                                                  const uintVector &weights,
                                                  const IDVector& tip_msa_idmap,
                                                  size_t start, size_t length,
                                                  size_t trimmed_length) {
  std::vector<std::vector<char>> trimmed_msa;
  trimmed_msa.reserve(msa.size());

  for (size_t tip_id = 0; tip_id < msa.size(); ++tip_id) {
    trimmed_msa.emplace_back(trimmed_length);
    /* map tip id -> msa row when an explicit id mapping is provided */
    auto seq_id = tip_msa_idmap.empty() ? tip_id : tip_msa_idmap[tip_id];
    const char *full_seq = msa.at(seq_id).c_str();
    size_t pos = 0;
    for (size_t j = start; j < start + length; ++j) {
      if (weights[j] > 0)
        trimmed_msa[tip_id][pos++] = full_seq[j];
    }
  }

  return trimmed_msa;
}

/* Build per-taxon probability (CLV-style) vectors for a probabilistic MSA,
 * restricted to the column range [start, start+length) with zero-weight
 * columns dropped. Each kept column contributes `states` doubles.
 *
 * Values are normalized to sum to 1 unless the MSA reports it is already
 * normalized; a column whose probabilities sum to zero becomes all 1.0
 * (fully ambiguous).
 *
 * @param trimmed_length number of columns in the range with weight > 0
 * @return one vector of trimmed_length * states doubles per taxon
 */
std::vector<std::vector<double>>
create_trimmed_prob_msa(const MSA &msa, const IDVector &tip_msa_idmap,
                        size_t start, size_t length, size_t trimmed_length,
                        size_t states) {

  std::vector<std::vector<double>> trimmed_msa;
  trimmed_msa.reserve(msa.size());

  auto normalize = !msa.normalized();
  auto weights_iter = msa.weights().cbegin() + start;

  auto clv_size = trimmed_length * states;
  for (size_t tip_id = 0; tip_id < msa.size(); ++tip_id) {
    trimmed_msa.emplace_back(clv_size);

    auto seq_id = tip_msa_idmap.empty() ? tip_id : tip_msa_idmap[tip_id];
    auto probs = msa.probs(seq_id, start);

    auto clvp = trimmed_msa[tip_id].begin();
    for (size_t i = 0; i < length; ++i) {
      if (weights_iter[i] > 0) {
        double sum = 0.;
        for (size_t j = 0; j < states; ++j)
          sum += probs[j];

        for (size_t j = 0; j < states; ++j) {
          if (sum > 0.)
            clvp[j] = normalize ? probs[j] / sum : probs[j];
          else
            clvp[j] = 1.0;
        }

        clvp += states;
      }

      /* NB: clv has to be padded, but msa arrays are not! */
      probs += states;
    }
  }

  return trimmed_msa;
}

/* Number of sites in the partition range that carry a non-zero weight.
 * An empty weight vector means every site in the range is used. */
size_t calculate_part_length(const uintVector &weights,
                             const PartitionRange &part_region) {
  if (weights.empty())
    return part_region.length;

  size_t used_sites = 0;
  auto site = weights.begin() + part_region.start;
  const auto range_end = site + part_region.length;
  for (; site != range_end; ++site) {
    if (*site > 0)
      ++used_sites;
  }
  return used_sites;
}

/* Return a copy of `weights` with all zero entries removed.
 *
 * @param part_length expected number of non-zero entries; used only to
 *                    pre-allocate the result.
 *
 * NOTE(review): this scans the WHOLE weight vector, while part_length is
 * computed per partition range — assumes `weights` covers exactly one
 * partition; confirm against callers if partitions can be sub-ranges. */
uintVector create_trimmed_weights(const uintVector &weights,
                                  size_t part_length) {
  uintVector trimmed_weights;
  trimmed_weights.reserve(part_length);
  for (size_t i = 0; i < weights.size(); i++) {
    if (weights[i] > 0)
      trimmed_weights.push_back(weights[i]);
  }
  return trimmed_weights;
}

/* Compute the pll_partition attribute flags for one partition.
 *
 * Default path: the user-requested flags (SIMD arch, rate scalers,
 * repeats/tip-inner) seed a dks::attributes_generator_t, unsupported CPU
 * instruction sets are masked out, and dks::select_kernel_auto() benchmarks
 * a trimmed copy of the partition alignment to pick the fastest kernel.
 * With --nodks the classic heuristic selection is used instead.
 * Ascertainment-bias flags are appended in both cases.
 *
 * @param part_length number of sites with non-zero weight in part_region
 * @return bitmask of PLL_ATTRIB_* flags for pll_partition_create()
 */
unsigned int create_pll_partition_attrs(const Options &opts, const MSA &msa,
                                        const Model &model, size_t part_length,
                                        const IDVector &tip_msa_idmap,
                                        const PartitionRange &part_region,
                                        const uintVector &weights) {
  unsigned int attrs = 0;

  if (!opts.dks_off) {
    LOG_INFO_TS << "Starting DKS" << std::endl;

    /* seed the generator with the user's explicit choices, if any */
    if (opts.simd_set) {
      attrs |= opts.simd_arch;
    }
    if (opts.use_rate_scalers && model.num_ratecats() > 1) {
      attrs |= PLL_ATTRIB_RATE_SCALERS;
    }
    if (opts.use_repeats) {
      assert(!(opts.use_prob_msa));
      attrs |= PLL_ATTRIB_SITE_REPEATS;
    } else if (opts.use_tip_inner) {
      attrs |= PLL_ATTRIB_PATTERN_TIP;
    }

    dks::attributes_generator_t gen;
    gen.enable(attrs);

    /* disable SIMD sets the current CPU does not support */
    unsigned int off_flags = 0;

    if (!PLL_STAT(sse3_present))
      off_flags |= PLL_ATTRIB_ARCH_SSE;
    if (!PLL_STAT(avx_present))
      off_flags |= PLL_ATTRIB_ARCH_AVX;
    if (!PLL_STAT(avx2_present))
      off_flags |= PLL_ATTRIB_ARCH_AVX2;

    /* for now, unconditionally disable avx512 */
    off_flags |= PLL_ATTRIB_ARCH_AVX512;
    gen.disable(off_flags);

    auto trimmed_weights = create_trimmed_weights(weights, part_length);
    if (opts.use_prob_msa && msa.probabilistic()) {
      auto trimmed_msa = create_trimmed_prob_msa(
          msa, tip_msa_idmap, part_region.start, part_region.length,
          part_length, msa.states());
      attrs =
          dks::select_kernel_auto(trimmed_msa, trimmed_weights,
                                  model.num_states(), model.num_states(), gen);
    } else {
      auto trimmed_msa =
          create_trimmed_msa(msa, weights, tip_msa_idmap, part_region.start,
                             part_region.length, part_length);
      attrs =
          dks::select_kernel_auto(trimmed_msa, trimmed_weights, model.charmap(),
                                  model.num_states(), model.num_states(), gen);
    }
    LOG_INFO_TS << "DKS Finished" << std::endl;
  } else {
    /* classic (pre-DKS) selection: honor the requested SIMD arch and rate
     * scalers — previously these were applied unconditionally, so dropping
     * them here would silently degrade the --nodks path */
    attrs |= opts.simd_arch;
    if (opts.use_rate_scalers && model.num_ratecats() > 1) {
      attrs |= PLL_ATTRIB_RATE_SCALERS;
    }
    if (opts.use_repeats) {
      assert(!(opts.use_prob_msa));
      attrs |= PLL_ATTRIB_SITE_REPEATS;
    } else if (opts.use_tip_inner) {
      assert(!(opts.use_prob_msa));
      // 1) SSE3 tip-inner kernels are not implemented so far, so generic
      //    version will be faster
      // 2) same for state-rich models
      if (opts.simd_arch != PLL_ATTRIB_ARCH_SSE && model.num_states() <= 20) {
        // TODO: use proper auto-tuning
        const unsigned long min_len_ti = model.num_states() > 4 ? 40 : 100;
        if ((unsigned long)part_length > min_len_ti)
          attrs |= PLL_ATTRIB_PATTERN_TIP;
      }
    }
  }

  // NOTE: if partition is split among multiple threads, asc. bias correction
  // must be applied only once!
  if (model.ascbias_type() == AscBiasCorrection::lewis ||
      (model.ascbias_type() != AscBiasCorrection::none &&
       part_region.master())) {
    attrs |= PLL_ATTRIB_AB_FLAG;
    attrs |= (unsigned int)model.ascbias_type();
  }
  return attrs;
}

/* Compute the pll attribute flags for every partition in the assignment.
 *
 * @param msa          partitioned alignment (one PartitionInfo per partition)
 * @param part_assign  this thread/rank's partition ranges
 * @param site_weights optional per-partition compressed site weights;
 *                     empty = use each partition's own MSA weights
 * @return one PLL_ATTRIB_* bitmask per partition, indexed like msa
 */
std::vector<unsigned int>
create_partition_attr_list(const Options &opts, const PartitionedMSA &msa,
                           const IDVector &tip_msa_idmap,
                           const PartitionAssignment& part_assign,
                           const std::vector<uintVector> &site_weights) {

  std::vector<unsigned int> attr_list;
  attr_list.reserve(msa.part_count());

  for (size_t i = 0; i < msa.part_count(); i++) {
    const PartitionInfo& pinfo = msa.part_info(i);
    auto part_region = *(part_assign.find(i));
    /* bind by reference — avoids copying the weight vector per partition */
    const auto& weights =
        site_weights.empty() ? pinfo.msa().weights() : site_weights.at(i);
    attr_list.push_back(create_pll_partition_attrs(
        opts,
        pinfo.msa(),
        pinfo.model(),
        calculate_part_length(weights, part_region),
        tip_msa_idmap,
        part_region,
        weights));
  }
  return attr_list;
}

pll_partition_t* create_pll_partition(const Options& opts, const PartitionInfo& pinfo,
const IDVector& tip_msa_idmap,
const PartitionRange& part_region, const uintVector& weights)
{
const MSA& msa = pinfo.msa();
const Model& model = pinfo.model();

/* part_length doesn't include columns with zero weight */
const size_t part_length = calculate_part_length(weights, part_region);

unsigned int attrs = create_pll_partition_attrs(opts, msa, model, part_length, tip_msa_idmap, part_region, weights);
BasicTree tree(msa.size());
pll_partition_t * partition = pll_partition_create(
tree.num_tips(), /* number of tip sequences */
Expand Down
1 change: 1 addition & 0 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ extern "C" {
#include <libpll/pllmod_util.h>
#include <libpll/pllmod_algorithm.h>
}
#include <libpll/dks.h>

#include "types.hpp"
#include "constants.hpp"
Expand Down
1 change: 0 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2202,7 +2202,6 @@ void master_main(RaxmlInstance& instance, CheckpointManager& cm)
load_parted_msa(instance);
assert(instance.parted_msa);
auto& parted_msa = *instance.parted_msa;

load_constraint(instance);

check_options(instance);
Expand Down