
Commit

fix: include tol check in fwd pass (#723)
alexander-camuto authored Feb 23, 2024
1 parent bf69b16 commit 6c0c17c
Showing 8 changed files with 94 additions and 52 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/rust.yml
@@ -236,6 +236,8 @@ jobs:
         with:
           crate: cargo-nextest
           locked: true
+      - name: public outputs and tolerance > 0
+        run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
       - name: public outputs + batch size == 10
         run: cargo nextest run --release --verbose tests::mock_large_batch_public_outputs_ --test-threads 32
       - name: kzg inputs
11 changes: 10 additions & 1 deletion src/circuit/ops/layouts.rs
@@ -2953,8 +2953,17 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
         return enforce_equality(config, region, values);
     }
 
+    let mut values = [values[0].clone(), values[1].clone()];
+
+    values[0] = region.assign(&config.inputs[0], &values[0])?;
+    values[1] = region.assign(&config.inputs[1], &values[1])?;
+    let total_assigned_0 = values[0].len();
+    let total_assigned_1 = values[1].len();
+    let total_assigned = std::cmp::max(total_assigned_0, total_assigned_1);
+    region.increment(total_assigned);
+
     // Calculate the difference between the expected output and actual output
-    let diff = pairwise(config, region, values, BaseOp::Sub)?;
+    let diff = pairwise(config, region, &values, BaseOp::Sub)?;
 
     // Calculate the reciprocal of the expected output tensor, scaling by double the scaling factor
     let recip = nonlinearity(
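For readers who want the intuition behind `range_check_percent`: after both tensors are assigned to the same region, the circuit checks that each actual output stays within a percentage tolerance of the expected output. A minimal, self-contained sketch of that check over plain fixed-point integers (a hypothetical helper, not the ezkl API):

```rust
/// Returns true if `actual` is within `tol_percent` percent of `expected`.
/// Both values are fixed-point integers sharing the same scale, mirroring how
/// expected and actual outputs are compared at a common scale in the circuit.
fn within_percent_tolerance(expected: i64, actual: i64, tol_percent: f64) -> bool {
    if expected == 0 {
        // With a zero expected value only an exact match passes.
        return actual == 0;
    }
    let diff = (expected - actual).abs() as f64;
    // Relative error against the expected value, compared to the tolerance.
    diff / (expected.abs() as f64) <= tol_percent / 100.0
}

fn main() {
    // 128 represents 1.0 at scale 7 (multiplier 2^7); 126 is ~1.6% below it.
    assert!(within_percent_tolerance(128, 126, 2.0));
    assert!(!within_percent_tolerance(128, 120, 2.0));
    println!("tolerance sketch ok");
}
```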
8 changes: 4 additions & 4 deletions src/circuit/ops/lookup.rs
@@ -243,10 +243,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
             LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a),
             LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a),
             LookupOp::Sign => "SIGN".into(),
-            LookupOp::GreaterThan { .. } => "GREATER_THAN".into(),
-            LookupOp::GreaterThanEqual { .. } => "GREATER_THAN_EQUAL".into(),
-            LookupOp::LessThan { .. } => "LESS_THAN".into(),
-            LookupOp::LessThanEqual { .. } => "LESS_THAN_EQUAL".into(),
+            LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
+            LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a),
+            LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a),
+            LookupOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
             LookupOp::Recip {
                 input_scale,
                 output_scale,
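Carrying the threshold `a` into the display string makes otherwise identical comparison ops distinguishable in logs and rendered graphs. A standalone sketch of the same formatting pattern (a hypothetical enum, not ezkl's `LookupOp`):

```rust
// Hypothetical stand-in for a lookup-style comparison op; ezkl's LookupOp has
// more variants and different field types, but the formatting idea is the same.
enum CmpOp {
    GreaterThan { a: f32 },
    LessThanEqual { a: f32 },
}

fn as_string(op: &CmpOp) -> String {
    match op {
        // Including the threshold disambiguates ops that differ only in `a`.
        CmpOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a),
        CmpOp::LessThanEqual { a } => format!("LESS_THAN_EQUAL(a={})", a),
    }
}

fn main() {
    assert_eq!(as_string(&CmpOp::GreaterThan { a: 0.0 }), "GREATER_THAN(a=0)");
    println!("{}", as_string(&CmpOp::LessThanEqual { a: 0.5 })); // LESS_THAN_EQUAL(a=0.5)
}
```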
2 changes: 1 addition & 1 deletion src/graph/mod.rs
@@ -1248,7 +1248,7 @@ impl GraphCircuit {
             }
         }
 
-        let mut model_results = self.model().forward(inputs)?;
+        let mut model_results = self.model().forward(inputs, &self.settings().run_args)?;
 
         if visibility.output.requires_processing() {
             let module_outlets = visibility.output.overwrites_inputs();
37 changes: 20 additions & 17 deletions src/graph/model.rs
@@ -556,12 +556,16 @@ impl Model {
     /// * `reader` - A reader for an Onnx file.
     /// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
     /// * `run_args` - [RunArgs]
-    pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
+    pub fn forward(
+        &self,
+        model_inputs: &[Tensor<Fp>],
+        run_args: &RunArgs,
+    ) -> Result<ForwardResult, Box<dyn Error>> {
         let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
             .iter()
             .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
             .collect();
-        let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
+        let res = self.dummy_layout(&run_args, &valtensor_inputs)?;
         Ok(res.into())
     }
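The signature change is the heart of the fix: the forward pass now receives the caller's `RunArgs` rather than falling back to `RunArgs::default()`, so the configured tolerance reaches the dummy layout. A self-contained sketch of the before/after behaviour with toy types (not ezkl's real `Model`, `RunArgs`, or `Tolerance`):

```rust
// Toy types standing in for ezkl's RunArgs/Tolerance; only the shape of the
// change is illustrated here, not the real fields or defaults.
#[derive(Clone, Copy, Debug, Default)]
struct Tolerance {
    val: f32,
    scale: f32,
}

#[derive(Clone, Copy, Debug, Default)]
struct RunArgs {
    tolerance: Tolerance,
}

struct Model;

impl Model {
    // Mirrors the new shape of the API: the forward pass takes the run-time
    // arguments instead of constructing defaults internally.
    fn forward(&self, inputs: &[f32], run_args: &RunArgs) -> Vec<f32> {
        println!("forward with tolerance = {:?}", run_args.tolerance);
        inputs.to_vec()
    }
}

fn main() {
    let configured = RunArgs {
        tolerance: Tolerance { val: 2.0, scale: 128.0 },
    };
    let model = Model;
    // Previously the forward pass built its own defaults, so the configured
    // tolerance never reached it; now it flows through explicitly.
    let _ = model.forward(&[1.0, 2.0], &configured);
}
```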

@@ -1371,27 +1375,26 @@ impl Model {
                 ValType::Constant(Fp::ONE)
             };
 
-            let comparator = outputs
+            let output_scales = self.graph.get_output_scales()?;
+            let res = outputs
                 .iter()
-                .map(|x| {
-                    let mut v: ValTensor<Fp> =
-                        vec![default_value.clone(); x.dims().iter().product::<usize>()].into();
-                    v.reshape(x.dims())?;
-                    Ok(v)
-                })
-                .collect::<Result<Vec<_>, Box<dyn Error>>>()?;
+                .enumerate()
+                .map(|(i, output)| {
+                    let mut tolerance = run_args.tolerance;
+                    tolerance.scale = scale_to_multiplier(output_scales[i]).into();
+
+                    let mut comparator: ValTensor<Fp> =
+                        vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
+                    comparator.reshape(output.dims())?;
 
-            let _ = outputs
-                .iter()
-                .zip(comparator)
-                .map(|(o, c)| {
                     dummy_config.layout(
                         &mut region,
-                        &[o.clone(), c],
-                        Box::new(HybridOp::RangeCheck(run_args.tolerance)),
+                        &[output.clone(), comparator],
+                        Box::new(HybridOp::RangeCheck(tolerance)),
                     )
                 })
-                .collect::<Result<Vec<_>, _>>()?;
+                .collect::<Result<Vec<_>, _>>();
+            res?;
         } else if !self.visibility.output.is_private() {
             for output in &outputs {
                 region.increment_total_constants(output.num_constants());
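The loop above attaches a per-output scale to the tolerance before building each range check, so a percentage tolerance means the same thing regardless of an output's fixed-point scale. A standalone sketch of that scaling step, assuming (as elsewhere in ezkl) that `scale_to_multiplier` maps a base-2 fixed-point scale to `2^scale`; the `Tolerance` type here is a simplified stand-in:

```rust
// Simplified stand-in for a tolerance: a percentage plus the fixed-point
// multiplier of the output tensor it will be checked against.
#[derive(Clone, Copy, Debug)]
struct Tolerance {
    val: f32,   // tolerance as a percentage
    scale: f64, // multiplier of the output's fixed-point representation
}

// Assumption: mirrors the convention of multiplier = 2^scale.
fn scale_to_multiplier(scale: i32) -> f64 {
    f64::powi(2.0, scale)
}

fn main() {
    let base = Tolerance { val: 2.0, scale: 1.0 };
    let output_scales = [7, 13]; // e.g. two graph outputs at different scales

    // Mirror of the per-output loop: copy the configured tolerance and attach
    // the multiplier of the output it will be compared against.
    for (i, s) in output_scales.iter().enumerate() {
        let mut tolerance = base;
        tolerance.scale = scale_to_multiplier(*s);
        // output 0: scale 7  -> multiplier 128
        // output 1: scale 13 -> multiplier 8192
        println!("output {}: tolerance = {:?}", i, tolerance);
    }
}
```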
16 changes: 10 additions & 6 deletions src/pfsys/mod.rs
@@ -728,12 +728,13 @@ where
     let f =
         File::open(path.clone()).map_err(|_| format!("failed to load vk at {}", path.display()))?;
     let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
-    VerifyingKey::<Scheme::Curve>::read::<_, C>(
+    let vk = VerifyingKey::<Scheme::Curve>::read::<_, C>(
         &mut reader,
         serde_format_from_str(&EZKL_KEY_FORMAT),
         params,
-    )
-    .map_err(Box::<dyn Error>::from)
+    )?;
+    info!("done loading verification key ✅");
+    Ok(vk)
 }
 
 /// Loads a [ProvingKey] at `path`.
@@ -750,12 +751,13 @@
     let f =
         File::open(path.clone()).map_err(|_| format!("failed to load pk at {}", path.display()))?;
     let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
-    ProvingKey::<Scheme::Curve>::read::<_, C>(
+    let pk = ProvingKey::<Scheme::Curve>::read::<_, C>(
         &mut reader,
         serde_format_from_str(&EZKL_KEY_FORMAT),
         params,
-    )
-    .map_err(Box::<dyn Error>::from)
+    )?;
+    info!("done loading proving key ✅");
+    Ok(pk)
 }
 
 /// Saves a [ProvingKey] to `path`.
@@ -772,6 +774,7 @@
     let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
     pk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
     writer.flush()?;
+    info!("done saving proving key ✅");
     Ok(())
 }
 
@@ -789,6 +792,7 @@
     let mut writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
     vk.write(&mut writer, serde_format_from_str(&EZKL_KEY_FORMAT))?;
     writer.flush()?;
+    info!("done saving verification key ✅");
     Ok(())
 }
 
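The loader changes share one pattern: bind the deserialized key to a local, log success, then return it, instead of tail-returning the read expression. An illustrative sketch of that shape with plain file I/O standing in for key parsing (a hypothetical function, not ezkl's `load_vk`/`load_pk`):

```rust
use std::error::Error;
use std::fs::File;
use std::io::{BufReader, Read};

// Illustrative only: the same "read, log success, then return" shape as the
// load_vk/load_pk changes, with a raw byte read standing in for key parsing.
fn load_bytes(path: &str) -> Result<Vec<u8>, Box<dyn Error>> {
    let f = File::open(path).map_err(|_| format!("failed to load file at {}", path))?;
    let mut reader = BufReader::new(f);
    let mut buf = Vec::new();
    reader.read_to_end(&mut buf)?;
    // Binding the result first (instead of tail-returning the read expression)
    // leaves room for a success log before Ok(...) is returned.
    println!("done loading {} bytes", buf.len());
    Ok(buf)
}

fn main() {
    match load_bytes("Cargo.toml") {
        Ok(bytes) => println!("loaded {} bytes", bytes.len()),
        Err(e) => eprintln!("{}", e),
    }
}
```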
10 changes: 7 additions & 3 deletions src/tensor/val.rs
@@ -672,7 +672,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
                 }
                 Ok(indices)
             }
-            ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
+            ValTensor::Instance { .. } => Ok(vec![]),
         }
     }
 
@@ -690,7 +690,7 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
                 }
                 Ok(indices)
             }
-            ValTensor::Instance { .. } => Err(TensorError::WrongMethod),
+            ValTensor::Instance { .. } => Ok(vec![]),
         }
     }
 
@@ -709,7 +709,11 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
                 *d = v.dims().to_vec();
             }
             ValTensor::Instance { .. } => {
-                return Err(TensorError::WrongMethod);
+                if indices.is_empty() {
+                    return Ok(());
+                } else {
+                    return Err(TensorError::WrongMethod);
+                }
             }
         }
         Ok(())
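Returning an empty index list for `Instance` tensors (instead of `WrongMethod`) lets callers that collect and then remove indices treat public instance tensors uniformly with value tensors: there is simply nothing to remove. A toy sketch of that control flow (hypothetical types, loosely modeled on the change, not ezkl's `ValTensor`):

```rust
// Hypothetical, stripped-down version of the idea: an instance-backed tensor
// tracks no removable indices, so asking for them yields an empty list rather
// than an error, and removing "nothing" from it is a no-op.
enum ToyTensor {
    Value { tracked_indices: Vec<usize> },
    Instance,
}

#[derive(Debug)]
enum ToyError {
    WrongMethod,
}

fn tracked_indices(t: &ToyTensor) -> Result<Vec<usize>, ToyError> {
    match t {
        ToyTensor::Value { tracked_indices } => Ok(tracked_indices.clone()),
        // Previously an error; now simply "there are none".
        ToyTensor::Instance => Ok(vec![]),
    }
}

fn remove_indices(t: &mut ToyTensor, indices: &[usize]) -> Result<(), ToyError> {
    match t {
        ToyTensor::Value { tracked_indices } => {
            tracked_indices.retain(|i| !indices.contains(i));
            Ok(())
        }
        // Removing an empty set from an instance tensor is fine; anything else
        // is still a misuse of the method.
        ToyTensor::Instance if indices.is_empty() => Ok(()),
        ToyTensor::Instance => Err(ToyError::WrongMethod),
    }
}

fn main() {
    let mut inst = ToyTensor::Instance;
    let idx = tracked_indices(&inst).unwrap();
    assert!(idx.is_empty());
    remove_indices(&mut inst, &idx).unwrap(); // no-op instead of an error
    println!("instance tensors handled uniformly");
}
```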