Skip to content

Commit

Permalink
Move primary key population to when a join path is added
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-cnivera committed Oct 21, 2024
1 parent 219f196 commit 938aaf4
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 28 deletions.
38 changes: 37 additions & 1 deletion journeys/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
import streamlit as st
from streamlit_extras.row import row

from app_utils.shared_utils import get_snowflake_connection
from semantic_model_generator.data_processing.cte_utils import (
fully_qualified_table_name,
)
from semantic_model_generator.protos import semantic_model_pb2
from semantic_model_generator.snowflake_utils.snowflake_connector import (
get_table_primary_keys,
)

SUPPORTED_JOIN_TYPES = [
join_type
Expand Down Expand Up @@ -167,7 +174,6 @@ def relationship_builder(

@st.experimental_dialog("Join Builder", width="large")
def joins_dialog() -> None:

if "builder_joins" not in st.session_state:
# Making a copy of the original relationships list so we can modify freely without affecting the original.
st.session_state.builder_joins = st.session_state.semantic_model.relationships[
Expand Down Expand Up @@ -210,6 +216,36 @@ def joins_dialog() -> None:
)
return

# Populate primary key information for each table in a join relationship.
left_table_object = next(
(
table
for table in st.session_state.semantic_model.tables
if table.name == relationship.left_table
)
)
right_table_object = next(
(
table
for table in st.session_state.semantic_model.tables
if table.name == relationship.right_table
)
)

if not left_table_object.primary_key.columns:
primary_keys = get_table_primary_keys(
get_snowflake_connection(),
table_fqn=fully_qualified_table_name(left_table_object.base_table),
)
left_table_object.primary_key.columns.extend(primary_keys or [""])

if not right_table_object.primary_key.columns:
primary_keys = get_table_primary_keys(
get_snowflake_connection(),
table_fqn=fully_qualified_table_name(right_table_object.base_table),
)
right_table_object.primary_key.columns.extend(primary_keys or [""])

del st.session_state.semantic_model.relationships[:]
st.session_state.semantic_model.relationships.extend(
st.session_state.builder_joins
Expand Down
1 change: 0 additions & 1 deletion semantic_model_generator/data_processing/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class Table:
id_: int
name: str
columns: List[Column]
primary_key: Optional[list[str]] = None
comment: Optional[str] = (
None # comment field to save the table comment the user specified on the table
)
Expand Down
18 changes: 1 addition & 17 deletions semantic_model_generator/generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_placeholder_joins() -> List[semantic_model_pb2.Relationship]:


def _raw_table_to_semantic_context_table(
database: str, schema: str, raw_table: data_types.Table, allow_joins: bool = False
database: str, schema: str, raw_table: data_types.Table
) -> semantic_model_pb2.Table:
"""
Converts a raw table representation to a semantic model table in protobuf format.
Expand All @@ -68,7 +68,6 @@ def _raw_table_to_semantic_context_table(
database (str): The name of the database containing the table.
schema (str): The name of the schema containing the table.
raw_table (data_types.Table): The raw table object to be transformed.
allow_joins (bool): Whether joins are enabled in the semantic model.
Returns:
semantic_model_pb2.Table: A protobuf representation of the semantic table.
Expand Down Expand Up @@ -147,18 +146,6 @@ def _raw_table_to_semantic_context_table(
f"No valid columns found for table {raw_table.name}. Please verify that this table contains column's datatypes not in {OBJECT_DATATYPES}."
)

primary_key = None
if allow_joins:
# Populate the primary key field if we were able to retrieve one during raw table construction.
# If not, leave a placeholder for the user to fill out.
primary_key = semantic_model_pb2.PrimaryKey(
columns=(
raw_table.primary_key
if raw_table.primary_key
else [_PLACEHOLDER_COMMENT]
)
)

return semantic_model_pb2.Table(
name=raw_table.name,
base_table=semantic_model_pb2.FullyQualifiedTable(
Expand All @@ -170,7 +157,6 @@ def _raw_table_to_semantic_context_table(
dimensions=dimensions,
time_dimensions=time_dimensions,
measures=measures,
primary_key=primary_key,
)


Expand Down Expand Up @@ -236,13 +222,11 @@ def raw_schema_to_semantic_context(
ndv_per_column=n_sample_values, # number of sample values to pull per column.
columns_df=valid_columns_df_this_table,
max_workers=1,
allow_joins=allow_joins,
)
table_object = _raw_table_to_semantic_context_table(
database=fqn_table.database,
schema=fqn_table.schema_name,
raw_table=raw_table,
allow_joins=allow_joins,
)
table_objects.append(table_object)
# TODO(jhilgart): Call cortex model to generate a semantically friendly name here.
Expand Down
13 changes: 4 additions & 9 deletions semantic_model_generator/snowflake_utils/snowflake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,11 @@ def _get_column_comment(
return ""


def _get_table_primary_keys(
conn: SnowflakeConnection, schema_name: str, table_name: str
def get_table_primary_keys(
conn: SnowflakeConnection,
table_fqn: str,
) -> list[str] | None:
query = f"show primary keys in table {schema_name}.{table_name};"
query = f"show primary keys in table {table_fqn};"
cursor = conn.cursor()
cursor.execute(query)
primary_keys = cursor.fetchall()
Expand All @@ -146,7 +147,6 @@ def get_table_representation(
ndv_per_column: int,
columns_df: pd.DataFrame,
max_workers: int,
allow_joins: bool = False,
) -> Table:
table_comment = _get_table_comment(conn, schema_name, table_name, columns_df)

Expand All @@ -172,16 +172,11 @@ def _get_col(col_index: int, column_row: pd.Series) -> Column:
index_and_column.append((col_index, column))
columns = [c for _, c in sorted(index_and_column, key=lambda x: x[0])]

primary_keys = (
_get_table_primary_keys(conn, schema_name, table_name) if allow_joins else None
)

return Table(
id_=table_index,
name=table_name,
comment=table_comment,
columns=columns,
primary_keys=primary_keys,
)


Expand Down

0 comments on commit 938aaf4

Please sign in to comment.