Skip to content

Commit

Permalink
Merge pull request #64 from Ten0/diesel_master_compat_2815
Browse files Browse the repository at this point in the history
Compatibility with diesel master after 72bfb356
  • Loading branch information
adwhit authored Dec 19, 2021
2 parents e4286e3 + d6daa39 commit a339d70
Show file tree
Hide file tree
Showing 12 changed files with 252 additions and 197 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
matrix:
rust:
- stable
- 1.40.0
- 1.48.0

services:
postgres:
Expand Down
131 changes: 84 additions & 47 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,41 @@ use syn::*;
/// # Attributes
///
/// ## Type attributes
///
/// * `#[PgType = "new_enum"]` specifies postgres name for the enum type. If ommitted, uses the enum's name in snake_case.
/// * `#[PgSchema = "schema"]` specifies the postgres schema containing the enum type. If omitted, diesel uses the default search path, but this can cause problems with caching.
/// * `#[DieselType = "NewEnumMapping"]` specifies the name for the diesel type. If omitted, uses the name + `Mapping`.
///
/// * `#[DieselExistingType = "crate::schema::sql_types::NewEnum"]` specifies the name for the corresponding diesel type that was already created by the diesel CLI. If omitted, uses `crate::schema::sql_types::EnumName`.
/// * `#[DieselType = "NewEnumMapping"]` specifies the name for the diesel type to create for Mysql or Sqlite. If omitted, uses the name + `Mapping`.
/// * `#[DbValueStyle = "snake_case"]` specifies a renaming style from each of the rust enum variants to each of the database variants. Either `camelCase`, `kebab-case`, `PascalCase`, `SCREAMING_SNAKE_CASE`, `snake_case`, `verbatim`. If omitted, uses `snake_case`.
///
/// ## Variant attributes
///
/// * `#[db_rename = "variant"]` specifies the db name for a specific variant.
#[proc_macro_derive(DbEnum, attributes(PgType, PgSchema, DieselType, DbValueStyle, db_rename))]
#[proc_macro_derive(
DbEnum,
attributes(DieselType, DieselExistingType, DbValueStyle, db_rename)
)]
pub fn derive(input: TokenStream) -> TokenStream {
let input: DeriveInput = parse_macro_input!(input as DeriveInput);
let db_type =
type_from_attrs(&input.attrs, "PgType").unwrap_or(input.ident.to_string().to_snake_case());
let db_schema =
type_from_attrs(&input.attrs, "PgSchema");
let diesel_mapping =
let diesel_existing_mapping = type_from_attrs(&input.attrs, "DieselExistingType")
.unwrap_or(format!("crate::schema::sql_types::{}", input.ident));
let new_diesel_mapping =
type_from_attrs(&input.attrs, "DieselType").unwrap_or(format!("{}Mapping", input.ident));

// Maintain backwards compatibility by defaulting to snake case.
let case_style =
type_from_attrs(&input.attrs, "DbValueStyle").unwrap_or("snake_case".to_string());
let case_style = CaseStyle::from_string(&case_style);

let diesel_mapping = Ident::new(diesel_mapping.as_ref(), Span::call_site());
let diesel_existing_mapping: proc_macro2::TokenStream =
diesel_existing_mapping.parse().unwrap();
let new_diesel_mapping = Ident::new(new_diesel_mapping.as_ref(), Span::call_site());
let quoted = if let Data::Enum(syn::DataEnum {
variants: data_variants,
..
}) = input.data
{
generate_derive_enum_impls(
&db_type,
db_schema.as_deref(),
&diesel_mapping,
&diesel_existing_mapping,
&new_diesel_mapping,
case_style,
&input.ident,
&data_variants,
Expand Down Expand Up @@ -103,9 +104,8 @@ impl CaseStyle {
}

fn generate_derive_enum_impls(
db_type: &str,
db_schema: Option<&str>,
diesel_mapping: &Ident,
diesel_existing_mapping: &proc_macro2::TokenStream,
new_diesel_mapping: &Ident,
case_style: CaseStyle,
enum_ty: &Ident,
variants: &syn::punctuated::Punctuated<Variant, syn::token::Comma>,
Expand Down Expand Up @@ -137,30 +137,72 @@ fn generate_derive_enum_impls(
let variants_rs: &[proc_macro2::TokenStream] = &variant_ids;
let variants_db: &[LitByteStr] = &variants_db;

let common_impl =
generate_common_impl(db_type, db_schema, diesel_mapping, enum_ty, variants_rs, variants_db);
let (common_diesel_mapping, common_diesel_mapping_use) =
if cfg!(feature = "mysql") || cfg!(feature = "sqlite") {
let new_diesel_mapping_impl = generate_common_diesel_mapping(new_diesel_mapping);
let common_impls_on_new_diesel_mapping = generate_common_impls(
&quote! { #new_diesel_mapping },
enum_ty,
variants_rs,
variants_db,
);
(
quote! {
#new_diesel_mapping_impl
#common_impls_on_new_diesel_mapping
},
quote! {
pub use self::#modname::#new_diesel_mapping;
},
)
} else {
(quote! {}, quote! {})
};

let pg_impl = if cfg!(feature = "postgres") {
generate_postgres_impl(diesel_mapping, enum_ty, variants_rs, variants_db)
let common_impls_on_existing_diesel_mapping =
generate_common_impls(diesel_existing_mapping, enum_ty, variants_rs, variants_db);
let postgres_impl =
generate_postgres_impl(diesel_existing_mapping, enum_ty, variants_rs, variants_db);
quote! {
#common_impls_on_existing_diesel_mapping
#postgres_impl
}
} else {
quote! {}
};
let mysql_impl = if cfg!(feature = "mysql") {
generate_mysql_impl(diesel_mapping, enum_ty, variants_rs, variants_db)
generate_mysql_impl(new_diesel_mapping, enum_ty, variants_rs, variants_db)
} else {
quote! {}
};
let sqlite_impl = if cfg!(feature = "sqlite") {
generate_sqlite_impl(diesel_mapping, enum_ty, variants_rs, variants_db)
generate_sqlite_impl(new_diesel_mapping, enum_ty, variants_rs, variants_db)
} else {
quote! {}
};

let imports = quote! {
use super::*;
use diesel::Queryable;
use diesel::backend::{self, Backend};
use diesel::expression::AsExpression;
use diesel::expression::bound::Bound;
use diesel::row::Row;
use diesel::sql_types::*;
use diesel::serialize::{self, ToSql, IsNull, Output};
use diesel::deserialize::{self, FromSql};
use diesel::query_builder::QueryId;
use std::io::Write;
};

let quoted = quote! {
pub use self::#modname::#diesel_mapping;
#common_diesel_mapping_use
#[allow(non_snake_case)]
mod #modname {
#common_impl
#imports

#common_diesel_mapping
#pg_impl
#mysql_impl
#sqlite_impl
Expand All @@ -181,33 +223,22 @@ fn stylize_value(value: &str, style: CaseStyle) -> String {
}
}

fn generate_common_impl(
db_type: &str,
db_schema: Option<&str>,
diesel_mapping: &Ident,
fn generate_common_diesel_mapping(new_diesel_mapping: &Ident) -> proc_macro2::TokenStream {
quote! {
#[derive(SqlType, Clone)]
#[mysql_type = "Enum"]
#[sqlite_type = "Text"]
pub struct #new_diesel_mapping;
}
}

fn generate_common_impls(
diesel_mapping: &proc_macro2::TokenStream,
enum_ty: &Ident,
variants_rs: &[proc_macro2::TokenStream],
variants_db: &[LitByteStr],
) -> proc_macro2::TokenStream {
let db_schema = db_schema.into_iter();
quote! {
use super::*;
use diesel::Queryable;
use diesel::backend::{self, Backend};
use diesel::expression::AsExpression;
use diesel::expression::bound::Bound;
use diesel::row::Row;
use diesel::sql_types::*;
use diesel::serialize::{self, ToSql, IsNull, Output};
use diesel::deserialize::{self, FromSql};
use diesel::query_builder::QueryId;
use std::io::Write;

#[derive(SqlType, Clone)]
#[postgres(type_name = #db_type, #(type_schema = #db_schema)*)]
#[mysql_type = "Enum"]
#[sqlite_type = "Text"]
pub struct #diesel_mapping;
impl QueryId for #diesel_mapping {
type QueryId = #diesel_mapping;
const HAS_STATIC_QUERY_ID: bool = true;
Expand Down Expand Up @@ -283,7 +314,7 @@ fn generate_common_impl(
}

fn generate_postgres_impl(
diesel_mapping: &Ident,
diesel_mapping: &proc_macro2::TokenStream,
enum_ty: &Ident,
variants_rs: &[proc_macro2::TokenStream],
variants_db: &[LitByteStr],
Expand All @@ -293,6 +324,12 @@ fn generate_postgres_impl(
use super::*;
use diesel::pg::{Pg, PgValue};

impl Clone for #diesel_mapping {
fn clone(&self) -> Self {
#diesel_mapping
}
}

impl FromSql<#diesel_mapping, Pg> for #enum_ty {
fn from_sql(raw: PgValue) -> deserialize::Result<Self> {
match raw.as_bytes() {
Expand Down
Loading

0 comments on commit a339d70

Please sign in to comment.