Skip to content

Commit

Permalink
Correctly generate custom type definitions for types in a custom schema
Browse files Browse the repository at this point in the history
  • Loading branch information
weiznich committed Jun 22, 2021
1 parent ce20188 commit df0cd1b
Show file tree
Hide file tree
Showing 12 changed files with 165 additions and 27 deletions.
2 changes: 1 addition & 1 deletion diesel_cli/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ pub fn build_cli() -> App<'static, 'static> {
.multiple(true)
.number_of_values(1)
.help("A list of types to import for every table, separated by commas."),
)
)
.arg(
Arg::with_name("generate-custom-type-definitions")
.long("no-generate-missing-sql-type-definitions")
Expand Down
18 changes: 13 additions & 5 deletions diesel_cli/src/infer_schema_internals/data_structures.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ use super::table_data::TableName;
pub struct ColumnInformation {
pub column_name: String,
pub type_name: String,
pub type_schema: Option<String>,
pub nullable: bool,
}

#[derive(Debug, PartialEq, Clone)]
pub struct ColumnType {
pub schema: Option<String>,
pub rust_name: String,
pub sql_name: String,
pub is_array: bool,
Expand Down Expand Up @@ -60,14 +62,20 @@ pub struct ColumnDefinition {
}

impl ColumnInformation {
pub fn new<T, U>(column_name: T, type_name: U, nullable: bool) -> Self
pub fn new<T, U>(
column_name: T,
type_name: U,
type_schema: Option<String>,
nullable: bool,
) -> Self
where
T: Into<String>,
U: Into<String>,
{
ColumnInformation {
column_name: column_name.into(),
type_name: type_name.into(),
type_schema,
nullable,
}
}
Expand All @@ -77,12 +85,12 @@ impl ColumnInformation {
impl<ST, DB> Queryable<ST, DB> for ColumnInformation
where
DB: Backend + UsesInformationSchema,
(String, String, String): FromStaticSqlRow<ST, DB>,
(String, String, Option<String>, String): FromStaticSqlRow<ST, DB>,
{
type Row = (String, String, String);
type Row = (String, String, Option<String>, String);

fn build(row: Self::Row) -> deserialize::Result<Self> {
Ok(ColumnInformation::new(row.0, row.1, row.2 == "YES"))
Ok(ColumnInformation::new(row.0, row.1, row.2, row.3 == "YES"))
}
}

Expand All @@ -94,7 +102,7 @@ where
type Row = (i32, String, String, bool, Option<String>, bool);

fn build(row: Self::Row) -> deserialize::Result<Self> {
Ok(ColumnInformation::new(row.1, row.2, !row.3))
Ok(ColumnInformation::new(row.1, row.2, None, !row.3))
}
}

Expand Down
10 changes: 7 additions & 3 deletions diesel_cli/src/infer_schema_internals/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,17 @@ fn get_column_information(

fn determine_column_type(
attr: &ColumnInformation,
conn: &InferConnection,
conn: &mut InferConnection,
) -> Result<ColumnType, Box<dyn Error + Send + Sync + 'static>> {
match *conn {
#[cfg(feature = "sqlite")]
InferConnection::Sqlite(_) => super::sqlite::determine_column_type(attr),
#[cfg(feature = "postgres")]
InferConnection::Pg(_) => super::pg::determine_column_type(attr),
InferConnection::Pg(ref mut conn) => {
use crate::infer_schema_internals::information_schema::UsesInformationSchema;

super::pg::determine_column_type(attr, diesel::pg::Pg::default_schema(conn)?)
}
#[cfg(feature = "mysql")]
InferConnection::Mysql(_) => super::mysql::determine_column_type(attr),
}
Expand Down Expand Up @@ -206,7 +210,7 @@ pub fn load_table_data(
let column_data = get_column_information(&mut connection, &name, column_sorting)?
.into_iter()
.map(|c| {
let ty = determine_column_type(&c, &connection)?;
let ty = determine_column_type(&c, &mut connection)?;
let rust_name = rust_name_for_sql_name(&c.column_name);

Ok(ColumnDefinition {
Expand Down
49 changes: 42 additions & 7 deletions diesel_cli/src/infer_schema_internals/information_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::error::Error;
use diesel::backend::Backend;
use diesel::deserialize::{FromSql, FromSqlRow};
use diesel::dsl::*;
use diesel::expression::{is_aggregate, QueryMetadata, ValidGrouping};
use diesel::expression::{is_aggregate, MixedAggregates, QueryMetadata, ValidGrouping};
#[cfg(feature = "mysql")]
use diesel::mysql::Mysql;
#[cfg(feature = "postgres")]
Expand All @@ -24,7 +24,16 @@ pub trait UsesInformationSchema: Backend {
+ QueryId
+ QueryFragment<Self>;

type TypeSchema: SelectableExpression<
self::information_schema::columns::table,
SqlType = sql_types::Nullable<sql_types::Text>,
> + ValidGrouping<()>
+ QueryId
+ QueryFragment<Self>;

fn type_column() -> Self::TypeColumn;
fn type_schema() -> Self::TypeSchema;

fn default_schema<C>(conn: &mut C) -> QueryResult<String>
where
C: Connection<Backend = Self>,
Expand All @@ -34,11 +43,16 @@ pub trait UsesInformationSchema: Backend {
#[cfg(feature = "postgres")]
impl UsesInformationSchema for Pg {
type TypeColumn = self::information_schema::columns::udt_name;
type TypeSchema = diesel::dsl::Nullable<self::information_schema::columns::udt_schema>;

fn type_column() -> Self::TypeColumn {
self::information_schema::columns::udt_name
}

fn type_schema() -> Self::TypeSchema {
self::information_schema::columns::udt_schema.nullable()
}

fn default_schema<C>(_conn: &mut C) -> QueryResult<String> {
Ok("public".into())
}
Expand All @@ -50,11 +64,16 @@ sql_function!(fn database() -> VarChar);
#[cfg(feature = "mysql")]
impl UsesInformationSchema for Mysql {
type TypeColumn = self::information_schema::columns::column_type;
type TypeSchema = diesel::dsl::AsExprOf<Option<String>, sql_types::Nullable<sql_types::Text>>;

fn type_column() -> Self::TypeColumn {
self::information_schema::columns::column_type
}

fn type_schema() -> Self::TypeSchema {
None.into_sql()
}

fn default_schema<C>(conn: &mut C) -> QueryResult<String>
where
C: Connection<Backend = Self>,
Expand Down Expand Up @@ -85,6 +104,7 @@ mod information_schema {
__is_nullable -> VarChar,
ordinal_position -> BigInt,
udt_name -> VarChar,
udt_schema -> VarChar,
column_type -> VarChar,
}
}
Expand Down Expand Up @@ -135,11 +155,17 @@ where
SqlTypeOf<(
columns::column_name,
<Conn::Backend as UsesInformationSchema>::TypeColumn,
<Conn::Backend as UsesInformationSchema>::TypeSchema,
columns::__is_nullable,
)>,
Conn::Backend,
>,
is_aggregate::No: MixedAggregates<
<<Conn::Backend as UsesInformationSchema>::TypeSchema as ValidGrouping<()>>::IsAggregate,
Output = is_aggregate::No,
>,
String: FromSql<sql_types::Text, Conn::Backend>,
Option<String>: FromSql<sql_types::Nullable<sql_types::Text>, Conn::Backend>,
Order<
Filter<
Filter<
Expand All @@ -148,6 +174,7 @@ where
(
columns::column_name,
<Conn::Backend as UsesInformationSchema>::TypeColumn,
<Conn::Backend as UsesInformationSchema>::TypeSchema,
columns::__is_nullable,
),
>,
Expand All @@ -165,6 +192,7 @@ where
(
columns::column_name,
<Conn::Backend as UsesInformationSchema>::TypeColumn,
<Conn::Backend as UsesInformationSchema>::TypeSchema,
columns::__is_nullable,
),
>,
Expand All @@ -174,7 +202,12 @@ where
>,
columns::column_name,
>: QueryFragment<Conn::Backend>,
Conn::Backend: QueryMetadata<(sql_types::Text, sql_types::Text, sql_types::Text)>,
Conn::Backend: QueryMetadata<(
sql_types::Text,
sql_types::Text,
sql_types::Nullable<sql_types::Text>,
sql_types::Text,
)>,
{
use self::information_schema::columns::dsl::*;

Expand All @@ -184,8 +217,9 @@ where
};

let type_column = Conn::Backend::type_column();
let type_schema = Conn::Backend::type_schema();
let query = columns
.select((column_name, type_column, __is_nullable))
.select((column_name, type_column, type_schema, __is_nullable))
.filter(table_name.eq(&table.sql_name))
.filter(table_schema.eq(schema_name));
match column_sorting {
Expand Down Expand Up @@ -512,10 +546,11 @@ mod tests {

let table_1 = TableName::new("table_1", "test_schema");
let table_2 = TableName::new("table_2", "test_schema");
let id = ColumnInformation::new("id", "int4", false);
let text_col = ColumnInformation::new("text_col", "varchar", true);
let not_null = ColumnInformation::new("not_null", "text", false);
let array_col = ColumnInformation::new("array_col", "_varchar", false);
let pg_catalog = Some(String::from("pg_catalog"));
let id = ColumnInformation::new("id", "int4", pg_catalog.clone(), false);
let text_col = ColumnInformation::new("text_col", "varchar", pg_catalog.clone(), true);
let not_null = ColumnInformation::new("not_null", "text", pg_catalog.clone(), false);
let array_col = ColumnInformation::new("array_col", "_varchar", pg_catalog.clone(), false);
assert_eq!(
Ok(vec![id, text_col, not_null]),
get_table_data(&mut connection, &table_1, &ColumnSorting::OrdinalPosition)
Expand Down
1 change: 1 addition & 0 deletions diesel_cli/src/infer_schema_internals/mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub fn determine_column_type(
let unsigned = determine_unsigned(&attr.type_name);

Ok(ColumnType {
schema: None,
sql_name: tpe.trim().to_lowercase(),
rust_name: tpe.trim().to_camel_case(),
is_array: false,
Expand Down
8 changes: 8 additions & 0 deletions diesel_cli/src/infer_schema_internals/pg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::io::{stderr, Write};

pub fn determine_column_type(
attr: &ColumnInformation,
default_schema: String,
) -> Result<ColumnType, Box<dyn Error + Send + Sync + 'static>> {
let is_array = attr.type_name.starts_with('_');
let tpe = if is_array {
Expand All @@ -30,6 +31,13 @@ pub fn determine_column_type(
}

Ok(ColumnType {
schema: attr.type_schema.as_ref().and_then(|s| {
if s == &default_schema {
None
} else {
Some(s.clone())
}
}),
sql_name: tpe.to_lowercase(),
rust_name: tpe.to_camel_case(),
is_array,
Expand Down
1 change: 1 addition & 0 deletions diesel_cli/src/infer_schema_internals/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ pub fn determine_column_type(
};

Ok(ColumnType {
schema: None,
rust_name: path.clone(),
sql_name: path,
is_array: false,
Expand Down
47 changes: 36 additions & 11 deletions diesel_cli/src/print_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ fn common_diesel_types(types: &mut HashSet<&str>) {
types.insert("Double");
types.insert("Float");
types.insert("Numeric");
types.insert("Timestamp");
types.insert("Date");
types.insert("Time");

// hidden type defs
types.insert("Float4");
Expand Down Expand Up @@ -104,7 +107,6 @@ fn common_diesel_types(types: &mut HashSet<&str>) {
fn pg_diesel_types() -> HashSet<&'static str> {
let mut types = HashSet::new();
types.insert("Cidr");
types.insert("Date");
types.insert("Inet");
types.insert("Jsonb");
types.insert("MacAddr");
Expand All @@ -114,7 +116,6 @@ fn pg_diesel_types() -> HashSet<&'static str> {
types.insert("Timestamptz");
types.insert("Uuid");
types.insert("Json");
types.insert("Timestamp");
types.insert("Record");
types.insert("Interval");

Expand Down Expand Up @@ -211,11 +212,10 @@ pub fn output_schema(
import_types: config.import_types(),
};

write!(out, "{}", definitions.custom_type_defs)?;

if let Some(schema_name) = config.schema_name() {
write!(out, "{}", ModuleDefinition(schema_name, definitions))?;
} else {
write!(out, "{}", definitions.custom_type_defs)?;
write!(out, "{}", definitions)?;
}

Expand Down Expand Up @@ -259,23 +259,48 @@ impl Display for CustomTypeList {
let mut out = PadAdapter::new(f);
writeln!(out, "pub mod sql_types {{")?;
if self.with_docs {
writeln!(out, "/// The `{}` SQL type", t.sql_name)?;
if let Some(ref schema) = t.schema {
writeln!(out, "/// The `{}.{}` SQL type", schema, t.sql_name)?;
} else {
writeln!(out, "/// The `{}` SQL type", t.sql_name)?;
}
writeln!(out, "///")?;
writeln!(out, "/// (Automatically generated by Diesel.)")?;
}
writeln!(out, "#[derive(diesel::SqlType)]")?;
writeln!(out, "#[postgres(type_name = \"{}\")]", t.sql_name)?;
if let Some(ref schema) = t.schema {
writeln!(
out,
"#[postgres(type_name = \"{}\", type_schema = \"{}\")]",
t.sql_name, schema
)?;
} else {
writeln!(out, "#[postgres(type_name = \"{}\")]", t.sql_name)?;
}
writeln!(out, "pub struct {};", t.rust_name)?;
writeln!(out)?;
writeln!(f, "}}")?;
writeln!(f, "}}\n")?;
}
#[cfg(feature = "sqlite")]
Backend::Sqlite => {
let _ = (&f, self.with_docs, t);
let _ = (&f, self.with_docs);
eprintln!("Encountered unknown type for Sqlite: {}", t.sql_name);
unreachable!(
"Diesel only support a closed set of types for Sqlite. \
If you ever see this error message please open an \
issue at https://github.com/diesel-rs/diesel containing \
a dump of your schema definition."
)
}
#[cfg(feature = "mysql")]
Backend::Mysql => {
let _ = (&f, self.with_docs, t);
let _ = (&f, self.with_docs);
eprintln!("Encountered unknown type for Mysql: {}", t.sql_name);
unreachable!(
"Mysql only supports a closed set of types.
If you ever see this error message please open an \
issue at https://github.com/diesel-rs/diesel containing \
a dump of your schema definition."
)
}
}
}
Expand All @@ -289,8 +314,8 @@ impl<'a> Display for ModuleDefinition<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
{
let mut out = PadAdapter::new(f);
write!(out, "{}", self.1.custom_type_defs)?;
writeln!(out, "pub mod {} {{", self.0)?;
write!(out, "{}", self.1.custom_type_defs)?;
write!(out, "{}", self.1)?;
}
writeln!(f, "}}")?;
Expand Down
9 changes: 9 additions & 0 deletions diesel_cli/tests/print_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ fn print_schema_default_is_to_generate_custom_types() {
)
}

#[test]
#[cfg(feature = "postgres")]
fn print_schema_specifying_schema_name_with_custom_type() {
test_print_schema(
"print_schema_specifying_schema_name_with_custom_type",
vec!["--with-docs", "--schema", "custom_schema"],
)
}

#[cfg(feature = "sqlite")]
const BACKEND: &str = "sqlite";
#[cfg(feature = "postgres")]
Expand Down
Loading

0 comments on commit df0cd1b

Please sign in to comment.