diff --git a/src/sqlalchemy_cratedb/__init__.py b/src/sqlalchemy_cratedb/__init__.py index 36198beb..968a23c9 100644 --- a/src/sqlalchemy_cratedb/__init__.py +++ b/src/sqlalchemy_cratedb/__init__.py @@ -23,8 +23,9 @@ from .dialect import CrateDialect from .sa_version import SA_1_4, SA_2_0, SA_VERSION from .support import insert_bulk -from .types import Geopoint, Geoshape, ObjectArray, ObjectType - +from .type.array import ObjectArray +from .type.geo import Geopoint, Geoshape +from .type.object import ObjectType if SA_VERSION < SA_1_4: import textwrap diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 767ad638..07106b87 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -27,7 +27,8 @@ from sqlalchemy.dialects.postgresql.base import PGCompiler from sqlalchemy.sql import compiler from sqlalchemy.types import String -from .types import MutableDict, ObjectTypeImpl, Geopoint, Geoshape +from .type.geo import Geopoint, Geoshape +from .type.object import MutableDict, ObjectTypeImpl from .sa_version import SA_VERSION, SA_1_4 diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 3f1197df..aebad9c2 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -33,7 +33,7 @@ ) from crate.client.exceptions import TimezoneUnawareException from .sa_version import SA_VERSION, SA_1_4, SA_2_0 -from .types import ObjectType, ObjectArray +from .type import ObjectArray, ObjectType TYPES_MAP = { "boolean": sqltypes.Boolean, diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py new file mode 100644 index 00000000..8e78f7da --- /dev/null +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -0,0 +1,3 @@ +from .array import ObjectArray +from .geo import Geopoint, Geoshape +from .object import ObjectType diff --git a/src/sqlalchemy_cratedb/types.py b/src/sqlalchemy_cratedb/type/array.py similarity index 51% rename from src/sqlalchemy_cratedb/types.py rename to src/sqlalchemy_cratedb/type/array.py index f9899d92..ae68d4b4 100644 --- a/src/sqlalchemy_cratedb/types.py +++ b/src/sqlalchemy_cratedb/type/array.py @@ -18,15 +18,12 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. -import warnings import sqlalchemy.types as sqltypes from sqlalchemy.sql import operators, expression from sqlalchemy.sql import default_comparator from sqlalchemy.ext.mutable import Mutable -import geojson - class MutableList(Mutable, list): @@ -74,91 +71,6 @@ def remove(self, item): self.changed() -class MutableDict(Mutable, dict): - - @classmethod - def coerce(cls, key, value): - "Convert plain dictionaries to MutableDict." - - if not isinstance(value, MutableDict): - if isinstance(value, dict): - return MutableDict(value) - - # this call will raise ValueError - return Mutable.coerce(key, value) - else: - return value - - def __init__(self, initval=None, to_update=None, root_change_key=None): - initval = initval or {} - self._changed_keys = set() - self._deleted_keys = set() - self._overwrite_key = root_change_key - self.to_update = self if to_update is None else to_update - for k in initval: - initval[k] = self._convert_dict(initval[k], - overwrite_key=k if self._overwrite_key is None else self._overwrite_key - ) - dict.__init__(self, initval) - - def __setitem__(self, key, value): - value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key) - dict.__setitem__(self, key, value) - self.to_update.on_key_changed( - key if self._overwrite_key is None else self._overwrite_key - ) - - def __delitem__(self, key): - dict.__delitem__(self, key) - # add the key to the deleted keys if this is the root object - # otherwise update on root object - if self._overwrite_key is None: - self._deleted_keys.add(key) - self.changed() - else: - self.to_update.on_key_changed(self._overwrite_key) - - def on_key_changed(self, key): - self._deleted_keys.discard(key) - self._changed_keys.add(key) - self.changed() - - def _convert_dict(self, value, overwrite_key): - if isinstance(value, dict) and not isinstance(value, MutableDict): - return MutableDict(value, self.to_update, overwrite_key) - return value - - def __eq__(self, other): - return dict.__eq__(self, other) - - -class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON): - - __visit_name__ = "OBJECT" - - cache_ok = False - none_as_null = False - - -# Designated name to refer to. `Object` is too ambiguous. -ObjectType = MutableDict.as_mutable(ObjectTypeImpl) - -# Backward-compatibility aliases. -_deprecated_Craty = ObjectType -_deprecated_Object = ObjectType - -# https://www.lesinskis.com/deprecating-module-scope-variables.html -deprecated_names = ["Craty", "Object"] - - -def __getattr__(name): - if name in deprecated_names: - warnings.warn(f"{name} is deprecated and will be removed in future releases. " - f"Please use ObjectType instead.", DeprecationWarning) - return globals()[f"_deprecated_{name}"] - raise AttributeError(f"module {__name__} has no attribute {name}") - - class Any(expression.ColumnElement): """Represent the clause ``left operator ANY (right)``. ``right`` must be an array expression. @@ -230,48 +142,3 @@ def get_col_spec(self, **kws): ObjectArray = MutableList.as_mutable(_ObjectArray) - - -class Geopoint(sqltypes.UserDefinedType): - cache_ok = True - - class Comparator(sqltypes.TypeEngine.Comparator): - - def __getitem__(self, key): - return default_comparator._binary_operate(self.expr, - operators.getitem, - key) - - def get_col_spec(self): - return 'GEO_POINT' - - def bind_processor(self, dialect): - def process(value): - if isinstance(value, geojson.Point): - return value.coordinates - return value - return process - - def result_processor(self, dialect, coltype): - return tuple - - comparator_factory = Comparator - - -class Geoshape(sqltypes.UserDefinedType): - cache_ok = True - - class Comparator(sqltypes.TypeEngine.Comparator): - - def __getitem__(self, key): - return default_comparator._binary_operate(self.expr, - operators.getitem, - key) - - def get_col_spec(self): - return 'GEO_SHAPE' - - def result_processor(self, dialect, coltype): - return geojson.GeoJSON.to_instance - - comparator_factory = Comparator diff --git a/src/sqlalchemy_cratedb/type/geo.py b/src/sqlalchemy_cratedb/type/geo.py new file mode 100644 index 00000000..31abd279 --- /dev/null +++ b/src/sqlalchemy_cratedb/type/geo.py @@ -0,0 +1,48 @@ +import geojson +from sqlalchemy import types as sqltypes +from sqlalchemy.sql import default_comparator, operators + + +class Geopoint(sqltypes.UserDefinedType): + cache_ok = True + + class Comparator(sqltypes.TypeEngine.Comparator): + + def __getitem__(self, key): + return default_comparator._binary_operate(self.expr, + operators.getitem, + key) + + def get_col_spec(self): + return 'GEO_POINT' + + def bind_processor(self, dialect): + def process(value): + if isinstance(value, geojson.Point): + return value.coordinates + return value + return process + + def result_processor(self, dialect, coltype): + return tuple + + comparator_factory = Comparator + + +class Geoshape(sqltypes.UserDefinedType): + cache_ok = True + + class Comparator(sqltypes.TypeEngine.Comparator): + + def __getitem__(self, key): + return default_comparator._binary_operate(self.expr, + operators.getitem, + key) + + def get_col_spec(self): + return 'GEO_SHAPE' + + def result_processor(self, dialect, coltype): + return geojson.GeoJSON.to_instance + + comparator_factory = Comparator diff --git a/src/sqlalchemy_cratedb/type/object.py b/src/sqlalchemy_cratedb/type/object.py new file mode 100644 index 00000000..32d36463 --- /dev/null +++ b/src/sqlalchemy_cratedb/type/object.py @@ -0,0 +1,92 @@ +import warnings + +from sqlalchemy import types as sqltypes +from sqlalchemy.ext.mutable import Mutable + + +class MutableDict(Mutable, dict): + + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __init__(self, initval=None, to_update=None, root_change_key=None): + initval = initval or {} + self._changed_keys = set() + self._deleted_keys = set() + self._overwrite_key = root_change_key + self.to_update = self if to_update is None else to_update + for k in initval: + initval[k] = self._convert_dict(initval[k], + overwrite_key=k if self._overwrite_key is None else self._overwrite_key + ) + dict.__init__(self, initval) + + def __setitem__(self, key, value): + value = self._convert_dict(value, key if self._overwrite_key is None else self._overwrite_key) + dict.__setitem__(self, key, value) + self.to_update.on_key_changed( + key if self._overwrite_key is None else self._overwrite_key + ) + + def __delitem__(self, key): + dict.__delitem__(self, key) + # add the key to the deleted keys if this is the root object + # otherwise update on root object + if self._overwrite_key is None: + self._deleted_keys.add(key) + self.changed() + else: + self.to_update.on_key_changed(self._overwrite_key) + + def on_key_changed(self, key): + self._deleted_keys.discard(key) + self._changed_keys.add(key) + self.changed() + + def _convert_dict(self, value, overwrite_key): + if isinstance(value, dict) and not isinstance(value, MutableDict): + return MutableDict(value, self.to_update, overwrite_key) + return value + + def __eq__(self, other): + return dict.__eq__(self, other) + + +class ObjectTypeImpl(sqltypes.UserDefinedType, sqltypes.JSON): + + __visit_name__ = "OBJECT" + + cache_ok = False + none_as_null = False + + +# Designated name to refer to. `Object` is too ambiguous. +ObjectType = MutableDict.as_mutable(ObjectTypeImpl) + +# Backward-compatibility aliases. +_deprecated_Craty = ObjectType +_deprecated_Object = ObjectType + +# https://www.lesinskis.com/deprecating-module-scope-variables.html +deprecated_names = ["Craty", "Object"] + + +def __getattr__(name): + if name in deprecated_names: + warnings.warn(f"{name} is deprecated and will be removed in future releases. " + f"Please use ObjectType instead.", DeprecationWarning) + return globals()[f"_deprecated_{name}"] + raise AttributeError(f"module {__name__} has no attribute {name}") + + +__all__ = deprecated_names diff --git a/tests/compiler_test.py b/tests/compiler_test.py index e280d6c6..2e6609cd 100644 --- a/tests/compiler_test.py +++ b/tests/compiler_test.py @@ -37,8 +37,7 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import SA_VERSION, SA_1_4, SA_2_0 -from sqlalchemy_cratedb import ObjectType +from sqlalchemy_cratedb import SA_VERSION, SA_1_4, SA_2_0, ObjectType from crate.client.test_util import ParametrizedTestCase diff --git a/tests/dialect_test.py b/tests/dialect_test.py index e797f0b5..d3b5c364 100644 --- a/tests/dialect_test.py +++ b/tests/dialect_test.py @@ -26,9 +26,8 @@ import sqlalchemy as sa from crate.client.cursor import Cursor -from sqlalchemy_cratedb import SA_VERSION +from sqlalchemy_cratedb import SA_VERSION, ObjectType from sqlalchemy_cratedb import SA_1_4, SA_2_0 -from sqlalchemy_cratedb import ObjectType from sqlalchemy import inspect from sqlalchemy.orm import Session try: diff --git a/tests/query_caching.py b/tests/query_caching.py index 4381f61c..16a7582f 100644 --- a/tests/query_caching.py +++ b/tests/query_caching.py @@ -26,7 +26,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.operators import eq -from sqlalchemy_cratedb import SA_VERSION, SA_1_4 +from sqlalchemy_cratedb import SA_VERSION, SA_1_4, ObjectArray, ObjectType from crate.testing.settings import crate_host try: @@ -34,8 +34,6 @@ except ImportError: from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy_cratedb import ObjectType, ObjectArray - class SqlAlchemyQueryCompilationCaching(TestCase): diff --git a/tests/warnings_test.py b/tests/warnings_test.py index ede78709..b74b8b30 100644 --- a/tests/warnings_test.py +++ b/tests/warnings_test.py @@ -44,7 +44,7 @@ def test_craty_object_deprecation_warning(self): with warnings.catch_warnings(record=True) as w: # Import the deprecated symbol. - from sqlalchemy_cratedb.types import Craty # noqa: F401 + from sqlalchemy_cratedb.type.object import Craty # noqa: F401 # Verify details of the deprecation warning. self.assertEqual(len(w), 1) @@ -55,7 +55,7 @@ def test_craty_object_deprecation_warning(self): with warnings.catch_warnings(record=True) as w: # Import the deprecated symbol. - from sqlalchemy_cratedb.types import Object # noqa: F401 + from sqlalchemy_cratedb.type.object import Object # noqa: F401 # Verify details of the deprecation warning. self.assertEqual(len(w), 1)