Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow for lookup overflow in calibration #558

Merged
merged 8 commits into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/accum_matmul_relu_overflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl Circuit<Fr> for MyCircuit {
fn runmatmul(c: &mut Criterion) {
let mut group = c.benchmark_group("accum_matmul");

for &k in [8, 10, 12, 14].iter() {
for &k in [8, 10, 11, 12, 13, 14].iter() {
let len = unsafe { LEN };
unsafe {
K = k;
Expand Down
12 changes: 6 additions & 6 deletions src/circuit/ops/chip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,12 +366,12 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {

let (default_x, default_y) = table.get_first_element(col_idx);

log::debug!("---------------- col {:?} ------------------", col_idx,);
log::debug!("expr: {:?}", col_expr,);
log::debug!("multiplier: {:?}", multiplier);
log::debug!("not_expr: {:?}", not_expr);
log::debug!("default x: {:?}", default_x);
log::debug!("default y: {:?}", default_y);
log::trace!("---------------- col {:?} ------------------", col_idx,);
log::trace!("expr: {:?}", col_expr,);
log::trace!("multiplier: {:?}", multiplier);
log::trace!("not_expr: {:?}", not_expr);
log::trace!("default x: {:?}", default_x);
log::trace!("default y: {:?}", default_y);

res.extend([
(
Expand Down
8 changes: 8 additions & 0 deletions src/circuit/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ pub enum InputType {
}

impl InputType {
///
pub fn is_integer(&self) -> bool {
match self {
InputType::Bool | InputType::Int | InputType::TDim => true,
_ => false,
}
}

///
pub fn roundtrip<T: num::ToPrimitive + num::FromPrimitive + Clone>(&self, input: &mut T) {
match self {
Expand Down
3 changes: 3 additions & 0 deletions src/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ pub enum Commands {
/// Optional scales to specifically try for calibration.
#[arg(long, value_delimiter = ',')]
scales: Option<Vec<u32>>,
/// max logrows to use for calibration, 26 is the max public SRS size
#[arg(long)]
max_logrows: Option<u32>,
},

/// Generates a dummy SRS
Expand Down
20 changes: 17 additions & 3 deletions src/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
data,
target,
scales,
} => calibrate(model, data, settings_path, target, scales).await,
max_logrows,
} => calibrate(model, data, settings_path, target, scales, max_logrows).await,
Commands::GenWitness {
data,
compiled_circuit,
Expand Down Expand Up @@ -561,6 +562,7 @@ pub(crate) async fn calibrate(
settings_path: PathBuf,
target: CalibrationTarget,
scales: Option<Vec<u32>>,
max_logrows: Option<u32>,
) -> Result<(), Box<dyn Error>> {
let data = GraphData::from_path(data)?;
// load the pre-generated settings
Expand Down Expand Up @@ -597,11 +599,23 @@ pub(crate) async fn calibrate(
.collect::<Vec<(u32, u32)>>();

// remove all entries where input_scale > param_scale
let range_grid = range_grid
let mut range_grid = range_grid
.into_iter()
.filter(|(a, b)| a <= b)
.collect::<Vec<(u32, u32)>>();

// if all integers
let all_scale_0 = model.graph.get_input_types().iter().all(|t| t.is_integer());
if all_scale_0 {
// set all a values to 0 then dedup
range_grid = range_grid
.iter()
.map(|(_, b)| (0, *b))
.sorted()
.dedup()
.collect::<Vec<(u32, u32)>>();
}

let range_grid = range_grid
.iter()
.cartesian_product(scale_rebase_multiplier.iter())
Expand Down Expand Up @@ -657,7 +671,7 @@ pub(crate) async fn calibrate(
.map_err(|e| format!("failed to load circuit inputs: {}", e))?;

circuit
.calibrate(&data)
.calibrate(&data, max_logrows)
.map_err(|e| format!("failed to calibrate: {}", e))?;

let settings = circuit.settings().clone();
Expand Down
38 changes: 23 additions & 15 deletions src/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ const ASSUMED_BLINDING_FACTORS: usize = 5;
pub const MIN_LOGROWS: u32 = 4;

/// 26
const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;
pub const MAX_PUBLIC_SRS: u32 = bn256::Fr::S - 2;

use std::cell::RefCell;

Expand Down Expand Up @@ -703,12 +703,7 @@ impl GraphCircuit {
let (_, client) = setup_eth_backend(Some(&source.rpc), None).await?;
let inputs = read_on_chain_inputs(client.clone(), client.address(), &source.calls).await?;
// quantize the supplied data using the provided scale + QuantizeData.sol
let quantized_evm_inputs = evm_quantize(
client,
scales,
&inputs,
)
.await?;
let quantized_evm_inputs = evm_quantize(client, scales, &inputs).await?;
// on-chain data has already been quantized at this point. Just need to reshape it and push into tensor vector
let mut inputs: Vec<Tensor<Fp>> = vec![];
for (input, shape) in [quantized_evm_inputs].iter().zip(shapes) {
Expand Down Expand Up @@ -780,15 +775,25 @@ impl GraphCircuit {
)
}

fn calc_min_logrows(&mut self, res: &GraphWitness) -> Result<(), Box<dyn std::error::Error>> {
fn calc_min_logrows(
&mut self,
res: &GraphWitness,
max_logrows: Option<u32>,
) -> Result<(), Box<dyn std::error::Error>> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
let max_logrows = std::cmp::min(max_logrows, MAX_PUBLIC_SRS);
let max_logrows = std::cmp::max(max_logrows, MIN_LOGROWS);

let reserved_blinding_rows = Self::reserved_blinding_rows();
let safe_range = Self::calc_safe_range(res);

let max_col_size =
Table::<Fp>::cal_col_size(MAX_PUBLIC_SRS as usize, reserved_blinding_rows as usize);
Table::<Fp>::cal_col_size(max_logrows as usize, reserved_blinding_rows as usize);
let num_cols = Table::<Fp>::num_cols_required(safe_range, max_col_size);

if num_cols > 1 {
// empirically determined that this is when performance starts to degrade significantly
if num_cols > 3 {
let err_string = format!(
"No possible lookup range can accomodate max value min and max value ({}, {})",
safe_range.0, safe_range.1
Expand Down Expand Up @@ -834,7 +839,7 @@ impl GraphCircuit {

// ensure logrows is at least 4
logrows = std::cmp::max(logrows, MIN_LOGROWS as usize);
logrows = std::cmp::min(logrows, MAX_PUBLIC_SRS as usize);
logrows = std::cmp::min(logrows, max_logrows as usize);
let model = self.model().clone();
let settings_mut = self.settings_mut();
settings_mut.run_args.lookup_range = safe_range;
Expand All @@ -857,8 +862,7 @@ impl GraphCircuit {
settings_mut.run_args.logrows =
std::cmp::max(settings_mut.run_args.logrows, min_rows_from_constraints);

settings_mut.run_args.logrows =
std::cmp::min(MAX_PUBLIC_SRS, settings_mut.run_args.logrows);
settings_mut.run_args.logrows = std::cmp::min(max_logrows, settings_mut.run_args.logrows);

info!(
"setting lookup_range to: {:?}, setting logrows to: {}",
Expand All @@ -870,10 +874,14 @@ impl GraphCircuit {
}

/// Calibrate the circuit to the supplied data.
pub fn calibrate(&mut self, input: &[Tensor<Fp>]) -> Result<(), Box<dyn std::error::Error>> {
pub fn calibrate(
&mut self,
input: &[Tensor<Fp>],
max_logrows: Option<u32>,
) -> Result<(), Box<dyn std::error::Error>> {
let res = self.forward(&mut input.to_vec())?;

self.calc_min_logrows(&res)
self.calc_min_logrows(&res, max_logrows)
}

/// Runs the forward pass of the model / graph of computations and any associated hashing.
Expand Down
5 changes: 4 additions & 1 deletion src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@ fn gen_settings(
settings,
target,
scales = None,
max_logrows = None,
))]
fn calibrate_settings(
py: Python,
Expand All @@ -632,12 +633,14 @@ fn calibrate_settings(
settings: PathBuf,
target: Option<CalibrationTarget>,
scales: Option<Vec<u32>>,
max_logrows: Option<u32>,
) -> PyResult<&pyo3::PyAny> {
let target = target.unwrap_or(CalibrationTarget::Resources {
col_overflow: false,
});

pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::calibrate(model, data, settings, target, scales)
crate::execute::calibrate(model, data, settings, target, scales, max_logrows)
.await
.map_err(|e| {
let err_str = format!("Failed to calibrate settings: {}", e);
Expand Down
Loading