diff --git a/pyathena/sqlalchemy/base.py b/pyathena/sqlalchemy/base.py index fe0fef6f..2a654dbb 100644 --- a/pyathena/sqlalchemy/base.py +++ b/pyathena/sqlalchemy/base.py @@ -34,8 +34,9 @@ import pyathena from pyathena.model import AthenaFileFormat, AthenaRowFormatSerde -from pyathena.sqlalchemy.types import DOUBLE, STRUCT, AthenaDate, AthenaTimestamp -from pyathena.sqlalchemy.util import _HashableDict + +from . import types as athena_types +from .util import _HashableDict if TYPE_CHECKING: from types import ModuleType @@ -352,7 +353,7 @@ ischema_names: Dict[str, Type[Any]] = { "boolean": types.BOOLEAN, "float": types.FLOAT, - "double": DOUBLE, + "double": athena_types.DOUBLE, "real": types.REAL, "tinyint": types.INTEGER, "smallint": types.INTEGER, @@ -369,8 +370,8 @@ "varbinary": types.BINARY, "array": types.ARRAY, "map": types.String, - "struct": STRUCT, - "row": STRUCT, + "struct": athena_types.STRUCT, + "row": athena_types.STRUCT, "json": types.String, } @@ -928,9 +929,9 @@ class AthenaDialect(DefaultDialect): ] colspecs = { - types.DATE: AthenaDate, - types.DATETIME: AthenaTimestamp, - types.TIMESTAMP: AthenaTimestamp, + types.DATE: athena_types.AthenaDate, + types.DATETIME: athena_types.AthenaTimestamp, + types.TIMESTAMP: athena_types.AthenaTimestamp, } ischema_names: Dict[str, Type[Any]] = ischema_names @@ -1132,7 +1133,7 @@ def _get_column_type(self, type_: str): args = [int(column_type_args)] elif col_type is types.ARRAY: args = [self._get_column_type(column_type_args.strip())] - elif col_type is STRUCT: + elif col_type is athena_types.STRUCT: args = self._parse_struct(column_type_args) return col_type(*args) diff --git a/pyathena/sqlalchemy/types.py b/pyathena/sqlalchemy/types.py index 100e15b6..37379d5a 100644 --- a/pyathena/sqlalchemy/types.py +++ b/pyathena/sqlalchemy/types.py @@ -10,8 +10,8 @@ from sqlalchemy.sql.sqltypes import Float, Indexable from sqlalchemy.sql.type_api import TypeEngine, UserDefinedType -from pyathena.sqlalchemy import base -from pyathena.sqlalchemy.constants import sqlalchemy_1_4_or_more +from . import base +from .constants import sqlalchemy_1_4_or_more if TYPE_CHECKING: from sqlalchemy import Dialect