From 7490eb7497409328ab53469533f141b30a15174f Mon Sep 17 00:00:00 2001 From: Pat Buxton Date: Fri, 13 Dec 2024 15:37:36 +0000 Subject: [PATCH 1/2] Adds COPY INTO support --- databend_sqlalchemy/__init__.py | 23 ++ databend_sqlalchemy/databend_dialect.py | 126 +++++++- databend_sqlalchemy/dml.py | 386 ++++++++++++++++++++++++ tests/test_copy_into.py | 139 +++++++++ tests/test_merge.py | 2 +- 5 files changed, 674 insertions(+), 2 deletions(-) create mode 100644 tests/test_copy_into.py diff --git a/databend_sqlalchemy/__init__.py b/databend_sqlalchemy/__init__.py index b4787b3..208dacb 100644 --- a/databend_sqlalchemy/__init__.py +++ b/databend_sqlalchemy/__init__.py @@ -3,3 +3,26 @@ VERSION = (0, 4, 8) __version__ = ".".join(str(x) for x in VERSION) + + +from .dml import ( + Merge, + WhenMergeUnMatchedClause, + WhenMergeMatchedDeleteClause, + WhenMergeMatchedUpdateClause, + CopyIntoTable, + CopyIntoLocation, + CopyIntoTableOptions, + CopyIntoLocationOptions, + CSVFormat, + TSVFormat, + NDJSONFormat, + ParquetFormat, + ORCFormat, + AmazonS3, + AzureBlobStorage, + GoogleCloudStorage, + FileColumnClause, + StageClause, + Compression, +) \ No newline at end of file diff --git a/databend_sqlalchemy/databend_dialect.py b/databend_sqlalchemy/databend_dialect.py index 127ff14..61a3668 100644 --- a/databend_sqlalchemy/databend_dialect.py +++ b/databend_sqlalchemy/databend_dialect.py @@ -28,6 +28,8 @@ import re import operator import datetime +from types import NoneType + import sqlalchemy.types as sqltypes from typing import Any, Dict, Optional, Union from sqlalchemy import util as sa_util @@ -50,7 +52,11 @@ ) from sqlalchemy.engine import ExecutionContext, default from sqlalchemy.exc import DBAPIError, NoSuchTableError -from .dml import Merge + +from .dml import ( + Merge, StageClause, _StorageClause, GoogleCloudStorage, + AzureBlobStorage, AmazonS3 +) RESERVED_WORDS = { 'Error', 'EOI', 'Whitespace', 'Comment', 'CommentBlock', 'Ident', 'ColumnPosition', 'LiteralString', @@ -490,6 +496,124 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw): ", ".join(map(lambda e: e._compiler_dispatch(self, **kw), sets_vals)), ) + def visit_copy_into(self, copy_into, **kw): + target = ( + self.preparer.format_table(copy_into.target) + if isinstance(copy_into.target, (TableClause,)) + else copy_into.target._compiler_dispatch(self, **kw) + ) + + if isinstance(copy_into.from_, (TableClause,)): + source = self.preparer.format_table(copy_into.from_) + elif isinstance(copy_into.from_, (_StorageClause, StageClause)): + source = copy_into.from_._compiler_dispatch(self, **kw) + # elif isinstance(copy_into.from_, (FileColumnClause)): + # source = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + else: + source = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + + result = ( + f"COPY INTO {target}" + f" FROM {source}" + ) + if hasattr(copy_into, 'files') and isinstance(copy_into.files, list): + result += f"FILES = {', '.join([f for f in copy_into.files])}" + if hasattr(copy_into, 'pattern') and copy_into.pattern: + result += f" PATTERN = '{copy_into.pattern}'" + if not isinstance(copy_into.file_format, NoneType): + result += f" {copy_into.file_format._compiler_dispatch(self, **kw)}\n" + if not isinstance(copy_into.options, NoneType): + result += f" {copy_into.options._compiler_dispatch(self, **kw)}\n" + + return result + + def visit_copy_format(self, file_format, **kw): + options_list = list(file_format.options.items()) + if kw.get("deterministic", False): + options_list.sort(key=operator.itemgetter(0)) 
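+            # stable option ordering keeps the generated FILE_FORMAT clause deterministic (e.g. for comparing compiled SQL in tests)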
+ # predefined format name + if "format_name" in file_format.options: + return f"FILE_FORMAT=(format_name = {file_format.options['format_name']})" + # format specifics + format_options = [f"TYPE = {file_format.format_type}"] + format_options.extend([ + "{} = {}".format( + option, + ( + value._compiler_dispatch(self, **kw) + if hasattr(value, "_compiler_dispatch") + else str(value) + ), + ) + for option, value in options_list + ]) + return f"FILE_FORMAT = ({', '.join(format_options)})" + + def visit_copy_into_options(self, copy_into_options, **kw): + options_list = list(copy_into_options.options.items()) + # if kw.get("deterministic", False): + # options_list.sort(key=operator.itemgetter(0)) + return "\n".join([ + f"{k} = {v}" + for k, v in options_list + ]) + + def visit_file_column(self, file_column_clause, **kw): + if isinstance(file_column_clause.from_, (TableClause,)): + source = self.preparer.format_table(file_column_clause.from_) + elif isinstance(file_column_clause.from_, (_StorageClause, StageClause)): + source = file_column_clause.from_._compiler_dispatch(self, **kw) + else: + source = f"({file_column_clause.from_._compiler_dispatch(self, **kw)})" + if isinstance(file_column_clause.columns, str): + select_str = file_column_clause.columns + else: + select_str = ",".join([col._compiler_dispatch(self, **kw) for col in file_column_clause.columns]) + return ( + f"SELECT {select_str}" + f" FROM {source}" + ) + + def visit_amazon_s3(self, amazon_s3: AmazonS3, **kw): + connection_params_str = f" ACCESS_KEY_ID = '{amazon_s3.access_key_id}' \n" + connection_params_str += f" SECRET_ACCESS_KEY = '{amazon_s3.secret_access_key}'\n" + if amazon_s3.endpoint_url: + connection_params_str += f" ENDPOINT_URL = '{amazon_s3.endpoint_url}' \n" + if amazon_s3.enable_virtual_host_style: + connection_params_str += f" ENABLE_VIRTUAL_HOST_STYLE = '{amazon_s3.enable_virtual_host_style}'\n" + if amazon_s3.master_key: + connection_params_str += f" MASTER_KEY = '{amazon_s3.master_key}'\n" + if amazon_s3.region: + connection_params_str += f" REGION = '{amazon_s3.region}'\n" + if amazon_s3.security_token: + connection_params_str += f" SECURITY_TOKEN = '{amazon_s3.security_token}'\n" + + return ( + f"'{amazon_s3.uri}' \n" + f"CONNECTION = (\n" + f"{connection_params_str}\n" + f")" + ) + + def visit_azure_blob_storage(self, azure: AzureBlobStorage, **kw): + return ( + f"'{azure.uri}' \n" + f"CONNECTION = (\n" + f" ENDPOINT_URL = 'https://{azure.account_name}.blob.core.windows.net' \n" + f" ACCOUNT_NAME = '{azure.account_name}' \n" + f" ACCOUNT_KEY = '{azure.account_key}'\n" + f")" + ) + + def visit_google_cloud_storage(self, gcs: GoogleCloudStorage, **kw): + return ( + f"'{gcs.uri}' \n" + f"CONNECTION = (\n" + f" ENDPOINT_URL = 'https://storage.googleapis.com' \n" + f" CREDENTIAL = '{gcs.credentials}' \n" + f")" + ) + class DatabendExecutionContext(default.DefaultExecutionContext): @sa_util.memoized_property diff --git a/databend_sqlalchemy/dml.py b/databend_sqlalchemy/dml.py index ab71da7..108d685 100644 --- a/databend_sqlalchemy/dml.py +++ b/databend_sqlalchemy/dml.py @@ -2,11 +2,15 @@ # # Note: parts of the file come from https://github.com/snowflakedb/snowflake-sqlalchemy # licensed under the same Apache 2.0 License +from enum import Enum +from types import NoneType +from urllib.parse import urlparse from sqlalchemy.sql.selectable import Select, Subquery, TableClause from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.expression import select +from 
sqlalchemy.sql.roles import FromClauseRole class _OnMergeBaseClause(ClauseElement): @@ -98,3 +102,385 @@ def when_not_matched_then_insert(self): clause = WhenMergeUnMatchedClause() self.clauses.append(clause) return clause + + +class _CopyIntoBase(UpdateBase): + __visit_name__ = "copy_into" + _bind = None + + def __init__(self, target: ['TableClause', 'StageClause', '_StorageClause'], from_, file_format: 'CopyFormat' = None, + options: ['CopyIntoLocationOptions', 'CopyIntoTableOptions'] = None): + self.target = target + self.from_ = from_ + self.file_format = file_format + self.options = options + + def __repr__(self): + """ + repr for debugging / logging purposes only. For compilation logic, see + the corresponding visitor in base.py + """ + val = f"COPY INTO {self.target} FROM {repr(self.from_)}" + return val + f" {repr(self.file_format)} ({self.options})" + + def bind(self): + return None + + +class CopyIntoLocation(_CopyIntoBase): + inherit_cache = False + def __init__(self, *, target: ['StageClause', '_StorageClause'], from_, file_format: 'CopyFormat' = None, options: 'CopyIntoLocationOptions' = None): + super().__init__(target, from_, file_format, options) + + +class CopyIntoTable(_CopyIntoBase): + inherit_cache = False + def __init__(self, *, target: [TableClause], from_: ['StageClause', '_StorageClause', 'FileColumnClause'], + files: list = None, pattern: str = None, file_format: 'CopyFormat' = None, options: 'CopyIntoTableOptions' = None): + super().__init__(target, from_, file_format, options) + self.files = files + self.pattern = pattern + + +class _CopyIntoOptions(ClauseElement): + __visit_name__ = "copy_into_options" + + def __init__(self): + self.options = dict() + + def __repr__(self): + return "\n".join([ + f"{k} = {v}" + for k, v in self.options.items() + ]) + +class CopyIntoLocationOptions(_CopyIntoOptions): + #__visit_name__ = "copy_into_location_options" + + def __init__(self, *, single: bool = None, max_file_size_bytes: int = None, overwrite: bool = None, + include_query_id: bool = None, use_raw_path: bool = None): + super().__init__() + if not isinstance(single, NoneType): + self.options['SINGLE'] = "TRUE" if single else "FALSE" + if not isinstance(max_file_size_bytes, NoneType): + self.options["MAX_FILE_SIZE "] = max_file_size_bytes + if not isinstance(overwrite, NoneType): + self.options["OVERWRITE"] = "TRUE" if overwrite else "FALSE" + if not isinstance(include_query_id, NoneType): + self.options["INCLUDE_QUERY_ID"] = "TRUE" if include_query_id else "FALSE" + if not isinstance(use_raw_path, NoneType): + self.options["USE_RAW_PATH"] = "TRUE" if use_raw_path else "FALSE" + + +class CopyIntoTableOptions(_CopyIntoOptions): + #__visit_name__ = "copy_into_table_options" + + def __init__(self, *, size_limit: int = None, purge: bool = None, force: bool = None, + disable_variant_check: bool = None, on_error: str = None, max_files: int = None, + return_failed_only: bool = None, column_match_mode: str = None): + super().__init__() + if not isinstance(size_limit, NoneType): + self.options['SIZE_LIMIT'] = size_limit + if not isinstance(purge, NoneType): + self.options["PURGE "] = "TRUE" if purge else "FALSE" + if not isinstance(force, NoneType): + self.options["FORCE"] = "TRUE" if force else "FALSE" + if not isinstance(disable_variant_check, NoneType): + self.options["DISABLE_VARIANT_CHECK"] = "TRUE" if disable_variant_check else "FALSE" + if not isinstance(on_error, NoneType): + self.options["ON_ERROR"] = on_error + if not isinstance(max_files, NoneType): + 
self.options["MAX_FILES"] = max_files + if not isinstance(return_failed_only, NoneType): + self.options["RETURN_FAILED_ONLY"] = return_failed_only + if not isinstance(column_match_mode, NoneType): + self.options["COLUMN_MATCH_MODE"] = column_match_mode + + + +class Compression(Enum): + NONE = "NONE" + AUTO = "AUTO" + GZIP = "GZIP" + BZ2 = "BZ2" + BROTLI = "BROTLI" + ZSTD = "ZSTD" + DEFLATE = "DEFLATE" + RAW_DEFLATE = "RAW_DEFLATE" + XZ = "XZ" + + +class CopyFormat(ClauseElement): + """ + Base class for Format specifications inside a COPY INTO statement. May also + be used to create a named format. + """ + + __visit_name__ = "copy_format" + + def __init__(self, format_name=None): + self.options = dict() + if format_name: + self.options["format_name"] = format_name + + def __repr__(self): + """ + repr for debugging / logging purposes only. For compilation logic, see + the respective visitor in the dialect + """ + return f"FILE_FORMAT=({self.options})" + + +class CSVFormat(CopyFormat): + format_type = "CSV" + + def __init__(self, *, + record_delimiter: str = None, + field_delimiter: str = None, + quote: str = None, + escape: str = None, + skip_header: int = None, + nan_display: str = None, + null_display: str = None, + error_on_column_mismatch: bool = None, + empty_field_as: str = None, + output_header: bool = None, + binary_format: str = None, + compression: Compression = None, + ): + super().__init__() + if record_delimiter: + if len(str(record_delimiter).encode().decode('unicode_escape')) != 1 and record_delimiter != '\r\n': + raise TypeError( + 'Record Delimiter should be a single character.' + ) + self.options['RECORD_DELIMITER'] = f"{repr(record_delimiter)}" + if field_delimiter: + if len(str(field_delimiter).encode().decode('unicode_escape')) != 1: + raise TypeError( + 'Field Delimiter should be a single character' + ) + self.options["FIELD_DELIMITER"] = f"{repr(field_delimiter)}" + if quote: + if quote not in ['\'', '"', '`']: + raise TypeError('Quote character must be one of [\', ", `].') + self.options["QUOTE"] = f"{repr(quote)}" + if escape: + if escape not in ['\\', '']: + raise TypeError('Escape character must be "\\" or "".') + self.options["ESCAPE"] = f"{repr(escape)}" + if skip_header: + if skip_header < 0: + raise TypeError('Skip header must be positive integer.') + self.options["SKIP_HEADER"] = skip_header + if nan_display: + if nan_display not in ['NULL', 'NaN']: + raise TypeError('NaN Display should be "NULL" or "NaN".') + self.options["NAN_DISPLAY"] = f"'{nan_display}'" + if null_display: + self.options["NULL_DISPLAY"] = f"'{null_display}'" + if error_on_column_mismatch: + self.options["ERROR_ON_COLUMN_MISMATCH"] = str(error_on_column_mismatch).upper() + if empty_field_as: + if empty_field_as not in ['NULL', 'STRING', 'FIELD_DEFAULT']: + raise TypeError('Empty Field As should be "NULL", "STRING" for "FIELD_DEFAULT".') + self.options["EMPTY_FIELD_AS"] = f"{empty_field_as}" + if output_header: + self.options["OUTPUT_HEADER"] = str(output_header).upper() + if binary_format: + if binary_format not in ['HEX', 'BASE64']: + raise TypeError('Binary Format should be "HEX" or "BASE64".') + self.options["BINARY_FORMAT"] = binary_format + if compression: + self.options["COMPRESSION"] = compression.value + + +class TSVFormat(CopyFormat): + format_type = "TSV" + + def __init__(self, *, + record_delimiter: str = None, + field_delimiter: str = None, + compression: Compression = None, + ): + super().__init__() + if record_delimiter: + if 
len(str(record_delimiter).encode().decode('unicode_escape')) != 1 and record_delimiter != '\r\n': + raise TypeError( + 'Record Delimiter should be a single character.' + ) + self.options['RECORD_DELIMITER'] = f"{repr(record_delimiter)}" + if field_delimiter: + if len(str(field_delimiter).encode().decode('unicode_escape')) != 1: + raise TypeError( + 'Field Delimiter should be a single character' + ) + self.options["FIELD_DELIMITER"] = f"{repr(field_delimiter)}" + if compression: + self.options["COMPRESSION"] = compression.value + + +class NDJSONFormat(CopyFormat): + format_type = "NDJSON" + + def __init__(self, *, + null_field_as: str = None, + missing_field_as: str = None, + compression: Compression = None, + ): + super().__init__() + if null_field_as: + if null_field_as not in ['NULL', 'FIELD_DEFAULT']: + raise TypeError('Null Field As should be "NULL" or "FIELD_DEFAULT".') + self.options["NULL_FIELD_AS"] = f"{null_field_as}" + if missing_field_as: + if missing_field_as not in ['ERROR', 'NULL', 'FIELD_DEFAULT', 'TYPE_DEFAULT']: + raise TypeError('Missing Field As should be "ERROR", "NULL", "FIELD_DEFAULT" or "TYPE_DEFAULT".') + self.options["MISSING_FIELD_AS"] = f"{missing_field_as}" + if compression: + self.options["COMPRESSION"] = compression.value + + +class ParquetFormat(CopyFormat): + format_type = "PARQUET" + + def __init__(self, *, + missing_field_as: str = None, + ): + super().__init__() + if missing_field_as: + if missing_field_as not in ['ERROR', 'FIELD_DEFAULT']: + raise TypeError('Missing Field As should be "ERROR" or "FIELD_DEFAULT".') + self.options["MISSING_FIELD_AS"] = f"{missing_field_as}" + + +class ORCFormat(CopyFormat): + format_type = "ORC" + +class StageClause(ClauseElement, FromClauseRole): + """Stage Clause""" + + __visit_name__ = "stage" + _hide_froms = () + + def __init__(self, *, name, path=None): + self.name = name + self.path = path + + def __repr__(self): + return f"@{self.name}/{self.path}" + + +class FileColumnClause(ClauseElement, FromClauseRole): + """Clause for selecting file columns from a Stage/Location""" + __visit_name__ = "file_column" + + def __init__(self, *, columns, from_: ['StageClause', '_StorageClause']): + # columns need to be expressions of column index, e.g. 
$1, IF($1 =='t', True, False), or string of these expressions that we just use + self.columns = columns + self.from_ = from_ + + def __repr__(self): + return ( + f"SELECT {self.columns if isinstance(self.columns, str) else ','.join(repr(col) for col in self.columns)}" + f" FROM {repr(self.from_)}" + ) + + +class _StorageClause(ClauseElement): + pass + + +class AmazonS3(_StorageClause): + """Amazon S3""" + + __visit_name__ = "amazon_s3" + + def __init__(self, uri: str, access_key_id: str, secret_access_key: str, endpoint_url: str = None, + enable_virtual_host_style: bool = None, master_key: str = None, + region: str = None, security_token: str = None): + r = urlparse(uri) + if r.scheme != 's3': + raise ValueError(f'Invalid S3 URI: {uri}') + + self.uri = uri + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.bucket = r.netloc + self.path = r.path + if endpoint_url: + self.endpoint_url = endpoint_url + if enable_virtual_host_style: + self.enable_virtual_host_style = enable_virtual_host_style + if master_key: + self.master_key = master_key + if region: + self.region = region + if security_token: + self.security_token = security_token + + def __repr__(self): + return ( + f"'{self.uri}' \n" + f"CONNECTION = (\n" + f" ENDPOINT_URL = '{self.endpoint_url}' \n" if self.endpoint_url else "" + f" ACCESS_KEY_ID = '{self.access_key_id}' \n" + f" SECRET_ACCESS_KEY = '{self.secret_access_key}'\n" + f" ENABLE_VIRTUAL_HOST_STYLE = '{self.enable_virtual_host_style}'\n" if self.enable_virtual_host_style else "" + f" MASTER_KEY = '{self.master_key}'\n" if self.master_key else "" + f" REGION = '{self.region}'\n" if self.region else "" + f" SECURITY_TOKEN = '{self.security_token}'\n" if self.security_token else "" + f")" + ) + + +class AzureBlobStorage(_StorageClause): + """Microsoft Azure Blob Storage""" + + __visit_name__ = "azure_blob_storage" + + def __init__(self, *, uri: str, account_name: str, account_key: str): + r = urlparse(uri) + if r.scheme != 'azblob': + raise ValueError(f'Invalid Azure URI: {uri}') + + self.uri = uri + self.account_name = account_name + self.account_key = account_key + self.container = r.netloc + self.path = r.path + + def __repr__(self): + return ( + f"'{self.uri}' \n" + f"CONNECTION = (\n" + f" ENDPOINT_URL = 'https://{self.account_name}.blob.core.windows.net' \n" + f" ACCOUNT_NAME = '{self.account_name}' \n" + f" ACCOUNT_KEY = '{self.account_key}'\n" + f")" + ) + + +class GoogleCloudStorage(_StorageClause): + """Google Cloud Storage""" + + __visit_name__ = "google_cloud_storage" + + def __init__(self, *, uri, credentials): + r = urlparse(uri) + if r.scheme != 'gcs': + raise ValueError(f'Invalid Google Cloud Storage URI: {uri}') + + self.uri = uri + self.credentials = credentials + + + def __repr__(self): + return ( + f"'{self.uri}' \n" + f"CONNECTION = (\n" + f" ENDPOINT_URL = 'https://storage.googleapis.com' \n" + f" CREDENTIAL = '{self.credentials}' \n" + f")" + ) + diff --git a/tests/test_copy_into.py b/tests/test_copy_into.py new file mode 100644 index 0000000..f6e7eec --- /dev/null +++ b/tests/test_copy_into.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python + +from sqlalchemy.testing import config, fixture, fixtures, util +from sqlalchemy.testing.assertions import AssertsCompiledSQL +from sqlalchemy import Table, Column, Integer, String, func, MetaData, schema, cast, literal_column + +from databend_sqlalchemy import ( + CopyIntoTable, CopyIntoLocation, CopyIntoTableOptions, CSVFormat, ParquetFormat, + GoogleCloudStorage, Compression, 
FileColumnClause +) + +class CompileDatabendCopyIntoTableTest(fixtures.TestBase, AssertsCompiledSQL): + + __only_on__ = "databend" + + def test_copy_into_table(self): + m = MetaData() + tbl = Table( + 'atable', m, Column("id", Integer), + schema="test_schema", + ) + + copy_into = CopyIntoTable( + target=tbl, + from_=GoogleCloudStorage( + uri='gcs://some-bucket/a/path/to/files', + credentials='XYZ', + ), + #files='', + #pattern='', + file_format=CSVFormat( + record_delimiter='\n', + field_delimiter=',', + quote='"', + #escape='\\', + #skip_header=1, + #nan_display='' + #null_display='', + error_on_column_mismatch=False, + #empty_field_as='STRING', + output_header=True, + #binary_format='', + compression=Compression.GZIP + ), + options=CopyIntoTableOptions( + size_limit=None, + purge=None, + force=None, + disable_variant_check=None, + on_error=None, + max_files=None, + return_failed_only=None, + column_match_mode=None, + ) + ) + + + self.assert_compile( + copy_into, + ("COPY INTO test_schema.atable" + " FROM 'gcs://some-bucket/a/path/to/files' " + "CONNECTION = (" + " ENDPOINT_URL = 'https://storage.googleapis.com' " + " CREDENTIAL = 'XYZ' " + ")" + " FILE_FORMAT = (TYPE = CSV, " + "RECORD_DELIMITER = '\\n', FIELD_DELIMITER = ',', QUOTE = '\"', OUTPUT_HEADER = TRUE, COMPRESSION = GZIP) " + ) + ) + + def test_copy_into_table_sub_select_string_columns(self): + m = MetaData() + tbl = Table( + 'atable', m, Column("id", Integer), + schema="test_schema", + ) + + copy_into = CopyIntoTable( + target=tbl, + from_=FileColumnClause( + columns='$1, $2, $3', + from_=GoogleCloudStorage( + uri='gcs://some-bucket/a/path/to/files', + credentials='XYZ', + ) + ), + file_format=CSVFormat(), + ) + + self.assert_compile( + copy_into, + ("COPY INTO test_schema.atable" + " FROM (SELECT $1, $2, $3" + " FROM 'gcs://some-bucket/a/path/to/files' " + "CONNECTION = (" + " ENDPOINT_URL = 'https://storage.googleapis.com' " + " CREDENTIAL = 'XYZ' " + ")" + ") FILE_FORMAT = (TYPE = CSV)" + ) + ) + + def test_copy_into_table_sub_select_column_clauses(self): + m = MetaData() + tbl = Table( + 'atable', m, Column("id", Integer), + schema="test_schema", + ) + + copy_into = CopyIntoTable( + target=tbl, + from_=FileColumnClause( + columns=[func.IF(literal_column("$1") == 'xyz', 'NULL', 'NOTNULL')], + # columns='$1, $2, $3', + from_=GoogleCloudStorage( + uri='gcs://some-bucket/a/path/to/files', + credentials='XYZ', + ) + ), + file_format=CSVFormat(), + ) + + self.assert_compile( + copy_into, + ("COPY INTO test_schema.atable" + " FROM (SELECT IF($1 = %(1_1)s, %(IF_1)s, %(IF_2)s)" + " FROM 'gcs://some-bucket/a/path/to/files' " + "CONNECTION = (" + " ENDPOINT_URL = 'https://storage.googleapis.com' " + " CREDENTIAL = 'XYZ' " + ")" + ") FILE_FORMAT = (TYPE = CSV)" + ), + checkparams={ + "1_1": "xyz", + "IF_1": "NULL", + "IF_2": "NOTNULL" + }, + ) diff --git a/tests/test_merge.py b/tests/test_merge.py index fcdedc9..074d4a2 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -15,7 +15,7 @@ from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import eq_ -from databend_sqlalchemy.databend_dialect import Merge +from databend_sqlalchemy import Merge class MergeIntoTest(fixtures.TablesTest): From be80ba9ae385e40c52d8fd6f0c5615eb7d2cb016 Mon Sep 17 00:00:00 2001 From: Pat Buxton Date: Thu, 2 Jan 2025 10:05:35 +0000 Subject: [PATCH 2/2] Add copy_into usage into ReadMe --- README.rst | 106 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/README.rst 
b/README.rst
index e103f1c..cb54a3d 100644
--- a/README.rst
+++ b/README.rst
@@ -62,6 +62,112 @@ The Merge command can be used as below::
     connection.execute(merge)
 
+Copy Into Command Support
+-------------------------
+
+Databend SQLAlchemy supports COPY INTO operations through its CopyIntoTable and CopyIntoLocation constructs.
+See `CopyIntoLocation <https://docs.databend.com/sql/sql-commands/dml/dml-copy-into-location>`_ or `CopyIntoTable <https://docs.databend.com/sql/sql-commands/dml/dml-copy-into-table>`_ for full documentation.
+
+The CopyIntoTable command can be used as below::
+
+    import base64
+
+    from sqlalchemy.orm import sessionmaker
+    from sqlalchemy import MetaData, create_engine
+    from databend_sqlalchemy import (
+        CopyIntoTable, GoogleCloudStorage, ParquetFormat, CopyIntoTableOptions,
+        FileColumnClause, CSVFormat, Compression,
+    )
+
+    engine = create_engine(db.url, echo=False)
+    session = sessionmaker(bind=engine)()
+    connection = engine.connect()
+
+    meta = MetaData()
+    meta.reflect(bind=session.bind)
+    t1 = meta.tables['t1']
+    t2 = meta.tables['t2']
+    gcs_private_key = 'full_gcs_json_private_key'
+    case_sensitive_columns = True
+
+    copy_into = CopyIntoTable(
+        target=t1,
+        from_=GoogleCloudStorage(
+            uri='gcs://bucket-name/path/to/file',
+            credentials=base64.b64encode(gcs_private_key.encode()).decode(),
+        ),
+        file_format=ParquetFormat(),
+        options=CopyIntoTableOptions(
+            force=True,
+            column_match_mode='CASE_SENSITIVE' if case_sensitive_columns else None,
+        )
+    )
+    result = connection.execute(copy_into)
+    result.fetchall() # always call fetchall() to ensure the cursor executes to completion
+
+    # A more involved example with a column selection clause, which can be altered to perform operations on the columns during import.
+
+    copy_into = CopyIntoTable(
+        target=t2,
+        from_=FileColumnClause(
+            columns=', '.join([
+                f'${index + 1}'
+                for index, column in enumerate(t2.columns)
+            ]),
+            from_=GoogleCloudStorage(
+                uri='gcs://bucket-name/path/to/file',
+                credentials=base64.b64encode(gcs_private_key.encode()).decode(),
+            )
+        ),
+        pattern='*.*',
+        file_format=CSVFormat(
+            record_delimiter='\n',
+            field_delimiter=',',
+            quote='"',
+            escape='',
+            skip_header=1,
+            empty_field_as='NULL',
+            compression=Compression.AUTO,
+        ),
+        options=CopyIntoTableOptions(
+            force=True,
+        )
+    )
+    result = connection.execute(copy_into)
+    result.fetchall() # always call fetchall() to ensure the cursor executes to completion
+
+The CopyIntoLocation command can be used as below::
+
+    import base64
+
+    from sqlalchemy.orm import sessionmaker
+    from sqlalchemy import MetaData, create_engine, select
+    from databend_sqlalchemy import (
+        CopyIntoLocation, GoogleCloudStorage, ParquetFormat, CopyIntoLocationOptions,
+    )
+
+    engine = create_engine(db.url, echo=False)
+    session = sessionmaker(bind=engine)()
+    connection = engine.connect()
+
+    meta = MetaData()
+    meta.reflect(bind=session.bind)
+    t1 = meta.tables['t1']
+    gcs_private_key = 'full_gcs_json_private_key'
+
+    copy_into = CopyIntoLocation(
+        target=GoogleCloudStorage(
+            uri='gcs://bucket-name/path/to/target_file',
+            credentials=base64.b64encode(gcs_private_key.encode()).decode(),
+        ),
+        from_=select(t1).where(t1.c['col1'] == 1),
+        file_format=ParquetFormat(),
+        options=CopyIntoLocationOptions(
+            single=True,
+            overwrite=True,
+            include_query_id=False,
+            use_raw_path=True,
+        )
+    )
+    result = connection.execute(copy_into)
+    result.fetchall() # always call fetchall() to ensure the cursor executes to completion
+
 Table Options
 ---------------------