diff --git a/src/rust/lib_gen.rs b/src/rust/lib_gen.rs index 0c8abef..80e9da5 100644 --- a/src/rust/lib_gen.rs +++ b/src/rust/lib_gen.rs @@ -15,6 +15,7 @@ use crate::printer::*; use crate::rust::printer::*; use crate::rust::types::{escape_keywords, RustPrinter}; +use convert_case::{Case, Casing}; use itertools::Itertools; #[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] @@ -39,20 +40,24 @@ pub struct ModuleName { } impl ModuleName { + pub fn name(&self) -> String { + self.name.clone() + } + fn code(&self) -> RustPrinter { line(unit() + self.verbosity.render() + "mod " + escape_keywords(&self.name) + ";") } - pub fn new>(s: S) -> ModuleName { + pub fn new(s: impl AsRef) -> ModuleName { ModuleName { - name: s.into(), + name: Self::escape_type_params(s.as_ref()).to_case(Case::Snake), verbosity: Verbosity::Default, } } - pub fn new_pub>(s: S) -> ModuleName { + pub fn new_pub(s: impl AsRef) -> ModuleName { ModuleName { - name: s.into(), + name: Self::escape_type_params(s.as_ref()).to_case(Case::Snake), verbosity: Verbosity::Pub, } } @@ -60,6 +65,13 @@ impl ModuleName { pub fn file_name(&self) -> String { format!("{}.rs", &self.name) } + + fn escape_type_params(s: &str) -> String { + s.replace("<", "_") + .replace(",", "_") + .replace(">", "_") + .replace(" ", "") + } } #[derive(Debug, Clone)] diff --git a/src/rust/model_gen.rs b/src/rust/model_gen.rs index 3056fb3..05f65cc 100644 --- a/src/rust/model_gen.rs +++ b/src/rust/model_gen.rs @@ -385,7 +385,8 @@ pub fn model_gen( "Unexpected reference format: {reference}." )))?; - let name = original_name.to_case(Case::UpperCamel); + let mod_name = ModuleName::new(original_name); + let name = mod_name.name().to_case(Case::UpperCamel); let schema = schemas.get(original_name).ok_or(Error::unexpected(format!( "Can't find schema by reference {original_name}" @@ -584,10 +585,11 @@ pub fn model_gen( } }; + let name = ModuleName::new(name); Ok(Module { def: ModuleDef { - name: ModuleName::new(name.to_case(Case::Snake)), - exports: vec![name], + name: name.clone(), + exports: vec![name.name().to_case(Case::Pascal)], }, code: RustContext::new().print_to_string(code?), }) diff --git a/src/rust/types.rs b/src/rust/types.rs index 93a0e2e..23b451e 100644 --- a/src/rust/types.rs +++ b/src/rust/types.rs @@ -13,6 +13,7 @@ // limitations under the License. use crate::printer::TreePrinter; +use crate::rust::lib_gen::ModuleName; use crate::rust::model_gen::RefCache; use crate::rust::printer::{rust_name, unit, RustContext}; use crate::{Error, Result}; @@ -120,7 +121,8 @@ impl DataType { } } DataType::Model(ModelType { name }) => { - let model_type = rust_name("crate::model", name); + let name = ModuleName::new(name); + let model_type = rust_name("crate::model", &name.name().to_case(Case::Pascal)); to_ref(model_type, top_param) } DataType::Array(item) => {