From df0cd1b91c54d41cd40e7f493d67a8ba3eee0471 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Tue, 22 Jun 2021 15:26:29 +0200 Subject: [PATCH] Correctly generate custom type definitions for types in a custom schema --- diesel_cli/src/cli.rs | 2 +- .../infer_schema_internals/data_structures.rs | 18 +++++-- .../src/infer_schema_internals/inference.rs | 10 ++-- .../information_schema.rs | 49 ++++++++++++++++--- .../src/infer_schema_internals/mysql.rs | 1 + diesel_cli/src/infer_schema_internals/pg.rs | 8 +++ .../src/infer_schema_internals/sqlite.rs | 1 + diesel_cli/src/print_schema.rs | 47 +++++++++++++----- diesel_cli/tests/print_schema.rs | 9 ++++ .../diesel.toml | 4 ++ .../postgres/expected.rs | 38 ++++++++++++++ .../postgres/schema.sql | 5 ++ 12 files changed, 165 insertions(+), 27 deletions(-) create mode 100644 diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/diesel.toml create mode 100644 diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/expected.rs create mode 100644 diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/schema.sql diff --git a/diesel_cli/src/cli.rs b/diesel_cli/src/cli.rs index e4ca5d4b8ded..dd7174c7b8ea 100644 --- a/diesel_cli/src/cli.rs +++ b/diesel_cli/src/cli.rs @@ -218,7 +218,7 @@ pub fn build_cli() -> App<'static, 'static> { .multiple(true) .number_of_values(1) .help("A list of types to import for every table, separated by commas."), - ) + ) .arg( Arg::with_name("generate-custom-type-definitions") .long("no-generate-missing-sql-type-definitions") diff --git a/diesel_cli/src/infer_schema_internals/data_structures.rs b/diesel_cli/src/infer_schema_internals/data_structures.rs index 54f4d8bf1a73..735f141edb23 100644 --- a/diesel_cli/src/infer_schema_internals/data_structures.rs +++ b/diesel_cli/src/infer_schema_internals/data_structures.rs @@ -12,11 +12,13 @@ use super::table_data::TableName; pub struct ColumnInformation { pub column_name: String, pub type_name: String, + pub type_schema: Option, pub nullable: bool, } #[derive(Debug, PartialEq, Clone)] pub struct ColumnType { + pub schema: Option, pub rust_name: String, pub sql_name: String, pub is_array: bool, @@ -60,7 +62,12 @@ pub struct ColumnDefinition { } impl ColumnInformation { - pub fn new(column_name: T, type_name: U, nullable: bool) -> Self + pub fn new( + column_name: T, + type_name: U, + type_schema: Option, + nullable: bool, + ) -> Self where T: Into, U: Into, @@ -68,6 +75,7 @@ impl ColumnInformation { ColumnInformation { column_name: column_name.into(), type_name: type_name.into(), + type_schema, nullable, } } @@ -77,12 +85,12 @@ impl ColumnInformation { impl Queryable for ColumnInformation where DB: Backend + UsesInformationSchema, - (String, String, String): FromStaticSqlRow, + (String, String, Option, String): FromStaticSqlRow, { - type Row = (String, String, String); + type Row = (String, String, Option, String); fn build(row: Self::Row) -> deserialize::Result { - Ok(ColumnInformation::new(row.0, row.1, row.2 == "YES")) + Ok(ColumnInformation::new(row.0, row.1, row.2, row.3 == "YES")) } } @@ -94,7 +102,7 @@ where type Row = (i32, String, String, bool, Option, bool); fn build(row: Self::Row) -> deserialize::Result { - Ok(ColumnInformation::new(row.1, row.2, !row.3)) + Ok(ColumnInformation::new(row.1, row.2, None, !row.3)) } } diff --git a/diesel_cli/src/infer_schema_internals/inference.rs b/diesel_cli/src/infer_schema_internals/inference.rs index ccf8b2de5e44..cde7762dc48b 100644 --- a/diesel_cli/src/infer_schema_internals/inference.rs +++ b/diesel_cli/src/infer_schema_internals/inference.rs @@ -105,13 +105,17 @@ fn get_column_information( fn determine_column_type( attr: &ColumnInformation, - conn: &InferConnection, + conn: &mut InferConnection, ) -> Result> { match *conn { #[cfg(feature = "sqlite")] InferConnection::Sqlite(_) => super::sqlite::determine_column_type(attr), #[cfg(feature = "postgres")] - InferConnection::Pg(_) => super::pg::determine_column_type(attr), + InferConnection::Pg(ref mut conn) => { + use crate::infer_schema_internals::information_schema::UsesInformationSchema; + + super::pg::determine_column_type(attr, diesel::pg::Pg::default_schema(conn)?) + } #[cfg(feature = "mysql")] InferConnection::Mysql(_) => super::mysql::determine_column_type(attr), } @@ -206,7 +210,7 @@ pub fn load_table_data( let column_data = get_column_information(&mut connection, &name, column_sorting)? .into_iter() .map(|c| { - let ty = determine_column_type(&c, &connection)?; + let ty = determine_column_type(&c, &mut connection)?; let rust_name = rust_name_for_sql_name(&c.column_name); Ok(ColumnDefinition { diff --git a/diesel_cli/src/infer_schema_internals/information_schema.rs b/diesel_cli/src/infer_schema_internals/information_schema.rs index 07842db356d4..698f604eb285 100644 --- a/diesel_cli/src/infer_schema_internals/information_schema.rs +++ b/diesel_cli/src/infer_schema_internals/information_schema.rs @@ -4,7 +4,7 @@ use std::error::Error; use diesel::backend::Backend; use diesel::deserialize::{FromSql, FromSqlRow}; use diesel::dsl::*; -use diesel::expression::{is_aggregate, QueryMetadata, ValidGrouping}; +use diesel::expression::{is_aggregate, MixedAggregates, QueryMetadata, ValidGrouping}; #[cfg(feature = "mysql")] use diesel::mysql::Mysql; #[cfg(feature = "postgres")] @@ -24,7 +24,16 @@ pub trait UsesInformationSchema: Backend { + QueryId + QueryFragment; + type TypeSchema: SelectableExpression< + self::information_schema::columns::table, + SqlType = sql_types::Nullable, + > + ValidGrouping<()> + + QueryId + + QueryFragment; + fn type_column() -> Self::TypeColumn; + fn type_schema() -> Self::TypeSchema; + fn default_schema(conn: &mut C) -> QueryResult where C: Connection, @@ -34,11 +43,16 @@ pub trait UsesInformationSchema: Backend { #[cfg(feature = "postgres")] impl UsesInformationSchema for Pg { type TypeColumn = self::information_schema::columns::udt_name; + type TypeSchema = diesel::dsl::Nullable; fn type_column() -> Self::TypeColumn { self::information_schema::columns::udt_name } + fn type_schema() -> Self::TypeSchema { + self::information_schema::columns::udt_schema.nullable() + } + fn default_schema(_conn: &mut C) -> QueryResult { Ok("public".into()) } @@ -50,11 +64,16 @@ sql_function!(fn database() -> VarChar); #[cfg(feature = "mysql")] impl UsesInformationSchema for Mysql { type TypeColumn = self::information_schema::columns::column_type; + type TypeSchema = diesel::dsl::AsExprOf, sql_types::Nullable>; fn type_column() -> Self::TypeColumn { self::information_schema::columns::column_type } + fn type_schema() -> Self::TypeSchema { + None.into_sql() + } + fn default_schema(conn: &mut C) -> QueryResult where C: Connection, @@ -85,6 +104,7 @@ mod information_schema { __is_nullable -> VarChar, ordinal_position -> BigInt, udt_name -> VarChar, + udt_schema -> VarChar, column_type -> VarChar, } } @@ -135,11 +155,17 @@ where SqlTypeOf<( columns::column_name, ::TypeColumn, + ::TypeSchema, columns::__is_nullable, )>, Conn::Backend, >, + is_aggregate::No: MixedAggregates< + <::TypeSchema as ValidGrouping<()>>::IsAggregate, + Output = is_aggregate::No, + >, String: FromSql, + Option: FromSql, Conn::Backend>, Order< Filter< Filter< @@ -148,6 +174,7 @@ where ( columns::column_name, ::TypeColumn, + ::TypeSchema, columns::__is_nullable, ), >, @@ -165,6 +192,7 @@ where ( columns::column_name, ::TypeColumn, + ::TypeSchema, columns::__is_nullable, ), >, @@ -174,7 +202,12 @@ where >, columns::column_name, >: QueryFragment, - Conn::Backend: QueryMetadata<(sql_types::Text, sql_types::Text, sql_types::Text)>, + Conn::Backend: QueryMetadata<( + sql_types::Text, + sql_types::Text, + sql_types::Nullable, + sql_types::Text, + )>, { use self::information_schema::columns::dsl::*; @@ -184,8 +217,9 @@ where }; let type_column = Conn::Backend::type_column(); + let type_schema = Conn::Backend::type_schema(); let query = columns - .select((column_name, type_column, __is_nullable)) + .select((column_name, type_column, type_schema, __is_nullable)) .filter(table_name.eq(&table.sql_name)) .filter(table_schema.eq(schema_name)); match column_sorting { @@ -512,10 +546,11 @@ mod tests { let table_1 = TableName::new("table_1", "test_schema"); let table_2 = TableName::new("table_2", "test_schema"); - let id = ColumnInformation::new("id", "int4", false); - let text_col = ColumnInformation::new("text_col", "varchar", true); - let not_null = ColumnInformation::new("not_null", "text", false); - let array_col = ColumnInformation::new("array_col", "_varchar", false); + let pg_catalog = Some(String::from("pg_catalog")); + let id = ColumnInformation::new("id", "int4", pg_catalog.clone(), false); + let text_col = ColumnInformation::new("text_col", "varchar", pg_catalog.clone(), true); + let not_null = ColumnInformation::new("not_null", "text", pg_catalog.clone(), false); + let array_col = ColumnInformation::new("array_col", "_varchar", pg_catalog.clone(), false); assert_eq!( Ok(vec![id, text_col, not_null]), get_table_data(&mut connection, &table_1, &ColumnSorting::OrdinalPosition) diff --git a/diesel_cli/src/infer_schema_internals/mysql.rs b/diesel_cli/src/infer_schema_internals/mysql.rs index 8418c8403af1..5b4172ad8ccf 100644 --- a/diesel_cli/src/infer_schema_internals/mysql.rs +++ b/diesel_cli/src/infer_schema_internals/mysql.rs @@ -92,6 +92,7 @@ pub fn determine_column_type( let unsigned = determine_unsigned(&attr.type_name); Ok(ColumnType { + schema: None, sql_name: tpe.trim().to_lowercase(), rust_name: tpe.trim().to_camel_case(), is_array: false, diff --git a/diesel_cli/src/infer_schema_internals/pg.rs b/diesel_cli/src/infer_schema_internals/pg.rs index 03a2bdd62c38..cc4bf346e242 100644 --- a/diesel_cli/src/infer_schema_internals/pg.rs +++ b/diesel_cli/src/infer_schema_internals/pg.rs @@ -5,6 +5,7 @@ use std::io::{stderr, Write}; pub fn determine_column_type( attr: &ColumnInformation, + default_schema: String, ) -> Result> { let is_array = attr.type_name.starts_with('_'); let tpe = if is_array { @@ -30,6 +31,13 @@ pub fn determine_column_type( } Ok(ColumnType { + schema: attr.type_schema.as_ref().and_then(|s| { + if s == &default_schema { + None + } else { + Some(s.clone()) + } + }), sql_name: tpe.to_lowercase(), rust_name: tpe.to_camel_case(), is_array, diff --git a/diesel_cli/src/infer_schema_internals/sqlite.rs b/diesel_cli/src/infer_schema_internals/sqlite.rs index 4e4c41b38bb7..1f2009d71a1d 100644 --- a/diesel_cli/src/infer_schema_internals/sqlite.rs +++ b/diesel_cli/src/infer_schema_internals/sqlite.rs @@ -172,6 +172,7 @@ pub fn determine_column_type( }; Ok(ColumnType { + schema: None, rust_name: path.clone(), sql_name: path, is_array: false, diff --git a/diesel_cli/src/print_schema.rs b/diesel_cli/src/print_schema.rs index 554ce6444d66..ac0d39c40708 100644 --- a/diesel_cli/src/print_schema.rs +++ b/diesel_cli/src/print_schema.rs @@ -76,6 +76,9 @@ fn common_diesel_types(types: &mut HashSet<&str>) { types.insert("Double"); types.insert("Float"); types.insert("Numeric"); + types.insert("Timestamp"); + types.insert("Date"); + types.insert("Time"); // hidden type defs types.insert("Float4"); @@ -104,7 +107,6 @@ fn common_diesel_types(types: &mut HashSet<&str>) { fn pg_diesel_types() -> HashSet<&'static str> { let mut types = HashSet::new(); types.insert("Cidr"); - types.insert("Date"); types.insert("Inet"); types.insert("Jsonb"); types.insert("MacAddr"); @@ -114,7 +116,6 @@ fn pg_diesel_types() -> HashSet<&'static str> { types.insert("Timestamptz"); types.insert("Uuid"); types.insert("Json"); - types.insert("Timestamp"); types.insert("Record"); types.insert("Interval"); @@ -211,11 +212,10 @@ pub fn output_schema( import_types: config.import_types(), }; - write!(out, "{}", definitions.custom_type_defs)?; - if let Some(schema_name) = config.schema_name() { write!(out, "{}", ModuleDefinition(schema_name, definitions))?; } else { + write!(out, "{}", definitions.custom_type_defs)?; write!(out, "{}", definitions)?; } @@ -259,23 +259,48 @@ impl Display for CustomTypeList { let mut out = PadAdapter::new(f); writeln!(out, "pub mod sql_types {{")?; if self.with_docs { - writeln!(out, "/// The `{}` SQL type", t.sql_name)?; + if let Some(ref schema) = t.schema { + writeln!(out, "/// The `{}.{}` SQL type", schema, t.sql_name)?; + } else { + writeln!(out, "/// The `{}` SQL type", t.sql_name)?; + } writeln!(out, "///")?; writeln!(out, "/// (Automatically generated by Diesel.)")?; } writeln!(out, "#[derive(diesel::SqlType)]")?; - writeln!(out, "#[postgres(type_name = \"{}\")]", t.sql_name)?; + if let Some(ref schema) = t.schema { + writeln!( + out, + "#[postgres(type_name = \"{}\", type_schema = \"{}\")]", + t.sql_name, schema + )?; + } else { + writeln!(out, "#[postgres(type_name = \"{}\")]", t.sql_name)?; + } writeln!(out, "pub struct {};", t.rust_name)?; - writeln!(out)?; - writeln!(f, "}}")?; + writeln!(f, "}}\n")?; } #[cfg(feature = "sqlite")] Backend::Sqlite => { - let _ = (&f, self.with_docs, t); + let _ = (&f, self.with_docs); + eprintln!("Encountered unknown type for Sqlite: {}", t.sql_name); + unreachable!( + "Diesel only support a closed set of types for Sqlite. \ + If you ever see this error message please open an \ + issue at https://github.com/diesel-rs/diesel containing \ + a dump of your schema definition." + ) } #[cfg(feature = "mysql")] Backend::Mysql => { - let _ = (&f, self.with_docs, t); + let _ = (&f, self.with_docs); + eprintln!("Encountered unknown type for Mysql: {}", t.sql_name); + unreachable!( + "Mysql only supports a closed set of types. + If you ever see this error message please open an \ + issue at https://github.com/diesel-rs/diesel containing \ + a dump of your schema definition." + ) } } } @@ -289,8 +314,8 @@ impl<'a> Display for ModuleDefinition<'a> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { { let mut out = PadAdapter::new(f); - write!(out, "{}", self.1.custom_type_defs)?; writeln!(out, "pub mod {} {{", self.0)?; + write!(out, "{}", self.1.custom_type_defs)?; write!(out, "{}", self.1)?; } writeln!(f, "}}")?; diff --git a/diesel_cli/tests/print_schema.rs b/diesel_cli/tests/print_schema.rs index bb5dc27ceae5..dc5dacf285a3 100644 --- a/diesel_cli/tests/print_schema.rs +++ b/diesel_cli/tests/print_schema.rs @@ -189,6 +189,15 @@ fn print_schema_default_is_to_generate_custom_types() { ) } +#[test] +#[cfg(feature = "postgres")] +fn print_schema_specifying_schema_name_with_custom_type() { + test_print_schema( + "print_schema_specifying_schema_name_with_custom_type", + vec!["--with-docs", "--schema", "custom_schema"], + ) +} + #[cfg(feature = "sqlite")] const BACKEND: &str = "sqlite"; #[cfg(feature = "postgres")] diff --git a/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/diesel.toml b/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/diesel.toml new file mode 100644 index 000000000000..c976c2ccdd93 --- /dev/null +++ b/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/diesel.toml @@ -0,0 +1,4 @@ +[print_schema] +file = "src/schema.rs" +with_docs = true +schema = "custom_schema" diff --git a/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/expected.rs b/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/expected.rs new file mode 100644 index 000000000000..cbab01b9de03 --- /dev/null +++ b/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/expected.rs @@ -0,0 +1,38 @@ +// @generated automatically by Diesel CLI. + +pub mod custom_schema { + /// A module containing custom SQL type definitions + /// + /// (Automatically generated by Diesel.) + pub mod sql_types { + /// The `custom_schema.my_enum` SQL type + /// + /// (Automatically generated by Diesel.) + #[derive(diesel::SqlType)] + #[postgres(type_name = "my_enum", type_schema = "custom_schema")] + pub struct MyEnum; + } + + diesel::table! { + use diesel::sql_types::*; + use super::sql_types::MyEnum; + + /// Representation of the `custom_schema.in_schema` table. + /// + /// (Automatically generated by Diesel.) + custom_schema.in_schema (id) { + /// The `id` column of the `custom_schema.in_schema` table. + /// + /// Its SQL type is `Int4`. + /// + /// (Automatically generated by Diesel.) + id -> Int4, + /// The `custom_type` column of the `custom_schema.in_schema` table. + /// + /// Its SQL type is `Nullable`. + /// + /// (Automatically generated by Diesel.) + custom_type -> Nullable, + } + } +} diff --git a/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/schema.sql b/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/schema.sql new file mode 100644 index 000000000000..21337d277288 --- /dev/null +++ b/diesel_cli/tests/print_schema/print_schema_specifying_schema_name_with_custom_type/postgres/schema.sql @@ -0,0 +1,5 @@ +CREATE SCHEMA custom_schema; +CREATE TABLE in_public (id SERIAL PRIMARY KEY); +CREATE TYPE my_public_enum AS ENUM('A', 'B'); +CREATE TYPE custom_schema.my_enum AS ENUM ('A', 'B'); +CREATE TABLE custom_schema.in_schema (id SERIAL PRIMARY KEY, custom_type custom_schema.MY_ENUM);