From c8e4fe1f6367cab3a96ca026a9ad577fac433e3b Mon Sep 17 00:00:00 2001 From: Schuyler Cebulskie Date: Sat, 2 Mar 2024 20:04:29 -0500 Subject: [PATCH] Allow creation and rolling of dice with zero sides --- benches/dice.rs | 20 ++------------ src/dice.rs | 68 ++++++++++++++++++---------------------------- src/parse.rs | 12 ++++---- src/term.rs | 2 +- src/tests/dice.rs | 52 ++++++++--------------------------- src/tests/parse.rs | 42 ++++++---------------------- 6 files changed, 55 insertions(+), 141 deletions(-) diff --git a/benches/dice.rs b/benches/dice.rs index 592da0b..49c2c16 100644 --- a/benches/dice.rs +++ b/benches/dice.rs @@ -2,8 +2,6 @@ extern crate test; -use std::num::NonZeroU8; - use test::Bencher; use dicey::{ @@ -19,11 +17,7 @@ fn roll_4d8(b: &mut Bencher) { #[bench] fn roll_8d6x(b: &mut Bencher) { - let dice = Dice::builder() - .count(8) - .sides(NonZeroU8::new(6).unwrap()) - .explode(None, true) - .build(); + let dice = Dice::builder().count(8).sides(6).explode(None, true).build(); b.iter(|| dice.roll().unwrap()); } @@ -41,11 +35,7 @@ fn roll_and_total_4d8(b: &mut Bencher) { #[bench] fn roll_and_total_8d6x(b: &mut Bencher) { - let dice = Dice::builder() - .count(8) - .sides(NonZeroU8::new(6).unwrap()) - .explode(None, true) - .build(); + let dice = Dice::builder().count(8).sides(6).explode(None, true).build(); b.iter(|| dice.roll().unwrap().total().unwrap()); } @@ -67,11 +57,7 @@ fn explain_4d8_result(b: &mut Bencher) { #[bench] fn explain_8d6x_result(b: &mut Bencher) { - let dice = Dice::builder() - .count(8) - .sides(NonZeroU8::new(6).unwrap()) - .explode(None, true) - .build(); + let dice = Dice::builder().count(8).sides(6).explode(None, true).build(); let modifier = dice.modifiers.first(); let roll = Rolled { rolls: vec![ diff --git a/src/dice.rs b/src/dice.rs index b4ac31e..d4cbace 100644 --- a/src/dice.rs +++ b/src/dice.rs @@ -1,4 +1,4 @@ -use std::{cmp, fmt, num::NonZeroU8}; +use std::{cmp, fmt}; use fastrand::Rng; @@ -11,7 +11,7 @@ pub struct Dice { pub count: u8, /// Number of sides for each die - pub sides: NonZeroU8, + pub sides: u8, /// Modifiers to apply to rolls from this set of dice pub modifiers: Vec, @@ -42,19 +42,19 @@ impl Dice { /// Rolls a single die of this type with no modifiers using the default Rng pub fn roll_single(&self) -> DieRoll { - DieRoll::new_rand(self.sides.get()) + DieRoll::new_rand(self.sides) } /// Rolls a single die of this type with no modifiers using the given Rng pub fn roll_single_with_rng(&self, rng: &mut Rng) -> DieRoll { - DieRoll::new_rand_with_rng(self.sides.get(), rng) + DieRoll::new_rand_with_rng(self.sides, rng) } - /// Creates a new set of dice with a given count and number of sides. Panics if the number of sides given is 0. + /// Creates a new set of dice with a given count and number of sides pub fn new(count: u8, sides: u8) -> Self { Self { count, - sides: NonZeroU8::new(sides).expect("dice sides must be nonzero"), + sides, modifiers: Vec::new(), } } @@ -96,10 +96,10 @@ pub enum Modifier { Explode(Option, bool), /// Keep the highest x dice, dropping the rest - KeepHigh(NonZeroU8), + KeepHigh(u8), /// Keep the lowest x dice, dropping the rest - KeepLow(NonZeroU8), + KeepLow(u8), } impl Modifier { @@ -117,7 +117,7 @@ impl Modifier { match self { Self::Explode(cond, recurse) => { // Don't allow recursively exploding dice with 1 side since that would result in infinite explosions - if *recurse && rolled.dice.sides.get() == 1 { + if *recurse && rolled.dice.sides == 1 { return Err(Error::InfiniteExplosion(rolled.dice.clone())); } @@ -166,7 +166,7 @@ impl Modifier { refs.sort(); refs.reverse(); refs.iter_mut() - .skip(count.get() as usize) + .skip(*count as usize) .for_each(|roll| roll.dropped_by = Some(self)); } @@ -174,7 +174,7 @@ impl Modifier { let mut refs = rolled.rolls.iter_mut().filter(|r| !r.is_dropped()).collect::>(); refs.sort(); refs.iter_mut() - .skip(count.get() as usize) + .skip(*count as usize) .for_each(|roll| roll.dropped_by = Some(self)); } }; @@ -190,22 +190,8 @@ impl fmt::Display for Modifier { "{}{}", match self { Self::Explode(_, recurse) => format!("x{}", recurse.then_some("").unwrap_or("o")), - Self::KeepHigh(count) => format!( - "kh{}", - if count.get() > 1 { - count.to_string() - } else { - "".to_owned() - } - ), - Self::KeepLow(count) => format!( - "kl{}", - if count.get() > 1 { - count.to_string() - } else { - "".to_owned() - } - ), + Self::KeepHigh(count) => format!("kh{}", if *count > 1 { count.to_string() } else { "".to_owned() }), + Self::KeepLow(count) => format!("kl{}", if *count > 1 { count.to_string() } else { "".to_owned() }), }, match self { Self::Explode(Some(cond), _) => cond.to_string(), @@ -218,16 +204,16 @@ impl fmt::Display for Modifier { /// Conditions that die values can be tested against #[derive(Debug, Clone, PartialEq, Eq)] pub enum Condition { - Eq(NonZeroU8), - Gt(NonZeroU8), - Gte(NonZeroU8), - Lt(NonZeroU8), - Lte(NonZeroU8), + Eq(u8), + Gt(u8), + Gte(u8), + Lt(u8), + Lte(u8), } impl Condition { /// Creates a Condition from its corresponding symbol and a value - pub fn from_symbol_and_val(symbol: &str, val: NonZeroU8) -> Result { + pub fn from_symbol_and_val(symbol: &str, val: u8) -> Result { Ok(match symbol { "=" => Self::Eq(val), ">" => Self::Gt(val), @@ -239,7 +225,7 @@ impl Condition { } /// Checks a value against the condition - pub fn check(&self, val: NonZeroU8) -> bool { + pub fn check(&self, val: u8) -> bool { match self { Self::Eq(expected) => val == *expected, Self::Gt(expected) => val > *expected, @@ -282,7 +268,7 @@ impl fmt::Display for Condition { #[derive(Debug, Clone, PartialEq, Eq)] pub struct DieRoll<'a> { /// Value that was rolled - pub val: NonZeroU8, + pub val: u8, /// Modifier that caused the addition of this die pub added_by: Option<&'a Modifier>, @@ -305,7 +291,7 @@ impl DieRoll<'_> { /// Creates a new DieRoll with the given value pub fn new(val: u8) -> Self { Self { - val: NonZeroU8::new(val).expect("roll val must be nonzero"), + val, added_by: None, dropped_by: None, } @@ -319,7 +305,7 @@ impl DieRoll<'_> { /// Creates a new DieRoll with a random value using the given Rng pub fn new_rand_with_rng(max: u8, rng: &mut Rng) -> Self { - Self::new(rng.u8(1..=max)) + Self::new(if max > 0 { rng.u8(1..=max) } else { 0 }) } } @@ -358,7 +344,7 @@ impl Rolled<'_> { // Sum all rolls that haven't been dropped for r in self.rolls.iter().filter(|r| !r.is_dropped()) { - sum = sum.checked_add(r.val.get().into()).ok_or(Error::Overflow)?; + sum = sum.checked_add(r.val as u16).ok_or(Error::Overflow)?; } Ok(sum) @@ -424,7 +410,7 @@ impl Builder { } /// Sets the number of sides per die - pub fn sides(mut self, sides: NonZeroU8) -> Self { + pub fn sides(mut self, sides: u8) -> Self { self.0.sides = sides; self } @@ -436,13 +422,13 @@ impl Builder { } /// Adds the keep highest modifier to the dice - pub fn keep_high(mut self, count: NonZeroU8) -> Self { + pub fn keep_high(mut self, count: u8) -> Self { self.0.modifiers.push(Modifier::KeepHigh(count)); self } /// Adds the keep lowest modifier to the dice - pub fn keep_low(mut self, count: NonZeroU8) -> Self { + pub fn keep_low(mut self, count: u8) -> Self { self.0.modifiers.push(Modifier::KeepLow(count)); self } diff --git a/src/parse.rs b/src/parse.rs index 8b81da1..82ea00f 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -1,7 +1,5 @@ #![cfg(feature = "parse")] -use std::num::NonZeroU8; - use chumsky::prelude::*; use crate::{ @@ -13,11 +11,11 @@ use crate::{ pub fn dice_part<'src>() -> impl Parser<'src, &'src str, Dice, extra::Err>> + Clone { // Parser for dice modifier conditions let condition = choice(( - just(">=").to(Condition::Gte as fn(NonZeroU8) -> _), - just("<=").to(Condition::Lte as fn(NonZeroU8) -> _), - just('>').to(Condition::Gt as fn(NonZeroU8) -> _), - just('<').to(Condition::Lt as fn(NonZeroU8) -> _), - just('=').to(Condition::Eq as fn(NonZeroU8) -> _), + just(">=").to(Condition::Gte as fn(u8) -> _), + just("<=").to(Condition::Lte as fn(u8) -> _), + just('>').to(Condition::Gt as fn(u8) -> _), + just('<').to(Condition::Lt as fn(u8) -> _), + just('=').to(Condition::Eq as fn(u8) -> _), )) .or_not() .then(text::int::<&'src str, _, _>(10)) diff --git a/src/term.rs b/src/term.rs index db6f144..fce3501 100644 --- a/src/term.rs +++ b/src/term.rs @@ -76,7 +76,7 @@ impl Term { pub fn is_deterministic(&self) -> bool { match self { Self::Num(..) => true, - Self::Dice(dice) => dice.sides.get() == 1, + Self::Dice(dice) => dice.sides == 1, Self::Neg(x) => x.is_deterministic(), Self::Add(a, b) | Self::Sub(a, b) | Self::Mul(a, b) | Self::DivDown(a, b) | Self::DivUp(a, b) => { a.is_deterministic() && b.is_deterministic() diff --git a/src/tests/dice.rs b/src/tests/dice.rs index e0c9d87..940ab67 100644 --- a/src/tests/dice.rs +++ b/src/tests/dice.rs @@ -1,5 +1,3 @@ -use std::num::NonZeroU8; - use crate::dice::{Dice, DieRoll, Modifier, Rolled}; #[test] @@ -64,7 +62,7 @@ fn all_dice_sides_occur() { rolls_in_range(&rolls, 20); for side in 1..20 { - assert!(rolls.iter().filter(|roll| roll.val.get() == side).count() > 0); + assert!(rolls.iter().filter(|roll| roll.val == side).count() > 0); } } @@ -86,26 +84,14 @@ fn dice_inequality() { assert_ne!(da, db); let da = Dice::new(4, 8); - let db = Dice::builder() - .count(4) - .sides(NonZeroU8::new(8).unwrap()) - .explode(None, true) - .build(); + let db = Dice::builder().count(4).sides(8).explode(None, true).build(); assert_ne!(da, db); } #[test] fn roll_equality() { - let da = Dice::builder() - .count(4) - .sides(NonZeroU8::new(8).unwrap()) - .explode(None, true) - .build(); - let db = Dice::builder() - .count(4) - .sides(NonZeroU8::new(8).unwrap()) - .explode(None, true) - .build(); + let da = Dice::builder().count(4).sides(8).explode(None, true).build(); + let db = Dice::builder().count(4).sides(8).explode(None, true).build(); let ra = Rolled { rolls: vec![ DieRoll::new(4), @@ -145,16 +131,8 @@ fn roll_equality() { #[test] fn roll_inequality() { - let da = Dice::builder() - .count(4) - .sides(NonZeroU8::new(8).unwrap()) - .explode(None, true) - .build(); - let db = Dice::builder() - .count(4) - .sides(NonZeroU8::new(8).unwrap()) - .explode(None, true) - .build(); + let da = Dice::builder().count(4).sides(8).explode(None, true).build(); + let db = Dice::builder().count(4).sides(8).explode(None, true).build(); let ra = Rolled { rolls: vec![ DieRoll::new(4), @@ -191,11 +169,7 @@ fn roll_inequality() { }; assert_ne!(ra, rb); - let da = Dice::builder() - .count(4) - .sides(NonZeroU8::new(8).unwrap()) - .explode(None, true) - .build(); + let da = Dice::builder().count(4).sides(8).explode(None, true).build(); let db = Dice::new(4, 8); let ra = Rolled { rolls: vec![ @@ -224,19 +198,15 @@ fn roll_inequality() { fn construct_plain(count: u8, sides: u8) -> Dice { let dice = Dice::new(count, sides); assert_eq!(dice.count, count); - assert_eq!(dice.sides, NonZeroU8::new(sides).unwrap()); + assert_eq!(dice.sides, sides); assert_eq!(dice.modifiers.len(), 0); dice } fn construct_exploding(count: u8, sides: u8) -> Dice { - let dice = Dice::builder() - .count(count) - .sides(NonZeroU8::new(sides).expect("sides must be nonzero")) - .explode(None, true) - .build(); + let dice = Dice::builder().count(count).sides(sides).explode(None, true).build(); assert_eq!(dice.count, count); - assert_eq!(dice.sides, NonZeroU8::new(sides).unwrap()); + assert_eq!(dice.sides, sides); assert_eq!(dice.modifiers.len(), 1); assert!(matches!(dice.modifiers.first().unwrap(), Modifier::Explode(..))); dice @@ -253,5 +223,5 @@ fn rolls_successfully_and_in_range<'a>(dice: &'a Dice) -> Rolled<'a> { } fn rolls_in_range(rolls: &[DieRoll], sides: u8) { - assert!(!rolls.iter().any(|roll| roll.val.get() < 1 || roll.val.get() > sides)); + assert!(!rolls.iter().any(|roll| roll.val < 1 || roll.val > sides)); } diff --git a/src/tests/parse.rs b/src/tests/parse.rs index a60df3c..4ccc681 100644 --- a/src/tests/parse.rs +++ b/src/tests/parse.rs @@ -1,7 +1,5 @@ use chumsky::Parser; -use std::num::NonZeroU8; - use crate::{ dice::{Condition, Dice}, parse::{dice as dice_parser, term as term_parser}, @@ -76,22 +74,14 @@ fn basic_dice_math() { #[test] fn dice_explode() { - let expected = Dice::builder() - .count(4) - .sides(NonZeroU8::new(6).unwrap()) - .explode(None, true) - .build(); + let expected = Dice::builder().count(4).sides(6).explode(None, true).build(); let ast = dice_parser().parse("4d6x").unwrap(); assert_eq!(ast, expected); } #[test] fn dice_explode_once() { - let expected = Dice::builder() - .count(4) - .sides(NonZeroU8::new(6).unwrap()) - .explode(None, false) - .build(); + let expected = Dice::builder().count(4).sides(6).explode(None, false).build(); let ast = dice_parser().parse("4d6xo").unwrap(); assert_eq!(ast, expected); } @@ -100,8 +90,8 @@ fn dice_explode_once() { fn dice_explode_condition() { let expected = Dice::builder() .count(4) - .sides(NonZeroU8::new(6).unwrap()) - .explode(Some(Condition::Gte(NonZeroU8::new(5).unwrap())), true) + .sides(6) + .explode(Some(Condition::Gte(5)), true) .build(); let ast = dice_parser().parse("4d6x>=5").unwrap(); assert_eq!(ast, expected); @@ -109,44 +99,28 @@ fn dice_explode_condition() { #[test] fn dice_keep_high() { - let expected = Dice::builder() - .count(4) - .sides(NonZeroU8::new(20).unwrap()) - .keep_high(NonZeroU8::new(1).unwrap()) - .build(); + let expected = Dice::builder().count(4).sides(20).keep_high(1).build(); let ast = dice_parser().parse("4d20kh").unwrap(); assert_eq!(ast, expected); } #[test] fn dice_keep_high_2() { - let expected = Dice::builder() - .count(4) - .sides(NonZeroU8::new(20).unwrap()) - .keep_high(NonZeroU8::new(2).unwrap()) - .build(); + let expected = Dice::builder().count(4).sides(20).keep_high(2).build(); let ast = dice_parser().parse("4d20kh2").unwrap(); assert_eq!(ast, expected); } #[test] fn dice_keep_low() { - let expected = Dice::builder() - .count(4) - .sides(NonZeroU8::new(20).unwrap()) - .keep_low(NonZeroU8::new(1).unwrap()) - .build(); + let expected = Dice::builder().count(4).sides(20).keep_low(1).build(); let ast = dice_parser().parse("4d20kl").unwrap(); assert_eq!(ast, expected); } #[test] fn dice_keep_low_2() { - let expected = Dice::builder() - .count(4) - .sides(NonZeroU8::new(20).unwrap()) - .keep_low(NonZeroU8::new(2).unwrap()) - .build(); + let expected = Dice::builder().count(4).sides(20).keep_low(2).build(); let ast = dice_parser().parse("4d20kl2").unwrap(); assert_eq!(ast, expected); }