Skip to content

Commit

Permalink
Add values method for repos
Browse files Browse the repository at this point in the history
  • Loading branch information
TheSuperiorStanislav committed Mar 28, 2024
1 parent e0d1bf4 commit c27e733
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
29 changes: 29 additions & 0 deletions saritasa_sqlalchemy_tools/repositories/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,35 @@ async def exists(
)
) or False

async def values(
self,
field: types.ColumnField[types.ColumnTypeT],
statement: models.SelectStatement[models.BaseModelT] | None = None,
joined_load: types.LazyLoadedSequence = (),
select_in_load: types.LazyLoadedSequence = (),
annotations: types.AnnotationSequence = (),
ordering_clauses: ordering.OrderingClauses = (),
where: filters.WhereFilters = (),
**filters_by: typing.Any,
) -> collections.abc.Sequence[types.ColumnTypeT]:
"""Get all values of field."""
return (
await self.db_session.scalars(
(
statement
if statement is not None
else self.get_fetch_statement(
joined_load=joined_load,
select_in_load=select_in_load,
annotations=annotations,
ordering_clauses=ordering_clauses,
where=where,
**filters_by,
)
).with_only_columns(field),
)
).all()


class BaseSoftDeleteRepository(
BaseRepository[models.BaseSoftDeleteModelT],
Expand Down
8 changes: 8 additions & 0 deletions saritasa_sqlalchemy_tools/repositories/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
import typing

import sqlalchemy.orm
import sqlalchemy.sql.elements
import sqlalchemy.sql.roles

ColumnTypeT = typing.TypeVar("ColumnTypeT", bound=typing.Any)
ColumnField = (
sqlalchemy.sql.roles.TypedColumnsClauseRole[ColumnTypeT]
| sqlalchemy.sql.roles.ColumnsClauseRole
)

# For some reason mypy demands that orm.QueryableAttribute has two generic args
Annotation: typing.TypeAlias = (
Expand Down
14 changes: 14 additions & 0 deletions tests/test_repositories.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,17 @@ async def test_annotation_query(
assert instance.related_models_count_query == len(
test_model.related_models,
)


async def test_values(
test_model_list: list[models.TestModel],
repository: repositories.TestModelRepository,
) -> None:
"""Test values method."""
excepted_text_values = {test_model.text for test_model in test_model_list}
actual_text_values = set(
await repository.values(
field=models.TestModel.text,
),
)
assert excepted_text_values == actual_text_values

0 comments on commit c27e733

Please sign in to comment.