Skip to content

Commit

Permalink
Merge pull request #36 from tomtung/limit-candidate-per-leaf
Browse files Browse the repository at this point in the history
Limit number of label candidates per leaf for prediction
  • Loading branch information
tomtung authored Dec 6, 2021
2 parents 94ceee9 + 58ff8c0 commit 8ca8e03
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ serde_cbor = "0.11.*"
serde_json = "1.0.*"
simple_logger = { version = "1.15.*", features = ["stderr"], optional = true }
sprs = { version = "0.9.*", features = ["serde"] }
pdqselect = "0.1.*"

[dev-dependencies]
assert_approx_eq = "1.1.*"
Expand All @@ -47,4 +48,4 @@ cli = ["simple_logger", "clap"]

[profile.release]
lto = true
codegen-units = 1
codegen-units = 1
7 changes: 7 additions & 0 deletions c-api/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 13 additions & 3 deletions src/model/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,9 @@ impl TreeNode {

swap(&mut curr_level, &mut next_level);
if curr_level.len() > beam_size {
curr_level.sort_unstable_by_key(|&(_, score)| Reverse(NotNan::new(score).unwrap()));
pdqselect::select_by_key(curr_level.as_mut_slice(), beam_size, |&(_, score)| {
Reverse(NotNan::new(score).unwrap())
});
curr_level.truncate(beam_size);
}
}
Expand All @@ -336,11 +338,19 @@ impl TreeNode {
let mut label_scores =
liblinear::predict(weights, classifier_loss_type, feature_vec);
label_scores.mapv_inplace(|v| (v + leaf_score).exp());
labels

let mut label_score_pairs = labels
.iter()
.cloned()
.zip_eq(label_scores.into_iter().cloned())
.collect_vec()
.collect_vec();
pdqselect::select_by_key(
label_score_pairs.as_mut_slice(),
beam_size,
|&(_, score)| Reverse(NotNan::new(score).unwrap()),
);
label_score_pairs.truncate(beam_size);
label_score_pairs
}
_ => unreachable!(),
})
Expand Down

0 comments on commit 8ca8e03

Please sign in to comment.