diff --git a/src/lib.rs b/src/lib.rs index 0c36e21..191839d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -99,6 +99,14 @@ pub fn gen( let mut known_refs = RefCache::new(); let mut models = Vec::new(); + let multipart_field_file = rust::model_gen::multipart_field_module()?; + std::fs::write( + model.join(multipart_field_file.def.name.file_name()), + multipart_field_file.code, + ) + .unwrap(); + models.push(multipart_field_file.def); + while !ref_cache.is_empty() { let mut next_ref_cache = RefCache::new(); diff --git a/src/rust/client_gen.rs b/src/rust/client_gen.rs index ce1ac2f..f203243 100644 --- a/src/rust/client_gen.rs +++ b/src/rust/client_gen.rs @@ -334,35 +334,39 @@ fn request_body_params( ReferenceOr::Reference { reference } => Err(Error::unimplemented( format!("Unexpected ref multipart schema: '{reference}'."), )), - ReferenceOr::Item(schema) => { - match &schema.schema_kind { - SchemaKind::Type(Type::Object(obj)) => { - fn multipart_param( - name: &str, - schema: &ReferenceOr>, - ref_cache: &mut RefCache, - ) -> Result { - Ok(Param { - original_name: name.to_string(), - name: name.to_case(Case::Snake), - tpe: ref_or_box_schema_type(schema, ref_cache)?, - required: true, // TODO - kind: ParamKind::Multipart, - }) - } - - obj.properties - .iter() - .map(|(name, schema)| { - multipart_param(name, schema, ref_cache) - }) - .collect() + ReferenceOr::Item(schema) => match &schema.schema_kind { + SchemaKind::Type(Type::Object(obj)) => { + fn multipart_param( + name: &str, + required: bool, + schema: &ReferenceOr>, + ref_cache: &mut RefCache, + ) -> Result { + Ok(Param { + original_name: name.to_string(), + name: name.to_case(Case::Snake), + tpe: ref_or_box_schema_type(schema, ref_cache)?, + required, + kind: ParamKind::Multipart, + }) } - _ => Err(Error::unimplemented( - "Object schema expected for multipart request body.", - )), + + obj.properties + .iter() + .map(|(name, schema)| { + multipart_param( + name, + body.required && obj.required.contains(name), + schema, + ref_cache, + ) + }) + .collect() } - } + _ => Err(Error::unimplemented( + "Object schema expected for multipart request body.", + )), + }, }, } } else { @@ -657,14 +661,26 @@ fn header_setter(param: &Param) -> RustResult { fn make_part(param: &Param) -> RustResult { let part_type = rust_name("reqwest::multipart", "Part"); - if param.tpe == DataType::Binary { - Ok(indent() + r#".part(""# + ¶m.original_name + r#"", "# + part_type + "::stream(" + ¶m.name + r#").mime_str("application/octet-stream")?)"#) - } else if param.tpe == DataType::String { - Ok(indent() + r#".part(""# + ¶m.original_name + r#"", "# + part_type + "::text(" + ¶m.name + r#".to_string()).mime_str("text/plain; charset=utf-8")?)"#) - } else if let DataType::Model(_) = param.tpe { - Ok(indent() + r#".part(""# + ¶m.original_name + r#"", "# + part_type + "::text(serde_json::to_string(" + ¶m.name + r#")?).mime_str("application/json")?)"#) + let inner = + if param.tpe == DataType::Binary { + Ok(indent() + r#"form = form.part(""# + ¶m.original_name + r#"", "# + part_type + "::stream(" + ¶m.name + r#").mime_str("application/octet-stream")?);"#) + } else if param.tpe == DataType::String { + Ok(indent() + r#"form = form.part(""# + ¶m.original_name + r#"", "# + part_type + "::text(" + ¶m.name + r#".to_string()).mime_str("text/plain; charset=utf-8")?);"#) + } + else if let DataType::Model(_) = param.tpe { + Ok(indent() + r#"form = form.part(""# + ¶m.original_name + r#"", "# + part_type + "::text(crate::model::MultipartField::to_multipart_field(" + ¶m.name + r#")).mime_str(crate::model::MultipartField::mime_type("# + ¶m.name + r#"))?);"#) + } else { + Err(Error::unimplemented(format!("Unsupported multipart part type {:?}", param.tpe))) + }; + + if param.required { + inner } else { - Err(Error::unimplemented(format!("Unsupported multipart part type {:?}", param.tpe))) + Ok( + indent() + line(unit() + r#"if let Some("# + ¶m.name + r#") = "# + ¶m.name + " {") + + indented(inner?) + + line("}") + ) } } @@ -837,10 +853,8 @@ fn render_method_implementation(method: &Method, error_kind: &ErrorKind) -> Rust let multipart_setter = if is_multipart { #[rustfmt::skip] let code = unit() + - indent() + "let form = " + rust_name("reqwest::multipart", "Form") + "::new()" + - indented( - multipart_parts?.into_iter().map(|p| unit() + NewLine + p).reduce(|acc, e| acc + e). unwrap_or_else(unit) + ";" + NewLine - ) + + indent() + "let mut form = " + rust_name("reqwest::multipart", "Form") + "::new();" + + (multipart_parts?.into_iter().map(|p| unit() + NewLine + p).reduce(|acc, e| acc + e). unwrap_or_else(unit) + NewLine) + NewLine + line("request = request.multipart(form);"); diff --git a/src/rust/model_gen.rs b/src/rust/model_gen.rs index 016935c..28e4e6e 100644 --- a/src/rust/model_gen.rs +++ b/src/rust/model_gen.rs @@ -348,6 +348,25 @@ fn extract_enum_cases( .collect() } +pub fn multipart_field_module() -> Result { + let code = unit() + + line(unit() + "pub trait MultipartField {") + + indented( + unit() + + line("fn to_multipart_field(&self) -> String;") + + line("fn mime_type(&self) -> &'static str;"), + ) + + line(unit() + "}"); + + Ok(Module { + def: ModuleDef { + name: ModuleName::new("multipart_field"), + exports: vec!["MultipartField".to_string()], + }, + code: RustContext::new().print_to_string(code), + }) +} + pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache) -> Result { let schemas = &open_api .components @@ -438,6 +457,22 @@ pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache) ) + line("}") ) + + line("}") + + NewLine + + line(unit() + "impl " + rust_name("crate::model", "MultipartField") + " for " + &name + "{") + + indented( + line(unit() + "fn to_multipart_field(&self) -> String {") + + indented( + line("self.to_string()") + ) + + line("}") + + NewLine + + line(unit() + "fn mime_type(&self) -> &'static str {") + + indented( + line(r#""text/plain; charset=utf-8""#) + ) + + line("}") + ) + line("}"); Ok(code)