Skip to content

Commit

Permalink
binfield: usize -> enum (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
jtriley-eth authored Jul 6, 2024
1 parent 586f7fc commit 1d90de6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 27 deletions.
9 changes: 6 additions & 3 deletions src/field/binary_towers/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ pub(super) fn multiply(a: &[BinaryField], b: &[BinaryField], k: usize) -> Vec<Bi

// r1r2*X_{i-2}
// X_{i-2}: a+b*x where a, b are vectors of length K/4, with a = 0, b = 1
let mut x_i_high = vec![BinaryField::new(0); 1 << (k - 1)];
let mut x_i_high = vec![BinaryField::Zero; 1 << (k - 1)];
x_i_high[quarterlen] = BinaryField::ONE;
let r1r2_high = multiply(&x_i_high, &r1r2, k - 1);

Expand All @@ -321,9 +321,12 @@ fn add_vec(lhs: &[BinaryField], rhs: &[BinaryField]) -> Vec<BinaryField> {
pub(super) fn to_bool_vec(mut num: u64, length: usize) -> Vec<BinaryField> {
let mut result = Vec::new();
while num > 0 {
result.push(BinaryField::new(((num & 1) != 0) as u8));
result.push(match num & 1 {
0 => BinaryField::Zero,
_ => BinaryField::One,
});
num >>= 1;
}
result.extend(std::iter::repeat(BinaryField::new(0)).take(length - result.len()));
result.extend(std::iter::repeat(BinaryField::Zero).take(length - result.len()));
result
}
57 changes: 39 additions & 18 deletions src/field/binary_towers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,41 +13,51 @@ pub use extension::BinaryTowers;

/// binary field containing element `{0,1}`
#[derive(Debug, Default, Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct BinaryField(u8);

impl BinaryField {
/// create new binary field element
pub const fn new(value: u8) -> Self {
debug_assert!(value < 2, "value should be less than 2");
Self(value)
}
pub enum BinaryField {
/// binary field element `0`
#[default]
Zero,
/// binary field element `1`
One,
}

impl FiniteField for BinaryField {
const ONE: Self = BinaryField(1);
const ONE: Self = BinaryField::One;
const ORDER: usize = 2;
const PRIMITIVE_ELEMENT: Self = Self::ONE;
const ZERO: Self = BinaryField(0);
const ZERO: Self = BinaryField::Zero;

fn inverse(&self) -> Option<Self> {
if *self == Self::ZERO {
return None;
match *self {
Self::Zero => None,
Self::One => Some(Self::One),
}
Some(*self)
}

fn pow(self, _: usize) -> Self { self }
}

impl From<usize> for BinaryField {
fn from(value: usize) -> Self { Self::new(value as u8) }
fn from(value: usize) -> Self {
match value {
0 => BinaryField::Zero,
1 => BinaryField::One,
_ => panic!("Invalid `usize` value. Must be 0 or 1."),
}
}
}

impl Add for BinaryField {
type Output = Self;

#[allow(clippy::suspicious_arithmetic_impl)]
fn add(self, rhs: Self) -> Self::Output { BinaryField::new(self.0 ^ rhs.0) }
fn add(self, rhs: Self) -> Self::Output {
if self == rhs {
Self::ZERO
} else {
Self::ONE
}
}
}

impl AddAssign for BinaryField {
Expand All @@ -64,7 +74,13 @@ impl Sub for BinaryField {
type Output = Self;

#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self::Output { BinaryField(self.0 ^ rhs.0) }
fn sub(self, rhs: Self) -> Self::Output {
if self == rhs {
Self::ZERO
} else {
Self::ONE
}
}
}

impl SubAssign for BinaryField {
Expand All @@ -81,7 +97,12 @@ impl Mul for BinaryField {
type Output = Self;

#[allow(clippy::suspicious_arithmetic_impl)]
fn mul(self, rhs: Self) -> Self::Output { BinaryField(self.0 & rhs.0) }
fn mul(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Self::One, Self::One) => Self::ONE,
_ => Self::ZERO,
}
}
}

impl MulAssign for BinaryField {
Expand All @@ -98,7 +119,7 @@ impl Div for BinaryField {
type Output = Self;

#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self::Output { self * rhs.inverse().unwrap() }
fn div(self, rhs: Self) -> Self::Output { self * rhs.inverse().expect("divide by zero") }
}

impl DivAssign for BinaryField {
Expand Down
12 changes: 6 additions & 6 deletions src/field/binary_towers/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ pub(super) fn num_digits(n: u64) -> usize {
fn from_bool_vec(num: Vec<BinaryField>) -> u64 {
let mut result: u64 = 0;
for (i, &bit) in num.iter().rev().enumerate() {
if bit.0 == 1 {
if bit == BinaryField::One {
result |= 1 << (num.len() - 1 - i);
}
}
Expand All @@ -146,20 +146,20 @@ fn from_bool_vec(num: Vec<BinaryField>) -> u64 {
#[should_panic]
#[case(1, 0)]
fn binary_field_arithmetic(#[case] a: usize, #[case] b: usize) {
let arg1 = BinaryField::new(a as u8);
let arg2 = BinaryField::new(b as u8);
let arg1 = BinaryField::from(a);
let arg2 = BinaryField::from(b);
let a_test = TestBinaryField::new(a);
let b_test = TestBinaryField::new(b);

assert_eq!((arg1 + arg2).0, (a_test + b_test).value as u8);
assert_eq!((arg1 + arg2), BinaryField::from((a_test + b_test).value));
assert_eq!(arg1 - arg2, arg1 + arg2);
assert_eq!((arg1 * arg2).0, (a_test * b_test).value as u8);
assert_eq!((arg1 * arg2), BinaryField::from((a_test * b_test).value));

let inv_res = arg2.inverse();
assert!(inv_res.is_some());
assert_eq!(inv_res.unwrap(), arg2);

assert_eq!((arg1 / arg2).0, (a_test / b_test).value as u8);
assert_eq!((arg1 / arg2), BinaryField::from((a_test / b_test).value));
}

#[rstest]
Expand Down

0 comments on commit 1d90de6

Please sign in to comment.