From c56e458aad72124d72b79a0924461a0a2c91a256 Mon Sep 17 00:00:00 2001 From: Abhi Agarwal Date: Sun, 9 Jun 2024 07:03:19 -0400 Subject: [PATCH] expose underlying schema type of statement (#333) --- crates/duckdb/src/column.rs | 21 +++++++++++++- crates/duckdb/src/statement.rs | 52 +++++++++++++++++++++++++++++----- 2 files changed, 65 insertions(+), 8 deletions(-) diff --git a/crates/duckdb/src/column.rs b/crates/duckdb/src/column.rs index b7ab88ae..8e898fef 100644 --- a/crates/duckdb/src/column.rs +++ b/crates/duckdb/src/column.rs @@ -1,5 +1,7 @@ use std::str; +use arrow::datatypes::DataType; + use crate::{Error, Result, Statement}; /// Information about a column of a DuckDB query. @@ -29,6 +31,9 @@ impl Statement<'_> { /// If associated DB schema can be altered concurrently, you should make /// sure that current statement has already been stepped once before /// calling this method. + /// + /// # Caveats + /// Panics if the query has not been [`execute`](Statement::execute)d yet. pub fn column_names(&self) -> Vec { self.stmt .schema() @@ -87,7 +92,9 @@ impl Statement<'_> { /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid /// column range for this row. /// - /// Panics when column name is not valid UTF-8. + /// # Caveats + /// Panics if the query has not been [`execute`](Statement::execute)d yet + /// or when column name is not valid UTF-8. #[inline] pub fn column_name(&self, col: usize) -> Result<&String> { self.stmt.column_name(col).ok_or(Error::InvalidColumnIndex(col)) @@ -106,6 +113,9 @@ impl Statement<'_> { /// /// Will return an `Error::InvalidColumnName` when there is no column with /// the specified `name`. + /// + /// # Caveats + /// Panics if the query has not been [`execute`](Statement::execute)d yet. #[inline] pub fn column_index(&self, name: &str) -> Result { let n = self.column_count(); @@ -119,6 +129,15 @@ impl Statement<'_> { Err(Error::InvalidColumnName(String::from(name))) } + /// Returns the declared data type of the column. + /// + /// # Caveats + /// Panics if the query has not been [`execute`](Statement::execute)d yet. + #[inline] + pub fn column_type(&self, idx: usize) -> DataType { + self.stmt.column_type(idx) + } + /// Returns a slice describing the columns of the result of the query. /// /// If associated DB schema can be altered concurrently, you should make diff --git a/crates/duckdb/src/statement.rs b/crates/duckdb/src/statement.rs index 1a05c2f4..a30c11c6 100644 --- a/crates/duckdb/src/statement.rs +++ b/crates/duckdb/src/statement.rs @@ -1,6 +1,6 @@ use std::{convert, ffi::c_void, fmt, mem, os::raw::c_char, ptr, str}; -use arrow::{array::StructArray, datatypes::DataType}; +use arrow::{array::StructArray, datatypes::SchemaRef}; use super::{ffi, AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef}; #[cfg(feature = "polars")] @@ -452,6 +452,15 @@ impl Statement<'_> { Rows::new(self) } + /// Returns the underlying schema of the prepared statement. + /// + /// # Caveats + /// Panics if the query has not been [`execute`](Statement::execute)d yet. + #[inline] + pub fn schema(&self) -> SchemaRef { + self.stmt.schema() + } + // generic because many of these branches can constant fold away. fn bind_parameter(&self, param: &P, col: usize) -> Result<()> { let value = param.to_sql()?; @@ -542,12 +551,6 @@ impl Statement<'_> { pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> { Statement { conn, stmt } } - - /// column_type - #[inline] - pub fn column_type(&self, idx: usize) -> DataType { - self.stmt.column_type(idx) - } } #[cfg(test)] @@ -806,6 +809,41 @@ mod test { Ok(()) } + #[test] + fn test_get_schema_of_executed_result() -> Result<()> { + use arrow::datatypes::{DataType, Field, Schema}; + let db = Connection::open_in_memory()?; + let sql = "BEGIN; + CREATE TABLE foo(x STRING, y INTEGER); + INSERT INTO foo VALUES('hello', 3); + END;"; + db.execute_batch(sql)?; + let mut stmt = db.prepare("SELECT x, y FROM foo")?; + let _ = stmt.execute([]); + let schema = stmt.schema(); + assert_eq!( + *schema, + Schema::new(vec![ + Field::new("x", DataType::Utf8, true), + Field::new("y", DataType::Int32, true) + ]) + ); + Ok(()) + } + + #[test] + #[should_panic(expected = "called `Option::unwrap()` on a `None` value")] + fn test_unexecuted_schema_panics() { + let db = Connection::open_in_memory().unwrap(); + let sql = "BEGIN; + CREATE TABLE foo(x STRING, y INTEGER); + INSERT INTO foo VALUES('hello', 3); + END;"; + db.execute_batch(sql).unwrap(); + let stmt = db.prepare("SELECT x, y FROM foo").unwrap(); + let _ = stmt.schema(); + } + #[test] fn test_query_by_column_name_ignore_case() -> Result<()> { let db = Connection::open_in_memory()?;