Skip to content

Commit

Permalink
Merge pull request #19 from golemcloud/multipart-fixes
Browse files Browse the repository at this point in the history
Multipart fixes
  • Loading branch information
vigoo authored Sep 17, 2024
2 parents 832f974 + 37dcb97 commit 5ae9bb7
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 38 deletions.
8 changes: 8 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ pub fn gen(
let mut known_refs = RefCache::new();
let mut models = Vec::new();

let multipart_field_file = rust::model_gen::multipart_field_module()?;
std::fs::write(
model.join(multipart_field_file.def.name.file_name()),
multipart_field_file.code,
)
.unwrap();
models.push(multipart_field_file.def);

while !ref_cache.is_empty() {
let mut next_ref_cache = RefCache::new();

Expand Down
90 changes: 52 additions & 38 deletions src/rust/client_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,35 +334,39 @@ fn request_body_params(
ReferenceOr::Reference { reference } => Err(Error::unimplemented(
format!("Unexpected ref multipart schema: '{reference}'."),
)),
ReferenceOr::Item(schema) => {
match &schema.schema_kind {
SchemaKind::Type(Type::Object(obj)) => {
fn multipart_param(
name: &str,
schema: &ReferenceOr<Box<Schema>>,
ref_cache: &mut RefCache,
) -> Result<Param> {
Ok(Param {
original_name: name.to_string(),
name: name.to_case(Case::Snake),
tpe: ref_or_box_schema_type(schema, ref_cache)?,
required: true, // TODO
kind: ParamKind::Multipart,
})
}

obj.properties
.iter()
.map(|(name, schema)| {
multipart_param(name, schema, ref_cache)
})
.collect()
ReferenceOr::Item(schema) => match &schema.schema_kind {
SchemaKind::Type(Type::Object(obj)) => {
fn multipart_param(
name: &str,
required: bool,
schema: &ReferenceOr<Box<Schema>>,
ref_cache: &mut RefCache,
) -> Result<Param> {
Ok(Param {
original_name: name.to_string(),
name: name.to_case(Case::Snake),
tpe: ref_or_box_schema_type(schema, ref_cache)?,
required,
kind: ParamKind::Multipart,
})
}
_ => Err(Error::unimplemented(
"Object schema expected for multipart request body.",
)),

obj.properties
.iter()
.map(|(name, schema)| {
multipart_param(
name,
body.required && obj.required.contains(name),
schema,
ref_cache,
)
})
.collect()
}
}
_ => Err(Error::unimplemented(
"Object schema expected for multipart request body.",
)),
},
},
}
} else {
Expand Down Expand Up @@ -657,14 +661,26 @@ fn header_setter(param: &Param) -> RustResult {
fn make_part(param: &Param) -> RustResult {
let part_type = rust_name("reqwest::multipart", "Part");

if param.tpe == DataType::Binary {
Ok(indent() + r#".part(""# + &param.original_name + r#"", "# + part_type + "::stream(" + &param.name + r#").mime_str("application/octet-stream")?)"#)
} else if param.tpe == DataType::String {
Ok(indent() + r#".part(""# + &param.original_name + r#"", "# + part_type + "::text(" + &param.name + r#".to_string()).mime_str("text/plain; charset=utf-8")?)"#)
} else if let DataType::Model(_) = param.tpe {
Ok(indent() + r#".part(""# + &param.original_name + r#"", "# + part_type + "::text(serde_json::to_string(" + &param.name + r#")?).mime_str("application/json")?)"#)
let inner =
if param.tpe == DataType::Binary {
Ok(indent() + r#"form = form.part(""# + &param.original_name + r#"", "# + part_type + "::stream(" + &param.name + r#").mime_str("application/octet-stream")?);"#)
} else if param.tpe == DataType::String {
Ok(indent() + r#"form = form.part(""# + &param.original_name + r#"", "# + part_type + "::text(" + &param.name + r#".to_string()).mime_str("text/plain; charset=utf-8")?);"#)
}
else if let DataType::Model(_) = param.tpe {
Ok(indent() + r#"form = form.part(""# + &param.original_name + r#"", "# + part_type + "::text(crate::model::MultipartField::to_multipart_field(" + &param.name + r#")).mime_str(crate::model::MultipartField::mime_type("# + &param.name + r#"))?);"#)
} else {
Err(Error::unimplemented(format!("Unsupported multipart part type {:?}", param.tpe)))
};

if param.required {
inner
} else {
Err(Error::unimplemented(format!("Unsupported multipart part type {:?}", param.tpe)))
Ok(
indent() + line(unit() + r#"if let Some("# + &param.name + r#") = "# + &param.name + " {") +
indented(inner?) +
line("}")
)
}
}

Expand Down Expand Up @@ -837,10 +853,8 @@ fn render_method_implementation(method: &Method, error_kind: &ErrorKind) -> Rust
let multipart_setter = if is_multipart {
#[rustfmt::skip]
let code = unit() +
indent() + "let form = " + rust_name("reqwest::multipart", "Form") + "::new()" +
indented(
multipart_parts?.into_iter().map(|p| unit() + NewLine + p).reduce(|acc, e| acc + e). unwrap_or_else(unit) + ";" + NewLine
) +
indent() + "let mut form = " + rust_name("reqwest::multipart", "Form") + "::new();" +
(multipart_parts?.into_iter().map(|p| unit() + NewLine + p).reduce(|acc, e| acc + e). unwrap_or_else(unit) + NewLine) +
NewLine +
line("request = request.multipart(form);");

Expand Down
35 changes: 35 additions & 0 deletions src/rust/model_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,25 @@ fn extract_enum_cases(
.collect()
}

pub fn multipart_field_module() -> Result<Module> {
let code = unit()
+ line(unit() + "pub trait MultipartField {")
+ indented(
unit()
+ line("fn to_multipart_field(&self) -> String;")
+ line("fn mime_type(&self) -> &'static str;"),
)
+ line(unit() + "}");

Ok(Module {
def: ModuleDef {
name: ModuleName::new("multipart_field"),
exports: vec!["MultipartField".to_string()],
},
code: RustContext::new().print_to_string(code),
})
}

pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache) -> Result<Module> {
let schemas = &open_api
.components
Expand Down Expand Up @@ -438,6 +457,22 @@ pub fn model_gen(reference: &str, open_api: &OpenAPI, ref_cache: &mut RefCache)
) +
line("}")
) +
line("}") +
NewLine +
line(unit() + "impl " + rust_name("crate::model", "MultipartField") + " for " + &name + "{") +
indented(
line(unit() + "fn to_multipart_field(&self) -> String {") +
indented(
line("self.to_string()")
) +
line("}") +
NewLine +
line(unit() + "fn mime_type(&self) -> &'static str {") +
indented(
line(r#""text/plain; charset=utf-8""#)
) +
line("}")
) +
line("}");

Ok(code)
Expand Down

0 comments on commit 5ae9bb7

Please sign in to comment.