diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..3dbfbb4 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,2 @@ +[run] +omit = tests/* \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 748a892..c825fe9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -40,6 +40,20 @@ category = "dev" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +[[package]] +name = "coverage" +version = "7.3.2" +description = "Code coverage measurement for Python" +category = "dev" +optional = false +python-versions = ">=3.8" + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + [[package]] name = "exceptiongroup" version = "1.2.0" @@ -175,6 +189,43 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + +[[package]] +name = "pytest-cover" +version = "3.0.0" +description = "Pytest plugin for measuring coverage. Forked from `pytest-cov`." +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +pytest-cov = ">=2.0" + +[[package]] +name = "pytest-coverage" +version = "0.0" +description = "Pytest plugin for measuring coverage. Forked from `pytest-cov`." +category = "dev" +optional = false +python-versions = "*" + +[package.dependencies] +pytest-cover = "*" + [[package]] name = "python-dateutil" version = "2.8.2" @@ -250,7 +301,7 @@ python-versions = ">=3.8" [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "505d32a18e15dd4aa9d80f95409cb93a192976d897856579c8bb0de5ece39f9f" +content-hash = "dec73ae351f76ed7d60336dc9bbcdb3a18832d27ff41567c7c2dd03000330e81" [metadata.files] black = [ @@ -281,6 +332,60 @@ colorama = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +coverage = [ + {file = "coverage-7.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d872145f3a3231a5f20fd48500274d7df222e291d90baa2026cc5152b7ce86bf"}, + {file = "coverage-7.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:310b3bb9c91ea66d59c53fa4989f57d2436e08f18fb2f421a1b0b6b8cc7fffda"}, + {file = "coverage-7.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f47d39359e2c3779c5331fc740cf4bce6d9d680a7b4b4ead97056a0ae07cb49a"}, + {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa72dbaf2c2068404b9870d93436e6d23addd8bbe9295f49cbca83f6e278179c"}, + {file = "coverage-7.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:beaa5c1b4777f03fc63dfd2a6bd820f73f036bfb10e925fce067b00a340d0f3f"}, + {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:dbc1b46b92186cc8074fee9d9fbb97a9dd06c6cbbef391c2f59d80eabdf0faa6"}, + {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:315a989e861031334d7bee1f9113c8770472db2ac484e5b8c3173428360a9148"}, + {file = "coverage-7.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d1bc430677773397f64a5c88cb522ea43175ff16f8bfcc89d467d974cb2274f9"}, + {file = "coverage-7.3.2-cp310-cp310-win32.whl", hash = "sha256:a889ae02f43aa45032afe364c8ae84ad3c54828c2faa44f3bfcafecb5c96b02f"}, + {file = "coverage-7.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:c0ba320de3fb8c6ec16e0be17ee1d3d69adcda99406c43c0409cb5c41788a611"}, + {file = "coverage-7.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac8c802fa29843a72d32ec56d0ca792ad15a302b28ca6203389afe21f8fa062c"}, + {file = "coverage-7.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:89a937174104339e3a3ffcf9f446c00e3a806c28b1841c63edb2b369310fd074"}, + {file = "coverage-7.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e267e9e2b574a176ddb983399dec325a80dbe161f1a32715c780b5d14b5f583a"}, + {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2443cbda35df0d35dcfb9bf8f3c02c57c1d6111169e3c85fc1fcc05e0c9f39a3"}, + {file = "coverage-7.3.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4175e10cc8dda0265653e8714b3174430b07c1dca8957f4966cbd6c2b1b8065a"}, + {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0cbf38419fb1a347aaf63481c00f0bdc86889d9fbf3f25109cf96c26b403fda1"}, + {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5c913b556a116b8d5f6ef834038ba983834d887d82187c8f73dec21049abd65c"}, + {file = "coverage-7.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1981f785239e4e39e6444c63a98da3a1db8e971cb9ceb50a945ba6296b43f312"}, + {file = "coverage-7.3.2-cp311-cp311-win32.whl", hash = "sha256:43668cabd5ca8258f5954f27a3aaf78757e6acf13c17604d89648ecc0cc66640"}, + {file = "coverage-7.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10c39c0452bf6e694511c901426d6b5ac005acc0f78ff265dbe36bf81f808a2"}, + {file = "coverage-7.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4cbae1051ab791debecc4a5dcc4a1ff45fc27b91b9aee165c8a27514dd160836"}, + {file = "coverage-7.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:12d15ab5833a997716d76f2ac1e4b4d536814fc213c85ca72756c19e5a6b3d63"}, + {file = "coverage-7.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c7bba973ebee5e56fe9251300c00f1579652587a9f4a5ed8404b15a0471f216"}, + {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fe494faa90ce6381770746077243231e0b83ff3f17069d748f645617cefe19d4"}, + {file = "coverage-7.3.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6e9589bd04d0461a417562649522575d8752904d35c12907d8c9dfeba588faf"}, + {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:d51ac2a26f71da1b57f2dc81d0e108b6ab177e7d30e774db90675467c847bbdf"}, + {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:99b89d9f76070237975b315b3d5f4d6956ae354a4c92ac2388a5695516e47c84"}, + {file = "coverage-7.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:fa28e909776dc69efb6ed975a63691bc8172b64ff357e663a1bb06ff3c9b589a"}, + {file = "coverage-7.3.2-cp312-cp312-win32.whl", hash = "sha256:289fe43bf45a575e3ab10b26d7b6f2ddb9ee2dba447499f5401cfb5ecb8196bb"}, + {file = "coverage-7.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:7dbc3ed60e8659bc59b6b304b43ff9c3ed858da2839c78b804973f613d3e92ed"}, + {file = "coverage-7.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f94b734214ea6a36fe16e96a70d941af80ff3bfd716c141300d95ebc85339738"}, + {file = "coverage-7.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:af3d828d2c1cbae52d34bdbb22fcd94d1ce715d95f1a012354a75e5913f1bda2"}, + {file = "coverage-7.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630b13e3036e13c7adc480ca42fa7afc2a5d938081d28e20903cf7fd687872e2"}, + {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9eacf273e885b02a0273bb3a2170f30e2d53a6d53b72dbe02d6701b5296101c"}, + {file = "coverage-7.3.2-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8f17966e861ff97305e0801134e69db33b143bbfb36436efb9cfff6ec7b2fd9"}, + {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b4275802d16882cf9c8b3d057a0839acb07ee9379fa2749eca54efbce1535b82"}, + {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:72c0cfa5250f483181e677ebc97133ea1ab3eb68645e494775deb6a7f6f83901"}, + {file = "coverage-7.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:cb536f0dcd14149425996821a168f6e269d7dcd2c273a8bff8201e79f5104e76"}, + {file = "coverage-7.3.2-cp38-cp38-win32.whl", hash = "sha256:307adb8bd3abe389a471e649038a71b4eb13bfd6b7dd9a129fa856f5c695cf92"}, + {file = "coverage-7.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:88ed2c30a49ea81ea3b7f172e0269c182a44c236eb394718f976239892c0a27a"}, + {file = "coverage-7.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b631c92dfe601adf8f5ebc7fc13ced6bb6e9609b19d9a8cd59fa47c4186ad1ce"}, + {file = "coverage-7.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d3d9df4051c4a7d13036524b66ecf7a7537d14c18a384043f30a303b146164e9"}, + {file = "coverage-7.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f7363d3b6a1119ef05015959ca24a9afc0ea8a02c687fe7e2d557705375c01f"}, + {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2f11cc3c967a09d3695d2a6f03fb3e6236622b93be7a4b5dc09166a861be6d25"}, + {file = "coverage-7.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:149de1d2401ae4655c436a3dced6dd153f4c3309f599c3d4bd97ab172eaf02d9"}, + {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:3a4006916aa6fee7cd38db3bfc95aa9c54ebb4ffbfc47c677c8bba949ceba0a6"}, + {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9028a3871280110d6e1aa2df1afd5ef003bab5fb1ef421d6dc748ae1c8ef2ebc"}, + {file = "coverage-7.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9f805d62aec8eb92bab5b61c0f07329275b6f41c97d80e847b03eb894f38d083"}, + {file = "coverage-7.3.2-cp39-cp39-win32.whl", hash = "sha256:d1c88ec1a7ff4ebca0219f5b1ef863451d828cccf889c173e1253aa84b1e07ce"}, + {file = "coverage-7.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b4767da59464bb593c07afceaddea61b154136300881844768037fd5e859353f"}, + {file = "coverage-7.3.2-pp38.pp39.pp310-none-any.whl", hash = "sha256:ae97af89f0fbf373400970c0a21eef5aa941ffeed90aee43650b81f7d7f47637"}, + {file = "coverage-7.3.2.tar.gz", hash = "sha256:be32ad29341b0170e795ca590e1c07e81fc061cb5b10c74ce7203491484404ef"}, +] exceptiongroup = [ {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, @@ -384,6 +489,18 @@ pytest = [ {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, ] +pytest-cov = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] +pytest-cover = [ + {file = "pytest-cover-3.0.0.tar.gz", hash = "sha256:5bdb6c1cc3dd75583bb7bc2c57f5e1034a1bfcb79d27c71aceb0b16af981dbf4"}, + {file = "pytest_cover-3.0.0-py2.py3-none-any.whl", hash = "sha256:578249955eb3b5f3991209df6e532bb770b647743b7392d3d97698dc02f39ebb"}, +] +pytest-coverage = [ + {file = "pytest-coverage-0.0.tar.gz", hash = "sha256:db6af2cbd7e458c7c9fd2b4207cee75258243c8a81cad31a7ee8cfad5be93c05"}, + {file = "pytest_coverage-0.0-py2.py3-none-any.whl", hash = "sha256:dedd084c5e74d8e669355325916dc011539b190355021b037242514dee546368"}, +] python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, diff --git a/pyproject.toml b/pyproject.toml index c82c629..a6c798c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ sqlalchemy = "^2.0" isort = "^5.12.0" black = "^23.11.0" pytest = "^7.4.3" +pytest-coverage = "^0.0" [build-system] requires = ["poetry>=0.12"] diff --git a/rqlalchemy/query.py b/rqlalchemy/query.py index f856f5d..59a15b3 100644 --- a/rqlalchemy/query.py +++ b/rqlalchemy/query.py @@ -29,6 +29,7 @@ from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.sql import _typing from sqlalchemy.sql import elements +from sqlalchemy.sql.sqltypes import JSON ArgsType = List[Any] BinaryOperator = Callable[[Any, Any], Any] @@ -80,9 +81,9 @@ def _rql_select_limit(self): def _rql_select_offset(self): return self._offset_clause.value if self._offset_clause is not None else None - def rql(self, query: str = "", limit: Optional[int] = None) -> "RQLSelect": + def rql(self, query: str = "", limit: Optional[int] = None) -> "RQLSelect": # noqa: C901 if len(self._rql_select_entities) > 1: - raise NotImplementedError("Select must have only one entity") + raise self._rql_error_cls("Select must have only one entity") if not query: self.rql_parsed = None @@ -124,7 +125,9 @@ def rql(self, query: str = "", limit: Optional[int] = None) -> "RQLSelect": return select_ - def execute(self, session: Session) -> Sequence[Union[Union[Row, RowMapping], Any]]: + def execute( # noqa: C901 + self, session: Session + ) -> Sequence[Union[Union[Row, RowMapping], Any]]: # noqa: C901 """ Executes the sql expression differently based on which clauses included: - For single aggregates a scalar is returned @@ -256,9 +259,10 @@ def _rql_apply(self, node: Dict[str, Any]) -> Any: return node - def _rql_attr(self, attr): - model = self._rql_select_entities[0] + def _rql_attr(self, attr, model=None): + model = model or self._rql_select_entities[0] + # if it's just a plain attribute name, return it if isinstance(attr, str): try: return getattr(model, attr) @@ -266,16 +270,37 @@ def _rql_attr(self, attr): raise self._rql_error_cls(f"Invalid query attribute: {attr}") from e elif isinstance(attr, tuple): - # Every entry in attr but the last should be a relationship name. - for name in attr[:-1]: - if name not in inspect(model).relationships: - raise AttributeError(f'{model} has no relationship "{name}"') + # if it's an one-item tuple resolve it recursively + if len(attr) == 1: + return self._rql_attr(attr[0], model) + # if there are more than one item in the tuple, resolve the first + # item + + name = attr[0] + # if it's a relationship, resolve it, add a join, and resolve the + # rest recursively + if name in inspect(model).relationships: relation = getattr(model, name) self._rql_joins.append(relation) model = relation.mapper.class_ - return getattr(model, attr[-1]) - raise NotImplementedError + return self._rql_attr(attr[1:], model) + + # if it's a JSON column, build a path to the value using the + # remaining entries, set the field name as key to be used in RQL + # select clauses, and return the result immediately. + if isinstance(inspect(model).columns[name].type, JSON): + json_column = getattr(model, name) + json_path = reduce(operator.getitem, attr[1:], json_column) # noqa: E203 + json_path.key = attr[-1] + return json_path + + # if it's neither, something is wrong. + raise self._rql_error_cls(f"Invalid nested query attribute: {name}") + + # Parsed RQL attributes are either strings or tuples. We should never + # get here. + raise TypeError(f"Invalid attribute type: {attr}") def _rql_value(self, value: Any) -> Any: if isinstance(value, dict): @@ -283,11 +308,46 @@ def _rql_value(self, value: Any) -> Any: return value + def _rql_set_attr_type_for_json_value(self, attr: Any, value: Any) -> Any: + # if it's not a JSON column, return it unchanged + if not isinstance(attr.type, JSON): + return attr + + # if value is a list of values, they must all be of the same type + if isinstance(value, list): + if not value: + return attr + + value_type = type(value[0]) + if not all(isinstance(v, value_type) for v in value): + raise self._rql_error_cls( + "Cannot compare JSON column against multiple values of different types" + ) + + value = value[0] + + # if it's a JSON column, cast the value to the appropriate type + if isinstance(value, str): + return attr.as_string() + if isinstance(value, bool): + return attr.as_boolean() + elif isinstance(value, int): + return attr.as_integer() + elif isinstance(value, float): + return attr.as_float() + else: + # NOTE: we might have to add support for all pyrql types here + raise self._rql_error_cls( + f"Cannot cast to type {type(value)} for comparison with JSON column" + ) + def _rql_compare(self, args: ArgsType, op: BinaryOperator) -> elements.BinaryExpression: attr, value = args attr = self._rql_attr(attr=attr) value = self._rql_value(value) + attr = self._rql_set_attr_type_for_json_value(attr, value) + return op(attr, value) def _rql_and(self, args: ArgsType) -> Optional[elements.BooleanClauseList]: @@ -305,6 +365,8 @@ def _rql_in(self, args: ArgsType) -> elements.BinaryExpression: attr = self._rql_attr(attr=attr) value = self._rql_value([str(v) for v in value]) + attr = self._rql_set_attr_type_for_json_value(attr, value) + return attr.in_(value) def _rql_out(self, args: ArgsType) -> elements.BinaryExpression: @@ -312,6 +374,8 @@ def _rql_out(self, args: ArgsType) -> elements.BinaryExpression: attr = self._rql_attr(attr=attr) value = self._rql_value([str(v) for v in value]) + attr = self._rql_set_attr_type_for_json_value(attr, value) + return sql.not_(attr.in_(value)) def _rql_like(self, args: ArgsType) -> elements.BinaryExpression: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py index d6bcb19..3efc3a9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,15 +2,18 @@ import json import os +import re import pytest -from fixtures import Base -from fixtures import Blog -from fixtures import Post -from fixtures import User from sqlalchemy import create_engine +from sqlalchemy import select from sqlalchemy.orm import sessionmaker +from .fixtures import Base +from .fixtures import Blog +from .fixtures import Post +from .fixtures import User + @pytest.fixture(scope="session") def engine(): @@ -45,8 +48,24 @@ def session(engine): state=raw["state"], tags=raw["tags"], balance=raw["balance"], + raw=raw, + misc={ + "eye_color": raw["eyeColor"], + "likes_apples": raw["favoriteFruit"] == "apple", + "unread_messages": int( + re.search(r"You have (\d+) unread messages", raw["greeting"]).group(1) + ), + "latitude": raw["latitude"], + "preferences": { + "favorite_fruit": raw["favoriteFruit"], + }, + "location": { + "type": "Point", + "coordinates": [raw["longitude"], raw["latitude"]], + }, + "balance": float(raw["balance"].strip("$").replace(",", "")), + }, ) - localsession.add(obj) localsession.commit() @@ -83,3 +102,8 @@ def posts(blogs, session): posts.append(post) session.commit() yield (posts) + + +@pytest.fixture(name="users") +def _users(session): + return session.scalars(select(User)).all() diff --git a/tests/fixtures.py b/tests/fixtures.py index e327217..4486b22 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -38,6 +38,9 @@ class User(Base): state = sa.Column(sa.String(2)) balance = sa.Column(sa.Numeric(9, 2)) + raw = sa.Column(sa.JSON) + misc = sa.Column(sa.JSON) + _tags = relationship("Tag") tags = association_proxy("_tags", "name", creator=lambda name: Tag(name=name)) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index 0ffd3c7..403a960 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -1,9 +1,10 @@ import pytest -from fixtures import User from rqlalchemy import RQLSelectError from rqlalchemy.query import select +from .fixtures import User + class TestPagination: def test_pagination_no_limit_raises_error(self, session): diff --git a/tests/test_query.py b/tests/test_query.py index 843dc66..f00571f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,164 +1,160 @@ import pytest -from fixtures import Blog -from fixtures import Post -from fixtures import User from sqlalchemy import func -from sqlalchemy import not_ from rqlalchemy import RQLSelectError from rqlalchemy import select +from .fixtures import User + def to_dict(it): return [row._asdict() for row in it] class TestQuery: - def test_simple_sort(self, session): + def test_simple_sort(self, session, users): res = select(User).rql("sort(balance)").execute(session) - exp = session.scalars(select(User).order_by(User.balance)).all() + exp = sorted(users, key=lambda u: u.balance) assert res assert res == exp - def test_simple_sort_desc(self, session): + def test_simple_sort_desc(self, session, users): res = select(User).rql("sort(-balance)").execute(session) - exp = session.scalars(select(User).order_by(User.balance.desc())).all() + exp = sorted(users, key=lambda u: u.balance, reverse=True) assert res assert res == exp - def test_complex_sort(self, session): + def test_complex_sort(self, session, users): res = select(User).rql("sort(balance,registered,birthdate)").execute(session) - exp = session.scalars( - select(User).order_by(User.balance, User.registered, User.birthdate) - ).all() + exp = sorted(users, key=lambda u: (u.balance, u.registered, u.birthdate)) assert res assert res == exp - def test_in_operator(self, session): + def test_in_operator(self, session, users): res = select(User).rql("in(state,(FL,TX))").execute(session) - exp = session.scalars(select(User).filter(User.state.in_(["FL", "TX"]))).all() + exp = [u for u in users if u.state in ("FL", "TX")] assert res assert res == exp - def test_out_operator(self, session): + def test_out_operator(self, session, users): res = select(User).rql("out(state,(FL,TX))").execute(session) - exp = session.scalars(select(User).filter(not_(User.state.in_(["FL", "TX"])))).all() + exp = [u for u in users if u.state not in ("FL", "TX")] assert res assert res == exp - def test_contains_string(self, session): + def test_contains_string(self, session, users): res = select(User).rql("contains(email,besto.com)").execute(session) - exp = session.scalars(select(User).filter(User.email.contains("besto.com"))).all() + exp = [u for u in users if "besto.com" in u.email] assert res assert res == exp - def test_excludes_string(self, session): + def test_excludes_string(self, session, users): res = select(User).rql("excludes(email,besto.com)").execute(session) - exp = session.scalars(select(User).filter(not_(User.email.contains("besto.com")))).all() + exp = [u for u in users if "besto.com" not in u.email] assert res assert res == exp - def test_contains_array(self, session): + def test_contains_array(self, session, users): res = select(User).rql("contains(tags,aliqua)").execute(session) - exp = session.scalars(select(User).filter(User.tags.contains("aliqua"))).all() + exp = [u for u in users if "aliqua" in u.tags] assert res assert res == exp - def test_excludes_array(self, session): + def test_excludes_array(self, session, users): res = select(User).rql("excludes(tags,aliqua)").execute(session) - exp = session.scalars(select(User).filter(not_(User.tags.contains("aliqua")))).all() + exp = [u for u in users if "aliqua" not in u.tags] assert res assert res == exp - def test_limit(self, session): + def test_limit(self, session, users): res = select(User).rql("limit(2)").execute(session) - exp = session.scalars(select(User).limit(2)).all() + exp = [u for u in users][:2] assert res assert res == exp - def test_select(self, session): + def test_select(self, session, users): rql_res = select(User).rql("select(user_id,state)").execute(session) - res = to_dict(session.execute(select(User.user_id, User.state))) + res = [{"user_id": u.user_id, "state": u.state} for u in users] assert res assert rql_res == res - def test_values(self, session): + def test_values(self, session, users): res = select(User).rql("values(state)").execute(session) - exp = [v[0] for v in session.execute(select(User.state))] + exp = [u.state for u in users] assert res assert res == exp - def test_sum(self, session): + def test_sum(self, session, users): res = select(User).rql("sum(balance)").execute(session) - exp = session.scalar(select(func.sum(User.balance))) + exp = sum([u.balance for u in users]) assert res == exp - def test_mean(self, session): + def test_mean(self, session, users): res = select(User).rql("mean(balance)").execute(session) - exp = session.scalar(select(func.avg(User.balance))) + # SQLAlchemy average is cast to float instead of Decimal? + exp = sum([float(u.balance) for u in users]) / len(users) assert res == exp - def test_max(self, session): + def test_max(self, session, users): res = select(User).rql("max(balance)").execute(session) - exp = session.scalar(select(func.max(User.balance))) + exp = max([u.balance for u in users]) assert res == exp - def test_min(self, session): + def test_min(self, session, users): res = select(User).rql("min(balance)").execute(session) - exp = session.scalar(select(func.min(User.balance))) + exp = min([u.balance for u in users]) assert res == exp - def test_first(self, session): + def test_first(self, session, users): res = select(User).rql("first()").execute(session) - exp = [session.scalars(select(User)).first()] + exp = [users[0]] assert len(res) == 1 assert res == exp - def test_one(self, session): - res = select(User).rql("guid=658c407c-6c19-470e-9aa6-8c2b86cddb4b&one()").execute(session) - exp = [ - session.scalars( - select(User).filter(User.guid == "658c407c-6c19-470e-9aa6-8c2b86cddb4b") - ).one() - ] + def test_one(self, session, users): + guid = "658c407c-6c19-470e-9aa6-8c2b86cddb4b" + res = select(User).rql(f"guid={guid}&one()").execute(session) + exp = [u for u in users if u.guid == guid] assert len(res) == 1 assert res == exp - def test_one_no_results_found(self, session): + def test_one_no_results_found(self, session, users): with pytest.raises(RQLSelectError) as exc: select(User).rql("guid=lero&one()").execute(session) assert exc.value.args[0] == "No result found for one()" - def test_one_multiple_results_found(self, session): + def test_one_multiple_results_found(self, session, users): with pytest.raises(RQLSelectError) as exc: select(User).rql("state=FL&one()").execute(session) assert exc.value.args[0] == "Multiple results found for one()" - def test_distinct(self, session): + def test_distinct(self, session, users): res = select(User).rql("select(gender)&distinct()").execute(session) - exp = to_dict(session.execute(select(User.gender).distinct())) + exp = [{"gender": gender} for gender in {u.gender for u in users}] assert len(res) == 2 - assert res == exp + assert res == exp or res == exp[::-1] - def test_count(self, session): + def test_count(self, session, users): res = select(User).rql("count()").execute(session) - exp = session.scalar(select(func.count()).select_from(User)) + exp = len(users) assert res == exp @pytest.mark.parametrize("user_id", (1, 2, 3)) - def test_eq_operator(self, session, user_id): + def test_eq_operator(self, session, user_id, users): res = select(User).rql("user_id={}".format(user_id)).execute(session) + exp = [u for u in users if u.user_id == user_id] assert res - assert session.scalar(select(User).filter_by(user_id=user_id)).name == res[0].name + assert res == exp @pytest.mark.parametrize("balance", (1000, 2000, 3000)) - def test_gt_operator(self, session, balance): + def test_gt_operator(self, session, balance, users): res = select(User).rql("gt(balance,{})".format(balance)).execute(session) + exp = [u for u in users if u.balance > balance] assert res - assert all([u.balance > balance for u in res]) + assert res == exp def test_aggregate(self, session): res = select(User).rql("aggregate(state,sum(balance))").execute(session) @@ -195,14 +191,12 @@ def test_aggregate_with_filter(self, session): assert res assert res == exp - def test_like_with_relationship_1_deep(self, session, blogs): + def test_like_with_relationship_1_deep(self, session, blogs, users): res = select(User).rql("like((blogs, title), *1*)").execute(session) - exp = session.scalars(select(User).join(Blog).filter(Blog.title.like("%1%"))).all() + exp = [b.user for b in blogs if "1" in b.title] assert res == exp def test_like_with_relationship_2_deep(self, session, posts): res = select(User).rql("like((blogs, posts, title), *Post 1*)").execute(session) - exp = session.scalars( - select(User).join(Blog).join(Post).filter(Post.title.like("%Post 1%")) - ).all() + exp = [p.blog.user for p in posts if "Post 1" in p.title] assert res == exp diff --git a/tests/test_query_defaults.py b/tests/test_query_defaults.py index 8602d35..58d726c 100644 --- a/tests/test_query_defaults.py +++ b/tests/test_query_defaults.py @@ -2,10 +2,10 @@ from unittest.mock import patch -from fixtures import User - from rqlalchemy.query import select +from .fixtures import User + class TestQueryDefaults: @patch("rqlalchemy.RQLSelect._rql_default_limit", 10) diff --git a/tests/test_query_json.py b/tests/test_query_json.py new file mode 100644 index 0000000..eecdc47 --- /dev/null +++ b/tests/test_query_json.py @@ -0,0 +1,188 @@ +from sqlalchemy import func + +from rqlalchemy import select + +from .fixtures import User +from .test_query import to_dict + + +class TestQueryJSON: + def test_simple_sort(self, session, users): + res = select(User).rql("sort((raw,balance))").execute(session) + exp = sorted(users, key=lambda u: u.raw["balance"]) + assert res + assert res == exp + + def test_simple_sort_desc(self, session, users): + res = select(User).rql("sort(-(raw,balance))").execute(session) + exp = sorted(users, key=lambda u: u.raw["balance"], reverse=True) + assert res + assert res == exp + + def test_complex_sort(self, session, users): + res = ( + select(User) + .rql("sort((raw,balance),(raw,registered),(raw,birthdate))") + .execute(session) + ) + exp = sorted( + users, + key=lambda u: (u.raw["balance"], u.raw["registered"], u.raw["birthdate"]), + ) + assert res + assert res == exp + + def test_in_operator(self, session, users): + res = select(User).rql("in((raw,state),(FL,TX))").execute(session) + exp = [u for u in users if u.raw["state"] in ("FL", "TX")] + assert res + assert res == exp + + def test_out_operator(self, session, users): + res = select(User).rql("out((raw,state),(FL,TX))").execute(session) + exp = [u for u in users if u.raw["state"] not in ("FL", "TX")] + assert res + assert res == exp + + def test_contains_string(self, session, users): + res = select(User).rql("contains((raw,email),besto.com)").execute(session) + exp = [u for u in users if "besto.com" in u.raw["email"]] + assert res + assert res == exp + + def test_excludes_string(self, session, users): + res = select(User).rql("excludes((raw,email),besto.com)").execute(session) + exp = [u for u in users if "besto.com" not in u.raw["email"]] + assert res + assert res == exp + + def test_contains_array(self, session, users): + res = select(User).rql("contains((raw,tags),aliqua)").execute(session) + exp = [u for u in users if "aliqua" in u.raw["tags"]] + assert res + assert res == exp + + def test_excludes_array(self, session, users): + res = select(User).rql("excludes((raw,tags),aliqua)").execute(session) + exp = [u for u in users if "aliqua" not in u.raw["tags"]] + assert res + assert res == exp + + def test_select_1_deep(self, session, users): + res = select(User).rql("select((raw,guid),(raw,state),(raw,isActive))").execute(session) + exp = [ + {"guid": u.raw["guid"], "state": u.raw["state"], "isActive": u.raw["isActive"]} + for u in users + ] + assert res + assert res == exp + + def test_select_2_deep(self, session, users): + res = select(User).rql("select((misc,preferences,favorite_fruit))").execute(session) + exp = [{"favorite_fruit": u.misc["preferences"]["favorite_fruit"]} for u in users] + assert res + assert res == exp + + def test_values(self, session, users): + res = select(User).rql("values((raw,state))").execute(session) + exp = [u.raw["state"] for u in users] + assert res + assert res == exp + + def test_filter_by_json_key_1_deep_string(self, session): + res = select(User).rql("eq((misc,eye_color),blue)").execute(session) + exp = [u for u in session.scalars(select(User)) if u.misc["eye_color"] == "blue"] + assert res + assert res == exp + + def test_filter_by_json_key_1_deep_bool(self, session): + res = select(User).rql("eq((misc,likes_apples),true)").execute(session) + exp = [u for u in session.scalars(select(User)) if u.misc["likes_apples"]] + assert res + assert res == exp + + def test_filter_by_json_key_1_deep_integer(self, session): + res = select(User).rql("eq((misc,unread_messages),8)").execute(session) + exp = [u for u in session.scalars(select(User)) if u.misc["unread_messages"] == 8] + assert res + assert res == exp + + def test_filter_by_json_key_1_deep_float(self, session): + res = select(User).rql("gt((misc,balance),1000.0)").execute(session) + exp = [u for u in session.scalars(select(User)) if u.misc["balance"] > 1000] + assert res + assert res == exp + + def test_filter_by_json_key_2_deep(self, session, users): + res = select(User).rql("eq((misc,preferences,favorite_fruit),banana)").execute(session) + exp = [u for u in users if u.misc["preferences"]["favorite_fruit"] == "banana"] + assert res + assert res == exp + + def test_filter_by_json_key_3_deep_and_index(self, session, users): + res = select(User).rql("lt((misc,location,coordinates,1),-18.4)").execute(session) + exp = [u for u in users if u.misc["location"]["coordinates"][1] < -18.4] + assert res + assert res == exp + + def test_aggregate(self, session): + res = ( + select(User).rql("aggregate((raw,state),sum((misc,unread_messages)))").execute(session) + ) + exp = to_dict( + session.execute( + select( + User.raw["state"].label("state"), + func.sum(User.misc["unread_messages"].as_integer()).label("sum"), + ).group_by(User.raw["state"].label("state")) + ).all() + ) + + assert res + assert res == exp + + def test_aggregate_count(self, session): + res = select(User).rql("aggregate((raw,gender),count((raw,user_id)))").execute(session) + exp = to_dict( + session.execute( + select( + User.raw["gender"].label("gender"), + func.count(User.raw["user_id"]).label("count"), + ).group_by(User.raw["gender"].label("gender")) + ).all() + ) + + assert res + assert res == exp + + def test_aggregate_with_filter(self, session): + res = ( + select(User) + .rql("aggregate((raw,state),sum((misc,unread_messages)))&eq((raw,isActive),true)") + .execute(session) + ) + exp = to_dict( + session.execute( + select( + User.raw["state"].label("state"), + func.sum(User.misc["unread_messages"].as_integer()).label("sum"), + ) + .filter(User.raw["isActive"].as_boolean()) + .group_by(User.raw["state"].label("state")) + ).all() + ) + + assert res + assert res == exp + + def test_like_with_relationship_1_deep(self, session, users): + res = select(User).rql("like((raw,name),*Jackson*)").execute(session) + exp = [u for u in users if "Jackson" in u.raw["name"]] + assert res + assert res == exp + + def test_like_with_relationship_2_deep(self, session, users): + res = select(User).rql("like((misc,preferences,favorite_fruit),*ana*)").execute(session) + exp = [u for u in users if "ana" in u.misc["preferences"]["favorite_fruit"]] + assert res + assert res == exp