From 58ff8c098632ff3f2a6fbfc22a27c6e9cfddcefd Mon Sep 17 00:00:00 2001 From: "Yubing Dong (Tom)" Date: Sun, 5 Dec 2021 21:51:36 -0800 Subject: [PATCH] Limit # of label candidates per leaf for prediction Also uses quick-select during beam search, instead of sorting the entire list. --- Cargo.lock | 7 +++++++ Cargo.toml | 3 ++- c-api/Cargo.lock | 7 +++++++ src/model/mod.rs | 16 +++++++++++++--- 4 files changed, 29 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fddd5ad..fbdf2ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -325,6 +325,7 @@ dependencies = [ "order-stat", "ordered-float", "pbr", + "pdqselect", "rand", "rayon", "serde", @@ -367,6 +368,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "pdqselect" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778906d9321dd56cde1d1ffa69a73e59dcf5fda6d366f62727adf2bd4193aee" + [[package]] name = "ppv-lite86" version = "0.2.15" diff --git a/Cargo.toml b/Cargo.toml index 0f53dcb..ad004fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ serde_cbor = "0.11.*" serde_json = "1.0.*" simple_logger = { version = "1.15.*", features = ["stderr"], optional = true } sprs = { version = "0.9.*", features = ["serde"] } +pdqselect = "0.1.*" [dev-dependencies] assert_approx_eq = "1.1.*" @@ -47,4 +48,4 @@ cli = ["simple_logger", "clap"] [profile.release] lto = true -codegen-units = 1 \ No newline at end of file +codegen-units = 1 diff --git a/c-api/Cargo.lock b/c-api/Cargo.lock index e7f6e32..b2c08a8 100644 --- a/c-api/Cargo.lock +++ b/c-api/Cargo.lock @@ -355,6 +355,7 @@ dependencies = [ "order-stat", "ordered-float", "pbr", + "pdqselect", "rand", "rayon", "serde", @@ -407,6 +408,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "pdqselect" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7778906d9321dd56cde1d1ffa69a73e59dcf5fda6d366f62727adf2bd4193aee" + [[package]] name = "ppv-lite86" version = "0.2.15" diff --git a/src/model/mod.rs b/src/model/mod.rs index 889a45a..158ac84 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -324,7 +324,9 @@ impl TreeNode { swap(&mut curr_level, &mut next_level); if curr_level.len() > beam_size { - curr_level.sort_unstable_by_key(|&(_, score)| Reverse(NotNan::new(score).unwrap())); + pdqselect::select_by_key(curr_level.as_mut_slice(), beam_size, |&(_, score)| { + Reverse(NotNan::new(score).unwrap()) + }); curr_level.truncate(beam_size); } } @@ -336,11 +338,19 @@ impl TreeNode { let mut label_scores = liblinear::predict(weights, classifier_loss_type, feature_vec); label_scores.mapv_inplace(|v| (v + leaf_score).exp()); - labels + + let mut label_score_pairs = labels .iter() .cloned() .zip_eq(label_scores.into_iter().cloned()) - .collect_vec() + .collect_vec(); + pdqselect::select_by_key( + label_score_pairs.as_mut_slice(), + beam_size, + |&(_, score)| Reverse(NotNan::new(score).unwrap()), + ); + label_score_pairs.truncate(beam_size); + label_score_pairs } _ => unreachable!(), })