diff --git a/src/lib.rs b/src/lib.rs index 191839d..da08b92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; @@ -68,7 +69,10 @@ pub fn gen( version: &str, overwrite_cargo: bool, disable_clippy: bool, + mapping: &[(&str, &str)], ) -> Result<()> { + let mapping: HashMap<&str, &str> = HashMap::from_iter(mapping.iter().cloned()); + let open_api = merge_all_openapi_specs(openapi_specs)?; let src = target.join("src"); @@ -113,7 +117,7 @@ pub fn gen( for ref_str in ref_cache.refs { if !known_refs.refs.contains(&ref_str) { let model_file = - rust::model_gen::model_gen(&ref_str, &open_api, &mut next_ref_cache)?; + rust::model_gen::model_gen(&ref_str, &open_api, &mapping, &mut next_ref_cache)?; std::fs::write(model.join(model_file.def.name.file_name()), model_file.code) .unwrap(); models.push(model_file.def); diff --git a/src/main.rs b/src/main.rs index 5f439d4..06db5fb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -64,6 +64,7 @@ fn main() { &args.client_version, true, false, + &[], ) .unwrap(); } diff --git a/src/rust/model_gen.rs b/src/rust/model_gen.rs index 1d69bbe..0d930fa 100644 --- a/src/rust/model_gen.rs +++ b/src/rust/model_gen.rs @@ -24,7 +24,7 @@ use convert_case::{Case, Casing}; use openapiv3::{ AnySchema, Discriminator, OpenAPI, ReferenceOr, Schema, SchemaData, SchemaKind, Type, }; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; #[derive(Debug, Clone)] pub struct RefCache { @@ -367,7 +367,12 @@ pub fn multipart_field_module() -> Result { }) } -pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache) -> Result { +pub fn model_gen( + reference: &str, + open_api: &OpenAPI, + mapping: &HashMap<&str, &str>, + ref_cache: &mut RefCache, +) -> Result { let schemas = &open_api .components .as_ref() @@ -391,191 +396,208 @@ pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache) "Direct cross reference in {reference}" )))?; - let code = match &schema.schema_kind { - SchemaKind::Type(tpe) => match tpe { - Type::String(string_type) => { - if string_type.enumeration.is_empty() { - Err(Error::unimplemented(format!( - "String schema without enum {reference}" - ))) - } else if string_type.enumeration.contains(&None) { - Err(Error::unimplemented(format!( - "String schema enum with empty string {reference}" - ))) - } else { - fn make_case(name: &str) -> RustPrinter { - let rust_name = name.to_case(Case::UpperCamel); + let code = if let Some(mapped_type) = mapping.get(name.as_str()) { + Ok(unit() + line(unit() + "pub type " + &name + " = " + *mapped_type + ";")) + } else { + match &schema.schema_kind { + SchemaKind::Type(tpe) => match tpe { + Type::String(string_type) => { + if string_type.enumeration.is_empty() { + Err(Error::unimplemented(format!( + "String schema without enum {reference}" + ))) + } else if string_type.enumeration.contains(&None) { + Err(Error::unimplemented(format!( + "String schema enum with empty string {reference}" + ))) + } else { + fn make_case(name: &str) -> RustPrinter { + let rust_name = name.to_case(Case::UpperCamel); - let rename = if name == rust_name { - unit() - } else { - rename_line(name) - }; + let rename = if name == rust_name { + unit() + } else { + rename_line(name) + }; - rename + line(unit() + rust_name + ",") - } + rename + line(unit() + rust_name + ",") + } - let cases = string_type - .enumeration - .iter() - .map(|n| make_case(n.as_ref().unwrap())) - .reduce(|acc, e| acc + e) - .unwrap_or_else(unit); + let cases = string_type + .enumeration + .iter() + .map(|n| make_case(n.as_ref().unwrap())) + .reduce(|acc, e| acc + e) + .unwrap_or_else(unit); - #[rustfmt::skip] - fn make_match_case(enum_name: &str, name: &str) -> RustPrinter { - let rust_name = name.to_case(Case::UpperCamel); + #[rustfmt::skip] + fn make_match_case(enum_name: &str, name: &str) -> RustPrinter { + let rust_name = name.to_case(Case::UpperCamel); - line(unit() + enum_name + "::" + rust_name + r#" => write!(f, ""# + name + r#""),"#) - } + line(unit() + enum_name + "::" + rust_name + r#" => write!(f, ""# + name + r#""),"#) + } - let match_cases = string_type - .enumeration - .iter() - .map(|n| make_match_case(&name, n.as_ref().unwrap())) - .reduce(|acc, e| acc + e) - .unwrap_or_else(unit); - - #[rustfmt::skip] - let code = unit() + - derive_line() + - line(unit() + "pub enum " + &name + " {") + - indented( - cases - ) + - line(unit() + "}") + - NewLine + - line(unit() + "impl " + rust_name("std::fmt", "Display") + " for " + &name + "{") + - indented( - line(unit() + "fn fmt(&self, f: &mut " + rust_name("std::fmt", "Formatter") + "<'_>) -> " + rust_name("std::fmt", "Result") + " {") + + let match_cases = string_type + .enumeration + .iter() + .map(|n| make_match_case(&name, n.as_ref().unwrap())) + .reduce(|acc, e| acc + e) + .unwrap_or_else(unit); + + #[rustfmt::skip] + let code = unit() + + derive_line() + + line(unit() + "pub enum " + &name + " {") + indented( - line("match self {") + - indented( - match_cases - ) + - line("}") + cases ) + - line("}") - ) + - line("}") + - NewLine + - line(unit() + "impl " + rust_name("crate::model", "MultipartField") + " for " + &name + "{") + - indented( - line(unit() + "fn to_multipart_field(&self) -> String {") + - indented( - line("self.to_string()") - ) + - line("}") + - NewLine + - line(unit() + "fn mime_type(&self) -> &'static str {") + + line(unit() + "}") + + NewLine + + line(unit() + "impl " + rust_name("std::fmt", "Display") + " for " + &name + "{") + + indented( + line(unit() + "fn fmt(&self, f: &mut " + rust_name("std::fmt", "Formatter") + "<'_>) -> " + rust_name("std::fmt", "Result") + " {") + indented( - line(r#""text/plain; charset=utf-8""#) + line("match self {") + + indented( + match_cases + ) + + line("}") ) + line("}") - ) + - line("}"); + ) + + line("}") + + NewLine + + line(unit() + "impl " + rust_name("crate::model", "MultipartField") + " for " + &name + "{") + + indented( + line(unit() + "fn to_multipart_field(&self) -> String {") + + indented( + line("self.to_string()") + ) + + line("}") + + NewLine + + line(unit() + "fn mime_type(&self) -> &'static str {") + + indented( + line(r#""text/plain; charset=utf-8""#) + ) + + line("}") + ) + + line("}"); - Ok(code) + Ok(code) + } } - } - Type::Number(_) => Err(Error::unimplemented(format!("Number schema {reference}"))), - Type::Integer(_) => Err(Error::unimplemented(format!("Integer schema {reference}"))), - Type::Boolean(_) => Err(Error::unimplemented(format!("Boolean schema {reference}"))), - Type::Array(_) => Err(Error::unimplemented(format!("Array schema {reference}"))), - Type::Object(obj) => { - let required: HashSet = obj.required.iter().map(|s| s.to_owned()).collect(); - - fn make_field( - name: &str, - schema: &ReferenceOr>, - required: &HashSet, - ref_cache: &mut RefCache, - ) -> RustResult { - let rust_name = name.to_case(Case::Snake); - - let rename = if rust_name == name { - unit() - } else { - rename_line(name) - }; + Type::Number(_) => Err(Error::unimplemented(format!("Number schema {reference}"))), + Type::Integer(_) => { + Err(Error::unimplemented(format!("Integer schema {reference}"))) + } + Type::Boolean(_) => { + Err(Error::unimplemented(format!("Boolean schema {reference}"))) + } + Type::Array(_) => Err(Error::unimplemented(format!("Array schema {reference}"))), + Type::Object(obj) => { + let required: HashSet = + obj.required.iter().map(|s| s.to_owned()).collect(); + + fn make_field( + name: &str, + schema: &ReferenceOr>, + required: &HashSet, + ref_cache: &mut RefCache, + ) -> RustResult { + let rust_name = name.to_case(Case::Snake); + + let rename = if rust_name == name { + unit() + } else { + rename_line(name) + }; - let tpe = ref_or_box_schema_type(schema, ref_cache)?.render_declaration(false); + let tpe = + ref_or_box_schema_type(schema, ref_cache)?.render_declaration(false); - let tpe = if required.contains(name) { - tpe - } else { - unit() + "Option<" + tpe + ">" - }; + let tpe = if required.contains(name) { + tpe + } else { + unit() + "Option<" + tpe + ">" + }; + + Ok(rename + line(unit() + "pub " + rust_name + ": " + tpe + ",")) + } + + let fields: Result> = obj + .properties + .iter() + .map(|(name, schema)| make_field(name, schema, &required, ref_cache)) + .collect(); + + let fields = + fields.map_err(|e| e.extend(format!("In reference {reference}.")))?; + + let code = unit() + + derive_line() + + line(unit() + "pub struct " + &name + " {") + + indented( + fields + .into_iter() + .reduce(|acc, e| acc + e) + .unwrap_or_else(unit), + ) + + line(unit() + "}") + + NewLine + + line( + unit() + + "impl " + + rust_name("crate::model", "MultipartField") + + " for " + + &name + + "{", + ) + + indented( + line(unit() + "fn to_multipart_field(&self) -> String {") + + indented(line("serde_json::to_string(self).unwrap()")) + + line("}") + + NewLine + + line(unit() + "fn mime_type(&self) -> &'static str {") + + indented(line(r#""application/json""#)) + + line("}"), + ) + + line("}"); - Ok(rename + line(unit() + "pub " + rust_name + ": " + tpe + ",")) + Ok(code) } + }, + SchemaKind::OneOf { .. } => { + Err(Error::unimplemented(format!("OneOf schema {reference}"))) + } + SchemaKind::AllOf { .. } => { + Err(Error::unimplemented(format!("AllOf schema {reference}"))) + } + SchemaKind::AnyOf { .. } => { + Err(Error::unimplemented(format!("AnyOf schema {reference}"))) + } + SchemaKind::Not { .. } => Err(Error::unimplemented(format!("Not schema {reference}"))), + SchemaKind::Any(any) => { + enum_schema_sanity_check(any, &schema.schema_data)?; - let fields: Result> = obj - .properties - .iter() - .map(|(name, schema)| make_field(name, schema, &required, ref_cache)) - .collect(); + let discriminator = schema.schema_data.discriminator.as_ref().unwrap(); - let fields = fields.map_err(|e| e.extend(format!("In reference {reference}.")))?; + let cases = extract_enum_cases(open_api, discriminator, ref_cache); + + let cases = cases? + .iter() + .map(|c| c.render(reference, open_api)) + .reduce(|acc, e| acc + e) + .unwrap_or_else(unit); let code = unit() - + derive_line() - + line(unit() + "pub struct " + &name + " {") - + indented( - fields - .into_iter() - .reduce(|acc, e| acc + e) - .unwrap_or_else(unit), - ) - + line(unit() + "}") - + NewLine - + line( - unit() - + "impl " - + rust_name("crate::model", "MultipartField") - + " for " - + &name - + "{", - ) - + indented( - line(unit() + "fn to_multipart_field(&self) -> String {") - + indented(line("serde_json::to_string(self).unwrap()")) - + line("}") - + NewLine - + line(unit() + "fn mime_type(&self) -> &'static str {") - + indented(line(r#""application/json""#)) - + line("}"), - ) - + line("}"); + + derive_line_simple() + + line(unit() + r#"#[serde(tag = ""# + &discriminator.property_name + r#"")]"#) + + line(unit() + "pub enum " + &name + " {") + + indented(cases) + + line(unit() + "}"); Ok(code) } - }, - SchemaKind::OneOf { .. } => Err(Error::unimplemented(format!("OneOf schema {reference}"))), - SchemaKind::AllOf { .. } => Err(Error::unimplemented(format!("AllOf schema {reference}"))), - SchemaKind::AnyOf { .. } => Err(Error::unimplemented(format!("AnyOf schema {reference}"))), - SchemaKind::Not { .. } => Err(Error::unimplemented(format!("Not schema {reference}"))), - SchemaKind::Any(any) => { - enum_schema_sanity_check(any, &schema.schema_data)?; - - let discriminator = schema.schema_data.discriminator.as_ref().unwrap(); - - let cases = extract_enum_cases(open_api, discriminator, ref_cache); - - let cases = cases? - .iter() - .map(|c| c.render(reference, open_api)) - .reduce(|acc, e| acc + e) - .unwrap_or_else(unit); - - let code = unit() - + derive_line_simple() - + line(unit() + r#"#[serde(tag = ""# + &discriminator.property_name + r#"")]"#) - + line(unit() + "pub enum " + &name + " {") - + indented(cases) - + line(unit() + "}"); - - Ok(code) } };