From 38397d346aac4ec19026ede2776e776fdceb854c Mon Sep 17 00:00:00 2001 From: jfecher Date: Tue, 6 Aug 2024 14:55:37 -0500 Subject: [PATCH] feat: Derive `Ord` and `Hash` in the stdlib; add `std::meta::make_impl` helper (#5683) # Description ## Problem\* ## Summary\* `Ord` and `Hash` are the last two traits in the stdlib that can be derived - and now we can. I've also added `std::meta::make_impl` so that there's not so much repeated code for each of this function. This also makes writing these derive functions somewhat easier for users. ## Additional Context ## Documentation\* Check one: - [ ] No documentation needed. - [ ] Documentation included in this PR. - [x] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: Michael J Klein --- noir_stdlib/src/cmp.nr | 42 +++++++++---------- noir_stdlib/src/default.nr | 27 +++--------- noir_stdlib/src/hash/mod.nr | 11 ++++- noir_stdlib/src/meta/mod.nr | 42 +++++++++++++++++++ .../execution_success/derive/src/main.nr | 30 ++++++++++++- 5 files changed, 105 insertions(+), 47 deletions(-) diff --git a/noir_stdlib/src/cmp.nr b/noir_stdlib/src/cmp.nr index 94cd284e238..10182ca83b0 100644 --- a/noir_stdlib/src/cmp.nr +++ b/noir_stdlib/src/cmp.nr @@ -8,28 +8,10 @@ trait Eq { // docs:end:eq-trait comptime fn derive_eq(s: StructDefinition) -> Quoted { - let typ = s.as_type(); - - let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,}); - - let where_clause = s.generics().map(|name| quote { $name: Eq }).join(quote {,}); - - // `(self.a == other.a) & (self.b == other.b) & ...` - let equalities = s.fields().map( - |f: (Quoted, Type)| { - let name = f.0; - quote { (self.$name == other.$name) } - } - ); - let body = equalities.join(quote { & }); - - quote { - impl<$impl_generics> Eq for $typ where $where_clause { - fn eq(self, other: Self) -> bool { - $body - } - } - } + let signature = quote { fn eq(_self: Self, _other: Self) -> bool }; + let for_each_field = |name| quote { (_self.$name == _other.$name) }; + let body = |fields| fields; + crate::meta::make_trait_impl(s, quote { Eq }, signature, for_each_field, quote { & }, body) } impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } } @@ -127,12 +109,28 @@ impl Ordering { } } +#[derive_via(derive_ord)] // docs:start:ord-trait trait Ord { fn cmp(self, other: Self) -> Ordering; } // docs:end:ord-trait +comptime fn derive_ord(s: StructDefinition) -> Quoted { + let signature = quote { fn cmp(_self: Self, _other: Self) -> std::cmp::Ordering }; + let for_each_field = |name| quote { + if result == std::cmp::Ordering::equal() { + result = _self.$name.cmp(_other.$name); + } + }; + let body = |fields| quote { + let mut result = std::cmp::Ordering::equal(); + $fields + result + }; + crate::meta::make_trait_impl(s, quote { Ord }, signature, for_each_field, quote {}, body) +} + // Note: Field deliberately does not implement Ord impl Ord for u64 { diff --git a/noir_stdlib/src/default.nr b/noir_stdlib/src/default.nr index 4fbde09b512..f9399bfb865 100644 --- a/noir_stdlib/src/default.nr +++ b/noir_stdlib/src/default.nr @@ -8,28 +8,11 @@ trait Default { // docs:end:default-trait comptime fn derive_default(s: StructDefinition) -> Quoted { - let typ = s.as_type(); - - let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,}); - - let where_clause = s.generics().map(|name| quote { $name: Default }).join(quote {,}); - - // `foo: Default::default(), bar: Default::default(), ...` - let fields = s.fields().map( - |f: (Quoted, Type)| { - let name = f.0; - quote { $name: Default::default() } - } - ); - let fields = fields.join(quote {,}); - - quote { - impl<$impl_generics> Default for $typ where $where_clause { - fn default() -> Self { - Self { $fields } - } - } - } + let name = quote { Default }; + let signature = quote { fn default() -> Self }; + let for_each_field = |name| quote { $name: Default::default() }; + let body = |fields| quote { Self { $fields } }; + crate::meta::make_trait_impl(s, name, signature, for_each_field, quote { , }, body) } impl Default for Field { fn default() -> Field { 0 } } diff --git a/noir_stdlib/src/hash/mod.nr b/noir_stdlib/src/hash/mod.nr index 8e9fe75d982..84ad7c22bb3 100644 --- a/noir_stdlib/src/hash/mod.nr +++ b/noir_stdlib/src/hash/mod.nr @@ -8,6 +8,7 @@ use crate::uint128::U128; use crate::sha256::{digest, sha256_var}; use crate::collections::vec::Vec; use crate::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul, multi_scalar_mul_slice}; +use crate::meta::derive_via; #[foreign(sha256)] // docs:start:sha256 @@ -141,10 +142,18 @@ pub fn sha256_compression(_input: [u32; 16], _state: [u32; 8]) -> [u32; 8] {} // Partially ported and impacted by rust. // Hash trait shall be implemented per type. -trait Hash{ +#[derive_via(derive_hash)] +trait Hash { fn hash(self, state: &mut H) where H: Hasher; } +comptime fn derive_hash(s: StructDefinition) -> Quoted { + let name = quote { Hash }; + let signature = quote { fn hash(_self: Self, _state: &mut H) where H: std::hash::Hasher }; + let for_each_field = |name| quote { _self.$name.hash(_state); }; + crate::meta::make_trait_impl(s, name, signature, for_each_field, quote {}, |fields| fields) +} + // Hasher trait shall be implemented by algorithms to provide hash-agnostic means. // TODO: consider making the types generic here ([u8], [Field], etc.) trait Hasher{ diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index 615b4e5aa14..ec264ed8022 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -42,3 +42,45 @@ pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted unconstrained pub comptime fn derive_via(t: TraitDefinition, f: DeriveFunction) { HANDLERS.insert(t, f); } + +/// `make_impl` is a helper function to make a simple impl, usually while deriving a trait. +/// This impl has a couple assumptions: +/// 1. The impl only has one function, with the signature `function_signature` +/// 2. The trait itself does not have any generics. +/// +/// While these assumptions are met, `make_impl` will create an impl from a StructDefinition, +/// automatically filling in the required generics from the struct, along with the where clause. +/// The function body is created by mapping each field with `for_each_field` and joining the +/// results with `join_fields_with`. The result of this is passed to the `body` function for +/// any final processing - e.g. wrapping each field in a `StructConstructor { .. }` expression. +/// +/// See `derive_eq` and `derive_default` for example usage. +pub comptime fn make_trait_impl( + s: StructDefinition, + trait_name: Quoted, + function_signature: Quoted, + for_each_field: fn[Env1](Quoted) -> Quoted, + join_fields_with: Quoted, + body: fn[Env2](Quoted) -> Quoted +) -> Quoted { + let typ = s.as_type(); + let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,}); + let where_clause = s.generics().map(|name| quote { $name: $trait_name }).join(quote {,}); + + // `for_each_field(field1) $join_fields_with for_each_field(field2) $join_fields_with ...` + let fields = s.fields().map( + |f: (Quoted, Type)| { + let name = f.0; + for_each_field(name) + } + ); + let body = body(fields.join(join_fields_with)); + + quote { + impl<$impl_generics> $trait_name for $typ where $where_clause { + $function_signature { + $body + } + } + } +} diff --git a/test_programs/execution_success/derive/src/main.nr b/test_programs/execution_success/derive/src/main.nr index f344defe41e..5ec2fb32a79 100644 --- a/test_programs/execution_success/derive/src/main.nr +++ b/test_programs/execution_success/derive/src/main.nr @@ -1,3 +1,5 @@ +use std::hash::Hash; + #[derive_via(derive_do_nothing)] trait DoNothing { fn do_nothing(self); @@ -20,14 +22,15 @@ comptime fn derive_do_nothing(s: StructDefinition) -> Quoted { } // Test stdlib derive fns & multiple traits -#[derive(Eq, Default)] +// - We can derive Ord and Hash even though std::cmp::Ordering and std::hash::Hasher aren't imported +#[derive(Eq, Default, Hash, Ord)] struct MyOtherStruct { field1: A, field2: B, field3: MyOtherOtherStruct, } -#[derive(Eq, Default)] +#[derive(Eq, Default, Hash, Ord)] struct MyOtherOtherStruct { x: T, } @@ -41,4 +44,27 @@ fn main() { let o: MyOtherStruct]> = MyOtherStruct::default(); assert_eq(o, o); + + // Field & str<2> above don't implement Ord + let o1 = MyOtherStruct { field1: 12 as u32, field2: 24 as i8, field3: MyOtherOtherStruct { x: 54 as i8 } }; + let o2 = MyOtherStruct { field1: 12 as u32, field2: 24 as i8, field3: MyOtherOtherStruct { x: 55 as i8 } }; + assert(o1 < o2); + + let mut hasher = TestHasher { result: 0 }; + o1.hash(&mut hasher); + assert_eq(hasher.finish(), 12 + 24 + 54); +} + +struct TestHasher { + result: Field, +} + +impl std::hash::Hasher for TestHasher { + fn finish(self) -> Field { + self.result + } + + fn write(&mut self, input: Field) { + self.result += input; + } }