Skip to content

Commit

Permalink
feat: FIELD_eval_h_lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
lanbones committed Nov 17, 2022
1 parent 76598b3 commit 77025cc
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
3 changes: 2 additions & 1 deletion ec-gpu-gen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ license = "MIT/Apache-2.0"
[dependencies]
bitvec = "1.0.1"
crossbeam-channel = "0.5.1"
ec-gpu = "0.2.0"
ec-gpu = { path = "../ec-gpu" }
execute = "0.2.9"
ff = { version = "0.12.0", default-features = false }
group = "0.12.0"
Expand All @@ -24,6 +24,7 @@ rust-gpu-tools = { version = "0.6.1", default-features = false, optional = true
sha2 = "0.10"
thiserror = "1.0.30"
yastl = "0.1.2"
ark-std = { version = "0.3", features = ["print-trace"] }

[dev-dependencies]
# NOTE vmx 2022-07-07: Using the `__private_bench` feature of `blstrs` is just
Expand Down
65 changes: 63 additions & 2 deletions ec-gpu-gen/src/cl/fft.cl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/*
* FFT algorithm is inspired from: http://www.bealto.com/gpu-fft_group-1.html
*/
KERNEL void FIELD_radix_fft(GLOBAL FIELD* x, // Source buffer
KERNEL void FIELD_radix_fft(const GLOBAL FIELD* x, // Source buffer
GLOBAL FIELD* y, // Destination buffer
GLOBAL FIELD* pq, // Precalculated twiddle factors
const GLOBAL FIELD* pq, // Precalculated twiddle factors
GLOBAL FIELD* omegas, // [omega, omega^2, omega^4, ...]
LOCAL FIELD* u_arg, // Local buffer to store intermediary values
uint n, // Number of elements
Expand Down Expand Up @@ -74,3 +74,64 @@ KERNEL void FIELD_mul_by_field(GLOBAL FIELD* elements,
const uint gid = GET_GLOBAL_ID();
elements[gid] = FIELD_mul(elements[gid], field);
}

/*
 * Folds the five lookup-argument constraints for one row into the running
 * accumulator `value`, one work item per row.
 *
 * For every constraint `c_k` (in the order written below) the accumulator is
 * updated Horner-style: `acc = acc * y + c_k`, where `y = y_beta_gamma[0]`.
 *
 *   value                 in/out accumulator, updated in place at `idx`
 *   table                 factor multiplied with z(X) in the product constraint
 *                         (presumably the compressed input/table expressions
 *                         with beta/gamma already folded in - confirm with caller)
 *   permuted_input_coset  a'(X)
 *   permuted_table_coset  s'(X)
 *   product_coset         z(X)
 *   l0, l_last            factors used for the first-row / last-row constraints
 *   l_active_row          factor used where the formulas below write
 *                         (1 - (l_last + l_blind))
 *   y_beta_gamma          three elements: [0] = y, [1] = added to a'(X),
 *                         [2] = added to s'(X)
 *   rot_scale             index distance of one rotation step
 *   size                  number of rows; MUST be a power of two because
 *                         rotations wrap with `& (size - 1)`
 *
 * All buffers except `y_beta_gamma` are indexed up to `size - 1`.
 */
KERNEL void FIELD_eval_h_lookups(
  GLOBAL FIELD* value,
  const GLOBAL FIELD* table,
  const GLOBAL FIELD* permuted_input_coset,
  const GLOBAL FIELD* permuted_table_coset,
  const GLOBAL FIELD* product_coset,
  const GLOBAL FIELD* l0,
  const GLOBAL FIELD* l_last,
  const GLOBAL FIELD* l_active_row,
  const GLOBAL FIELD* y_beta_gamma,
  uint32_t rot_scale,
  uint32_t size
) {
  const uint idx = GET_GLOBAL_ID();

  // The launch size may be rounded up to a multiple of the work-group size;
  // surplus work items must not touch memory past the end of the buffers.
  if (idx >= size) {
    return;
  }

  // Wrapped indices of the next / previous rotated row (size is a power of two).
  const uint32_t r_next = (idx + rot_scale) & (size - 1);
  const uint32_t r_prev = (idx + size - rot_scale) & (size - 1);

  // Read every per-row operand from global memory exactly once.
  const FIELD y = y_beta_gamma[0];
  const FIELD z = product_coset[idx];
  const FIELD a_perm = permuted_input_coset[idx];
  const FIELD s_perm = permuted_table_coset[idx];
  const FIELD first_row = l0[idx];
  const FIELD active_row = l_active_row[idx];

  // Keep the accumulator private and write it back once at the end instead of
  // doing a global read-modify-write for every constraint.
  FIELD acc = value[idx];
  FIELD term;

  // l_0(X) * (1 - z(X)) = 0
  acc = FIELD_mul(acc, y);
  term = FIELD_sub(FIELD_ONE, z);
  term = FIELD_mul(term, first_row);
  acc = FIELD_add(acc, term);

  // l_last(X) * (z(X)^2 - z(X)) = 0
  acc = FIELD_mul(acc, y);
  term = FIELD_sqr(z);
  term = FIELD_sub(term, z);
  term = FIELD_mul(term, l_last[idx]);
  acc = FIELD_add(acc, term);

  // (1 - (l_last(X) + l_blind(X))) * (
  //   z(\omega X) (a'(X) + \beta) (s'(X) + \gamma)
  //   - z(X) (\theta^{m-1} a_0(X) + ... + a_{m-1}(X) + \beta)
  //          (\theta^{m-1} s_0(X) + ... + s_{m-1}(X) + \gamma)
  // ) = 0
  acc = FIELD_mul(acc, y);
  term = FIELD_add(a_perm, y_beta_gamma[1]);
  term = FIELD_mul(term, FIELD_add(s_perm, y_beta_gamma[2]));
  term = FIELD_mul(term, product_coset[r_next]);
  term = FIELD_sub(term, FIELD_mul(z, table[idx]));
  term = FIELD_mul(term, active_row);
  acc = FIELD_add(acc, term);

  // a'(X) - s'(X) is shared by the last two constraints.
  const FIELD a_minus_s = FIELD_sub(a_perm, s_perm);

  // l_0(X) * (a'(X) - s'(X)) = 0
  acc = FIELD_mul(acc, y);
  term = FIELD_mul(a_minus_s, first_row);
  acc = FIELD_add(acc, term);

  // (1 - (l_last + l_blind)) * (a'(X) - s'(X)) * (a'(X) - a'(\omega^{-1} X)) = 0
  acc = FIELD_mul(acc, y);
  term = FIELD_sub(a_perm, permuted_input_coset[r_prev]);
  term = FIELD_mul(term, a_minus_s);
  term = FIELD_mul(term, active_row);
  acc = FIELD_add(acc, term);

  value[idx] = acc;
}
7 changes: 5 additions & 2 deletions ec-gpu-gen/src/fft.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::cmp;
use std::sync::{Arc, RwLock};

use ark_std::{start_timer, end_timer};
use ec_gpu::GpuName;
use ff::Field;
use log::{error, info};
Expand All @@ -18,7 +19,7 @@ pub struct SingleFftKernel<'a, F>
where
F: Field + GpuName,
{
program: Program,
pub program: Program,
/// An optional function which will be called at places where it is possible to abort the FFT
/// calculations. If it returns true, the calculation will be aborted with an
/// [`EcError::Aborted`].
Expand Down Expand Up @@ -77,7 +78,9 @@ impl<'a, F: Field + GpuName> SingleFftKernel<'a, F> {
}
let omegas_buffer = program.create_buffer_from_slice(&omegas)?;

//let timer = start_timer!(|| format!("copy {}", log_n));
program.write_from_buffer(&mut src_buffer, &*input)?;
//end_timer!(timer);
// Specifies log2 of `p`, (http://www.bealto.com/gpu-fft_group-1.html)
let mut log_p = 0u32;
// Each iteration performs a FFT round
Expand Down Expand Up @@ -130,7 +133,7 @@ pub struct FftKernel<'a, F>
where
F: Field + GpuName,
{
kernels: Vec<SingleFftKernel<'a, F>>,
pub kernels: Vec<SingleFftKernel<'a, F>>,
}

impl<'a, F> FftKernel<'a, F>
Expand Down

0 comments on commit 77025cc

Please sign in to comment.