diff --git a/pb-rs/src/lib.rs b/pb-rs/src/lib.rs index c23d01e..f060207 100644 --- a/pb-rs/src/lib.rs +++ b/pb-rs/src/lib.rs @@ -56,6 +56,7 @@ pub struct ConfigBuilder { error_cycle: bool, headers: bool, dont_use_cow: bool, + custom_enum_derive: Vec, custom_struct_derive: Vec, custom_repr: Option, owned: bool, @@ -146,6 +147,12 @@ impl ConfigBuilder { self } + /// Add custom values to `#[derive(...)]` at the beginning of every enum + pub fn custom_enum_derive(mut self, val: Vec) -> Self { + self.custom_enum_derive = val; + self + } + /// Add custom values to `#[derive(...)]` at the beginning of every structure pub fn custom_struct_derive(mut self, val: Vec) -> Self { self.custom_struct_derive = val; @@ -226,6 +233,7 @@ impl ConfigBuilder { error_cycle: self.error_cycle, headers: self.headers, dont_use_cow: self.dont_use_cow, //Change this to true to not use cow with ./generate.sh for v2 and v3 tests + custom_enum_derive: self.custom_enum_derive.clone(), custom_struct_derive: self.custom_struct_derive.clone(), custom_repr: self.custom_repr.clone(), custom_rpc_generator: Box::new(|_, _| Ok(())), diff --git a/pb-rs/src/main.rs b/pb-rs/src/main.rs index 94a6305..36a7815 100644 --- a/pb-rs/src/main.rs +++ b/pb-rs/src/main.rs @@ -62,6 +62,13 @@ fn run() -> Result<(), Error> { .short("H") .required(false) .help("Do not add module comments and module attributes in generated file"), + ).arg( + Arg::with_name("CUSTOM_ENUM_DERIVE") + .long("custom_enum_derive") + .short("E") + .required(false) + .takes_value(true) + .help("The comma separated values to add to #[derive(...)] for every enum"), ).arg( Arg::with_name("CUSTOM_STRUCT_DERIVE") .long("custom_struct_derive") @@ -125,6 +132,12 @@ fn run() -> Result<(), Error> { .split(',') .map(|s| s.to_string()) .collect(); + let custom_enum_derive: Vec = matches + .value_of("CUSTOM_ENUM_DERIVE") + .unwrap_or("") + .split(',') + .map(|s| s.to_string()) + .collect(); let compiler = ConfigBuilder::new( &in_files, @@ -138,6 +151,7 @@ fn run() -> Result<(), Error> { .headers(!matches.is_present("NO_HEADERS")) .dont_use_cow(matches.is_present("DONT_USE_COW")) .custom_struct_derive(custom_struct_derive) + .custom_enum_derive(custom_enum_derive) .nostd(matches.is_present("NOSTD")) .hashbrown(matches.is_present("HASHBROWN")) .gen_info(matches.is_present("GEN_INFO")) diff --git a/pb-rs/src/types.rs b/pb-rs/src/types.rs index 5459689..2ed1327 100644 --- a/pb-rs/src/types.rs +++ b/pb-rs/src/types.rs @@ -1563,7 +1563,7 @@ impl Message { m.write(w, desc, config)?; } for e in &self.enums { - e.write(w)?; + e.write(w, config)?; } for o in &self.oneofs { o.write(w, desc, config)?; @@ -1582,23 +1582,23 @@ impl Message { desc: &FileDescriptor, config: &Config, ) -> Result<()> { - let mut custom_struct_derive = config.custom_struct_derive.join(", "); - - if !self.must_generate_impl_default(desc, config) { - custom_struct_derive += "Default"; - } - - if !custom_struct_derive.is_empty() { - custom_struct_derive += ", "; - } + let derives = if self.must_generate_impl_default(desc, config) { + vec!["Debug", "PartialEq", "Clone"] + } else { + vec!["Default", "Debug", "PartialEq", "Clone"] + }; + let derives_str = config + .custom_struct_derive + .iter() + .map(|s| s.as_str()) + .filter(|s| !s.is_empty()) + .chain(derives.into_iter()) + .collect::>() + .join(", "); writeln!(w, "#[allow(clippy::derive_partial_eq_without_eq)]")?; - writeln!( - w, - "#[derive({}Debug, PartialEq, Clone)]", - custom_struct_derive - )?; + writeln!(w, "#[derive({derives_str})]")?; if let Some(repr) = &config.custom_repr { writeln!(w, "#[repr({})]", repr)?; @@ -2163,10 +2163,10 @@ impl Enumerator { get_modules(&self.module, self.imported, desc) } - fn write(&self, w: &mut W) -> Result<()> { + fn write(&self, w: &mut W, config: &Config) -> Result<()> { println!("Writing enum {}", self.name); writeln!(w)?; - self.write_definition(w)?; + self.write_definition(w, config)?; writeln!(w)?; if self.fields.is_empty() { Ok(()) @@ -2179,8 +2179,18 @@ impl Enumerator { } } - fn write_definition(&self, w: &mut W) -> Result<()> { - writeln!(w, "#[derive(Debug, PartialEq, Eq, Clone, Copy)]")?; + fn write_definition(&self, w: &mut W, config: &Config) -> Result<()> { + let mut custom_enum_derive = config.custom_enum_derive.join(", "); + + if !custom_enum_derive.is_empty() { + custom_enum_derive += ", "; + } + + writeln!( + w, + "#[derive({}Debug, PartialEq, Eq, Clone, Copy)]", + custom_enum_derive + )?; writeln!(w, "pub enum {} {{", self.name)?; for (f, number) in &self.fields { writeln!(w, " {} = {},", f, number)?; @@ -2544,6 +2554,7 @@ pub struct Config { pub error_cycle: bool, pub headers: bool, pub dont_use_cow: bool, + pub custom_enum_derive: Vec, pub custom_struct_derive: Vec, pub custom_repr: Option, pub custom_rpc_generator: RpcGeneratorFunction, @@ -3030,7 +3041,7 @@ impl FileDescriptor { self.write_package_start(w)?; self.write_uses(w, config)?; self.write_imports(w)?; - self.write_enums(w)?; + self.write_enums(w, config)?; self.write_messages(w, config)?; self.write_rpc_services(w, config)?; self.write_package_end(w)?; @@ -3116,11 +3127,11 @@ impl FileDescriptor { Ok(()) } - fn write_enums(&self, w: &mut W) -> Result<()> { + fn write_enums(&self, w: &mut W, config: &Config) -> Result<()> { for m in self.enums.iter().filter(|e| !e.imported) { println!("Writing enum {}", m.name); writeln!(w)?; - m.write_definition(w)?; + m.write_definition(w, config)?; writeln!(w)?; m.write_impl_default(w)?; writeln!(w)?; diff --git a/perftest/build.rs b/perftest/build.rs index d8c80a6..7fa1eba 100644 --- a/perftest/build.rs +++ b/perftest/build.rs @@ -54,6 +54,7 @@ fn main() { error_cycle: false, headers: false, dont_use_cow: false, + custom_enum_derive: vec![], custom_struct_derive: vec![], custom_repr: None, custom_rpc_generator: Box::new(|rpc, writer| generate_rpc_test(rpc, writer)),