From c8876a1deaae7131fc79662c1a16392e6416f0f2 Mon Sep 17 00:00:00 2001 From: Kurt Wolf Date: Sun, 24 Dec 2023 19:02:45 -0500 Subject: [PATCH] add support for oauth2 flows. make ${service}Client much cheaper to clone now requires httpclient 0.20 --- Cargo.lock | 84 ++++---- core/Cargo.toml | 2 +- core/src/child_schemas.rs | 25 ++- core/src/extractor.rs | 152 ++++++------- core/src/extractor/record.rs | 31 +-- core/src/extractor/resolution.rs | 14 +- core/src/options.rs | 6 +- core/template/rust/src/lib.rs | 1 + hir/Cargo.toml | 2 +- hir/src/lib.rs | 88 +++++--- libninja/Cargo.toml | 2 +- libninja/src/custom.rs | 6 +- libninja/src/rust.rs | 84 +++++++- libninja/src/rust/cargo_toml.rs | 5 +- libninja/src/rust/client.rs | 340 +++++++++++++++++------------- libninja/src/rust/codegen.rs | 3 +- libninja/tests/all_of/main.rs | 10 +- libninja/tests/regression/main.rs | 4 +- mir/src/macro.rs | 28 --- 19 files changed, 499 insertions(+), 388 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d8b9229..bc3057d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -65,7 +65,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e01ed3140b2f8d422c68afa1ed2e85d996ea619c988ac834d255db32138655cb" dependencies = [ "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -178,7 +178,7 @@ dependencies = [ "actix-router", "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -298,9 +298,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.75" +version = "1.0.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +checksum = "59d2a3357dde987206219e78ecfbbb6e8dad06cbb65292758d3270e6254f7355" dependencies = [ "backtrace", ] @@ -495,7 +495,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1047,7 +1047,7 @@ dependencies = [ "serde_json", "serde_yaml", "strum", - "syn 2.0.41", + "syn 2.0.42", "tempfile", "tera", "text_io", @@ -1080,7 +1080,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "syn 2.0.41", + "syn 2.0.42", "tera", "tracing", "tracing-ez", @@ -1228,9 +1228,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.1" +version = "0.32.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" dependencies = [ "memchr", ] @@ -1243,9 +1243,7 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "openapiv3-extended" -version = "3.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d7caf406151bbc00996287fcab02cc12b23069b3ab6b95fa746ce8f133e32" +version = "4.0.1" dependencies = [ "anyhow", "http", @@ -1335,7 +1333,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1401,9 +1399,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkg-config" -version = "0.3.27" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" +checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" [[package]] name = "powerfmt" @@ -1434,14 +1432,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" dependencies = [ "proc-macro2", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] name = "proc-macro2" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39278fbbf5fb4f646ce651690877f89d1c5811a3d4acb27700c1cb3cdb78fd3b" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" dependencies = [ "unicode-ident", ] @@ -1595,7 +1593,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1611,9 +1609,9 @@ dependencies = [ [[package]] name = "serde_spanned" -version = "0.6.4" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12022b835073e5b11e90a14f86838ceb1c8fb0325b72416845c487ac0fa95e80" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" dependencies = [ "serde", ] @@ -1632,9 +1630,9 @@ dependencies = [ [[package]] name = "serde_yaml" -version = "0.9.27" +version = "0.9.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cc7a1570e38322cfe4154732e5110f887ea57e22b76f4bfd32b5bdd3368666c" +checksum = "a15e0ef66bf939a7c890a0bf6d5a733c70202225f9888a89ed5c62298b019129" dependencies = [ "indexmap", "itoa", @@ -1749,9 +1747,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.41" +version = "2.0.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" +checksum = "5b7d0a2c048d661a1a59fcd7355baa232f7ed34e0ee4df2eef3c1c1c0d3852d8" dependencies = [ "proc-macro2", "quote", @@ -1825,7 +1823,7 @@ checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1840,9 +1838,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4a34ab300f2dee6e562c10a046fc05e358b29f9bf92277f30c3c8d82275f6f5" +checksum = "f657ba42c3f86e7680e53c8cd3af8abbe56b5491790b46e22e19c0d57463583e" dependencies = [ "deranged", "itoa", @@ -1860,9 +1858,9 @@ checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ad70d68dba9e1f8aceda7aa6711965dfec1cac869f311a51bd08b3a2ccbce20" +checksum = "26197e33420244aeb70c3e8c78376ca46571bc4e701e4791c2cd9f57dcb3a43f" dependencies = [ "time-core", ] @@ -1884,9 +1882,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.0" +version = "1.35.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "841d45b238a16291a4e1584e61820b8ae57d696cc5015c459c229ccc6990cc1c" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" dependencies = [ "backtrace", "bytes", @@ -1909,7 +1907,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -1986,7 +1984,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] @@ -2195,7 +2193,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", "wasm-bindgen-shared", ] @@ -2217,7 +2215,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2402,9 +2400,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.28" +version = "0.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c830786f7720c2fd27a1a0e27a709dbd3c4d009b56d098fc742d4f4eab91fe2" +checksum = "9b5c3db89721d50d0e2a673f5043fc4722f76dcc352d7b1ab8b8288bed4ed2c5" dependencies = [ "memchr", ] @@ -2417,22 +2415,22 @@ checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" [[package]] name = "zerocopy" -version = "0.7.31" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" +checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.31" +version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" +checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.41", + "syn 2.0.42", ] [[package]] diff --git a/core/Cargo.toml b/core/Cargo.toml index 858542e..9a3fc74 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -11,7 +11,7 @@ name = "ln_core" anyhow = "1.0.71" clap = { version = "4.3.11", features = ["derive"] } convert_case = "0.6.0" -openapiv3-extended = "3" +openapiv3-extended = "4" serde = { version = "1.0.166", features = ["derive"] } quote = "1.0.29" serde_json = "1.0.100" diff --git a/core/src/child_schemas.rs b/core/src/child_schemas.rs index 2fd98a3..d959ebf 100644 --- a/core/src/child_schemas.rs +++ b/core/src/child_schemas.rs @@ -8,23 +8,22 @@ pub trait ChildSchemas { impl ChildSchemas for Schema { fn add_child_schemas<'a>(&'a self, acc: &mut HashMap) { - match &self.schema_kind { + match &self.kind { SchemaKind::Type(Type::Array(a)) => { let Some(items) = &a.items else { return; }; - let Some(items) = items.as_item() else { return; }; - let items = items.as_ref(); - if let Some(title) = &items.schema_data.title { - acc.entry(title.clone()).or_insert(items); + let Some(item) = items.as_item() else { return; }; + if let Some(title) = &item.title { + acc.entry(title.clone()).or_insert(item); } - items.add_child_schemas(acc); + item.add_child_schemas(acc); } SchemaKind::Type(Type::Object(o)) => { - if let Some(title) = &self.schema_data.title { + if let Some(title) = &self.title { acc.entry(title.clone()).or_insert(self); } for (_name, prop) in &o.properties { let Some(prop) = prop.as_item() else { continue; }; - if let Some(title) = &prop.schema_data.title { + if let Some(title) = &prop.title { acc.entry(title.clone()).or_insert(prop); } prop.add_child_schemas(acc); @@ -36,7 +35,7 @@ impl ChildSchemas for Schema { | SchemaKind::AnyOf { any_of: schemas} => { for schema in schemas { let Some(schema) = schema.as_item() else { continue; }; - if let Some(title) = &schema.schema_data.title { + if let Some(title) = &schema.title { acc.entry(title.clone()).or_insert(schema); } schema.add_child_schemas(acc); @@ -57,7 +56,7 @@ impl ChildSchemas for Operation { } for par in &self.parameters { let Some(par) = par.as_item() else { continue; }; - let Some(schema) = par.parameter_data_ref().schema() else { continue; }; + let Some(schema) = par.data.schema() else { continue; }; let Some(schema) = schema.as_item() else { continue; }; schema.add_child_schemas(acc); } @@ -73,7 +72,7 @@ impl ChildSchemas for RequestBody { for (_key, content) in &self.content { let Some(schema) = &content.schema else { continue; }; let Some(schema) = schema.as_item() else { continue; }; - if let Some(title) = &schema.schema_data.title { + if let Some(title) = &schema.title { acc.entry(title.clone()).or_insert(schema); } schema.add_child_schemas(acc); @@ -86,7 +85,7 @@ impl ChildSchemas for Response { for (k, content) in &self.content { let Some(schema) = &content.schema else { continue; }; let Some(schema) = schema.as_item() else { continue; }; - if let Some(title) = &schema.schema_data.title { + if let Some(title) = &schema.title { acc.entry(title.clone()).or_insert(schema); } schema.add_child_schemas(acc); @@ -99,7 +98,7 @@ impl ChildSchemas for OpenAPI { for (_path, _method, op, _item) in self.operations() { op.add_child_schemas(acc); } - for (name, schema) in self.schemas() { + for (name, schema) in &self.schemas { let Some(schema) = schema.as_item() else { continue; }; acc.entry(name.clone()).or_insert(schema); schema.add_child_schemas(acc); diff --git a/core/src/extractor.rs b/core/src/extractor.rs index 677405e..fdec6d1 100644 --- a/core/src/extractor.rs +++ b/core/src/extractor.rs @@ -2,15 +2,16 @@ use std::collections::{BTreeMap, HashMap, HashSet}; use anyhow::{anyhow, Result}; use convert_case::{Case, Casing}; -use openapiv3::{OpenAPI, ReferenceOr, Schema}; +use openapiv3::{APIKeyLocation, OpenAPI, ReferenceOr, Schema, SecurityScheme}; use openapiv3 as oa; -use ::hir::{AuthLocation, AuthorizationParameter, AuthorizationStrategy, DocFormat, HirSpec, Language, Location, Operation, Record, Ty, Parameter, Doc}; +use ::hir::{AuthLocation, AuthParam, AuthStrategy, DocFormat, HirSpec, Language, Location, Operation, Record, Ty, Parameter, Doc}; pub use record::*; pub use resolution::{schema_ref_to_ty, schema_ref_to_ty_already_resolved, schema_to_ty}; pub use resolution::*; use mir::NewType; use tracing_ez::{warn, debug, span}; +use hir::{Oauth2Auth, TokenAuth}; mod resolution; mod record; @@ -33,7 +34,7 @@ pub fn extract_spec(spec: &OpenAPI) -> Result { } pub fn is_optional(name: &str, param: &Schema, parent: &Schema) -> bool { - param.schema_data.nullable || !parent.required(name) + param.nullable || !parent.required(name) } pub fn extract_request_schema<'a>( @@ -56,7 +57,7 @@ pub fn extract_param(param: &ReferenceOr, spec: &OpenAPI) -> Resu span!("extract_param", param = ?param); let param = param.resolve(spec)?; - let data = param.parameter_data_ref(); + let data = ¶m.data; let param_schema_ref = data .schema() .ok_or_else(|| anyhow!("No schema for parameter: {:?}", param))?; @@ -68,7 +69,7 @@ pub fn extract_param(param: &ReferenceOr, spec: &OpenAPI) -> Resu optional: !data.required, location: param.into(), ty, - example: schema.schema_data.example.clone(), + example: schema.example.clone(), }) } @@ -95,9 +96,9 @@ pub fn extract_inputs<'a>( Ok(schema) => schema, }; - if let oa::SchemaKind::Type(oa::Type::Array(oa::ArrayType { items, .. })) = &schema.schema_kind { + if let oa::SchemaKind::Type(oa::Type::Array(oa::ArrayType { items, .. })) = &schema.kind { let ty = if let Some(items) = items { - schema_ref_to_ty(&items.unbox(), spec) + schema_ref_to_ty(&items, spec) } else { Ty::Any }; @@ -108,7 +109,7 @@ pub fn extract_inputs<'a>( optional: false, doc: None, location: Location::Body, - example: schema.schema_data.example.clone(), + example: schema.example.clone(), }); } else if let Ok(props) = schema.properties_iter(spec) { let body_args = props.map(|(name, param)| { @@ -122,7 +123,7 @@ pub fn extract_inputs<'a>( optional, doc: None, location: Location::Body, - example: schema.schema_data.example.clone(), + example: schema.example.clone(), } }); for param in body_args { @@ -137,7 +138,7 @@ pub fn extract_inputs<'a>( optional: false, doc: None, location: Location::Body, - example: schema.schema_data.example.clone(), + example: schema.example.clone(), }); } Ok(inputs) @@ -200,7 +201,6 @@ pub fn extract_operation_doc(operation: &oa::Operation, format: DocFormat) -> Op pub fn extract_schema_docs(schema: &Schema) -> Option { schema - .schema_data .description .as_ref() .map(|d| Doc(d.trim().to_string())) @@ -243,7 +243,7 @@ pub fn extract_api_operations(spec: &OpenAPI, result: &mut HirSpec) -> Result<() let ret = match response_success { None => Ty::Unit, Some(ReferenceOr::Item(s)) => { - if matches!(s.schema_kind, oa::SchemaKind::Type(oa::Type::Object(_))) { + if matches!(s.kind, oa::SchemaKind::Type(oa::Type::Object(_))) { needs_response_model = Some(s); Ty::model(&format!("{}Response", name)) } else { @@ -369,85 +369,67 @@ pub fn spec_defines_auth(spec: &HirSpec) -> bool { !spec.security.is_empty() } -fn extract_security_fields(_name: &str, requirement: &oa::SecurityRequirement, spec: &OpenAPI) -> Result> { - use openapiv3::{SecurityScheme, APIKeyLocation}; - let security_schemas = &spec.components.as_ref().unwrap().security_schemes; - let mut fields = vec![]; - for (name, _scopes) in requirement { - let schema = security_schemas.get(name).unwrap().as_item().unwrap(); - let location = match schema { - SecurityScheme::APIKey { - location, - name, - description: _, - } => match location { - APIKeyLocation::Header => { - if ["bearer_auth", "bearer"].contains(&&*name.to_case(Case::Snake)) { - AuthLocation::Bearer - } else { - AuthLocation::Header { - key: name.to_string(), - } - } - } - APIKeyLocation::Query => { - AuthLocation::Query { - key: name.to_string(), - } - } - APIKeyLocation::Cookie => { - AuthLocation::Cookie { - key: name.to_string(), - } - } - }, - SecurityScheme::HTTP { - scheme, - bearer_format: _, - description: _, - } => match scheme.as_str() { - "basic" => AuthLocation::Basic, - "bearer" => AuthLocation::Bearer, - "token" => AuthLocation::Token, - _ => { - println!("{:?}", schema); - unimplemented!() - } - }, - _ => { - warn!("Skipping authorization for {:?}", schema); - return Err(anyhow!("Unsupported authorization schema")); +fn extract_key_location(loc: &APIKeyLocation, name: &str) -> AuthLocation { + match loc { + APIKeyLocation::Header => { + if ["bearer_auth", "bearer"].contains(&&*name.to_case(Case::Snake)) { + AuthLocation::Bearer + } else { + AuthLocation::Header { key: name.to_string() } } - }; - - fields.push(AuthorizationParameter { - name: name.to_string(), - env_var: name.to_case(Case::ScreamingSnake), - location, - }); + } + APIKeyLocation::Query => AuthLocation::Query { key: name.to_string() }, + APIKeyLocation::Cookie => AuthLocation::Cookie { key: name.to_string() }, } - Ok(fields) } - -pub fn extract_security_strategies(spec: &OpenAPI) -> Vec { +pub fn extract_security_strategies(spec: &OpenAPI) -> Vec { + dbg!("extracting security", &spec.security); let mut strats = vec![]; - let security = match spec.security.as_ref() { - None => return strats, - Some(s) => s, - }; - for requirement in security { - let (name, _scopes) = requirement.iter().next().unwrap(); - let fields = match extract_security_fields(name, requirement, spec) { - Ok(f) => f, - Err(_e) => { - continue; + let schemes = &spec.security_schemes; + for requirement in &spec.security { + if requirement.is_empty() { + strats.push(AuthStrategy::NoAuth); + continue; + } + let (scheme_name, _scopes) = requirement.iter().next().unwrap(); + let scheme = schemes.get(scheme_name).expect(&format!("Security scheme {} not found.", scheme_name)); + debug!("Found security scheme for {}: {:?}", scheme_name, scheme); + let scheme = scheme.as_item().expect("TODO support refs in securitySchemes"); + match scheme { + SecurityScheme::APIKey { location, name, .. } => { + let location = extract_key_location(&location, &name); + strats.push(AuthStrategy::Token(TokenAuth { + name: scheme_name.to_string(), + fields: vec![AuthParam { + name: name.to_string(), + location, + }], + })); } - }; - strats.push(AuthorizationStrategy { - name: name.clone(), - fields, - }) + SecurityScheme::OAuth2 { flows, .. } => { + if let Some(flow) = &flows.authorization_code { + strats.push(AuthStrategy::OAuth2(Oauth2Auth { + auth_url: flow.authorization_url.clone(), + exchange_url: flow.token_url.clone(), + refresh_url: flow.refresh_url.as_ref().expect("Must have refresh URL").clone(), + scopes: flow.scopes.iter().map(|(k, v)| (k.clone(), v.clone())).collect(), + })) + } + } + SecurityScheme::HTTP { scheme, bearer_format, description } => { + strats.push(AuthStrategy::Token(TokenAuth { + name: scheme_name.to_string(), + fields: vec![AuthParam { + name: scheme_name.to_string(), + // env_var: scheme_name.to_case(Case::ScreamingSnake), + location: AuthLocation::Bearer, + }], + })); + } + SecurityScheme::OpenIDConnect { .. } => {} + } } + debug!("extracted {} security: {:?}", strats.len(), strats); strats } diff --git a/core/src/extractor/record.rs b/core/src/extractor/record.rs index e7dd0e8..55f159a 100644 --- a/core/src/extractor/record.rs +++ b/core/src/extractor/record.rs @@ -4,7 +4,7 @@ use anyhow::Result; use indexmap::IndexMap; /// Records are the "model"s of the MIR world. model is a crazy overloaded word though. -use openapiv3::{ObjectType, OpenAPI, ReferenceOr, Schema, SchemaData, SchemaKind, SchemaReference, StringType, Type}; +use openapiv3::{ObjectType, OpenAPI, ReferenceOr, Schema, SchemaData, SchemaKind, SchemaReference, StringType, Type, RefOrMap}; use tracing::warn; use hir::{Doc, HirField, Record, StrEnum, Struct, NewType, HirSpec}; @@ -13,7 +13,7 @@ use crate::extractor; use crate::child_schemas::ChildSchemas; use crate::extractor::{schema_ref_to_ty_already_resolved, schema_to_ty}; -fn properties_to_fields(properties: &IndexMap>, schema: &Schema, spec: &OpenAPI) -> BTreeMap { +fn properties_to_fields(properties: &RefOrMap, schema: &Schema, spec: &OpenAPI) -> BTreeMap { properties .iter() .map(|(name, field_schema_ref)| { @@ -49,15 +49,15 @@ pub fn effective_length(all_of: &[ReferenceOr]) -> usize { pub fn create_record(name: &str, schema: &Schema, spec: &OpenAPI) -> Record { let name = name.to_string(); - match &schema.schema_kind { + match &schema.kind { // The base case, a regular object SchemaKind::Type(Type::Object(ObjectType { properties, .. })) => { let fields = properties_to_fields(properties, schema, spec); Record::Struct(Struct { name, fields, - nullable: schema.schema_data.nullable, - docs: schema.schema_data.description.as_ref().map(|d| Doc(d.trim().to_string())), + nullable: schema.nullable, + docs: schema.description.as_ref().map(|d| Doc(d.trim().to_string())), }) } // An enum @@ -70,19 +70,20 @@ pub fn create_record(name: &str, schema: &Schema, spec: &OpenAPI) -> Record { .iter() .map(|s| s.to_string()) .collect(), - docs: schema.schema_data.description.as_ref().map(|d| Doc(d.clone())), + docs: schema.description.as_ref().map(|d| Doc(d.clone())), }) } // A newtype with multiple fields SchemaKind::AllOf { all_of } => { + let all_of = all_of.as_slice(); if effective_length(all_of) == 1 { Record::TypeAlias(name, HirField { ty: schema_ref_to_ty_already_resolved(&all_of[0], spec, schema), - optional: schema.schema_data.nullable, + optional: schema.nullable, ..HirField::default() }) } else { - create_record_from_all_of(&name, all_of, &schema.schema_data, spec) + create_record_from_all_of(&name, all_of, &schema.data, spec) } } // Default case, a newtype with a single field @@ -90,12 +91,12 @@ pub fn create_record(name: &str, schema: &Schema, spec: &OpenAPI) -> Record { name, fields: vec![HirField { ty: schema_to_ty(schema, spec), - optional: schema.schema_data.nullable, + optional: schema.nullable, doc: None, example: None, flatten: false, }], - docs: schema.schema_data.description.as_ref().map(|d| Doc(d.clone())), + docs: schema.description.as_ref().map(|d| Doc(d.clone())), }), } } @@ -108,9 +109,9 @@ fn create_field(field_schema_ref: &ReferenceOr, spec: &OpenAPI) -> HirFi spec, field_schema, ); - let optional = field_schema.schema_data.nullable; - let example = field_schema.schema_data.example.clone(); - let doc = field_schema.schema_data.description.clone().map(Doc); + let optional = field_schema.nullable; + let example = field_schema.example.clone(); + let doc = field_schema.description.clone().map(Doc); HirField { ty, optional, doc, example, flatten: false } } @@ -162,7 +163,7 @@ pub fn extract_records(spec: &OpenAPI, result: &mut HirSpec) -> Result<()> { result.schemas.insert(name, rec); } - for (name, schema_ref) in spec.schemas() { + for (name, schema_ref) in &spec.schemas { let Some(reference) = schema_ref.as_ref_str() else { continue; }; result.schemas.insert(name.clone(), Record::TypeAlias(name.clone(), create_field(&schema_ref, spec))); } @@ -178,7 +179,7 @@ mod tests { #[test] fn test_all_of_required_set_correctly() { let mut additional_props: Schema = serde_yaml::from_str(include_str!("./pet_tag.yaml")).unwrap(); - let SchemaKind::AllOf { all_of } = &additional_props.schema_kind else { panic!() }; + let SchemaKind::AllOf { all_of } = &additional_props.kind else { panic!() }; let spec = OpenAPI::default(); let rec = create_record_from_all_of("PetTag", &all_of, &SchemaData::default(), &spec); let mut fields = rec.fields(); diff --git a/core/src/extractor/resolution.rs b/core/src/extractor/resolution.rs index b8f5223..81d16c3 100644 --- a/core/src/extractor/resolution.rs +++ b/core/src/extractor/resolution.rs @@ -30,7 +30,7 @@ pub fn schema_ref_to_ty_already_resolved(schema_ref: &ReferenceOr, spec: /// You probably want schema_ref_to_ty, not this method. Reason being, you want /// to use the ref'd model if one exists (e.g. User instead of resolving to Ty::Any) pub fn schema_to_ty(schema: &Schema, spec: &OpenAPI) -> Ty { - match &schema.schema_kind { + match &schema.kind { SchemaKind::Type(oa::Type::String(s)) => { match s.format.as_str() { "decimal" => Ty::Currency { @@ -46,12 +46,12 @@ pub fn schema_to_ty(schema: &Schema, spec: &OpenAPI) -> Ty { } SchemaKind::Type(oa::Type::Number(_)) => Ty::Float, SchemaKind::Type(oa::Type::Integer(_)) => { - let null_as_zero = schema.schema_data.extensions.get("x-null-as-zero") + let null_as_zero = schema.data.extensions.get("x-null-as-zero") .and_then(|v| v.as_bool()).unwrap_or(false); if null_as_zero { return Ty::Integer { serialization: hir::IntegerSerialization::NullAsZero }; } - match schema.schema_data.extensions.get("x-format").and_then(|s| s.as_str()) { + match schema.data.extensions.get("x-format").and_then(|s| s.as_str()) { Some("date") => Ty::Date { serialization: hir::DateSerialization::Integer, }, @@ -60,7 +60,7 @@ pub fn schema_to_ty(schema: &Schema, spec: &OpenAPI) -> Ty { } SchemaKind::Type(oa::Type::Boolean {}) => Ty::Boolean, SchemaKind::Type(oa::Type::Object(_)) => { - if let Some(title) = &schema.schema_data.title { + if let Some(title) = &schema.title { Ty::model(&title) } else { Ty::Any @@ -69,8 +69,7 @@ pub fn schema_to_ty(schema: &Schema, spec: &OpenAPI) -> Ty { SchemaKind::Type(oa::Type::Array(ArrayType { items: Some(item), .. })) => { - let inner = item.unbox(); - let inner = schema_ref_to_ty(&inner, spec); + let inner = schema_ref_to_ty(&item, spec); Ty::Array(Box::new(inner)) } SchemaKind::Type(oa::Type::Array(ArrayType { items: None, .. })) => { @@ -95,7 +94,7 @@ pub fn schema_to_ty(schema: &Schema, spec: &OpenAPI) -> Ty { pub fn is_primitive(schema: &Schema, spec: &OpenAPI) -> bool { use openapiv3::SchemaKind::*; use openapiv3::Type::*; - match &schema.schema_kind { + match &schema.kind { Type(String(_)) => true, Type(Number(_)) => true, Type(Integer(_)) => true, @@ -103,7 +102,6 @@ pub fn is_primitive(schema: &Schema, spec: &OpenAPI) -> bool { Type(Array(ArrayType { items: Some(inner), .. })) => { - let inner = inner.unbox(); let inner = inner.resolve(spec); is_primitive(inner, spec) } diff --git a/core/src/options.rs b/core/src/options.rs index ba81170..bd68624 100644 --- a/core/src/options.rs +++ b/core/src/options.rs @@ -47,11 +47,7 @@ impl PackageConfig { } pub fn authenticator_name(&self) -> String { - format!("{} Authentication", self.service_name) - } - - pub fn bare_client_name(&self) -> String { - "Client".to_string() + format!("{} Auth", self.service_name) } pub fn env_var(&self, name: &str) -> Literal { diff --git a/core/template/rust/src/lib.rs b/core/template/rust/src/lib.rs index 73568ae..b445906 100644 --- a/core/template/rust/src/lib.rs +++ b/core/template/rust/src/lib.rs @@ -3,4 +3,5 @@ pub mod model; pub mod request; pub use httpclient::{Error, Result, InMemoryResponseExt}; +use std::sync::{Arc, OnceLock}; use crate::model::*; diff --git a/hir/Cargo.toml b/hir/Cargo.toml index adde088..ff24591 100644 --- a/hir/Cargo.toml +++ b/hir/Cargo.toml @@ -12,5 +12,5 @@ path = "src/lib.rs" anyhow = "1.0.75" convert_case = "0.6.0" serde_json = "1.0.108" -openapiv3-extended = "3.0.0" +openapiv3-extended = "4" clap = { version = "4.4.10", features = ["derive"] } diff --git a/hir/src/lib.rs b/hir/src/lib.rs index 9cbf963..471d299 100644 --- a/hir/src/lib.rs +++ b/hir/src/lib.rs @@ -7,15 +7,14 @@ use std::string::{String, ToString}; use anyhow::Result; use convert_case::{Case, Casing}; +use openapiv3 as oa; + pub use doc::*; +pub use lang::*; mod doc; mod lang; -pub use lang::*; - -use openapiv3 as oa; - #[derive(Debug, Clone, Copy, PartialEq)] pub enum DateSerialization { Iso8601, @@ -149,11 +148,11 @@ pub enum Location { impl From<&oa::Parameter> for Location { fn from(p: &oa::Parameter) -> Self { - match p { - oa::Parameter::Query { .. } => Location::Query, - oa::Parameter::Header { .. } => Location::Header, - oa::Parameter::Path { .. } => Location::Path, - oa::Parameter::Cookie { .. } => Location::Cookie, + match p.kind { + oa::ParameterKind::Query { .. } => Location::Query, + oa::ParameterKind::Header { .. } => Location::Header, + oa::ParameterKind::Path { .. } => Location::Path, + oa::ParameterKind::Cookie { .. } => Location::Cookie, } } } @@ -174,22 +173,21 @@ impl std::fmt::Display for ParamKey { } #[derive(Debug, Clone)] -pub struct AuthorizationParameter { +pub struct AuthParam { pub name: String, - pub env_var: String, pub location: AuthLocation, } -impl AuthorizationParameter { - pub fn env_var_for_service(&self, service_name: &str) -> String { - let service = service_name.to_case(Case::ScreamingSnake); - if self.env_var.starts_with(&service) { - self.env_var.clone() - } else { - format!("{}_{}", service, self.env_var) - } - } -} +// impl AuthParam { +// pub fn env_var_for_service(&self, service_name: &str) -> String { +// let service = service_name.to_case(Case::ScreamingSnake); +// if self.env_var.starts_with(&service) { +// self.env_var.clone() +// } else { +// format!("{}_{}", service, self.env_var) +// } +// } +// } #[derive(Debug, Clone)] pub enum AuthLocation { @@ -202,11 +200,26 @@ pub enum AuthLocation { } #[derive(Debug, Clone)] -pub struct AuthorizationStrategy { +pub enum AuthStrategy { + Token(TokenAuth), + OAuth2(Oauth2Auth), + NoAuth, +} + +#[derive(Debug, Clone)] +pub struct TokenAuth { pub name: String, - pub fields: Vec, + pub fields: Vec, } +#[derive(Debug, Clone)] +pub struct Oauth2Auth { + pub auth_url: String, + pub exchange_url: String, + pub refresh_url: String, + // scope name, scope description + pub scopes: Vec<(String, String)>, +} #[derive(Debug, Default, Clone)] pub struct HirField { @@ -315,7 +328,7 @@ pub struct HirSpec { pub schemas: BTreeMap, pub servers: BTreeMap, - pub security: Vec, + pub security: Vec, pub api_docs_url: Option, } @@ -339,6 +352,10 @@ impl ServerStrategy { } } +pub fn qualified_env_var(service: &str, var_name: &str) -> String { + format!("{} {}", service, var_name).to_case(Case::ScreamingSnake) +} + impl HirSpec { pub fn get_record(&self, name: &str) -> Result<&Record> { self.schemas.get(name).ok_or_else(|| anyhow::anyhow!("No record named {}", name)) @@ -369,8 +386,18 @@ impl HirSpec { env_vars.push(env); } for strategy in &self.security { - for param in &strategy.fields { - env_vars.push(param.env_var_for_service(service_name)); + match strategy { + AuthStrategy::Token(t) => { + for f in &t.fields { + let qev = qualified_env_var(service_name, &f.name); + env_vars.push(qev); + } + } + AuthStrategy::OAuth2(_) => { + env_vars.push(qualified_env_var(service_name, "CLIENT_ID")); + env_vars.push(qualified_env_var(service_name, "CLIENT_SECRET")); + } + AuthStrategy::NoAuth => {} } } env_vars @@ -381,7 +408,14 @@ impl HirSpec { } pub fn has_basic_auth(&self) -> bool { - self.security.iter().any(|s| s.fields.iter().any(|p| matches!(p.location, AuthLocation::Basic))) + self.security.iter().any(|s| matches!(s, AuthStrategy::Token(_))) + } + + pub fn oauth2_auth(&self) -> Option<&Oauth2Auth> { + self.security.iter().filter_map(|s| match s { + AuthStrategy::OAuth2(o) => Some(o), + _ => None, + }).next() } } diff --git a/libninja/Cargo.toml b/libninja/Cargo.toml index 0c28815..ba023eb 100644 --- a/libninja/Cargo.toml +++ b/libninja/Cargo.toml @@ -22,7 +22,7 @@ serde_json = "1.0.100" serde_yaml = "0.9.22" syn = "2.0" tokio = { version = "1.29.1", features = ["full"] } -openapiv3-extended = { version = "3.1", features = ["v2"] } +openapiv3-extended = { version = "4", features = ["v2"] } convert_case = "0.6.0" prettyplease = "0.2" clap = { version = "4.4.10", features = ["derive"] } diff --git a/libninja/src/custom.rs b/libninja/src/custom.rs index 99e4cd9..adb46cc 100644 --- a/libninja/src/custom.rs +++ b/libninja/src/custom.rs @@ -23,12 +23,12 @@ pub fn modify_recurly(mut yaml: Value) -> OpenAPI { pub fn modify_openai(mut yaml: Value) -> OpenAPI { let mut spec: OpenAPI = serde_yaml::from_value(yaml).expect("Could not structure OpenAPI file."); - spec.security = Some(vec![{ + spec.security = vec![{ let mut map = indexmap::IndexMap::new(); map.insert("Bearer".to_string(), vec![]); map - }]); - spec.components.as_mut().unwrap().security_schemes.insert("Bearer".to_string(), openapiv3::ReferenceOr::Item(openapiv3::SecurityScheme::HTTP { + }]; + spec.security_schemes.insert("Bearer".to_string(), openapiv3::ReferenceOr::Item(openapiv3::SecurityScheme::HTTP { scheme: "bearer".to_string(), bearer_format: None, description: None, diff --git a/libninja/src/rust.rs b/libninja/src/rust.rs index da6231c..d443b48 100644 --- a/libninja/src/rust.rs +++ b/libninja/src/rust.rs @@ -19,11 +19,11 @@ use format::format_code; use ln_core::{copy_builtin_files, copy_builtin_templates, create_context, get_template_file, prepare_templates}; use ::mir::{Visibility, Import, File}; use ln_core::fs; -use hir::{HirSpec, IntegerSerialization, DateSerialization, Location, Parameter}; +use hir::{HirSpec, IntegerSerialization, DateSerialization, Location, Parameter, AuthStrategy, Oauth2Auth, qualified_env_var}; use mir::Ident; use crate::{add_operation_models, extract_spec, PackageConfig, OutputConfig}; -use crate::rust::client::build_Client_authenticate; +use crate::rust::client::{build_Client_authenticate, server_url}; pub use crate::rust::codegen::generate_example; use crate::rust::codegen::{codegen_function, sanitize_filename, ToRustCode}; use crate::rust::io::write_rust_file_to_path; @@ -47,6 +47,7 @@ pub struct Extras { currency: bool, integer_date_serialization: bool, basic_auth: bool, + oauth2: bool, } impl Extras { @@ -85,7 +86,8 @@ pub fn calculate_extras(spec: &HirSpec) -> Extras { } } } - let basic_auth = spec.security.iter().any(|f| f.fields.iter().any(|f| matches!(f.location, hir::AuthLocation::Basic))); + let basic_auth = spec.has_basic_auth(); + let oauth2 = spec.oauth2_auth().is_some(); Extras { null_as_zero, date_serialization, @@ -93,6 +95,7 @@ pub fn calculate_extras(spec: &HirSpec) -> Extras { currency, option_i64_str, basic_auth, + oauth2, } } @@ -227,6 +230,71 @@ fn write_model_module(spec: &HirSpec, opts: &PackageConfig) -> Result<()> { Ok(()) } +fn static_shared_http_client(spec: &HirSpec, opt: &PackageConfig) -> TokenStream { + let url = server_url(spec, opt); + quote! { + static SHARED_HTTPCLIENT: OnceLock = OnceLock::new(); + + pub fn default_http_client() -> httpclient::Client { + httpclient::Client::new() + .base_url(#url) + } + + /// Use this method if you want to add custom middleware to the httpclient. + /// Example usage: + /// + /// ``` + /// init_http_client(|| { + /// default_http_client() + /// .with_middleware(..) + /// }); + /// ``` + pub fn init_http_client(init: fn() -> httpclient::Client) { + SHARED_HTTPCLIENT.get_or_init(init); + } + + fn shared_http_client() -> &'static httpclient::Client { + SHARED_HTTPCLIENT.get_or_init(default_http_client) + } + } +} + +fn shared_oauth2_flow(auth: &Oauth2Auth, spec: &HirSpec, opts: &PackageConfig) -> TokenStream { + let service_name = opts.service_name.as_str(); + + let client_id = qualified_env_var(service_name, "client id"); + let client_id_expect = format!("{} must be set", client_id); + let client_secret = qualified_env_var(service_name, "client secret"); + let client_secret_expect = format!("{} must be set", client_secret); + let redirect_uri = qualified_env_var(service_name, "redirect uri"); + let redirect_uri_expect = format!("{} must be set", redirect_uri); + + let init_endpoint = auth.auth_url.as_str(); + let exchange_endpoint = auth.exchange_url.as_str(); + let refresh_endpoint = auth.refresh_url.as_str(); + quote! { + static SHARED_OAUTH2FLOW: OnceLock = OnceLock::new(); + + pub fn init_oauth2_flow(init: fn() -> httpclient_oauth2::OAuth2Flow) { + SHARED_OAUTH2FLOW.get_or_init(init); + } + + fn shared_oauth2_flow() -> &'static httpclient_oauth2::OAuth2Flow { + let client_id = std::env::var(#client_id).expect(#client_id_expect); + let client_secret = std::env::var(#client_secret).expect(#client_secret_expect); + let redirect_uri = std::env::var(#redirect_uri).expect(#redirect_uri_expect); + SHARED_OAUTH2FLOW.get_or_init(|| httpclient_oauth2::OAuth2Flow { + client_id, + client_secret, + init_endpoint: #init_endpoint.to_string(), + exchange_endpoint: #exchange_endpoint.to_string(), + refresh_endpoint: #refresh_endpoint.to_string(), + redirect_uri, + }) + } + } +} + /// Generates the client code for a given OpenAPI specification. fn write_lib_rs(spec: &HirSpec, extras: &Extras, opts: &PackageConfig) -> Result<()> { let src_path = opts.dest.join("src"); @@ -282,10 +350,20 @@ fn write_lib_rs(spec: &HirSpec, extras: &Extras, opts: &PackageConfig) -> Result #impl_ServiceAuthentication } }).unwrap_or_default(); + let static_shared_http_client = static_shared_http_client(spec, opts); + let oauth = spec.security.iter().filter_map(|s| match s { + AuthStrategy::OAuth2(auth) => Some(auth), + _ => None, + }).next(); + let shared_oauth2_flow = oauth.map(|auth| { + shared_oauth2_flow(auth, spec, opts) + }).unwrap_or_default(); let code = quote! { #base64_import #serde + #static_shared_http_client + #shared_oauth2_flow #fluent_request #struct_Client #impl_Client diff --git a/libninja/src/rust/cargo_toml.rs b/libninja/src/rust/cargo_toml.rs index 3615f79..4fcef98 100644 --- a/libninja/src/rust/cargo_toml.rs +++ b/libninja/src/rust/cargo_toml.rs @@ -44,7 +44,7 @@ pub fn update_cargo_toml(extras: &Extras, opts: &OutputConfig, context: &HashMap } let package_version = package.version().to_string(); - ensure_dependency(&mut m.dependencies, "httpclient", "0.19.0", &[]); + ensure_dependency(&mut m.dependencies, "httpclient", "0.20.0", &[]); ensure_dependency(&mut m.dependencies, "serde", "1.0.137", &["derive"]); ensure_dependency(&mut m.dependencies, "serde_json", "1.0.81", &[]); ensure_dependency(&mut m.dependencies, "futures", "0.3.25", &[]); @@ -76,6 +76,9 @@ pub fn update_cargo_toml(extras: &Extras, opts: &OutputConfig, context: &HashMap if extras.basic_auth { ensure_dependency(&mut m.dependencies, "base64", "0.21.0", &[]); } + if extras.oauth2 { + ensure_dependency(&mut m.dependencies, "httpclient_oauth2", "0.1.0", &[]); + } m.example = vec![]; fs::write_file(&cargo, &toml::to_string(&m).unwrap())?; Ok(package_version) diff --git a/libninja/src/rust/client.rs b/libninja/src/rust/client.rs index 7c27614..9f3596b 100644 --- a/libninja/src/rust/client.rs +++ b/libninja/src/rust/client.rs @@ -3,8 +3,8 @@ use openapiv3::OpenAPI; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; -use hir::{AuthLocation, AuthorizationStrategy, DocFormat, Location, Parameter, ServerStrategy, Doc, HirSpec, Language, Operation}; -use mir::{field, Function, Ident}; +use hir::{AuthLocation, AuthStrategy, DocFormat, Location, Parameter, ServerStrategy, Doc, HirSpec, Language, Operation, qualified_env_var}; +use mir::{ArgIdent, Function, Ident}; use mir::{Class, Field, FnArg, Visibility}; use ln_core::PackageConfig; @@ -12,40 +12,37 @@ use crate::rust::codegen::ToRustCode; use crate::rust::codegen::ToRustIdent; use crate::rust::codegen::ToRustType; -fn build_Client_from_env(spec: &HirSpec, opt: &PackageConfig) -> Function { - let declare_url = match spec.server_strategy() { - ServerStrategy::Single(url) => quote! { - .base_url(#url) - }, + +pub fn server_url(spec: &HirSpec, opt: &PackageConfig) -> TokenStream { + match spec.server_strategy() { + ServerStrategy::Single(url) => quote!(#url), ServerStrategy::Env => { - let var = opt.env_var("env").0; + let var = qualified_env_var(&opt.service_name, "env"); let error = format!("Missing environment variable {}", var); - quote! { - .base_url(std::env::var(#var).expect(#error).as_str()) - } + quote!(std::env::var(#var).expect(#error).as_str()) } ServerStrategy::BaseUrl => { - let var = opt.env_var("base_url").0; + let var = qualified_env_var(&opt.service_name, "base_url"); let error = format!("Missing environment variable {}", var); - quote! { - .base_url(std::env::var(#var).expect(#error).as_str()) - } + quote!(std::env::var(#var).expect(#error).as_str()) } - }; + } +} +fn build_Client_from_env(spec: &HirSpec, opt: &PackageConfig) -> Function { let auth_struct = opt.authenticator_name().to_rust_struct(); let body = if spec.has_security() { let auth_struct = opt.authenticator_name().to_rust_struct(); quote! { Self { - client: httpclient::Client::new()#declare_url, + client: shared_http_client(), authentication: #auth_struct::from_env(), } } } else { quote! { Self { - client: httpclient::Client::new()#declare_url + client: shared_http_client() } } }; @@ -59,17 +56,65 @@ fn build_Client_from_env(spec: &HirSpec, opt: &PackageConfig) -> Function Class { +fn build_Client_with_auth(spec: &HirSpec, opt: &PackageConfig) -> Function { + let auth_struct = opt.authenticator_name().to_rust_struct(); + let body = quote! { + Self { + client: shared_http_client(), + authentication + } + }; + Function { + name: Ident::new("with_auth"), + public: true, + ret: quote!(Self), + body, + args: vec![FnArg { + name: ArgIdent::Ident("authentication".to_string()), + ty: quote!(#auth_struct), + default: None, + treatment: None, + }], + ..Function::default() + } +} + +pub fn struct_Client(spec: &HirSpec, opt: &PackageConfig) -> Class { let auth_struct_name = opt.authenticator_name().to_rust_struct(); let mut instance_fields = vec![ - field!(pub client: quote!(httpclient::Client)), + Field { + name: "client".to_string(), + ty: quote!(&'static httpclient::Client), + ..Field::default() + } ]; - if mir_spec.has_security() { - instance_fields.push(field!(authentication: quote!(#auth_struct_name))); + if spec.has_security() { + instance_fields.push(Field { + name: "authentication".to_string(), + ty: quote!(#auth_struct_name), + ..Field::default() + }); } - let class_methods = vec![build_Client_from_env(mir_spec, opt)]; + let mut class_methods = vec![ + build_Client_from_env(spec, opt) + ]; + if spec.has_security() { + class_methods.push(build_Client_with_auth(spec, opt)); + } else { + class_methods.push(Function { + name: Ident::new("new"), + public: true, + ret: quote!(Self), + body: quote! { + Self { + client: shared_http_client() + } + }, + ..Function::default() + }); + } Class { name: opt.client_name().to_rust_struct(), instance_fields, @@ -149,46 +194,63 @@ pub fn impl_ServiceClient_paths(spec: &HirSpec) -> Vec { } pub fn authenticate_variant( - req: &AuthorizationStrategy, + req: &AuthStrategy, opt: &PackageConfig, ) -> TokenStream { + let auth_struct = opt.authenticator_name().to_rust_struct(); - let variant_name = req.name.to_rust_struct(); - let fields = req - .fields - .iter() - .map(|field| { - let field = syn::Ident::new( - &field.name.to_case(Case::Snake), - proc_macro2::Span::call_site(), - ); - quote! { #field } - }) - .collect::>(); + match req { + AuthStrategy::Token(req) => { + let variant_name = req.name.to_rust_struct(); + let fields = req + .fields + .iter() + .map(|field| { + let field = syn::Ident::new( + &field.name.to_case(Case::Snake), + proc_macro2::Span::call_site(), + ); + quote! { #field } + }) + .collect::>(); - let set_values = req - .fields - .iter() - .map(|sec_field| { - let field = syn::Ident::new( - &sec_field.name.to_case(Case::Snake), - proc_macro2::Span::call_site(), - ); - match &sec_field.location { - AuthLocation::Header { key } => quote! { r = r.header(#key, #field); }, - AuthLocation::Basic => quote! { r = r.basic_auth(#field); }, - AuthLocation::Bearer => quote! { r = r.bearer_auth(#field); }, - AuthLocation::Token => quote! { r = r.token_auth(#field); }, - AuthLocation::Query { key } => quote! { r = r.query(#key, #field); }, - AuthLocation::Cookie { key } => quote! { r = r.cookie(#key, #field); }, - } - }) - .collect::>(); + let set_values = req + .fields + .iter() + .map(|sec_field| { + let field = syn::Ident::new( + &sec_field.name.to_case(Case::Snake), + proc_macro2::Span::call_site(), + ); + match &sec_field.location { + AuthLocation::Header { key } => quote! { r = r.header(#key, #field); }, + AuthLocation::Basic => quote! { r = r.basic_auth(#field); }, + AuthLocation::Bearer => quote! { r = r.bearer_auth(#field); }, + AuthLocation::Token => quote! { r = r.token_auth(#field); }, + AuthLocation::Query { key } => quote! { r = r.query(#key, #field); }, + AuthLocation::Cookie { key } => quote! { r = r.cookie(#key, #field); }, + } + }) + .collect::>(); - quote! { - #auth_struct::#variant_name { #(#fields,)*} => { - #(#set_values)* + quote! { + #auth_struct::#variant_name { #(#fields,)*} => { + #(#set_values)* + } + } + } + AuthStrategy::OAuth2(_) => { + quote! { + #auth_struct::OAuth2 { middleware } => { + r.middlewares.insert(0, middleware.clone()); + } + } + } + AuthStrategy::NoAuth => { + quote! { + #auth_struct::NoAuth => {} + } } } } @@ -209,66 +271,18 @@ pub fn build_Client_authenticate(spec: &HirSpec, opt: &PackageConfig) -> TokenSt } } -fn build_new_fn(security: bool, opt: &PackageConfig) -> TokenStream { - if security { - let auth_struct_name = opt.authenticator_name().to_rust_struct(); - quote! { - pub fn new(url: &str, authentication: #auth_struct_name) -> Self { - let client = httpclient::Client::new() - .base_url(url); - Self { - client, - authentication, - } - } - } - } else { - quote! { - pub fn new(url: &str) -> Self { - let client = httpclient::Client::new() - .base_url(url); - Self { - client - } - } - } - } -} - pub fn impl_Client(spec: &HirSpec, opt: &PackageConfig) -> TokenStream { let client_struct_name = opt.client_name().to_rust_struct(); let path_fns = impl_ServiceClient_paths(spec); let security = spec.has_security(); - let new_fn = build_new_fn(security, opt); - let authenticate = if security { + let authenticate = security.then(|| { build_Client_authenticate(spec, opt) - } else { - TokenStream::new() - }; - let with_authentication = if security { - let auth_struct_name = opt.authenticator_name().to_rust_struct(); - quote! { - pub fn with_authentication(mut self, authentication: #auth_struct_name) -> Self { - self.authentication = authentication; - self - } - } - } else { - TokenStream::new() - }; + }).unwrap_or_default(); quote! { impl #client_struct_name { - #new_fn - #with_authentication #authenticate - - pub fn with_middleware(mut self, middleware: M) -> Self { - self.client = self.client.with_middleware(middleware); - self - } - #(#path_fns)* } } @@ -278,11 +292,25 @@ pub fn struct_Authentication(mir_spec: &HirSpec, opt: &PackageConfig) -> TokenSt let auth_struct_name = opt.authenticator_name().to_rust_struct(); let variants = mir_spec.security.iter().map(|strategy| { - let variant_name = strategy.name.to_rust_struct(); - let args = strategy.fields.iter().map(|f| f.name.to_rust_ident()); - quote! { - #variant_name { - #(#args: String),* + match strategy { + AuthStrategy::Token(strategy) => { + let variant_name = strategy.name.to_rust_struct(); + let args = strategy.fields.iter().map(|f| f.name.to_rust_ident()); + quote! { + #variant_name { + #(#args: String),* + } + } + } + AuthStrategy::OAuth2(_) => { + quote! { + OAuth2 { middleware: Arc } + } + } + AuthStrategy::NoAuth => { + quote! { + NoAuth + } } } }); @@ -293,52 +321,74 @@ pub fn struct_Authentication(mir_spec: &HirSpec, opt: &PackageConfig) -> TokenSt } } -fn build_Authentication_from_env(hir_spec: &HirSpec, service_name: &str) -> TokenStream { - let first_variant = hir_spec.security.first() - .unwrap(); - let fields = first_variant - .fields - .iter() - .map(|f| { - let basic = matches!(f.location, AuthLocation::Basic); - let field = - syn::Ident::new(&f.name.to_case(Case::Snake), proc_macro2::Span::call_site()); - let expect = format!("Environment variable {} is not set.", f.env_var); - let env_var = &f.env_var_for_service(service_name); - if basic { - quote! { - #field: { - let value = std::env::var(#env_var).expect(#expect); - STANDARD_NO_PAD.encode(value) +fn build_Authentication_from_env(spec: &HirSpec, service_name: &str) -> TokenStream { + for strat in &spec.security { + match strat { + AuthStrategy::Token(strat) => { + let fields = strat.fields + .iter() + .map(|f| { + let basic = matches!(f.location, AuthLocation::Basic); + let field = + syn::Ident::new(&f.name.to_case(Case::Snake), proc_macro2::Span::call_site()); + let env_var = qualified_env_var(service_name, &f.name); + let expect = format!("Environment variable {} is not set.", env_var); + if basic { + quote! { + #field: { + let value = std::env::var(#env_var).expect(#expect); + STANDARD_NO_PAD.encode(value) + } + } + } else { + quote! { + #field: std::env::var(#env_var).expect(#expect) + } + } + }) + .collect::>(); + let variant_name = syn::Ident::new( + &strat.name.to_case(Case::Pascal), + proc_macro2::Span::call_site(), + ); + return quote! { + pub fn from_env() -> Self { + Self::#variant_name { + #(#fields),* + } } } - } else { - quote! { - #field: std::env::var(#env_var).expect(#expect) - } } - }) - .collect::>(); - let variant_name = syn::Ident::new( - &first_variant.name.to_case(Case::Pascal), - proc_macro2::Span::call_site(), - ); - quote! { - pub fn from_env() -> Self { - Self::#variant_name { - #(#fields),* + AuthStrategy::NoAuth => { + return quote! { + pub fn from_env() -> Self { + Self::NoAuth + } + } } + _ => {} } } + TokenStream::new() } -pub fn impl_Authentication(mir_spec: &HirSpec, opt: &PackageConfig) -> TokenStream { +pub fn impl_Authentication(spec: &HirSpec, opt: &PackageConfig) -> TokenStream { let auth_struct_name = opt.authenticator_name().to_rust_struct(); - let from_env = build_Authentication_from_env(mir_spec, &opt.service_name); + let from_env = build_Authentication_from_env(spec, &opt.service_name); + let oauth2 = spec.oauth2_auth().map(|oauth| { + quote! { + pub fn oauth2(access: String, refresh: String) -> Self { + let mw = shared_oauth2_flow().middleware_from_pieces(access, refresh, httpclient_oauth2::TokenType::Bearer); + Self::OAuth2 { middleware: Arc::new(mw) } + } + } + + }).unwrap_or_default(); quote! { impl #auth_struct_name { #from_env + #oauth2 } } } diff --git a/libninja/src/rust/codegen.rs b/libninja/src/rust/codegen.rs index 7832fa4..b8394cf 100644 --- a/libninja/src/rust/codegen.rs +++ b/libninja/src/rust/codegen.rs @@ -391,12 +391,11 @@ impl ToRustCode for ParamKey { /// If you can use reference types to represent the data (e.g. &str instead of String) pub fn is_referenceable(schema: &Schema, spec: &OpenAPI) -> bool { - match &schema.schema_kind { + match &schema.kind { SchemaKind::Type(openapiv3::Type::String(_)) => true, SchemaKind::Type(openapiv3::Type::Array(ArrayType { items: Some(inner), .. })) => { - let inner = inner.unbox(); let inner = inner.resolve(spec); is_primitive(inner, spec) } diff --git a/libninja/tests/all_of/main.rs b/libninja/tests/all_of/main.rs index 819dcf6..6e1260a 100644 --- a/libninja/tests/all_of/main.rs +++ b/libninja/tests/all_of/main.rs @@ -28,10 +28,10 @@ fn formatted_code(record: Record, spec: &HirSpec) -> String { #[test] fn test_transaction() { let mut spec = OpenAPI::default(); - spec.add_schema("TransactionBase", Schema::new_object()); - spec.add_schema("TransactionCode", Schema::new_string()); - spec.add_schema("PersonalFinanceCategory", Schema::new_string()); - spec.add_schema("TransactionCounterparty", Schema::new_string()); + spec.schemas.insert("TransactionBase", Schema::new_object()); + spec.schemas.insert("TransactionCode", Schema::new_string()); + spec.schemas.insert("PersonalFinanceCategory", Schema::new_string()); + spec.schemas.insert("TransactionCounterparty", Schema::new_string()); let mut result = HirSpec::default(); extract_records(&spec, &mut result).unwrap(); @@ -44,7 +44,7 @@ fn test_transaction() { #[test] fn test_nullable_doesnt_deref() { let mut spec = OpenAPI::default(); - spec.add_schema("RecipientBACS", Schema::new_object()); + spec.schemas.insert("RecipientBACS", Schema::new_object()); let record = record_for_schema("PaymentInitiationOptionalRestrictionBacs", RESTRICTION_BACS, &spec); let code = formatted_code(record, &HirSpec::default()); diff --git a/libninja/tests/regression/main.rs b/libninja/tests/regression/main.rs index 5624bc4..35cc494 100644 --- a/libninja/tests/regression/main.rs +++ b/libninja/tests/regression/main.rs @@ -17,8 +17,8 @@ fn record_for_schema(name: &str, schema: &str, spec: &OpenAPI) -> Record { #[test] fn test_link_token_create() { let mut spec = OpenAPI::default(); - spec.add_schema("UserAddress", Schema::new_object()); - spec.add_schema("UserIDNumber", Schema::new_string()); + spec.schemas.insert("UserAddress", Schema::new_object()); + spec.schemas.insert("UserIDNumber", Schema::new_string()); let record = record_for_schema("LinkTokenCreateRequestUser", LINK_TOKEN_CREATE, &spec); let Record::Struct(struc) = record else { panic!("expected struct"); diff --git a/mir/src/macro.rs b/mir/src/macro.rs index 378cd75..1a2321b 100644 --- a/mir/src/macro.rs +++ b/mir/src/macro.rs @@ -52,34 +52,6 @@ macro_rules! arg { }; } - -#[macro_export] -macro_rules! field { - (pub(crate) $name:ident : $ty:expr) => { - ::mir::Field { - name: stringify!($name).to_string(), - ty: ($ty).into(), - visibility: ::mir::Visibility::Crate, - ..Field::default() - } - }; - (pub $name:ident : $ty:expr) => { - ::mir::Field { - name: stringify!($name).to_string(), - ty: ($ty).into(), - visibility: ::mir::Visibility::Public, - ..Field::default() - } - }; - ($name:ident : $ty:expr) => { - ::mir::Field { - name: stringify!($name).to_string(), - ty: ($ty).into(), - ..Field::default() - } - }; -} - /// A literal value. #[macro_export] macro_rules! lit {