diff --git a/src/rust/client_gen.rs b/src/rust/client_gen.rs index f203243..6ef16b3 100644 --- a/src/rust/client_gen.rs +++ b/src/rust/client_gen.rs @@ -27,7 +27,7 @@ use openapiv3::{ OpenAPI, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem, ReferenceOr, RequestBody, Response, Schema, SchemaKind, StatusCode, Tag, Type, }; -use std::collections::{BTreeMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; #[derive(Debug, Clone, PartialEq, Eq)] enum PathElement { @@ -186,6 +186,37 @@ struct Param { kind: ParamKind, } +#[derive(Debug, Clone)] +pub struct RequestBodyParams { + params: HashMap>, +} + +impl RequestBodyParams { + fn has_single_content_type(&self) -> bool { + self.params.len() == 1 + } + + fn get_default_request_body_param(&self) -> Option<&Vec> { + self.params + .values() + .next() + .filter(|_| self.has_single_content_type()) + } +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct ContentType(pub String); + +impl ContentType { + pub fn is_json(&self) -> bool { + self.0 == "application/json" + } + + pub fn is_yaml(&self) -> bool { + self.0 == "application/x-yaml" + } +} + impl Param { fn render_declaration(&self) -> RustPrinter { let type_name = self.tpe.render_declaration(true); @@ -268,7 +299,7 @@ fn match_tag(tag: &Option, path_item: &ReferenceOr) -> bool { fn param_data_to_type(data: &ParameterData, ref_cache: &mut RefCache) -> Result { match &data.format { ParameterSchemaOrContent::Schema(ref_or_schema) => { - ref_or_schema_type(ref_or_schema, ref_cache) + ref_or_schema_type(ref_or_schema, ref_cache, None) } ParameterSchemaOrContent::Content(_) => { Err(Error::unimplemented("Content parameter is not supported.")) @@ -295,45 +326,74 @@ fn parameter(p: &ReferenceOr, ref_cache: &mut RefCache) -> Result, ref_cache: &mut RefCache, -) -> Result> { +) -> Result { + let mut content_type_params = HashMap::new(); + match body { - ReferenceOr::Reference { reference } => Err(Error::unimplemented(format!( - "Unexpected ref request body: '{reference}'." - ))), + ReferenceOr::Reference { reference } => { + return Err(Error::unimplemented(format!( + "Unexpected ref request body: '{reference}'." + ))) + } ReferenceOr::Item(body) => { - if body.content.len() != 1 { - Err(Error::unimplemented("Content with not exactly 1 option.")) - } else { - let (content_type, media_type) = body.content.first().unwrap(); - - if content_type.starts_with("application/json") { + for (content_type, media_type) in &body.content { + if content_type.starts_with("application/json") || content_type == "*/*" { let schema = match &media_type.schema { None => Err(Error::unimplemented("JSON content without schema.")), Some(schema) => Ok(schema), }; - Ok(vec![Param { - original_name: "".to_string(), - name: "value".to_string(), - tpe: ref_or_schema_type(schema?, ref_cache)?, - required: body.required, - kind: ParamKind::Body, - }]) + content_type_params.insert( + ContentType(content_type.clone()), + vec![Param { + original_name: "".to_string(), + name: "value".to_string(), + tpe: ref_or_schema_type( + schema?, + ref_cache, + Some(content_type.clone()), + )?, + required: body.required, + kind: ParamKind::Body, + }], + ); } else if content_type == "application/octet-stream" { - Ok(vec![Param { + content_type_params.insert( + ContentType(content_type.clone()), + vec![Param { + original_name: "".to_string(), + name: "value".to_string(), + tpe: DataType::Binary, + required: body.required, + kind: ParamKind::Body, + }], + ); + } else if content_type.contains("application/x-yaml") { + let schema = match &media_type.schema { + None => Err(Error::unimplemented("YAML content without schema.")), + Some(schema) => Ok(schema), + }; + + let param = Param { original_name: "".to_string(), name: "value".to_string(), - tpe: DataType::Binary, + tpe: ref_or_schema_type(schema?, ref_cache, Some(content_type.clone()))?, required: body.required, kind: ParamKind::Body, - }]) + }; + + content_type_params.insert(ContentType(content_type.clone()), vec![param]); } else if content_type == "multipart/form-data" { match &media_type.schema { - None => Err(Error::unimplemented("Multipart content without schema.")), + None => { + return Err(Error::unimplemented("Multipart content without schema.")) + } Some(schema) => match schema { - ReferenceOr::Reference { reference } => Err(Error::unimplemented( - format!("Unexpected ref multipart schema: '{reference}'."), - )), + ReferenceOr::Reference { reference } => { + return Err(Error::unimplemented(format!( + "Unexpected ref multipart schema: '{reference}'." + ))) + } ReferenceOr::Item(schema) => match &schema.schema_kind { SchemaKind::Type(Type::Object(obj)) => { fn multipart_param( @@ -351,7 +411,8 @@ fn request_body_params( }) } - obj.properties + let params = obj + .properties .iter() .map(|(name, schema)| { multipart_param( @@ -361,41 +422,39 @@ fn request_body_params( ref_cache, ) }) - .collect() + .collect::>>()?; + + content_type_params + .insert(ContentType(content_type.clone()), params); + } + _ => { + return Err(Error::unimplemented( + "Object schema expected for multipart request body.", + )) } - _ => Err(Error::unimplemented( - "Object schema expected for multipart request body.", - )), }, }, } } else { - Err(Error::unimplemented(format!( + return Err(Error::unimplemented(format!( "Request body content type: '{content_type}'." - ))) + ))); } } } } + + Ok(RequestBodyParams { + params: content_type_params, + }) } fn parameters(op: &PathOperation, ref_cache: &mut RefCache) -> Result> { - let params: Result> = op - .op + op.op .parameters .iter() .map(|p| parameter(p, ref_cache)) - .collect(); - - let mut params = params?; - - if let Some(body) = &op.op.request_body { - for p in request_body_params(body, ref_cache)? { - params.push(p); - } - } - - Ok(params) + .collect() } fn as_code(code: &StatusCode) -> Option { @@ -449,7 +508,11 @@ fn response_type(response: &ReferenceOr, ref_cache: &mut RefCache) -> Some(schema) => Ok(schema), }; - Ok(ref_or_schema_type(schema?, ref_cache)?) + Ok(ref_or_schema_type( + schema?, + ref_cache, + Some(content_type.clone()), + )?) } else if content_type == "application/octet-stream" { Ok(DataType::Binary) } else { @@ -494,40 +557,107 @@ fn method_errors( Ok(MethodErrors { codes: codes? }) } -fn trait_method( +fn trait_methods_specific_to_content_type( op: &PathOperation, prefix_length: usize, ref_cache: &mut RefCache, -) -> Result { +) -> Result> { let (result_code, result_type) = method_result(&op.op.responses.responses, ref_cache)?; - let name = if let Some(op_id) = &op.op.operation_id { - op_id.to_case(Case::Snake) + let name = op + .op + .operation_id + .as_ref() + .map(|op_id| op_id.to_case(Case::Snake)) + .unwrap_or_else(|| op.path.strip_prefix(prefix_length).method_name(&op.method)); + + let mut main_params = parameters(op, ref_cache)?; + + if let Some(body) = &op.op.request_body { + let content_specific = request_body_params(body, ref_cache)?; + + if let Some(request_body_params) = content_specific.get_default_request_body_param() { + main_params.extend(request_body_params.iter().cloned()); + return Ok(vec![create_method( + name, + op, + &main_params, + result_type, + result_code, + ref_cache, + )?]); + } + + let mut methods = Vec::new(); + for (content_type, params) in content_specific.params { + let method_name = match_content_type(content_type, &name)?; + let new_params = [main_params.clone(), params].concat(); + methods.push(create_method( + method_name, + op, + &new_params, + result_type.clone(), + result_code.clone(), + ref_cache, + )?); + } + + Ok(methods) } else { - op.path.strip_prefix(prefix_length).method_name(&op.method) - }; + Ok(vec![create_method( + name, + op, + &main_params, + result_type, + result_code, + ref_cache, + )?]) + } +} +fn create_method( + name: String, + op: &PathOperation, + params: &[Param], + result_type: DataType, + result_code: StatusCode, + ref_cache: &mut RefCache, +) -> Result { Ok(Method { name, path: op.path.clone(), original_path: op.original_path.clone(), http_method: op.method.to_string(), - params: parameters(op, ref_cache)?, + params: params.to_vec(), result: result_type, result_status_code: result_code.clone(), errors: method_errors(&op.op.responses.responses, result_code, ref_cache)?, }) } +fn match_content_type(content_type: ContentType, base_name: &str) -> Result { + if content_type.is_json() { + Ok(format!("{}_json", base_name)) + } else if content_type.is_yaml() { + Ok(format!("{}_yaml", base_name)) + } else { + Err(Error::unimplemented( + "Multiple content types supported only for JSON and YAML", + )) + } +} + fn trait_methods( operations: &[PathOperation], prefix_length: usize, ref_cache: &mut RefCache, ) -> Result> { - operations + let res = operations .iter() - .map(|op| trait_method(op, prefix_length, ref_cache)) - .collect() + .map(|op| trait_methods_specific_to_content_type(op, prefix_length, ref_cache)) + .collect::>>>()?; + + Ok(res.into_iter().flatten().collect()) } fn render_errors(method_name: &str, error_kind: &ErrorKind, errors: &MethodErrors) -> RustResult { @@ -837,7 +967,18 @@ fn render_method_implementation(method: &Method, error_kind: &ErrorKind) -> Rust r#"request = request.header(reqwest::header::CONTENT_TYPE, "application/octet-stream");"#, ) + NewLine - } else { + } else if param.tpe == DataType::Yaml { + line( + unit() + + "request = request.body(serde_yaml::to_string(" + + ¶m.name + + ").unwrap_or_default().into_bytes());", + ) + line( + r#"request = request.header(reqwest::header::CONTENT_TYPE, "application/x-yaml");"#, + ) + NewLine + } + // Not sure why everything else is assumed to be json (previously) + else { line(unit() + "request = request.json(" + ¶m.name + ");") + NewLine } } diff --git a/src/rust/lib_gen.rs b/src/rust/lib_gen.rs index 80e9da5..f01d9ad 100644 --- a/src/rust/lib_gen.rs +++ b/src/rust/lib_gen.rs @@ -142,11 +142,11 @@ mod tests { "lib", &[ ModuleDef { - name: ModuleName::new("abc".to_string()), + name: ModuleName::new("abc"), exports: vec!["C".to_string(), "B".to_string()], }, ModuleDef { - name: ModuleName::new("xyz".to_string()), + name: ModuleName::new("xyz"), exports: vec!["A".to_string(), "Y".to_string()], }, ], diff --git a/src/rust/printer.rs b/src/rust/printer.rs index 416d91d..9ef5db7 100644 --- a/src/rust/printer.rs +++ b/src/rust/printer.rs @@ -135,6 +135,18 @@ pub fn rust_name(import: &str, name: &str) -> TreePrinter { }) } +pub fn rust_name_with_alias(import: &str, name: &str, alias: &str) -> TreePrinter { + let import_name = if name.ends_with('!') { + &name[0..name.len() - 1] + } else { + name + }; + TreePrinter::leaf(RustCode { + imports: HashSet::from([RustUse(format!("{import}::{import_name} as {alias}"))]), + code: alias.to_string(), + }) +} + impl IntoRustTree for TreePrinter { fn tree(self) -> TreePrinter { self diff --git a/src/rust/types.rs b/src/rust/types.rs index 23b451e..d0d0c8a 100644 --- a/src/rust/types.rs +++ b/src/rust/types.rs @@ -15,7 +15,7 @@ use crate::printer::TreePrinter; use crate::rust::lib_gen::ModuleName; use crate::rust::model_gen::RefCache; -use crate::rust::printer::{rust_name, unit, RustContext}; +use crate::rust::printer::{rust_name, rust_name_with_alias, unit, RustContext}; use crate::{Error, Result}; use convert_case::{Case, Casing}; use openapiv3::{ @@ -73,6 +73,7 @@ pub enum DataType { Array(Box), MapOf(Box), Json, + Yaml, } pub fn escape_keywords(name: &str) -> String { @@ -143,6 +144,11 @@ impl DataType { let res = rust_name("serde_json::value", "Value"); to_ref(res, top_param) } + + DataType::Yaml => { + let res = rust_name_with_alias("serde_yaml::value", "Value", "YamlValue"); + to_ref(res, top_param) + } } } } @@ -159,7 +165,11 @@ pub fn ref_type_name(reference: &str, ref_cache: &mut RefCache) -> Result Result { +fn schema_type( + schema: &Schema, + ref_cache: &mut RefCache, + content_type: Option, +) -> Result { match &schema.schema_kind { SchemaKind::Type(tpe) => match tpe { Type::String(string_type) => { @@ -213,7 +223,7 @@ fn schema_type(schema: &Schema, ref_cache: &mut RefCache) -> Result { "Object parameter with Any additional_properties is not supported.", )), AdditionalProperties::Schema(element_schema) => Ok(DataType::MapOf( - Box::new(ref_or_schema_type(element_schema, ref_cache)?), + Box::new(ref_or_schema_type(element_schema, ref_cache, None)?), )), } } else { @@ -238,17 +248,33 @@ fn schema_type(schema: &Schema, ref_cache: &mut RefCache) -> Result { SchemaKind::AllOf { .. } => Err(Error::unimplemented("AllOf parameter is not supported.")), SchemaKind::AnyOf { .. } => Err(Error::unimplemented("AnyOf parameter is not supported.")), SchemaKind::Not { .. } => Err(Error::unimplemented("Not parameter is not supported.")), - SchemaKind::Any(_) => Ok(DataType::Json), + SchemaKind::Any(_) => { + if let Some(content_type) = content_type { + if &content_type == "application/json" || &content_type == "*/*" { + Ok(DataType::Json) + } else if &content_type == "application/x-yaml" { + Ok(DataType::Yaml) + } else { + Err(Error::unexpected(format!( + "Cannot resolve the data type for content_type {} with `any` schema-kind", + content_type + ))) + } + } else { + Err(Error::unexpected("Cannot resolve the data type for any schema-kind with no details on content_type")) + } + } } } pub fn ref_or_schema_type( ref_or_schema: &ReferenceOr, ref_cache: &mut RefCache, + content_type: Option, ) -> Result { match ref_or_schema { ReferenceOr::Reference { reference } => ref_type_name(reference, ref_cache), - ReferenceOr::Item(schema) => schema_type(schema, ref_cache), + ReferenceOr::Item(schema) => schema_type(schema, ref_cache, content_type), } } @@ -258,6 +284,6 @@ pub fn ref_or_box_schema_type( ) -> Result { match ref_or_schema { ReferenceOr::Reference { reference } => ref_type_name(reference, ref_cache), - ReferenceOr::Item(schema) => schema_type(schema, ref_cache), + ReferenceOr::Item(schema) => schema_type(schema, ref_cache, None), } }