Skip to content

Commit

Permalink
(ml5717) Untested impl of more robust index sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
juntyr committed Dec 17, 2021
1 parent 98dc48e commit fcd8e45
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 36 deletions.
130 changes: 96 additions & 34 deletions necsim/core/src/cogs/rng.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::{
convert::AsMut,
default::Default,
num::{NonZeroU128, NonZeroU32, NonZeroUsize},
num::{NonZeroU128, NonZeroU32, NonZeroU64, NonZeroUsize},
ptr::copy_nonoverlapping,
};

Expand Down Expand Up @@ -41,6 +41,7 @@ pub trait SeedableRng<M: MathsCore>: RngCore<M> {
const INC: u64 = 11_634_580_027_462_260_723_u64;

let mut seed = Self::Seed::default();

for chunk in seed.as_mut().chunks_mut(4) {
// We advance the state first (to get away from the input value,
// in case it has low Hamming Weight).
Expand Down Expand Up @@ -96,51 +97,112 @@ pub trait RngSampler<M: MathsCore>: RngCore<M> {
#[inline]
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
fn sample_index(&mut self, length: NonZeroUsize) -> usize {
// attributes on expressions are experimental
// see https://github.com/rust-lang/rust/issues/15701
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
let index =
M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as usize;
// Safety in case of f64 rounding errors
index.min(length.get() - 1)
#[cfg(target_pointer_width = "32")]
#[allow(clippy::cast_possible_truncation)]
{
self.sample_index_u32(unsafe { NonZeroU32::new_unchecked(length.get() as u32) })
as usize
}
#[cfg(target_pointer_width = "64")]
#[allow(clippy::cast_possible_truncation)]
{
self.sample_index_u64(unsafe { NonZeroU64::new_unchecked(length.get() as u64) })
as usize
}
}

#[must_use]
#[inline]
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
fn sample_index_u32(&mut self, length: NonZeroU32) -> u32 {
// attributes on expressions are experimental
// see https://github.com/rust-lang/rust/issues/15701
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
let index =
M::floor(self.sample_uniform_closed_open().get() * f64::from(length.get())) as u32;
// Safety in case of f64 rounding errors
index.min(length.get() - 1)
// TODO: Check if delegation to `sample_index_u64` is faster

// Adapted from:
// https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single

const LOWER_MASK: u64 = !0 >> 32;

// Conservative approximation of the acceptance zone
let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);

loop {
let raw = self.sample_u64();

let sample_check_lo = (raw & LOWER_MASK) * u64::from(length.get());

#[allow(clippy::cast_possible_truncation)]
if (sample_check_lo as u32) <= acceptance_zone {
return (sample_check_lo >> 32) as u32;
}

let sample_check_hi = (raw >> 32) * u64::from(length.get());

#[allow(clippy::cast_possible_truncation)]
if (sample_check_hi as u32) <= acceptance_zone {
return (sample_check_hi >> 32) as u32;
}
}
}

#[must_use]
#[inline]
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
fn sample_index_u64(&mut self, length: NonZeroU64) -> u64 {
// Adapted from:
// https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single

// Conservative approximation of the acceptance zone
let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);

loop {
let raw = self.sample_u64();

let sample_check = u128::from(raw) * u128::from(length.get());

#[allow(clippy::cast_possible_truncation)]
if (sample_check as u64) <= acceptance_zone {
return (sample_check >> 64) as u64;
}
}
}

#[must_use]
#[inline]
#[debug_ensures(ret < length.get(), "samples U(0, length - 1)")]
fn sample_index_u128(&mut self, length: NonZeroU128) -> u128 {
// attributes on expressions are experimental
// see https://github.com/rust-lang/rust/issues/15701
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
let index =
M::floor(self.sample_uniform_closed_open().get() * (length.get() as f64)) as u128;
// Safety in case of f64 rounding errors
index.min(length.get() - 1)
// Adapted from:
// https://docs.rs/rand/0.8.4/rand/distributions/uniform/trait.UniformSampler.html#method.sample_single

const LOWER_MASK: u128 = !0 >> 64;

// Conservative approximation of the acceptance zone
let acceptance_zone = (length.get() << length.leading_zeros()).wrapping_sub(1);

loop {
let raw_hi = u128::from(self.sample_u64());
let raw_lo = u128::from(self.sample_u64());

// 256-bit multiplication (hi, lo) = (raw_hi, raw_lo) * length
let mut low = raw_lo * (length.get() & LOWER_MASK);
let mut t = low >> 64;
low &= LOWER_MASK;
t += raw_hi * (length.get() & LOWER_MASK);
low += (t & LOWER_MASK) << 64;
let mut high = t >> 64;
t = low >> 64;
low &= LOWER_MASK;
t += (length.get() >> 64) * raw_lo;
low += (t & LOWER_MASK) << 64;
high += t >> 64;
high += raw_hi * (length.get() >> 64);

let sample = high;
let check = low;

if check <= acceptance_zone {
return sample;
}
}
}

#[must_use]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use alloc::{vec, vec::Vec};
use core::{
cmp::Ordering,
convert::TryFrom,
fmt,
hash::Hash,
num::{NonZeroU128, NonZeroUsize},
num::{NonZeroU128, NonZeroU64, NonZeroUsize},
};
use fnv::FnvBuildHasher;

Expand Down Expand Up @@ -191,6 +192,8 @@ impl<E: Eq + Hash + Clone> DynamicAliasMethodIndexedSampler<E> {
if let Some(total_weight) = NonZeroU128::new(self.total_weight) {
let cdf_sample = if let [_group] = &self.groups[..] {
0_u128
} else if let Ok(total_weight) = NonZeroU64::try_from(total_weight) {
u128::from(rng.sample_index_u64(total_weight))
} else {
rng.sample_index_u128(total_weight)
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use alloc::{vec, vec::Vec};
use core::{
cmp::Ordering,
convert::TryFrom,
fmt,
hash::Hash,
num::{NonZeroU128, NonZeroUsize},
num::{NonZeroU128, NonZeroU64, NonZeroUsize},
};

use necsim_core::cogs::{MathsCore, RngCore, RngSampler};
Expand Down Expand Up @@ -125,6 +126,8 @@ impl<E: Eq + Hash + Clone> DynamicAliasMethodStackSampler<E> {
if let Some(total_weight) = NonZeroU128::new(self.total_weight) {
let cdf_sample = if let [_group] = &self.groups[..] {
0_u128
} else if let Ok(total_weight) = NonZeroU64::try_from(total_weight) {
u128::from(rng.sample_index_u64(total_weight))
} else {
rng.sample_index_u128(total_weight)
};
Expand Down

0 comments on commit fcd8e45

Please sign in to comment.