Skip to content

Commit

Permalink
expose underlying schema type of statement (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhiaagarwal authored Jun 9, 2024
1 parent c175a8a commit c56e458
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
21 changes: 20 additions & 1 deletion crates/duckdb/src/column.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::str;

use arrow::datatypes::DataType;

use crate::{Error, Result, Statement};

/// Information about a column of a DuckDB query.
Expand Down Expand Up @@ -29,6 +31,9 @@ impl Statement<'_> {
/// If associated DB schema can be altered concurrently, you should make
/// sure that current statement has already been stepped once before
/// calling this method.
///
/// # Caveats
/// Panics if the query has not been [`execute`](Statement::execute)d yet.
pub fn column_names(&self) -> Vec<String> {
self.stmt
.schema()
Expand Down Expand Up @@ -87,7 +92,9 @@ impl Statement<'_> {
/// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid
/// column range for this row.
///
/// Panics when column name is not valid UTF-8.
/// # Caveats
/// Panics if the query has not been [`execute`](Statement::execute)d yet
/// or when column name is not valid UTF-8.
#[inline]
pub fn column_name(&self, col: usize) -> Result<&String> {
self.stmt.column_name(col).ok_or(Error::InvalidColumnIndex(col))
Expand All @@ -106,6 +113,9 @@ impl Statement<'_> {
///
/// Will return an `Error::InvalidColumnName` when there is no column with
/// the specified `name`.
///
/// # Caveats
/// Panics if the query has not been [`execute`](Statement::execute)d yet.
#[inline]
pub fn column_index(&self, name: &str) -> Result<usize> {
let n = self.column_count();
Expand All @@ -119,6 +129,15 @@ impl Statement<'_> {
Err(Error::InvalidColumnName(String::from(name)))
}

/// Returns the declared data type of the column.
///
/// # Caveats
/// Panics if the query has not been [`execute`](Statement::execute)d yet.
#[inline]
pub fn column_type(&self, idx: usize) -> DataType {
self.stmt.column_type(idx)
}

/// Returns a slice describing the columns of the result of the query.
///
/// If associated DB schema can be altered concurrently, you should make
Expand Down
52 changes: 45 additions & 7 deletions crates/duckdb/src/statement.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{convert, ffi::c_void, fmt, mem, os::raw::c_char, ptr, str};

use arrow::{array::StructArray, datatypes::DataType};
use arrow::{array::StructArray, datatypes::SchemaRef};

use super::{ffi, AndThenRows, Connection, Error, MappedRows, Params, RawStatement, Result, Row, Rows, ValueRef};
#[cfg(feature = "polars")]
Expand Down Expand Up @@ -452,6 +452,15 @@ impl Statement<'_> {
Rows::new(self)
}

/// Returns the underlying schema of the prepared statement.
///
/// # Caveats
/// Panics if the query has not been [`execute`](Statement::execute)d yet.
#[inline]
pub fn schema(&self) -> SchemaRef {
self.stmt.schema()
}

// generic because many of these branches can constant fold away.
fn bind_parameter<P: ?Sized + ToSql>(&self, param: &P, col: usize) -> Result<()> {
let value = param.to_sql()?;
Expand Down Expand Up @@ -542,12 +551,6 @@ impl Statement<'_> {
pub(super) fn new(conn: &Connection, stmt: RawStatement) -> Statement<'_> {
Statement { conn, stmt }
}

/// column_type
#[inline]
pub fn column_type(&self, idx: usize) -> DataType {
self.stmt.column_type(idx)
}
}

#[cfg(test)]
Expand Down Expand Up @@ -806,6 +809,41 @@ mod test {
Ok(())
}

#[test]
fn test_get_schema_of_executed_result() -> Result<()> {
use arrow::datatypes::{DataType, Field, Schema};
let db = Connection::open_in_memory()?;
let sql = "BEGIN;
CREATE TABLE foo(x STRING, y INTEGER);
INSERT INTO foo VALUES('hello', 3);
END;";
db.execute_batch(sql)?;
let mut stmt = db.prepare("SELECT x, y FROM foo")?;
let _ = stmt.execute([]);
let schema = stmt.schema();
assert_eq!(
*schema,
Schema::new(vec![
Field::new("x", DataType::Utf8, true),
Field::new("y", DataType::Int32, true)
])
);
Ok(())
}

#[test]
#[should_panic(expected = "called `Option::unwrap()` on a `None` value")]
fn test_unexecuted_schema_panics() {
let db = Connection::open_in_memory().unwrap();
let sql = "BEGIN;
CREATE TABLE foo(x STRING, y INTEGER);
INSERT INTO foo VALUES('hello', 3);
END;";
db.execute_batch(sql).unwrap();
let stmt = db.prepare("SELECT x, y FROM foo").unwrap();
let _ = stmt.schema();
}

#[test]
fn test_query_by_column_name_ignore_case() -> Result<()> {
let db = Connection::open_in_memory()?;
Expand Down

0 comments on commit c56e458

Please sign in to comment.