diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 4d3386ba..4921df02 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -12,6 +12,8 @@ from sqlalchemy.sql.elements import quoted_name from sqlalchemy.util.compat import string_types +from snowflake.sqlalchemy.custom_types import OBJECT + from .custom_commands import AWSBucket, AzureContainer, ExternalStage RESERVED_WORDS = frozenset( @@ -150,6 +152,36 @@ def _split_schema_by_dot(self, schema): class SnowflakeCompiler(compiler.SQLCompiler): + def visit_insert(self, stmt, **kw): + # https://github.com/sqlalchemy/sqlalchemy/discussions/7894#discussioncomment-2520337 + insert_sql = super().visit_insert(stmt, **kw) + + columns = self.column_keys + if columns is None: + columns = stmt.table.columns.keys() + + # look in the columns being inserted, see if there's + # JSON being inserted. also can just look at the INSERT string + # and look for the json function + + use_json = any(isinstance(stmt.table.c[key].type, OBJECT) for key in columns) + + if not use_json: + return insert_sql + + stmt_reg = re.match( + r"^INSERT INTO (.+?) \((.+?)\) VALUES \((.+)\)$", insert_sql + ) + if not stmt_reg: + return insert_sql + + # rewrite INSERT as per + # https://docs.snowflake.com/en/sql-reference/sql/insert.html#usage-notes + return ( + f"INSERT INTO {stmt_reg.group(1)} " + f"({stmt_reg.group(2)}) SELECT {stmt_reg.group(3)}" + ) + def visit_sequence(self, sequence, **kw): return self.dialect.identifier_preparer.format_sequence(sequence) + ".nextval"