Skip to content

Commit

Permalink
Return table names in postgres for non 'public' schemas, too (#62)
Browse files Browse the repository at this point in the history
* return table names in postgres for non 'public' schemas, too

* upped version
  • Loading branch information
rishsriv authored Sep 11, 2024
1 parent 60310e6 commit 6dfbd78
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 10 deletions.
35 changes: 26 additions & 9 deletions defog/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def generate_postgres_schema(
return_format: str = "csv",
scan: bool = True,
return_tables_only: bool = False,
schemas: List[str] = ["public"],
schemas: List[str] = [],
) -> str:
# when upload is True, we send the schema to the defog servers and generate a CSV
# when its false, we return the schema as a dict
Expand All @@ -26,19 +26,36 @@ def generate_postgres_schema(

conn = psycopg2.connect(**self.db_creds)
cur = conn.cursor()
schemas = tuple(schemas)

if len(tables) == 0:
# get all tables
for schema in schemas:
if len(schemas) > 0:
for schema in schemas:
cur.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = %s;",
(schema,),
)
if schema == "public":
tables += [row[0] for row in cur.fetchall()]
else:
tables += [schema + "." + row[0] for row in cur.fetchall()]
else:
excluded_schemas = (
"information_schema",
"pg_catalog",
"pg_toast",
"pg_temp_1",
"pg_toast_temp_1",
)
cur.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = %s;",
(schema,),
"SELECT table_name, table_schema FROM information_schema.tables WHERE table_schema NOT IN %s;",
(excluded_schemas,),
)
if schema == "public":
tables += [row[0] for row in cur.fetchall()]
else:
tables += [schema + "." + row[0] for row in cur.fetchall()]
for row in cur.fetchall():
if row[1] == "public":
tables.append(row[0])
else:
tables.append(f"{row[1]}.{row[0]}")

if return_tables_only:
return tables
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def package_files(directory):
name="defog",
packages=find_packages(),
package_data={"defog": ["gcp/*", "aws/*"] + next_static_files},
version="0.65.10",
version="0.65.11",
description="Defog is a Python library that helps you generate data queries from natural language questions.",
author="Full Stack Data Pte. Ltd.",
license="MIT",
Expand Down

0 comments on commit 6dfbd78

Please sign in to comment.