Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add INSERT support for OBJECT #320

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.util.compat import string_types

from snowflake.sqlalchemy.custom_types import OBJECT

from .custom_commands import AWSBucket, AzureContainer, ExternalStage

RESERVED_WORDS = frozenset(
Expand Down Expand Up @@ -150,6 +152,36 @@ def _split_schema_by_dot(self, schema):


class SnowflakeCompiler(compiler.SQLCompiler):
def visit_insert(self, stmt, **kw):
    """Compile an INSERT, switching to ``INSERT ... SELECT`` for OBJECT columns.

    The statement is first compiled by the parent compiler. When one of the
    columns being inserted has the ``OBJECT`` type and the compiled text is
    a plain single-row ``INSERT INTO <table> (<columns>) VALUES (<exprs>)``,
    the ``VALUES (<exprs>)`` clause is replaced by ``SELECT <exprs>``,
    following the Snowflake INSERT usage notes linked below. Every other
    statement is returned exactly as the parent compiler produced it.

    :param stmt: the INSERT construct being compiled.
    :param kw: compiler keyword arguments, forwarded to the parent compiler.
    :return: the compiled SQL string.
    """
    # https://github.com/sqlalchemy/sqlalchemy/discussions/7894#discussioncomment-2520337
    insert_sql = super().visit_insert(stmt, **kw)

    table_columns = stmt.table.columns
    columns = self.column_keys
    if columns is None:
        columns = table_columns.keys()

    def _is_object_column(key):
        # The compile-time keys may name something other than a table
        # column (for example an explicit bind parameter), so look the key
        # up defensively rather than indexing and risking a KeyError.
        column = table_columns.get(key)
        return column is not None and isinstance(column.type, OBJECT)

    # Only statements that actually insert into an OBJECT column need the
    # rewrite.
    if not any(_is_object_column(key) for key in columns):
        return insert_sql

    stmt_reg = re.match(
        r"^INSERT INTO (.+?) \((.+?)\) VALUES \((.+)\)$", insert_sql
    )
    if not stmt_reg:
        return insert_sql

    target, column_list, values = stmt_reg.groups()

    # The last group is greedy, so for a multi-row "VALUES (..), (..)" it
    # captures "a, b), (c, d". Rewriting that would emit malformed SQL, so
    # only continue when the captured text is one balanced tuple body.
    # Parentheses inside single-quoted literals are ignored (heuristic).
    depth = 0
    in_string = False
    for char in values:
        if char == "'":
            in_string = not in_string
        elif in_string:
            continue
        elif char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth < 0:
                return insert_sql

    # rewrite INSERT as per
    # https://docs.snowflake.com/en/sql-reference/sql/insert.html#usage-notes
    return f"INSERT INTO {target} ({column_list}) SELECT {values}"

def visit_sequence(self, sequence, **kw):
    """Render a sequence as its formatted name followed by ``.nextval``.

    :param sequence: the sequence construct to render.
    :param kw: compiler keyword arguments (unused here).
    :return: the SQL fragment ``<formatted sequence name>.nextval``.
    """
    preparer = self.dialect.identifier_preparer
    sequence_name = preparer.format_sequence(sequence)
    return sequence_name + ".nextval"

Expand Down