Update SQL table reading functions to return maps
Refactored the SQL table reading functions to return maps instead of lists. Keying each dataframe by its table name makes it easy to correlate a dataframe with its underlying table. The function comments and test cases were updated to match this change.
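
The effect of moving from a list to a map can be shown with a minimal sketch that runs without the DataFrame library or a database. `FakeFrame` and `readAllTablesSketch` are hypothetical stand-ins invented for illustration, not part of the library; only the accumulation pattern (`mutableMapOf` plus `+= name to value`) is taken from the diff.

```kotlin
// Hypothetical stand-in for AnyFrame so the sketch needs no library or database.
data class FakeFrame(val rows: List<List<Any?>>)

// Hypothetical helper mirroring the accumulation pattern used in the commit:
// a mutable map filled with `tableName to frame` pairs.
fun readAllTablesSketch(tables: List<Pair<String, FakeFrame>>): Map<String, FakeFrame> {
    val frames = mutableMapOf<String, FakeFrame>()
    for ((tableName, frame) in tables) {
        frames += tableName to frame
    }
    return frames
}

fun main() {
    val result = readAllTablesSketch(
        listOf(
            "Customer" to FakeFrame(listOf(listOf(1, "John"), listOf(2, "Alice"))),
            "Sale" to FakeFrame(listOf(listOf(1, 100.50))),
        )
    )

    // Look a table up by name instead of by position.
    val customers = result.getValue("Customer")
    println("Customer rows: ${customers.rows.size}")

    // mutableMapOf returns a LinkedHashMap, so values keep insertion order.
    // That is why positional access through `values.toList()` still works.
    val byPosition = result.values.toList()
    println(byPosition[0] == customers)
}
```

Callers that relied on list positions can keep working through `values.toList()`, as the updated tests in this commit do, while new code can address tables by name.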
zaleslaw committed Jun 4, 2024
1 parent 55a6fe9 commit f254519
Showing 9 changed files with 51 additions and 40 deletions.
@@ -290,33 +290,35 @@ public fun DataFrame.Companion.readResultSet(
}

/**
* Reads all tables from the given database using the provided database configuration and limit.
* Reads all non-system tables from a database and returns them
* as a map of SQL tables and corresponding dataframes using the provided database configuration and limit.
*
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @param [limit] the maximum number of rows to read from each table.
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
* @param [inferNullability] indicates how the column nullability should be inferred.
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
* @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
*/
public fun DataFrame.Companion.readAllSqlTables(
dbConfig: DatabaseConfiguration,
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): List<AnyFrame> {
): Map<String, AnyFrame> {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return readAllSqlTables(connection, catalogue, limit, inferNullability)
}
}

/**
* Reads all non-system tables from a database and returns them as a list of data frames.
* Reads all non-system tables from a database and returns them
* as a map of SQL tables and corresponding dataframes.
*
* @param [connection] the database connection to read tables from.
* @param [limit] the maximum number of rows to read from each table.
* @param [catalogue] a name of the catalog from which tables will be retrieved. A null value retrieves tables from all catalogs.
* @param [inferNullability] indicates how the column nullability should be inferred.
* @return a list of [AnyFrame] objects representing the non-system tables from the database.
* @return a map of [String] to [AnyFrame] objects representing the non-system tables from the database.
*
* @see DriverManager.getConnection
*/
@@ -325,20 +327,20 @@ public fun DataFrame.Companion.readAllSqlTables(
catalogue: String? = null,
limit: Int = DEFAULT_LIMIT,
inferNullability: Boolean = true,
): List<AnyFrame> {
): Map<String, AnyFrame> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)

// exclude system tables and other tables without data; this filter appears to be poorly supported by many databases
val tables = metaData.getTables(catalogue, null, null, arrayOf("TABLE"))

val dataFrames = mutableListOf<AnyFrame>()
val dataFrames = mutableMapOf<String, AnyFrame>()

while (tables.next()) {
val table = dbType.buildTableMetadata(tables)
if (!dbType.isSystemTable(table)) {
// we filter her second time because of specific logic with SQLite and possible issues with future databases
// we filter here a second time because of specific logic with SQLite and possible issues with future databases
val tableName = when {
catalogue != null && table.schemaName != null -> "$catalogue.${table.schemaName}.${table.name}"
catalogue != null && table.schemaName == null -> "$catalogue.${table.name}"
@@ -351,7 +353,7 @@ public fun DataFrame.Companion.readAllSqlTables(
logger.debug { "Reading table: $tableName" }

val dataFrame = readSqlTable(connection, tableName, limit, inferNullability)
dataFrames += dataFrame
dataFrames += tableName to dataFrame
logger.debug { "Finished reading table: $tableName" }
}
}
@@ -474,24 +476,24 @@ public fun DataFrame.Companion.getSchemaForResultSet(resultSet: ResultSet, conne
}

/**
* Retrieves the schema of all non-system tables in the database using the provided database configuration.
* Retrieves the schemas of all non-system tables in the database using the provided database configuration.
*
* @param [dbConfig] the database configuration to connect to the database, including URL, user, and password.
* @return a list of [DataFrameSchema] objects representing the schema of each non-system table.
* @return a map of [String] to [DataFrameSchema] objects, pairing each non-system table name with its schema.
*/
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List<DataFrameSchema> {
public fun DataFrame.Companion.getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map<String, DataFrameSchema> {
DriverManager.getConnection(dbConfig.url, dbConfig.user, dbConfig.password).use { connection ->
return getSchemaForAllSqlTables(connection)
}
}

/**
* Retrieves the schema of all non-system tables in the database using the provided database connection.
* Retrieves the schemas of all non-system tables in the database using the provided database connection.
*
* @param [connection] the database connection.
* @return a list of [DataFrameSchema] objects representing the schema of each non-system table.
* @return a map of [String] to [DataFrameSchema] objects, pairing each non-system table name with its schema.
*/
public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): List<DataFrameSchema> {
public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection): Map<String, DataFrameSchema> {
val metaData = connection.metaData
val url = connection.metaData.url
val dbType = extractDBTypeFromUrl(url)
@@ -500,14 +502,15 @@ public fun DataFrame.Companion.getSchemaForAllSqlTables(connection: Connection):
// exclude a system and other tables without data
val tables = metaData.getTables(null, null, null, tableTypes)

val dataFrameSchemas = mutableListOf<DataFrameSchema>()
val dataFrameSchemas = mutableMapOf<String, DataFrameSchema>()

while (tables.next()) {
val jdbcTable = dbType.buildTableMetadata(tables)
if (!dbType.isSystemTable(jdbcTable)) {
// we filter her second time because of specific logic with SQLite and possible issues with future databases
val dataFrameSchema = getSchemaForSqlTable(connection, jdbcTable.name)
dataFrameSchemas += dataFrameSchema
// we filter here a second time because of specific logic with SQLite and possible issues with future databases
val tableName = jdbcTable.name
val dataFrameSchema = getSchemaForSqlTable(connection, tableName)
dataFrameSchemas += tableName to dataFrameSchema
}
}

@@ -597,7 +597,11 @@ class JdbcTest {

@Test
fun `read from all tables`() {
val dataframes = DataFrame.readAllSqlTables(connection)
val dataFrameMap = DataFrame.readAllSqlTables(connection)
dataFrameMap.containsKey("Customer") shouldBe true
dataFrameMap.containsKey("Sale") shouldBe true

val dataframes = dataFrameMap.values.toList()

val customerDf = dataframes[0].cast<Customer>()

@@ -611,7 +615,7 @@ class JdbcTest {
saleDf.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
(saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1)
val dataframes1 = DataFrame.readAllSqlTables(connection, limit = 1).values.toList()

val customerDf1 = dataframes1[0].cast<Customer>()

@@ -625,7 +629,11 @@ class JdbcTest {
saleDf1.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1
(saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

val dataSchemas = DataFrame.getSchemaForAllSqlTables(connection)
val dataFrameSchemaMap = DataFrame.getSchemaForAllSqlTables(connection)
dataFrameSchemaMap.containsKey("Customer") shouldBe true
dataFrameSchemaMap.containsKey("Sale") shouldBe true

val dataSchemas = dataFrameSchemaMap.values.toList()

val customerDataSchema = dataSchemas[0]
customerDataSchema.columns.size shouldBe 3
@@ -637,7 +645,7 @@ class JdbcTest {
saleDataSchema.columns["amount"]!!.type shouldBe typeOf<BigDecimal>()

val dbConfig = DatabaseConfiguration(url = URL)
val dataframes2 = DataFrame.readAllSqlTables(dbConfig)
val dataframes2 = DataFrame.readAllSqlTables(dbConfig).values.toList()

val customerDf2 = dataframes2[0].cast<Customer>()

@@ -651,7 +659,7 @@ class JdbcTest {
saleDf2.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 3
(saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1)
val dataframes3 = DataFrame.readAllSqlTables(dbConfig, limit = 1).values.toList()

val customerDf3 = dataframes3[0].cast<Customer>()

@@ -665,7 +673,7 @@ class JdbcTest {
saleDf3.filter { it[Sale::amount] > 40 }.rowsCount() shouldBe 1
(saleDf[0][2] as BigDecimal).compareTo(BigDecimal(100.50)) shouldBe 0

val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig)
val dataSchemas1 = DataFrame.getSchemaForAllSqlTables(dbConfig).values.toList()

val customerDataSchema1 = dataSchemas1[0]
customerDataSchema1.columns.size shouldBe 3
@@ -370,7 +370,7 @@ class MariadbTest {

@Test
fun `read from all tables`() {
val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000)
val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 1000).values.toList()

val table1Df = dataframes[0].cast<Table1MariaDb>()

@@ -63,7 +63,7 @@ interface Table1MSSSQL {
val geographyColumn: String
}

@Ignore

class MSSQLTest {
companion object {
private lateinit var connection: Connection
@@ -277,7 +277,7 @@ class MSSQLTest {

@Test
fun `read from all tables`() {
val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4)
val dataframes = DataFrame.readAllSqlTables(connection, TEST_DATABASE_NAME, 4).values.toList()

val table1Df = dataframes[0].cast<Table1MSSSQL>()

@@ -370,7 +370,7 @@ class MySqlTest {

@Test
fun `read from all tables`() {
val dataframes = DataFrame.readAllSqlTables(connection)
val dataframes = DataFrame.readAllSqlTables(connection).values.toList()

val table1Df = dataframes[0].cast<Table1MySql>()

@@ -298,7 +298,7 @@ class PostgresTest {

@Test
fun `read from all tables`() {
val dataframes = DataFrame.readAllSqlTables(connection)
val dataframes = DataFrame.readAllSqlTables(connection).values.toList()

val table1Df = dataframes[0].cast<Table1>()

@@ -193,7 +193,7 @@ class SqliteTest {

@Test
fun `read from all tables`() {
val dataframes = DataFrame.readAllSqlTables(connection)
val dataframes = DataFrame.readAllSqlTables(connection).values.toList()

val customerDf = dataframes[0].cast<CustomerSQLite>()

16 changes: 8 additions & 8 deletions docs/StardustDocs/topics/readSqlDatabases.md
@@ -59,7 +59,7 @@ In the second, be sure that you can establish a connection to the database.

For this, usually, you need to have three things: a URL to a database, a username and a password.

Call one of the following functions to obtain data from a database and transform it to the dataframe.
Call one of the following functions to collect data from a database and transform it into a dataframe.

For example, if you have a local PostgreSQL database named `testDatabase` with a table `Customer`,
you can read the first 100 rows and print the data by copying the code below:
@@ -105,7 +105,7 @@ Next, import `Kotlin DataFrame` library in the cell below.
**NOTE:** The order of cell execution is important,
the dataframe library is waiting for a JDBC driver to force classloading.

Find full example Notebook [here](https://github.com/zaleslaw/KotlinDataFrame-SQL-Examples/blob/master/notebooks/imdb.ipynb).
Find a full example Notebook [here](https://github.com/zaleslaw/KotlinDataFrame-SQL-Examples/blob/master/notebooks/imdb.ipynb).


## Reading Specific Tables
@@ -315,9 +315,9 @@ connection.close()
These functions read all data from all tables in the connected database.
Variants with a limit parameter restrict how many rows will be read from each table.

**readAllSqlTables(connection: Connection): List\<AnyFrame>**
**readAllSqlTables(connection: Connection): Map\<String, AnyFrame>**

Retrieves data from all the non-system tables in the SQL database and returns them as a list of AnyFrame objects.
Retrieves data from all the non-system tables in the SQL database and returns them as a map of table names to AnyFrame objects.

The `dbConfig: DatabaseConfiguration` parameter represents the configuration for a database connection,
created under the hood and managed by the library. Typically, it requires a URL, username and password.
@@ -330,7 +330,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS
val dataframes = DataFrame.readAllSqlTables(dbConfig)
```

**readAllSqlTables(connection: Connection, limit: Int): List\<AnyFrame>**
**readAllSqlTables(connection: Connection, limit: Int): Map\<String, AnyFrame>**

A variant of the previous function,
but with an added `limit: Int` parameter that allows setting the maximum number of records to be read from each table.
@@ -493,10 +493,10 @@ connection.close()
These functions return a map of table names to [`DataFrameSchema`](schema.md) objects for all the non-system tables in the SQL database.
They can be called with either a database configuration or a connection.

**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): List\<DataFrameSchema>**
**getSchemaForAllSqlTables(dbConfig: DatabaseConfiguration): Map\<String, DataFrameSchema>**

This function retrieves the schema of all tables from an SQL database
and returns them as a list of [`DataFrameSchema`](schema.md).
and returns them as a map of table names to [`DataFrameSchema`](schema.md) objects.

The `dbConfig: DatabaseConfiguration` parameter represents the configuration for a database connection,
created under the hood and managed by the library. Typically, it requires a URL, username and password.
@@ -509,7 +509,7 @@ val dbConfig = DatabaseConfiguration("URL_TO_CONNECT_DATABASE", "USERNAME", "PAS
val schemas = DataFrame.getSchemaForAllSqlTables(dbConfig)
```

**getSchemaForAllSqlTables(connection: Connection): List\<DataFrameSchema>**
**getSchemaForAllSqlTables(connection: Connection): Map\<String, DataFrameSchema>**

This function retrieves the schema of all tables using a JDBC connection: `Connection` object
and returns them as a map of table names to [`DataFrameSchema`](schema.md) objects.
2 changes: 1 addition & 1 deletion gradle/libs.versions.toml
@@ -49,7 +49,7 @@ kotestAsserions = "5.5.4"

jsoup = "1.17.2"
arrow = "15.0.0"
docProcessor = "0.3.6"
docProcessor = "0.3.7"
simpleGit = "2.0.3"
dependencyVersions = "0.51.0"
plugin-publish = "1.2.1"