From 399dd4b13d9f5fc64f0b5817f504e460f4e376a5 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 22 Aug 2024 14:06:35 +0100 Subject: [PATCH 1/6] feat: Add `fpow` and `fround` float operations --- hugr-core/src/std_extensions/arithmetic/float_ops.rs | 8 ++++++-- .../src/std_extensions/arithmetic/float_ops/const_fold.rs | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 9fe86ee1f..d87a5cf20 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -39,8 +39,10 @@ pub enum FloatOps { fabs, fmul, fdiv, + fpow, ffloor, fceil, + fround, ftostring, } @@ -60,10 +62,10 @@ impl MakeOpDef for FloatOps { feq | fne | flt | fgt | fle | fge => { Signature::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T]) } - fmax | fmin | fadd | fsub | fmul | fdiv => { + fmax | fmin | fadd | fsub | fmul | fdiv | fpow => { Signature::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]) } - fneg | fabs | ffloor | fceil => Signature::new_endo(type_row![FLOAT64_TYPE]), + fneg | fabs | ffloor | fceil | fround => Signature::new_endo(type_row![FLOAT64_TYPE]), ftostring => Signature::new(type_row![FLOAT64_TYPE], STRING_TYPE), } .into() @@ -86,8 +88,10 @@ impl MakeOpDef for FloatOps { fabs => "absolute value", fmul => "multiplication", fdiv => "division", + fpow => "exponentiation", ffloor => "floor", fceil => "ceiling", + fround => "round", ftostring => "string representation", } .to_string() diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs index 974dbe9b6..11310622f 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops/const_fold.rs @@ -12,9 +12,11 @@ pub(super) fn set_fold(op: &FloatOps, def: &mut OpDef) { use FloatOps::*; match op { - fmax | fmin | fadd | fsub | fmul | fdiv => def.set_constant_folder(BinaryFold::from_op(op)), + fmax | fmin | fadd | fsub | fmul | fdiv | fpow => { + def.set_constant_folder(BinaryFold::from_op(op)) + } feq | fne | flt | fgt | fle | fge => def.set_constant_folder(CmpFold::from_op(*op)), - fneg | fabs | ffloor | fceil => def.set_constant_folder(UnaryFold::from_op(op)), + fneg | fabs | ffloor | fceil | fround => def.set_constant_folder(UnaryFold::from_op(op)), ftostring => def.set_constant_folder(ToStringFold::from_op(op)), } } @@ -43,6 +45,7 @@ impl BinaryFold { fsub => std::ops::Sub::sub, fmul => std::ops::Mul::mul, fdiv => std::ops::Div::div, + fpow => f64::powf, _ => panic!("not binary op"), })) } @@ -106,6 +109,7 @@ impl UnaryFold { fabs => f64::abs, ffloor => f64::floor, fceil => f64::ceil, + fround => f64::round, _ => panic!("not unary op."), })) } From 52cad200e14189763c479f9980c33cb3873922b6 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 22 Aug 2024 14:46:03 +0100 Subject: [PATCH 2/6] feat: `ipow`, `iu_to_s`, `is_to_u` operations --- .../src/std_extensions/arithmetic/int_ops.rs | 12 ++- .../arithmetic/int_ops/const_fold.rs | 74 +++++++++++++++++++ 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index c3785ca9a..c1ae5bf7a 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -89,6 +89,7 @@ pub enum IntOpDef { idiv_s, imod_checked_s, imod_s, + ipow, iabs, iand, ior, @@ -98,6 +99,8 @@ pub enum IntOpDef { ishr, irotl, irotr, + iu_to_s, + is_to_u, itostring_u, itostring_s, } @@ -130,10 +133,10 @@ impl MakeOpDef for IntOpDef { ieq | ine | ilt_u | ilt_s | igt_u | igt_s | ile_u | ile_s | ige_u | ige_s => { int_polytype(1, vec![tv0; 2], type_row![BOOL_T]).into() } - imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor => { + imax_u | imax_s | imin_u | imin_s | iadd | isub | imul | iand | ior | ixor | ipow => { ibinop_sig().into() } - ineg | iabs | inot => iunop_sig().into(), + ineg | iabs | inot | iu_to_s | is_to_u => iunop_sig().into(), idivmod_checked_u | idivmod_checked_s => { let intpair: TypeRowRV = vec![tv0; 2].into(); int_polytype( @@ -209,6 +212,7 @@ impl MakeOpDef for IntOpDef { idiv_s => "as idivmod_s but discarding the second output", imod_checked_s => "as idivmod_checked_s but discarding the first output", imod_s => "as idivmod_s but discarding the first output", + ipow => "raise first input to the power of second input", iabs => "convert signed to unsigned by taking absolute value", iand => "bitwise AND", ior => "bitwise OR", @@ -222,6 +226,8 @@ impl MakeOpDef for IntOpDef { (leftmost bits replace rightmost bits)", irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)", + is_to_u => "convert signed to unsigned by taking absolute value. Panics if the input is negative", + iu_to_s => "convert unsigned to signed by taking absolute value. Panics if the input is too large", itostring_s => "convert a signed integer to its string representation", itostring_u => "convert an unsigned integer to its string representation", }.into() @@ -378,7 +384,7 @@ mod test { fn test_int_ops_extension() { assert_eq!(EXTENSION.name() as &str, "arithmetic.int"); assert_eq!(EXTENSION.types().count(), 0); - assert_eq!(EXTENSION.operations().count(), 47); + assert_eq!(EXTENSION.operations().count(), 50); for (name, _) in EXTENSION.operations() { assert!(name.starts_with('i')); } diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs b/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs index 34725677f..76b683bd7 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops/const_fold.rs @@ -587,6 +587,36 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { }, ), }, + IntOpDef::ipow => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let (n0, n1): (&ConstInt, &ConstInt) = get_pair_of_input_values(consts)?; + if n0.log_width() != logwidth || n1.log_width() != logwidth { + None + } else { + Some(vec![( + 0.into(), + Value::extension( + ConstInt::new_u( + logwidth, + n0.value_u() + .overflowing_pow( + n1.value_u().try_into().unwrap_or(u32::MAX), + ) + .0 + & bitmask_from_logwidth(logwidth), + ) + .unwrap(), + ), + )]) + } + }, + ), + }, IntOpDef::idivmod_checked_u => Folder { folder: Box::new( |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { @@ -1154,6 +1184,50 @@ pub(super) fn set_fold(op: &IntOpDef, def: &mut OpDef) { }, ), }, + IntOpDef::is_to_u => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != logwidth { + None + } else { + if n0.value_s() < 0 { + panic!( + "Cannot convert negative integer {} to unsigned.", + n0.value_s() + ); + } + Some(vec![(0.into(), Value::extension(n0.clone()))]) + } + }, + ), + }, + IntOpDef::iu_to_s => Folder { + folder: Box::new( + |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { + let [arg] = type_args else { + return None; + }; + let logwidth: u8 = get_log_width(arg).ok()?; + let n0: &ConstInt = get_single_input_value(consts)?; + if n0.log_width() != logwidth { + None + } else { + if n0.value_s() < 0 { + panic!( + "Unsigned integer {} is too large to be converted to signed.", + n0.value_u() + ); + } + Some(vec![(0.into(), Value::extension(n0.clone()))]) + } + }, + ), + }, IntOpDef::itostring_u => Folder { folder: Box::new( |type_args: &[TypeArg], consts: &[(IncomingPort, Value)]| -> ConstFoldResult { From 1f817e1aa7ceecaf75573f482878aecc05611207 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 22 Aug 2024 14:59:46 +0100 Subject: [PATCH 3/6] doc: Mention that `ifrombool` / `itobool` only work for 1-bit ints --- hugr-core/src/std_extensions/arithmetic/int_ops.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index c1ae5bf7a..c848598f3 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -119,12 +119,12 @@ impl MakeOpDef for IntOpDef { let tv0 = int_tv(0); match self { iwiden_s | iwiden_u => CustomValidator::new( - int_polytype(2, vec![tv0.clone()], vec![int_tv(1)]), + int_polytype(2, vec![tv0], vec![int_tv(1)]), IOValidator { f_ge_s: false }, ) .into(), inarrow_s | inarrow_u => CustomValidator::new( - int_polytype(2, tv0.clone(), sum_ty_with_err(int_tv(1))), + int_polytype(2, tv0, sum_ty_with_err(int_tv(1))), IOValidator { f_ge_s: true }, ) .into(), @@ -176,8 +176,8 @@ impl MakeOpDef for IntOpDef { iwiden_s => "widen a signed integer to a wider one with the same value", inarrow_u => "narrow an unsigned integer to a narrower one with the same value if possible", inarrow_s => "narrow a signed integer to a narrower one with the same value if possible", - itobool => "convert to bool (1 is true, 0 is false)", - ifrombool => "convert from bool (1 is true, 0 is false)", + itobool => "convert a 1-bit integer to bool (1 is true, 0 is false)", + ifrombool => "convert from bool into a 1-bit integer (1 is true, 0 is false)", ieq => "equality test", ine => "inequality test", ilt_u => "\"less than\" as unsigned integers", @@ -226,8 +226,8 @@ impl MakeOpDef for IntOpDef { (leftmost bits replace rightmost bits)", irotr => "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)", - is_to_u => "convert signed to unsigned by taking absolute value. Panics if the input is negative", - iu_to_s => "convert unsigned to signed by taking absolute value. Panics if the input is too large", + is_to_u => "convert signed to unsigned by taking absolute value", + iu_to_s => "convert unsigned to signed by taking absolute value", itostring_s => "convert a signed integer to its string representation", itostring_u => "convert an unsigned integer to its string representation", }.into() From 6100512119e99b4bc30919afcc84efe705c11f05 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 22 Aug 2024 15:07:19 +0100 Subject: [PATCH 4/6] Update extension defs --- .../hugr/std/_json_defs/arithmetic/float.json | 67 +++++++ .../hugr/std/_json_defs/arithmetic/int.json | 179 +++++++++++++++++- .../std_extensions/arithmetic/float.json | 67 +++++++ .../std_extensions/arithmetic/int.json | 179 +++++++++++++++++- 4 files changed, 488 insertions(+), 4 deletions(-) diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 8563fe57b..8bd9f3268 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -522,6 +522,73 @@ }, "binary": false }, + "fpow": { + "extension": "arithmetic.float", + "name": "fpow", + "description": "exponentiation", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + }, + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "fround": { + "extension": "arithmetic.float", + "name": "fround", + "description": "round", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "fsub": { "extension": "arithmetic.float", "name": "fsub", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 6eb546736..a6cc862f6 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -959,7 +959,7 @@ "ifrombool": { "extension": "arithmetic.int", "name": "ifrombool", - "description": "convert from bool (1 is true, 0 is false)", + "description": "convert from bool into a 1-bit integer (1 is true, 0 is false)", "signature": { "params": [], "body": { @@ -2489,6 +2489,75 @@ }, "binary": false }, + "ipow": { + "extension": "arithmetic.int", + "name": "ipow", + "description": "raise first input to the power of second input", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": 7 + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + }, + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "irotl": { "extension": "arithmetic.int", "name": "irotl", @@ -2627,6 +2696,59 @@ }, "binary": false }, + "is_to_u": { + "extension": "arithmetic.int", + "name": "is_to_u", + "description": "convert signed to unsigned by taking absolute value", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": 7 + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "ishl": { "extension": "arithmetic.int", "name": "ishl", @@ -2837,7 +2959,7 @@ "itobool": { "extension": "arithmetic.int", "name": "itobool", - "description": "convert to bool (1 is true, 0 is false)", + "description": "convert a 1-bit integer to bool (1 is true, 0 is false)", "signature": { "params": [], "body": { @@ -2955,6 +3077,59 @@ }, "binary": false }, + "iu_to_s": { + "extension": "arithmetic.int", + "name": "iu_to_s", + "description": "convert unsigned to signed by taking absolute value", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": 7 + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "iwiden_s": { "extension": "arithmetic.int", "name": "iwiden_s", diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 8563fe57b..8bd9f3268 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -522,6 +522,73 @@ }, "binary": false }, + "fpow": { + "extension": "arithmetic.float", + "name": "fpow", + "description": "exponentiation", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + }, + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "fround": { + "extension": "arithmetic.float", + "name": "fround", + "description": "round", + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.float.types", + "id": "float64", + "args": [], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "fsub": { "extension": "arithmetic.float", "name": "fsub", diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 6eb546736..a6cc862f6 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -959,7 +959,7 @@ "ifrombool": { "extension": "arithmetic.int", "name": "ifrombool", - "description": "convert from bool (1 is true, 0 is false)", + "description": "convert from bool into a 1-bit integer (1 is true, 0 is false)", "signature": { "params": [], "body": { @@ -2489,6 +2489,75 @@ }, "binary": false }, + "ipow": { + "extension": "arithmetic.int", + "name": "ipow", + "description": "raise first input to the power of second input", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": 7 + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + }, + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "irotl": { "extension": "arithmetic.int", "name": "irotl", @@ -2627,6 +2696,59 @@ }, "binary": false }, + "is_to_u": { + "extension": "arithmetic.int", + "name": "is_to_u", + "description": "convert signed to unsigned by taking absolute value", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": 7 + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "ishl": { "extension": "arithmetic.int", "name": "ishl", @@ -2837,7 +2959,7 @@ "itobool": { "extension": "arithmetic.int", "name": "itobool", - "description": "convert to bool (1 is true, 0 is false)", + "description": "convert a 1-bit integer to bool (1 is true, 0 is false)", "signature": { "params": [], "body": { @@ -2955,6 +3077,59 @@ }, "binary": false }, + "iu_to_s": { + "extension": "arithmetic.int", + "name": "iu_to_s", + "description": "convert unsigned to signed by taking absolute value", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": 7 + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "arithmetic.int.types", + "id": "int", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": 7 + } + } + ], + "bound": "C" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "iwiden_s": { "extension": "arithmetic.int", "name": "iwiden_s", From 9da92fdc1337e1edb3c6a7b08cf0af08f6bdf749 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 22 Aug 2024 15:44:57 +0100 Subject: [PATCH 5/6] Add constant folding tests --- .../std_extensions/arithmetic/float_ops.rs | 46 +++++++++++++++++ .../src/std_extensions/arithmetic/int_ops.rs | 50 +++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index d87a5cf20..8ef8850a8 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -137,6 +137,9 @@ impl MakeRegisteredOp for FloatOps { #[cfg(test)] mod test { + use cgmath::AbsDiffEq; + use rstest::rstest; + use super::*; #[test] @@ -148,4 +151,47 @@ mod test { assert!(name.as_str().starts_with('f')); } } + + #[rstest] + #[case::fadd(FloatOps::fadd, &[0.1, 0.2], &[0.30000000000000004])] + #[case::fsub(FloatOps::fsub, &[1., 2.], &[-1.])] + #[case::fmul(FloatOps::fmul, &[2., 3.], &[6.])] + #[case::fdiv(FloatOps::fdiv, &[7., 2.], &[3.5])] + #[case::fpow(FloatOps::fpow, &[0.5, 3.], &[0.125])] + #[case::ffloor(FloatOps::ffloor, &[42.42], &[42.])] + #[case::fceil(FloatOps::fceil, &[42.42], &[43.])] + #[case::fround(FloatOps::fround, &[42.42], &[42.])] + fn float_fold(#[case] op: FloatOps, #[case] inputs: &[f64], #[case] outputs: &[f64]) { + use crate::ops::Value; + use crate::std_extensions::arithmetic::float_types::ConstF64; + + let consts: Vec<_> = inputs + .iter() + .enumerate() + .map(|(i, &x)| (i.into(), Value::extension(ConstF64::new(x)))) + .collect(); + + let res = op + .to_extension_op() + .unwrap() + .constant_fold(&consts) + .unwrap(); + + for (i, expected) in outputs.iter().enumerate() { + let res_val: f64 = res + .get(i) + .unwrap() + .1 + .get_custom_value::() + .expect("This function assumes all incoming constants are floats.") + .value(); + + assert!( + res_val.abs_diff_eq(expected, f64::EPSILON), + "expected {:?}, got {:?}", + expected, + res_val + ); + } + } } diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index c848598f3..0bfdafb17 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -372,6 +372,8 @@ fn sum_ty_with_err(t: Type) -> Type { #[cfg(test)] mod test { + use rstest::rstest; + use crate::{ ops::{dataflow::DataflowOpTrait, ExtensionOp}, std_extensions::arithmetic::int_types::int_type, @@ -456,4 +458,52 @@ mod test { assert_eq!(ConcreteIntOp::from_op(&ext_op).unwrap(), o); assert_eq!(IntOpDef::from_op(&ext_op).unwrap(), IntOpDef::itobool); } + + #[rstest] + #[case::iadd(IntOpDef::iadd.with_log_width(5), &[1, 2], &[3], 5)] + #[case::isub(IntOpDef::isub.with_log_width(5), &[5, 2], &[3], 5)] + #[case::imul(IntOpDef::imul.with_log_width(5), &[2, 8], &[16], 5)] + #[case::idiv(IntOpDef::idiv_u.with_log_width(5), &[37, 8], &[4], 5)] + #[case::imod(IntOpDef::imod_u.with_log_width(5), &[43, 8], &[3], 5)] + #[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)] + #[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)] + #[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)] + fn int_fold( + #[case] op: ConcreteIntOp, + #[case] inputs: &[u64], + #[case] outputs: &[u64], + #[case] log_width: u8, + ) { + use crate::ops::Value; + use crate::std_extensions::arithmetic::int_types::ConstInt; + + let consts: Vec<_> = inputs + .iter() + .enumerate() + .map(|(i, &x)| { + ( + i.into(), + Value::extension(ConstInt::new_u(log_width, x).unwrap()), + ) + }) + .collect(); + + let res = op + .to_extension_op() + .unwrap() + .constant_fold(&consts) + .unwrap(); + + for (i, &expected) in outputs.iter().enumerate() { + let res_val: u64 = res + .get(i) + .unwrap() + .1 + .get_custom_value::() + .expect("This function assumes all incoming constants are floats.") + .value_u(); + + assert_eq!(res_val, expected); + } + } } From dce1f7cbf0e99dfcc61d8d49e79190796de045c7 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 22 Aug 2024 16:05:01 +0100 Subject: [PATCH 6/6] Test panic paths too --- hugr-core/src/std_extensions/arithmetic/int_ops.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 0bfdafb17..cae7627eb 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -468,6 +468,10 @@ mod test { #[case::ipow(IntOpDef::ipow.with_log_width(5), &[2, 8], &[256], 5)] #[case::iu_to_s(IntOpDef::iu_to_s.with_log_width(5), &[42], &[42], 5)] #[case::is_to_u(IntOpDef::is_to_u.with_log_width(5), &[42], &[42], 5)] + #[should_panic(expected = "too large to be converted to signed")] + #[case::iu_to_s_panic(IntOpDef::iu_to_s.with_log_width(5), &[u32::MAX as u64], &[], 5)] + #[should_panic(expected = "Cannot convert negative integer")] + #[case::is_to_u_panic(IntOpDef::is_to_u.with_log_width(5), &[(0u32.wrapping_sub(42)) as u64], &[], 5)] fn int_fold( #[case] op: ConcreteIntOp, #[case] inputs: &[u64],