Skip to content

Commit

Permalink
feat: lookupless sqrt and rsqrt
Browse files Browse the repository at this point in the history
+ no range check div and recip
  • Loading branch information
alexander-camuto committed Nov 8, 2024
1 parent 4285078 commit d457988
Show file tree
Hide file tree
Showing 13 changed files with 505 additions and 266 deletions.
10 changes: 0 additions & 10 deletions src/bindings/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,6 @@ struct PyRunArgs {
/// list[tuple[str, int]]: Hand-written parser for graph variables, eg. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
Expand Down Expand Up @@ -227,7 +224,6 @@ impl From<PyRunArgs> for RunArgs {
output_visibility: py_run_args.output_visibility,
param_visibility: py_run_args.param_visibility,
variables: py_run_args.variables,
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: Some(py_run_args.commitment.into()),
Expand All @@ -252,7 +248,6 @@ impl Into<PyRunArgs> for RunArgs {
output_visibility: self.output_visibility,
param_visibility: self.param_visibility,
variables: self.variables,
div_rebasing: self.div_rebasing,
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
check_mode: self.check_mode,
commitment: self.commitment.into(),
Expand Down Expand Up @@ -878,8 +873,6 @@ fn gen_settings(
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
Expand All @@ -894,7 +887,6 @@ fn gen_settings(
scales = None,
scale_rebase_multiplier = DEFAULT_SCALE_REBASE_MULTIPLIERS.split(",").map(|x| x.parse().unwrap()).collect(),
max_logrows = None,
only_range_check_rebase = DEFAULT_ONLY_RANGE_CHECK_REBASE.parse().unwrap(),
))]
fn calibrate_settings(
py: Python,
Expand All @@ -906,7 +898,6 @@ fn calibrate_settings(
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
only_range_check_rebase: bool,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::calibrate(
Expand All @@ -917,7 +908,6 @@ fn calibrate_settings(
lookup_safety_margin,
scales,
scale_rebase_multiplier,
only_range_check_rebase,
max_logrows,
)
.await
Expand Down
61 changes: 42 additions & 19 deletions src/circuit/ops/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@ pub enum HybridOp {
Ln {
scale: utils::F32,
},

Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
},
Exp {
scale: utils::F32,
},
Sqrt {
scale: utils::F32,
},
RoundHalfToEven {
scale: utils::F32,
legs: usize,
Expand All @@ -39,7 +48,6 @@ pub enum HybridOp {
},
Div {
denom: utils::F32,
use_range_check_for_int: bool,
},
ReduceMax {
axes: Vec<usize>,
Expand Down Expand Up @@ -116,6 +124,15 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid

fn as_string(&self) -> String {
match self {
HybridOp::Exp { scale } => format!("EXP(scale={})", scale),
HybridOp::Rsqrt {
input_scale,
output_scale,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
Expand All @@ -133,13 +150,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
),
HybridOp::Div {
denom,
use_range_check_for_int,
} => format!(
"DIV (denom={}, use_range_check_for_int={})",
denom, use_range_check_for_int
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
Expand Down Expand Up @@ -194,6 +205,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
)?,
HybridOp::Exp { scale } => {
layouts::exp(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
Expand Down Expand Up @@ -233,13 +260,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
)?,
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
if denom.0.fract() == 0.0 && *use_range_check_for_int {
layouts::loop_div(
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
layouts::div(
config,
region,
values[..].try_into()?,
Expand Down Expand Up @@ -330,9 +353,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
HybridOp::Softmax { output_scale, .. }
| HybridOp::Recip { output_scale, .. }
| HybridOp::Rsqrt { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
HybridOp::Ln {
scale: output_scale,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
Expand Down
Loading

0 comments on commit d457988

Please sign in to comment.