diff --git a/README.md b/README.md index 5253da3..a401334 100644 --- a/README.md +++ b/README.md @@ -30,30 +30,30 @@ without need for dynamic linking of C libraries. ## Support matrix -| | SQLite | DuckDB | PostgreSQL | MySQL | Microsoft SQL Server | -| --- | --- | --- | --- | --- | --- | -| feature | `src_sqlite` | `src_duckdb` | `src_postgres` | `src_mysql` | `src_tiberius` | -| dependency | [rusqlite](https://crates.io/crates/rusqlite) | [duckdb](https://crates.io/crates/duckdb) | [postgres](https://crates.io/crates/postgres) | [mysql](https://crates.io/crates/mysql) | [tiberius](https://crates.io/crates/tiberius) | -| query | x | x | x | x | x | -| query params | | | x | | | -| schema get | x | x | x | | | -| schema edit | x | x | x | | | -| append | x | x | x | | | -| roundtrip: null & bool | x | x | x | | | -| roundtrip: int | x | x | x | | | -| roundtrip: uint | x | x | x | | | -| roundtrip: float | x | x | x | | | -| roundtrip: decimal | x | | x | | | -| roundtrip: timestamp | x | x | x | | | -| roundtrip: date | x | | x | | | -| roundtrip: time | x | | x | | | -| roundtrip: duration | x | | x | | | -| roundtrip: interval | | | | | | -| roundtrip: utf8 | x | x | x | | | -| roundtrip: binary | x | x | x | | | -| roundtrip: empty | | x | x | | | -| containers | | | | | | -| binary fallback | x | | x | | | +| | SQLite | DuckDB | PostgreSQL | MySQL | Microsoft SQL Server | +|------------------------|-----------------------------------------------|-------------------------------------------|-----------------------------------------------|-----------------------------------------|-----------------------------------------------| +| feature | `src_sqlite` | `src_duckdb` | `src_postgres` | `src_mysql` | `src_tiberius` | +| dependency | [rusqlite](https://crates.io/crates/rusqlite) | [duckdb](https://crates.io/crates/duckdb) | [postgres](https://crates.io/crates/postgres) | [mysql](https://crates.io/crates/mysql) | [tiberius](https://crates.io/crates/tiberius) | +| query | x | x | x | x | x | +| query params | | | x | | | +| schema get | x | x | x | | | +| schema edit | x | x | x | | | +| append | x | x | x | | | +| roundtrip: null & bool | x | x | x | x | | +| roundtrip: int | x | x | x | x | | +| roundtrip: uint | x | x | x | x | | +| roundtrip: float | x | x | x | x | | +| roundtrip: decimal | x | | x | | | +| roundtrip: timestamp | x | x | x | | | +| roundtrip: date | x | | x | | | +| roundtrip: time | x | | x | | | +| roundtrip: duration | x | | x | | | +| roundtrip: interval | | | | | | +| roundtrip: utf8 | x | x | x | x | | +| roundtrip: binary | x | x | x | x | | +| roundtrip: empty | | x | x | x | | +| containers | | | | | | +| binary fallback | x | | x | | | None of the sources are enabled by default, use features to enable them. diff --git a/connector_arrow/src/mysql/append.rs b/connector_arrow/src/mysql/append.rs new file mode 100644 index 0000000..4047926 --- /dev/null +++ b/connector_arrow/src/mysql/append.rs @@ -0,0 +1,172 @@ +use arrow::datatypes::*; +use arrow::record_batch::RecordBatch; +use itertools::{zip_eq, Itertools}; +use mysql::prelude::Queryable; +use mysql::Value; + +use crate::api::Append; +use crate::types::{FixedSizeBinaryType, NullType}; +use crate::util::escape::escaped_ident_bt; +use crate::util::transport::{self, Consume, ConsumeTy}; +use crate::util::ArrayCellRef; +use crate::{impl_consume_unsupported, ConnectorError}; + +pub struct MySQLAppender<'conn, C: Queryable> { + table: String, + client: &'conn mut C, +} + +impl<'conn, C: Queryable> MySQLAppender<'conn, C> { + pub fn new(client: &'conn mut C, table_name: &str) -> Result { + client.query_drop("START TRANSACTION;")?; + Ok(Self { + table: table_name.to_owned(), + client, + }) + } +} + +impl<'conn, C: Queryable> Append<'conn> for MySQLAppender<'conn, C> { + fn append(&mut self, batch: RecordBatch) -> Result<(), ConnectorError> { + // TODO: 30 is a guess, we need benchmarking to find the optimum value + const BATCH_SIZE: usize = 30; + + let last_batch_size = batch.num_rows() % BATCH_SIZE; + + let batch_query = insert_query(&self.table, batch.num_columns(), BATCH_SIZE); + for batch_number in 0..(batch.num_rows() / BATCH_SIZE) { + let rows_range = (batch_number * BATCH_SIZE)..((batch_number + 1) * BATCH_SIZE); + + let params: Vec = collect_args(&batch, rows_range); + self.client.exec_iter(&batch_query, params)?; + } + + if last_batch_size > 0 { + let rows_range = (batch.num_rows() - last_batch_size)..batch.num_rows(); + + let last_query = insert_query(&self.table, batch.num_columns(), last_batch_size); + let params: Vec = collect_args(&batch, rows_range); + self.client.exec_iter(&last_query, params)?; + } + + Ok(()) + } + + fn finish(self) -> Result<(), ConnectorError> { + self.client.query_drop("COMMIT;")?; + Ok(()) + } +} + +fn insert_query(table_name: &str, cols: usize, rows: usize) -> String { + let values = (0..rows) + .map(|_| { + let row = (0..cols).map(|_| "?").join(","); + format!("({row})") + }) + .join(","); + + format!( + "INSERT INTO {} VALUES {values}", + escaped_ident_bt(table_name) + ) +} + +fn collect_args(batch: &RecordBatch, rows_range: std::ops::Range) -> Vec { + let mut res = Vec::with_capacity(rows_range.len() * batch.num_columns()); + + let schema = batch.schema(); + let mut row = zip_eq(batch.columns(), schema.fields()) + .map(|(array, field)| ArrayCellRef { + array, + field, + row_number: 0, + }) + .collect_vec(); + + for row_number in rows_range { + for cell in &mut row { + cell.row_number = row_number; + transport::transport(cell.field, cell as &_, &mut res).unwrap(); + } + } + res +} + +impl Consume for Vec {} + +macro_rules! impl_consume_ty { + ($ArrTy: ty, $value_kind: expr) => { + impl_consume_ty!($ArrTy, $value_kind, std::convert::identity); + }; + + ($ArrTy: ty, $value_kind: expr, $conversion: expr) => { + impl ConsumeTy<$ArrTy> for Vec { + fn consume( + &mut self, + _ty: &DataType, + value: <$ArrTy as crate::types::ArrowType>::Native, + ) { + let value: Value = $value_kind(($conversion)(value)); + self.push(value); + } + + fn consume_null(&mut self) { + self.push(Value::NULL); + } + } + }; +} + +impl_consume_ty!(BooleanType, Value::Int, i64::from); +impl_consume_ty!(Int8Type, Value::Int, i64::from); +impl_consume_ty!(Int16Type, Value::Int, i64::from); +impl_consume_ty!(Int32Type, Value::Int, i64::from); +impl_consume_ty!(Int64Type, Value::Int); +impl_consume_ty!(UInt8Type, Value::UInt, u64::from); +impl_consume_ty!(UInt16Type, Value::UInt, u64::from); +impl_consume_ty!(UInt32Type, Value::UInt, u64::from); +impl_consume_ty!(UInt64Type, Value::UInt); +impl_consume_ty!(Float16Type, Value::Float, f32::from); +impl_consume_ty!(Float32Type, Value::Float); +impl_consume_ty!(Float64Type, Value::Double); +impl_consume_ty!(Utf8Type, Value::Bytes, String::into_bytes); +impl_consume_ty!(BinaryType, Value::Bytes); +impl_consume_ty!(LargeBinaryType, Value::Bytes); +impl_consume_ty!(FixedSizeBinaryType, Value::Bytes); + +impl ConsumeTy for Vec { + fn consume(&mut self, _ty: &DataType, _value: ()) { + self.push(Value::NULL); + } + + fn consume_null(&mut self) { + self.push(Value::NULL); + } +} + +impl_consume_unsupported!( + Vec, + ( + TimestampSecondType, + TimestampMillisecondType, + TimestampMicrosecondType, + TimestampNanosecondType, + Date32Type, + Date64Type, + Time32SecondType, + Time32MillisecondType, + Time64MicrosecondType, + Time64NanosecondType, + DurationSecondType, + DurationMillisecondType, + DurationMicrosecondType, + DurationNanosecondType, + IntervalDayTimeType, + IntervalMonthDayNanoType, + IntervalYearMonthType, + Decimal128Type, + Decimal256Type, + LargeUtf8Type, + ) +); diff --git a/connector_arrow/src/mysql/mod.rs b/connector_arrow/src/mysql/mod.rs index 1eac4da..ad8cb45 100644 --- a/connector_arrow/src/mysql/mod.rs +++ b/connector_arrow/src/mysql/mod.rs @@ -1,10 +1,12 @@ +mod append; mod query; +mod schema; mod types; use arrow::datatypes::*; use mysql::prelude::*; -use crate::api::{unimplemented, Connector}; +use crate::api::Connector; use crate::ConnectorError; pub struct MySQLConnection { @@ -24,7 +26,7 @@ impl MySQLConnection { impl Connector for MySQLConnection { type Stmt<'conn> = query::MySQLStatement<'conn, C> where Self: 'conn; - type Append<'conn> = unimplemented::Appender where Self: 'conn; + type Append<'conn> = append::MySQLAppender<'conn, C> where Self: 'conn; fn query<'a>(&'a mut self, query: &str) -> Result, ConnectorError> { let stmt = self.conn.prep(query)?; @@ -34,37 +36,82 @@ impl Connector for MySQLConnection { }) } - fn append<'a>(&'a mut self, _table_name: &str) -> Result, ConnectorError> { - Ok(unimplemented::Appender {}) + fn append<'a>(&'a mut self, table_name: &str) -> Result, ConnectorError> { + append::MySQLAppender::new(&mut self.conn, table_name) } fn type_db_into_arrow(ty: &str) -> Option { - Some(match ty { - "null" => DataType::Null, + dbg!(ty); - "tinyint" | "bool" | "boolean" => DataType::Int8, - "smallint" => DataType::Int16, - "integer" | "int" => DataType::Int32, - "bigint" => DataType::Int64, + let (ty, unsigned) = ty + .strip_suffix(" unsigned") + .map(|p| (p, true)) + .unwrap_or((ty, false)); - "tinyint unsigned" => DataType::UInt8, - "smallint unsigned" => DataType::UInt16, - "integer unsigned" | "int unsigned" => DataType::UInt32, - "bigint unsigned" => DataType::UInt64, + // strip size suffix and anything following it + let ty = if let Some(open_parent) = ty.find('(') { + &ty[0..open_parent] + } else { + ty + }; - "real" | "float4" => DataType::Float32, - "double" | "float8" => DataType::Float64, + Some(match (ty, unsigned) { + ("null", _) => DataType::Null, - "bytea" => DataType::Binary, - "bit" | "tiny_blob" | "medium_blob" | "long_blob" | "blob" => DataType::Binary, + ("tinyint" | "bool" | "boolean", false) => DataType::Int8, + ("smallint", false) => DataType::Int16, + ("integer" | "int", false) => DataType::Int32, + ("bigint", false) => DataType::Int64, - "varchar" | "var_string" | "string" => DataType::Utf8, + ("tinyint", true) => DataType::UInt8, + ("smallint", true) => DataType::UInt16, + ("integer" | "int", true) => DataType::UInt32, + ("bigint", true) => DataType::UInt64, + + ("real" | "float" | "float4", _) => DataType::Float32, + ("double" | "float8", _) => DataType::Float64, + + ("bit" | "tinyblob" | "mediumblob" | "longblob" | "blob" | "binary", _) => { + DataType::Binary + } + + ("tinytext" | "mediumtext" | "longtext" | "text" | "varchar", _) => DataType::Utf8, _ => return None, }) } - fn type_arrow_into_db(_ty: &DataType) -> Option { - None + fn type_arrow_into_db(ty: &DataType) -> Option { + Some( + match ty { + DataType::Null => "tinyint", + DataType::Boolean => "tinyint", + DataType::Int8 => "tinyint", + DataType::Int16 => "smallint", + DataType::Int32 => "integer", + DataType::Int64 => "bigint", + DataType::UInt8 => "tinyint unsigned", + DataType::UInt16 => "smallint unsigned", + DataType::UInt32 => "integer unsigned", + DataType::UInt64 => "bigint unsigned", + DataType::Float16 => "float", + DataType::Float32 => "float", + DataType::Float64 => "double", + + DataType::Binary => "longblob", + DataType::FixedSizeBinary(1) => "binary", + DataType::FixedSizeBinary(2) => "blob", + DataType::FixedSizeBinary(3) => "mediumblob", + DataType::FixedSizeBinary(4) => "longblob", + DataType::FixedSizeBinary(_) => return None, + DataType::LargeBinary => return None, + + DataType::Utf8 => "longtext", + DataType::LargeUtf8 => return None, + + _ => return None, + } + .to_string(), + ) } } diff --git a/connector_arrow/src/mysql/query.rs b/connector_arrow/src/mysql/query.rs index 6370e79..ad0e82e 100644 --- a/connector_arrow/src/mysql/query.rs +++ b/connector_arrow/src/mysql/query.rs @@ -93,6 +93,7 @@ impl<'a> util::CellReader<'a> for MySQLCellReader { cell: self.cell, }; self.cell += 1; + Some(r) } } @@ -113,7 +114,12 @@ macro_rules! impl_produce_ty { Ok(self.row.take(self.cell).unwrap()) } fn produce_opt(self) -> Result::Native>, ConnectorError> { - Ok(self.row.take(self.cell)) + let res = self.row.take_opt(self.cell).unwrap(); + match res { + Ok(v) => Ok(Some(v)), + Err(mysql::FromValueError(mysql::Value::NULL)) => Ok(None), + Err(mysql::FromValueError(v)) => Err(ConnectorError::from(mysql::Error::FromValueError(v))) + } } } )+ diff --git a/connector_arrow/src/mysql/schema.rs b/connector_arrow/src/mysql/schema.rs new file mode 100644 index 0000000..e12d295 --- /dev/null +++ b/connector_arrow/src/mysql/schema.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use arrow::datatypes::{DataType, Schema}; +use itertools::Itertools; +use mysql::prelude::Queryable; + +use crate::{ + api::{Connector, SchemaEdit, SchemaGet}, + mysql::MySQLConnection, + util::escape::escaped_ident_bt, + ConnectorError, TableCreateError, TableDropError, +}; + +impl SchemaGet for super::MySQLConnection { + fn table_list(&mut self) -> Result, crate::ConnectorError> { + let mut results = self.conn.exec_iter("SHOW TABLES;", ())?; + let result = results.iter().ok_or(crate::ConnectorError::NoResultSets)?; + + let table_names = result + .into_iter() + .map(|r_row| r_row.map(|row| row.get::(0).unwrap())) + .collect::, _>>()?; + + Ok(table_names) + } + + fn table_get( + &mut self, + name: &str, + ) -> Result { + let mut results = self + .conn + .exec_iter(format!("DESCRIBE {};", escaped_ident_bt(name)), ())?; + let result = results.iter().ok_or(crate::ConnectorError::NoResultSets)?; + + let fields = result + .into_iter() + .map(|r_row| { + r_row.map(|row| { + let name = row.get::(0).unwrap(); + let ty = row.get::(1).unwrap(); + let nullable = row.get::(2).unwrap() == "YES"; + + super::types::create_field(name, &ty, nullable) + }) + }) + .collect::, _>>()?; + + Ok(Arc::new(Schema::new(fields))) + } +} + +impl SchemaEdit for super::MySQLConnection { + fn table_create( + &mut self, + name: &str, + schema: arrow::datatypes::SchemaRef, + ) -> Result<(), TableCreateError> { + let column_defs = schema + .fields() + .iter() + .map(|field| { + let ty = MySQLConnection::::type_arrow_into_db(field.data_type()) + .unwrap_or_else(|| { + unimplemented!("cannot store arrow type {} in MySQL", field.data_type()); + }); + + let is_nullable = + field.is_nullable() || matches!(field.data_type(), DataType::Null); + let not_null = if is_nullable { "" } else { " NOT NULL" }; + + let name = escaped_ident_bt(field.name()); + format!("{name} {ty}{not_null}",) + }) + .join(","); + + let ddl = format!("CREATE TABLE {} ({column_defs});", escaped_ident_bt(name)); + + let res = self.conn.query_drop(&ddl); + match res { + Ok(_) => Ok(()), + Err(mysql::Error::MySqlError(e)) if e.code == 1050 => { + Err(TableCreateError::TableExists) + } + Err(e) => Err(TableCreateError::Connector(ConnectorError::MySQL(e))), + } + } + + fn table_drop(&mut self, name: &str) -> Result<(), TableDropError> { + let res = self + .conn + .query_drop(format!("DROP TABLE {}", escaped_ident_bt(name))); + match res { + Ok(_) => Ok(()), + Err(mysql::Error::MySqlError(e)) if e.code == 1051 => { + Err(TableDropError::TableNonexistent) + } + Err(e) => Err(TableDropError::Connector(ConnectorError::MySQL(e))), + } + } +} diff --git a/connector_arrow/src/mysql/types.rs b/connector_arrow/src/mysql/types.rs index 403783d..650b3be 100644 --- a/connector_arrow/src/mysql/types.rs +++ b/connector_arrow/src/mysql/types.rs @@ -10,12 +10,21 @@ use crate::ConnectorError; pub fn get_result_schema<'a, P: Protocol>( result: &mysql::ResultSet<'a, 'a, 'a, 'a, P>, ) -> Result { + dbg!("get_result_schema"); + let mut fields = Vec::new(); for column in result.columns().as_ref() { let is_unsigned = !(column.flags() & ColumnFlags::UNSIGNED_FLAG).is_empty(); let is_not_null = !(column.flags() & ColumnFlags::NOT_NULL_FLAG).is_empty(); + let is_blob = !(column.flags() & ColumnFlags::BLOB_FLAG).is_empty(); + let is_binary = !(column.flags() & ColumnFlags::BINARY_FLAG).is_empty(); + + dbg!(column.name_str()); + dbg!(is_blob); + dbg!(is_binary); - let db_ty = get_name_of_column_type(&column.column_type(), is_unsigned); + let db_ty = get_name_of_column_type(&column.column_type(), is_unsigned, is_binary); + dbg!(db_ty); fields.push(create_field( column.name_str().to_string(), db_ty, @@ -26,67 +35,75 @@ pub fn get_result_schema<'a, P: Protocol>( Ok(Arc::new(Schema::new(fields))) } -fn create_field(name: String, db_ty: &str, nullable: bool) -> Field { +pub fn create_field(name: String, db_ty: &str, nullable: bool) -> Field { let data_type = super::MySQLConnection::::type_db_into_arrow(db_ty); - let data_type = data_type.unwrap_or_else(|| todo!()); + let data_type = data_type.unwrap_or_else(|| todo!("db type: {db_ty}")); Field::new(name, data_type, nullable) } -fn get_name_of_column_type(col_ty: &ColumnType, unsigned: bool) -> &'static str { +fn get_name_of_column_type(col_ty: &ColumnType, unsigned: bool, binary: bool) -> &'static str { use ColumnType::*; - match (col_ty, unsigned) { - (MYSQL_TYPE_NULL, _) => "null", + match (col_ty, unsigned, binary) { + (MYSQL_TYPE_NULL, _, _) => "null", + + (MYSQL_TYPE_TINY, false, _) => "tinyint", + (MYSQL_TYPE_TINY, true, _) => "tinyint unsigned", + + (MYSQL_TYPE_SHORT, false, _) => "smallint", + (MYSQL_TYPE_SHORT, true, _) => "smallint unsigned", - (MYSQL_TYPE_TINY, false) => "tinyint", - (MYSQL_TYPE_TINY, true) => "tinyint unsigned", + (MYSQL_TYPE_INT24, false, _) => "mediumint", + (MYSQL_TYPE_INT24, true, _) => "mediumint unsigned", - (MYSQL_TYPE_SHORT, false) => "smallint", - (MYSQL_TYPE_SHORT, true) => "smallint unsigned", + (MYSQL_TYPE_LONG, false, _) => "int", + (MYSQL_TYPE_LONG, true, _) => "int unsigned", - (MYSQL_TYPE_INT24, false) => "mediumint", - (MYSQL_TYPE_INT24, true) => "mediumint unsigned", + (MYSQL_TYPE_LONGLONG, false, _) => "bigint", + (MYSQL_TYPE_LONGLONG, true, _) => "bigint unsigned", - (MYSQL_TYPE_LONG, false) => "int", - (MYSQL_TYPE_LONG, true) => "int unsigned", + (MYSQL_TYPE_FLOAT, _, _) => "float", + (MYSQL_TYPE_DOUBLE, _, _) => "double", - (MYSQL_TYPE_LONGLONG, false) => "bigint", - (MYSQL_TYPE_LONGLONG, true) => "bigint unsigned", + (MYSQL_TYPE_TIMESTAMP, _, _) => "timestamp", + (MYSQL_TYPE_DATE, _, _) => "date", + (MYSQL_TYPE_TIME, _, _) => "time", + (MYSQL_TYPE_DATETIME, _, _) => "datetime", + (MYSQL_TYPE_YEAR, _, _) => "year", + (MYSQL_TYPE_NEWDATE, _, _) => "newdate", - (MYSQL_TYPE_FLOAT, _) => "float", - (MYSQL_TYPE_DOUBLE, _) => "double", + (MYSQL_TYPE_TIMESTAMP2, _, _) => "timestamp2", + (MYSQL_TYPE_DATETIME2, _, _) => "datetime2", + (MYSQL_TYPE_TIME2, _, _) => "time2", + (MYSQL_TYPE_TYPED_ARRAY, _, _) => "typed_array", - (MYSQL_TYPE_TIMESTAMP, _) => "timestamp", - (MYSQL_TYPE_DATE, _) => "date", - (MYSQL_TYPE_TIME, _) => "time", - (MYSQL_TYPE_DATETIME, _) => "datetime", - (MYSQL_TYPE_YEAR, _) => "year", - (MYSQL_TYPE_NEWDATE, _) => "newdate", + (MYSQL_TYPE_NEWDECIMAL, _, _) => "newdecimal", + (MYSQL_TYPE_DECIMAL, _, _) => "decimal", - (MYSQL_TYPE_TIMESTAMP2, _) => "timestamp2", - (MYSQL_TYPE_DATETIME2, _) => "datetime2", - (MYSQL_TYPE_TIME2, _) => "time2", - (MYSQL_TYPE_TYPED_ARRAY, _) => "typed_array", + (MYSQL_TYPE_VARCHAR, _, _) => "varchar", + (MYSQL_TYPE_JSON, _, _) => "json", - (MYSQL_TYPE_NEWDECIMAL, _) => "newdecimal", - (MYSQL_TYPE_DECIMAL, _) => "decimal", + (MYSQL_TYPE_ENUM, _, _) => "enum", + (MYSQL_TYPE_SET, _, _) => "set", - (MYSQL_TYPE_VARCHAR, _) => "varchar", - (MYSQL_TYPE_VAR_STRING, _) => "var_string", - (MYSQL_TYPE_STRING, _) => "string", - (MYSQL_TYPE_JSON, _) => "json", + (MYSQL_TYPE_BIT, _, _) => "bit", - (MYSQL_TYPE_ENUM, _) => "enum", - (MYSQL_TYPE_SET, _) => "set", + (MYSQL_TYPE_TINY_BLOB, _, true) => "tinyblob", + (MYSQL_TYPE_MEDIUM_BLOB, _, true) => "mediumblob", + (MYSQL_TYPE_LONG_BLOB, _, true) => "longblob", + (MYSQL_TYPE_BLOB, _, true) => "blob", + (MYSQL_TYPE_VAR_STRING, _, true) => "varbinary", + (MYSQL_TYPE_STRING, _, true) => "binary", - (MYSQL_TYPE_BIT, _) => "bit", - (MYSQL_TYPE_TINY_BLOB, _) => "tiny_blob", - (MYSQL_TYPE_MEDIUM_BLOB, _) => "medium_blob", - (MYSQL_TYPE_LONG_BLOB, _) => "long_blob", - (MYSQL_TYPE_BLOB, _) => "blob", + (MYSQL_TYPE_TINY_BLOB, _, false) => "tinytext", + (MYSQL_TYPE_MEDIUM_BLOB, _, false) => "mediumtext", + (MYSQL_TYPE_LONG_BLOB, _, false) => "longtext", + (MYSQL_TYPE_BLOB, _, false) => "text", + (MYSQL_TYPE_VAR_STRING, _, false) => "varchar", + (MYSQL_TYPE_STRING, _, false) => "char", - (MYSQL_TYPE_GEOMETRY, _) => "geometry", - (MYSQL_TYPE_UNKNOWN, _) => "unknown", + (MYSQL_TYPE_GEOMETRY, _, _) => "geometry", + (MYSQL_TYPE_UNKNOWN, _, _) => "unknown", } } diff --git a/connector_arrow/src/util/escape.rs b/connector_arrow/src/util/escape.rs index b1226a0..f429109 100644 --- a/connector_arrow/src/util/escape.rs +++ b/connector_arrow/src/util/escape.rs @@ -8,6 +8,11 @@ pub fn escaped_ident(ident: &str) -> EscapedIdent<'_> { EscapedIdent { ident, quote: '"' } } +#[allow(dead_code)] +pub fn escaped_ident_bt(ident: &str) -> EscapedIdent<'_> { + EscapedIdent { ident, quote: '`' } +} + pub static VALID_IDENT: Lazy = Lazy::new(|| { // An ident starting with `a-z_` and containing other characters `a-z0-9_$` // diff --git a/connector_arrow/tests/it/spec.rs b/connector_arrow/tests/it/spec.rs index ef257ee..329f98d 100644 --- a/connector_arrow/tests/it/spec.rs +++ b/connector_arrow/tests/it/spec.rs @@ -204,6 +204,10 @@ pub fn interval() -> Vec { } pub fn utf8() -> Vec { + domains_to_batch_spec(&[DataType::Utf8], &[false, true], &VALUE_GEN_PROCESS_ALL) +} + +pub fn utf8_large() -> Vec { domains_to_batch_spec( &[DataType::Utf8, DataType::LargeUtf8], &[false, true], @@ -212,11 +216,30 @@ pub fn utf8() -> Vec { } pub fn binary() -> Vec { + domains_to_batch_spec( + &[ + DataType::Binary, + DataType::FixedSizeBinary(1), + DataType::FixedSizeBinary(2), + DataType::FixedSizeBinary(3), + DataType::FixedSizeBinary(4), + ], + &[false, true], + &VALUE_GEN_PROCESS_ALL, + ) +} + +pub fn binary_large() -> Vec { domains_to_batch_spec( &[ DataType::Binary, DataType::LargeBinary, DataType::FixedSizeBinary(15), + DataType::FixedSizeBinary(1), + DataType::FixedSizeBinary(2), + DataType::FixedSizeBinary(3), + DataType::FixedSizeBinary(4), + DataType::FixedSizeBinary(5), DataType::FixedSizeBinary(0), ], &[false, true], diff --git a/connector_arrow/tests/it/test_duckdb.rs b/connector_arrow/tests/it/test_duckdb.rs index 92dd89c..f6c53c2 100644 --- a/connector_arrow/tests/it/test_duckdb.rs +++ b/connector_arrow/tests/it/test_duckdb.rs @@ -26,11 +26,11 @@ fn query_01() { // #[case::time("roundtrip::time", spec::time())] // #[case::duration("roundtrip::duration", spec::duration())] // #[case::interval("roundtrip::interval", spec::interval())] -#[case::utf8("roundtrip::utf8", spec::utf8())] -#[case::binary("roundtrip::binary", spec::binary())] +#[case::utf8("roundtrip::utf8", spec::utf8_large())] +#[case::binary("roundtrip::binary", spec::binary_large())] fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut conn = init(); - super::tests::roundtrip(&mut conn, table_name, spec); + super::tests::roundtrip(&mut conn, table_name, spec, '"', false); } #[test] diff --git a/connector_arrow/tests/it/test_mysql.rs b/connector_arrow/tests/it/test_mysql.rs index 5ef0582..71ce53a 100644 --- a/connector_arrow/tests/it/test_mysql.rs +++ b/connector_arrow/tests/it/test_mysql.rs @@ -1,4 +1,7 @@ use connector_arrow::mysql::MySQLConnection; +use rstest::*; + +use crate::spec; fn init() -> MySQLConnection { let _ = env_logger::builder().is_test(true).try_init(); @@ -13,3 +16,49 @@ fn query_01() { let mut conn = init(); super::tests::query_01(&mut conn); } + +#[test] +fn schema_get() { + let table_name = "schema_get"; + + let mut conn = init(); + let column_spec = super::spec::basic_types(); + super::tests::schema_get(&mut conn, table_name, column_spec); +} + +#[test] +fn schema_edit() { + let table_name = "schema_edit"; + + let mut conn = init(); + let column_spec = super::spec::basic_types(); + super::tests::schema_edit(&mut conn, table_name, column_spec); +} + +#[test] +fn ident_escaping() { + // https://github.com/blackbeam/rust_mysql_common/issues/129 + let table_name = "ident_escaping"; + + let mut conn = init(); + super::tests::ident_escaping(&mut conn, table_name); +} + +#[rstest] +#[case::empty("roundtrip__empty", spec::empty())] +#[case::null_bool("roundtrip__null_bool", spec::null_bool())] +#[case::int("roundtrip__int", spec::int())] +#[case::uint("roundtrip__uint", spec::uint())] +#[case::float("roundtrip__float", spec::float())] +// #[case::decimal("roundtrip__decimal", spec::decimal())] +// #[case::timestamp("roundtrip__timestamp", spec::timestamp())] +// #[case::date("roundtrip__date", spec::date())] +// #[case::time("roundtrip__time", spec::time())] +// #[case::duration("roundtrip__duration", spec::duration())] +// #[case::interval("roundtrip__interval", spec::interval())] +#[case::utf8("roundtrip__utf8", spec::utf8())] +#[case::binary("roundtrip__binary", spec::binary())] +fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { + let mut conn = init(); + super::tests::roundtrip(&mut conn, table_name, spec, '`', true); +} diff --git a/connector_arrow/tests/it/test_postgres_extended.rs b/connector_arrow/tests/it/test_postgres_extended.rs index ce29af5..d5f803f 100644 --- a/connector_arrow/tests/it/test_postgres_extended.rs +++ b/connector_arrow/tests/it/test_postgres_extended.rs @@ -43,12 +43,12 @@ fn query_03() { #[case::time("roundtrip::time", spec::time())] #[case::duration("roundtrip::duration", spec::duration())] // #[case::interval("roundtrip::interval", spec::interval())] -#[case::utf8("roundtrip::utf8", spec::utf8())] -#[case::binary("roundtrip::binary", spec::binary())] +#[case::utf8("roundtrip::utf8", spec::utf8_large())] +#[case::binary("roundtrip::binary", spec::binary_large())] fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut conn = init(); let table_name = format!("extended::{table_name}"); - super::tests::roundtrip(&mut conn, &table_name, spec); + super::tests::roundtrip(&mut conn, &table_name, spec, '"', false); } #[rstest] diff --git a/connector_arrow/tests/it/test_postgres_simple.rs b/connector_arrow/tests/it/test_postgres_simple.rs index 31310d9..3c3cdc4 100644 --- a/connector_arrow/tests/it/test_postgres_simple.rs +++ b/connector_arrow/tests/it/test_postgres_simple.rs @@ -37,12 +37,12 @@ fn query_02() { #[case::time("roundtrip::time", spec::time())] #[case::duration("roundtrip::duration", spec::duration())] // #[case::interval("roundtrip::interval", spec::interval())] -#[case::utf8("roundtrip::utf8", spec::utf8())] -#[case::binary("roundtrip::binary", spec::binary())] +#[case::utf8("roundtrip::utf8", spec::utf8_large())] +#[case::binary("roundtrip::binary", spec::binary_large())] fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut conn = init(); let table_name = format!("simple::{table_name}"); - super::tests::roundtrip(&mut conn, &table_name, spec); + super::tests::roundtrip(&mut conn, &table_name, spec, '"', false); } #[rstest] diff --git a/connector_arrow/tests/it/test_sqlite.rs b/connector_arrow/tests/it/test_sqlite.rs index 6628c07..5031877 100644 --- a/connector_arrow/tests/it/test_sqlite.rs +++ b/connector_arrow/tests/it/test_sqlite.rs @@ -26,11 +26,11 @@ fn query_01() { #[case::time("roundtrip::time", spec::time())] #[case::duration("roundtrip::duration", spec::duration())] // #[case::interval("roundtrip::interval", spec::interval())] -#[case::utf8("roundtrip::utf8", spec::utf8())] -#[case::binary("roundtrip::binary", spec::binary())] +#[case::utf8("roundtrip::utf8", spec::utf8_large())] +#[case::binary("roundtrip::binary", spec::binary_large())] fn roundtrip(#[case] table_name: &str, #[case] spec: spec::ArrowGenSpec) { let mut conn = init(); - super::tests::roundtrip(&mut conn, table_name, spec); + super::tests::roundtrip(&mut conn, table_name, spec, '"', false); } #[test] diff --git a/connector_arrow/tests/it/tests.rs b/connector_arrow/tests/it/tests.rs index a946754..3040222 100644 --- a/connector_arrow/tests/it/tests.rs +++ b/connector_arrow/tests/it/tests.rs @@ -77,8 +77,13 @@ pub fn query_03(conn: &mut C) { ); } -pub fn roundtrip(conn: &mut C, table_name: &str, spec: ArrowGenSpec) -where +pub fn roundtrip( + conn: &mut C, + table_name: &str, + spec: ArrowGenSpec, + ident_quote_char: char, + nullable_results: bool, +) where C: Connector + SchemaEdit, { let mut rng = rand_chacha::ChaCha8Rng::from_seed([0; 32]); @@ -86,10 +91,11 @@ where load_into_table(conn, schema.clone(), &batches, table_name).unwrap(); + let override_nullable = if !nullable_results { Some(true) } else { None }; let (schema_coerced, batches_coerced) = - coerce::coerce_batches(schema, &batches, coerce_type::, Some(true)).unwrap(); + coerce::coerce_batches(schema, &batches, coerce_type::, override_nullable).unwrap(); - let (schema_query, batches_query) = query_table(conn, table_name).unwrap(); + let (schema_query, batches_query) = query_table(conn, table_name, ident_quote_char).unwrap(); similar_asserts::assert_eq!(schema_coerced, schema_query); similar_asserts::assert_eq!(batches_coerced, batches_query); diff --git a/connector_arrow/tests/it/util.rs b/connector_arrow/tests/it/util.rs index 94e4301..74c44fb 100644 --- a/connector_arrow/tests/it/util.rs +++ b/connector_arrow/tests/it/util.rs @@ -59,9 +59,12 @@ where pub fn query_table( conn: &mut C, table_name: &str, + ident_quote_char: char, ) -> Result<(SchemaRef, Vec), ConnectorError> { let mut stmt = conn - .query(&format!("SELECT * FROM \"{table_name}\"")) + .query(&format!( + "SELECT * FROM {ident_quote_char}{table_name}{ident_quote_char}" + )) .unwrap(); let mut reader = stmt.start([])?;