Skip to content

Commit

Permalink
Optimize circuit by int_div_unsafe (#8)
Browse files Browse the repository at this point in the history
* fix failure on empty chip
* optimize int_div_unsafe
  • Loading branch information
lanbones authored Aug 22, 2024
1 parent 713b34d commit 2af5b97
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 23 deletions.
15 changes: 7 additions & 8 deletions src/circuit/ecc_chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,16 +844,15 @@ pub trait EccChipBaseOps<C: CurveAffine, N: FieldExt>:
) -> Result<AssignedNonZeroPoint<C, N>, UnsafeError> {
let diff_x = self.base_integer_chip().int_sub(&a.x, &b.x);
let diff_y = self.base_integer_chip().int_sub(&a.y, &b.y);
let (x_eq, tangent) = self.base_integer_chip().int_div(&diff_y, &diff_x);
let tangent = self.base_integer_chip().int_div_unsafe(&diff_y, &diff_x);

// x cannot be same
let succeed = self.base_integer_chip().base_chip().try_assert_false(&x_eq);
let res = self.lambda_to_point_non_zero(&tangent, a, b);
match tangent {
Some(tangent) => {
let res = self.lambda_to_point_non_zero(&tangent, a, b);

if succeed {
Ok(res)
} else {
Err(UnsafeError::AddSameOrNegPoint)
Ok(res)
}
None => Err(UnsafeError::AddSameOrNegPoint),
}
}

Expand Down
7 changes: 2 additions & 5 deletions src/circuit/fq12.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ pub trait Fq2ChipOps<W: BaseExt, N: FieldExt>: EccBaseIntegerChipWrapper<W, N> {
let t0 = self.base_integer_chip().int_square(&x.0);
let t1 = self.base_integer_chip().int_square(&x.1);
let t0 = self.base_integer_chip().int_add(&t0, &t1);
let t = self.base_integer_chip().int_unsafe_invert(&t0);
let t = self.base_integer_chip().int_unsafe_invert(&t0).unwrap();
let c0 = self.base_integer_chip().int_mul(&x.0, &t);
let c1 = self.base_integer_chip().int_mul(&x.1, &t);
let c1 = self.base_integer_chip().int_neg(&c1);
Expand Down Expand Up @@ -303,10 +303,7 @@ pub trait Fq6ChipOps<W: BaseExt, N: FieldExt>: Fq2ChipOps<W, N> + Fq2BnSpecificO

pub trait Fq12ChipOps<W: BaseExt, N: FieldExt>: Fq6ChipOps<W, N> + Fq6BnSpecificOps<W, N> {
fn fq12_reduce(&mut self, x: &AssignedFq12<W, N>) -> AssignedFq12<W, N> {
(
self.fq6_reduce(&x.0),
self.fq6_reduce(&x.1),
)
(self.fq6_reduce(&x.0), self.fq6_reduce(&x.1))
}
fn fq12_assert_one(&mut self, x: &AssignedFq12<W, N>) {
let one = self.fq12_assign_one();
Expand Down
56 changes: 49 additions & 7 deletions src/circuit/integer_chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ pub trait IntegerChipOps<W: BaseExt, N: FieldExt> {
a: &AssignedInteger<W, N>,
b: &AssignedInteger<W, N>,
) -> AssignedInteger<W, N>;
fn int_unsafe_invert(&mut self, x: &AssignedInteger<W, N>) -> AssignedInteger<W, N>;
fn int_unsafe_invert(&mut self, x: &AssignedInteger<W, N>) -> Option<AssignedInteger<W, N>>;
fn int_div(
&mut self,
a: &AssignedInteger<W, N>,
b: &AssignedInteger<W, N>,
) -> (AssignedCondition<N>, AssignedInteger<W, N>);
fn int_div_unsafe(
&mut self,
a: &AssignedInteger<W, N>,
b: &AssignedInteger<W, N>,
) -> Option<AssignedInteger<W, N>>;
fn is_pure_zero(&mut self, a: &AssignedInteger<W, N>) -> AssignedCondition<N>;
fn is_pure_w_modulus(&mut self, a: &AssignedInteger<W, N>) -> AssignedCondition<N>;
fn is_int_zero(&mut self, a: &AssignedInteger<W, N>) -> AssignedCondition<N>;
Expand Down Expand Up @@ -79,7 +84,7 @@ impl<W: BaseExt, N: FieldExt> IntegerContext<W, N> {
) {
assert!(a.times < self.info().overflow_limit);
assert!(b.times < self.info().overflow_limit);
assert!(rem.times == 1);
assert!(rem.times < self.info().overflow_limit);

let info = self.info();
let one = N::one();
Expand Down Expand Up @@ -482,12 +487,49 @@ impl<W: BaseExt, N: FieldExt> IntegerChipOps<W, N> for IntegerContext<W, N> {
rem
}

fn int_unsafe_invert(&mut self, x: &AssignedInteger<W, N>) -> AssignedInteger<W, N> {
//TODO: optimize
fn int_unsafe_invert(&mut self, x: &AssignedInteger<W, N>) -> Option<AssignedInteger<W, N>> {
let one = self.assign_int_constant(W::one());
let (c, v) = self.int_div(&one, x);
self.ctx.borrow_mut().assert_false(&c);
v
self.int_div_unsafe(&one, x)
}

fn int_div_unsafe(
&mut self,
a: &AssignedInteger<W, N>,
b: &AssignedInteger<W, N>,
) -> Option<AssignedInteger<W, N>> {
let info = self.info();

let mut b = b.clone();

// Ensure b > a, so c * b > a and we can find the d that c * b = d * w + a
if b.times <= a.times {
let assigned_w = self.assign_w(&info.w_modulus);
while b.times < a.times {
b = self.int_add(&b, &assigned_w);
}
}

let a_bn = self.get_w_bn(&a);
let b_bn = self.get_w_bn(&b);

let b_inv: Option<W> = bn_to_field::<W>(&b_bn).invert().into();

match b_inv {
Some(b_inv) => {
let c = bn_to_field::<W>(&a_bn) * b_inv;
let c_bn = field_to_bn(&c);
let d_bn = (&b_bn * &c_bn - &a_bn) / &info.w_modulus;

let c = self.assign_w(&c_bn);
let d = self.assign_d(&d_bn);

self.add_constraints_for_mul_equation_on_limbs(&b, &c, &d.0, &a);
self.add_constraints_for_mul_equation_on_native(&b, &c, &d.1, &a);

Some(c)
}
None => None,
}
}

fn int_div(
Expand Down
2 changes: 2 additions & 0 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ impl<N: FieldExt> Records<N> {

let threads = 16;
let chunk_size = (self.base_height + threads - 1) / threads;
let chunk_size = if chunk_size == 0 { 1 } else { chunk_size };
let chunk_num = chunk_size * threads;
self.inner
.base_adv_record
Expand Down Expand Up @@ -419,6 +420,7 @@ impl<N: FieldExt> Records<N> {

let threads = 16;
let chunk_size = (self.base_height + threads - 1) / threads;
let chunk_size = if chunk_size == 0 { 1 } else { chunk_size };
let chunk_num = chunk_size * threads;
self.inner
.range_adv_record
Expand Down
9 changes: 6 additions & 3 deletions src/range_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ impl<W: BaseExt, N: FieldExt> RangeInfo<W, N> {
let lcm = self
.n_modulus
.lcm(&(BigUint::from(1u64) << (self.limb_bits * self.mul_check_limbs)));
let max_rem = &self.w_ceil - 1u64;
let max_rem = &self.w_ceil * (self.overflow_limit - 1) - 1u64;
assert!(lcm > &max_a * max_b);
assert!(lcm > &max_d * &self.w_modulus + &max_rem);

Expand All @@ -273,7 +273,7 @@ impl<W: BaseExt, N: FieldExt> RangeInfo<W, N> {
.iter()
.reduce(|acc, x| acc.max(x))
.unwrap();
let max_rem_i = &self.limb_modulus - 1u64;
let max_rem_i = &self.limb_modulus * (self.overflow_limit - 1) - 1u64;
assert!(
&borrow * &self.limb_modulus - &borrow
>= self.limbs * max_d_j * max_w_j + max_rem_i
Expand All @@ -285,7 +285,10 @@ impl<W: BaseExt, N: FieldExt> RangeInfo<W, N> {
let max_v = &self.limb_modulus * common_modulus - 1u64;
let max_a_j = &self.limb_modulus * (self.overflow_limit - 1);
let max_b_j = &max_a_j;
assert!(&max_v * &self.limb_modulus >= &max_a_j * max_b_j * self.limbs + &self.limb_modulus * &borrow);
assert!(
&max_v * &self.limb_modulus
>= &max_a_j * max_b_j * self.limbs + &self.limb_modulus * &borrow
);

// To avoid overflow
// max(v) * limb_modulus < n_modulus
Expand Down

0 comments on commit 2af5b97

Please sign in to comment.