SNOW-1871175: Add support for specifying a schema string for DataFrame.create_dataframe #2828

Merged · 5 commits · Jan 16, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -42,6 +42,7 @@
- `try_to_binary`

- Added `Catalog` class to manage snowflake objects. It can be accessed via `Session.catalog`.
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.

#### Improvements

173 changes: 173 additions & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
@@ -969,6 +969,17 @@ def get_data_type_string_object_mappings(
STRING_RE = re.compile(r"^\s*(varchar|string|text)\s*\(\s*(\d*)\s*\)\s*$")
# supports type strings like " string ( 23 ) "

ARRAY_RE = re.compile(r"(?i)^\s*array\s*<")
# matches type strings starting with "array<", e.g. "array<int>"

MAP_RE = re.compile(r"(?i)^\s*map\s*<")
# matches type strings starting with "map<", e.g. "map<string, int>"

STRUCT_RE = re.compile(r"(?i)^\s*struct\s*<")
# matches type strings starting with "struct<", e.g. "struct<a: int>"

_NOT_NULL_PATTERN = re.compile(r"^(?P<base>.*?)\s+not\s+null\s*$", re.IGNORECASE)
# matches a trailing "NOT NULL" qualifier, capturing the base type in group "base"
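As a quick illustration of what these patterns accept (a hedged sketch, not part of the diff; the definitions are copied from above):

import re

ARRAY_RE = re.compile(r"(?i)^\s*array\s*<")
_NOT_NULL_PATTERN = re.compile(r"^(?P<base>.*?)\s+not\s+null\s*$", re.IGNORECASE)

# ARRAY_RE anchors only the prefix; the bracket contents are parsed separately.
assert ARRAY_RE.match("  Array< int >") is not None
assert ARRAY_RE.match("map<string, int>") is None

# The lazy 'base' group captures everything before a trailing NOT NULL.
m = _NOT_NULL_PATTERN.match("decimal(10, 2)  NOT  NULL")
assert m is not None and m.group("base") == "decimal(10, 2)"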


def get_number_precision_scale(type_str: str) -> Optional[Tuple[int, int]]:
    decimal_matches = DECIMAL_RE.match(type_str)
@@ -982,7 +993,169 @@ def get_string_length(type_str: str) -> Optional[int]:
        return int(string_matches.group(2))


def extract_bracket_content(type_str: str, keyword: str) -> str:
    """
    Given a string like "array<...>", returns the content inside the top-level <...>,
    e.g. "array<int>" => "int". Nested brackets such as "array<array<...>>" are handled.
    Raises ValueError on a mismatched or missing bracket.
    """
    type_str = type_str.strip()
    prefix_pattern = rf"(?i)^\s*{keyword}\s*<"
    match = re.match(prefix_pattern, type_str)
    if not match:
        raise ValueError(
            f"'{type_str}' does not match expected '{keyword}<...>' syntax."
        )

    start_index = match.end() - 1  # position at '<'
    bracket_depth = 0
    inside_chars: List[str] = []
    i = start_index
    while i < len(type_str):
        c = type_str[i]
        if c == "<":
            bracket_depth += 1
            # skip the outermost opening bracket (depth 0 -> 1);
            # only nested brackets are recorded in 'inside_chars'
            if bracket_depth > 1:
                inside_chars.append(c)
        elif c == ">":
            bracket_depth -= 1
            if bracket_depth < 0:
                raise ValueError(f"Mismatched '>' in '{type_str}'")
            if bracket_depth == 0:
                if i != len(type_str) - 1:
                    raise ValueError(
                        f"Unexpected characters after closing '>' in '{type_str}'"
                    )
                # done
                return "".join(inside_chars).strip()
            inside_chars.append(c)
        else:
            inside_chars.append(c)
        i += 1

    raise ValueError(f"Missing closing '>' in '{type_str}'.")


def extract_nullable_keyword(type_str: str) -> Tuple[str, bool]:
    """
    Checks if `type_str` ends with something like 'NOT NULL' (ignoring
    case and allowing arbitrary space between NOT and NULL). If found,
    return the type substring minus that part, along with nullable=False.
    Otherwise, return (type_str, True).
    """
    trimmed = type_str.strip()
    match = _NOT_NULL_PATTERN.match(trimmed)
    if match:
        # Group 'base' is everything before 'not null'
        base_type_str = match.group("base").strip()
        return base_type_str, False

    # By default, the field is nullable
    return trimmed, True
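Illustrative behavior of the nullability helper (hedged; mirrors the docstring):

# extract_nullable_keyword("string NOT NULL")  == ("string", False)
# extract_nullable_keyword("int  Not   Null")  == ("int", False)
# extract_nullable_keyword("decimal(10, 2)")   == ("decimal(10, 2)", True)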


def parse_struct_field_list(fields_str: str) -> StructType:
    """
    Parse something like "a: int, b: string, c: array<int>"
    into StructType([StructField('a', IntegerType()), ...]).
    """
    fields = []
    field_defs = split_top_level_comma_fields(fields_str)
    for field_def in field_defs:
        # Try splitting on colon first, else whitespace
        if ":" in field_def:
            left, right = field_def.split(":", 1)
        else:
            parts = field_def.split(None, 1)
            if len(parts) != 2:
                raise ValueError(f"Cannot parse struct field definition: '{field_def}'")
            left, right = parts[0], parts[1]

        field_name = left.strip()
        type_part = right.strip()
        if not field_name:
            raise ValueError(f"Struct field missing name in '{field_def}'")

        # 1) Check for trailing "NOT NULL" => sets nullable=False
        base_type_str, nullable = extract_nullable_keyword(type_part)
        # 2) Parse the base type
        field_type = type_string_to_type_object(base_type_str)
        fields.append(StructField(field_name, field_type, nullable=nullable))

    return StructType(fields)
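A sketch of what this produces (hedged; the concrete integer/string type classes follow Snowpark's existing name-to-type mapping):

# parse_struct_field_list("a: int, b: string not null") ->
#     StructType([
#         StructField("a", <integer type>, nullable=True),
#         StructField("b", <string type>, nullable=False),
#     ])
# Space-separated definitions are accepted too:
# "a int, b string" parses the same as "a: int, b: string".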


def split_top_level_comma_fields(s: str) -> List[str]:
    """
    Splits 's' by commas not enclosed in matching brackets.
    Example: "int, array<long>, decimal(10,2)" => ["int", "array<long>", "decimal(10,2)"].
    """
    parts = []
    bracket_depth = 0
    start_idx = 0
    for i, c in enumerate(s):
        if c in ["<", "("]:
            bracket_depth += 1
        elif c in [">", ")"]:
            bracket_depth -= 1
            if bracket_depth < 0:
                raise ValueError(f"Mismatched bracket in '{s}'.")
        elif c == "," and bracket_depth == 0:
            parts.append(s[start_idx:i].strip())
            start_idx = i + 1
    parts.append(s[start_idx:].strip())
    return parts

Comment on lines +1098 to +1104

sfc-gh-aling (Contributor) commented on Jan 9, 2025:

    I feel this bracket-check logic is repeated multiple times. Do you think it's possible to validate the bracket matching once, as an initial step over the whole input string, so that the downstream logic can focus only on extracting the names and types?

The author (Collaborator) replied:

    We need to parse brackets to split fields and extract names and types anyway. There is indeed some duplicated validation of whether a bracket expression is valid, and maybe we could remove it, but to keep each function self-contained let's keep it. These cases are also covered in the tests.
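Illustrative splits (hedged), including the nested cases discussed above:

# split_top_level_comma_fields("int, array<long>, decimal(10,2)")
#     == ["int", "array<long>", "decimal(10,2)"]
# Commas inside <...> or (...) do not split:
# split_top_level_comma_fields("map<string, int>, decimal(10, 2)")
#     == ["map<string, int>", "decimal(10, 2)"]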


def is_likely_struct(s: str) -> bool:
    """
    Heuristic: if there's a top-level comma or colon outside brackets,
    treat it like a struct with multiple fields, e.g. "a: int, b: string".
    """
    bracket_depth = 0
    for c in s:
        if c in ["<", "("]:
            bracket_depth += 1
        elif c in [">", ")"]:
            bracket_depth -= 1
        elif (c in [":", ","]) and bracket_depth == 0:
            return True
    return False
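Hedged examples of the heuristic:

# is_likely_struct("a: int, b: string")  -> True   (top-level colon and comma)
# is_likely_struct("array<int>")         -> False
# is_likely_struct("map<string, int>")   -> False  (comma is nested in <...>)
# is_likely_struct("decimal(10, 2)")     -> False  (comma is nested in (...))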


def type_string_to_type_object(type_str: str) -> DataType:
    type_str = type_str.strip()
    if not type_str:
        raise ValueError("Empty type string")

    # First check if this might be a top-level multi-field struct
    # (e.g. "a: int, b: string") even if not written as "struct<...>"
    if is_likely_struct(type_str):
        return parse_struct_field_list(type_str)

    # Check for array<...>
    if ARRAY_RE.match(type_str):
        inner = extract_bracket_content(type_str, "array")
        element_type = type_string_to_type_object(inner)
        return ArrayType(element_type)

    # Check for map<key, value>
    if MAP_RE.match(type_str):
        inner = extract_bracket_content(type_str, "map")
        parts = split_top_level_comma_fields(inner)
        if len(parts) != 2:
            raise ValueError(f"Invalid map type definition: '{type_str}'")
        key_type = type_string_to_type_object(parts[0])
        val_type = type_string_to_type_object(parts[1])
        return MapType(key_type, val_type)

    # Check for explicit struct<...>
    if STRUCT_RE.match(type_str):
        inner = extract_bracket_content(type_str, "struct")
        return parse_struct_field_list(inner)

    precision_scale = get_number_precision_scale(type_str)
    if precision_scale:
        return DecimalType(*precision_scale)
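The hunk is truncated here; the remainder of the function falls through to the pre-existing scalar-type handling. A minimal end-to-end sketch of the dispatch above (hedged; assumes the helpers in this module and the public type classes in snowflake.snowpark.types):

from snowflake.snowpark._internal.type_utils import type_string_to_type_object
from snowflake.snowpark.types import ArrayType, MapType, StructType

# nested container types recurse through extract_bracket_content
tp = type_string_to_type_object("map<string, array<int>>")
assert isinstance(tp, MapType) and isinstance(tp.value_type, ArrayType)

# an implicit struct string never needs the "struct<...>" wrapper
assert isinstance(type_string_to_type_object("a: int, b: string"), StructType)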
26 changes: 22 additions & 4 deletions src/snowflake/snowpark/session.py
@@ -95,6 +95,7 @@
    infer_schema,
    infer_type,
    merge_type,
    type_string_to_type_object,
)
from snowflake.snowpark._internal.udf_utils import generate_call_python_sp_sql
from snowflake.snowpark._internal.utils import (
@@ -3029,7 +3030,7 @@ def write_pandas(
    def create_dataframe(
        self,
        data: Union[List, Tuple, "pandas.DataFrame"],
        schema: Optional[Union[StructType, Iterable[str], str]] = None,
        _emit_ast: bool = True,
    ) -> DataFrame:
        """Creates a new DataFrame containing the specified values from the local data.
@@ -3046,9 +3047,15 @@ def create_dataframe(
                ``data`` will constitute a row in the DataFrame.
            schema: A :class:`~snowflake.snowpark.types.StructType` containing names and
                data types of columns, or a list of column names, or ``None``.

                - When passing a **string**, it can be either an *explicit* struct
                  (e.g. ``"struct<a: int, b: string>"``) or an *implicit* struct
                  (e.g. ``"a: int, b: string"``). Internally, the string is parsed and
                  converted into a :class:`StructType` using Snowpark's type parsing.
Comment on lines +3051 to +3054

sfc-gh-aling (Contributor) commented:

    Does this mean a valid struct string must contain col_name: data_type pairs?

The author (Collaborator) replied:

    Yes.

                - When ``schema`` is a list of column names or ``None``, the schema of the
                  DataFrame will be inferred from the data across all rows.

                To improve performance, provide a schema. This avoids the need to infer data types
                with large data sets.

Examples::
@@ -3078,6 +3085,10 @@
>>> session.create_dataframe(pd.DataFrame([(1, 2, 3, 4)], columns=["a", "b", "c", "d"])).collect()
[Row(a=1, b=2, c=3, d=4)]

>>> # create a dataframe using an implicit struct schema string
>>> session.create_dataframe([[10, 20], [30, 40]], schema="x: int, y: int").collect()
[Row(X=10, Y=20), Row(X=30, Y=40)]
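>>> # illustrative (not in the original doctest): the explicit struct syntax
>>> # takes the same parse path as the implicit string above
>>> session.create_dataframe([[10, 20], [30, 40]], schema="struct<x: int, y: int>").collect()
[Row(X=10, Y=20), Row(X=30, Y=40)]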

Note:
When `data` is a pandas DataFrame, `snowflake.connector.pandas_tools.write_pandas` is called, which
requires permission to (1) CREATE STAGE (2) CREATE TABLE and (3) CREATE FILE FORMAT under the current
@@ -3156,6 +3167,13 @@ def create_dataframe(
        # infer the schema based on the data
        names = None
        schema_query = None
        if isinstance(schema, str):
            schema = type_string_to_type_object(schema)
            if not isinstance(schema, StructType):
                raise ValueError(
                    f"Invalid schema string: {schema}. "
                    f"You should provide a valid schema string representing a struct type."
                )
        if isinstance(schema, StructType):
            new_schema = schema
            # SELECT query has an undefined behavior for nullability, so if the schema requires non-nullable column and
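Taken together, a hedged end-to-end sketch of the new surface (assumes an existing `session`; Snowflake's usual identifier uppercasing applies to the column names):

df = session.create_dataframe(
    [[1, ["a", "b"]], [2, ["c"]]],
    schema="id: int not null, tags: array<string>",
)
print(df.schema)
# expected: ID as a non-nullable integer column, TAGS as an array-of-string column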