diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index d9dac47b..6b29e736 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -309,17 +309,17 @@ def wrapper(self, *args, **kwargs): class QuerysetEndpoint(Endpoint, Generic[T]): @api(version="2.0") - def all(self, *args, **kwargs) -> QuerySet[T]: + def all(self, *args, page_size: Optional[int] = None, **kwargs) -> QuerySet[T]: if args or kwargs: raise ValueError(".all method takes no arguments.") - queryset = QuerySet(self) + queryset = QuerySet(self, page_size=page_size) return queryset @api(version="2.0") - def filter(self, *_, **kwargs) -> QuerySet[T]: + def filter(self, *_, page_size: Optional[int] = None, **kwargs) -> QuerySet[T]: if _: raise RuntimeError("Only keyword arguments accepted.") - queryset = QuerySet(self).filter(**kwargs) + queryset = QuerySet(self, page_size=page_size).filter(**kwargs) return queryset @api(version="2.0") diff --git a/tableauserverclient/server/query.py b/tableauserverclient/server/query.py index 51c34d08..19513926 100644 --- a/tableauserverclient/server/query.py +++ b/tableauserverclient/server/query.py @@ -33,9 +33,9 @@ def to_camel_case(word: str) -> str: class QuerySet(Iterable[T], Sized): - def __init__(self, model: "QuerysetEndpoint[T]") -> None: + def __init__(self, model: "QuerysetEndpoint[T]", page_size: Optional[int] = None) -> None: self.model = model - self.request_options = RequestOptions() + self.request_options = RequestOptions(pagesize=page_size or 100) self._result_cache: List[T] = [] self._pagination_item = PaginationItem() @@ -134,12 +134,15 @@ def page_size(self: Self) -> int: self._fetch_all() return self._pagination_item.page_size - def filter(self: Self, *invalid, **kwargs) -> Self: + def filter(self: Self, *invalid, page_size: Optional[int] = None, **kwargs) -> Self: if invalid: raise RuntimeError("Only accepts keyword arguments.") for kwarg_key, value in kwargs.items(): field_name, operator = self._parse_shorthand_filter(kwarg_key) self.request_options.filter.add(Filter(field_name, operator, value)) + + if page_size: + self.request_options.pagesize = page_size return self def order_by(self: Self, *args) -> Self: @@ -155,11 +158,8 @@ def paginate(self: Self, **kwargs) -> Self: self.request_options.pagesize = kwargs["page_size"] return self - def with_page_size(self: Self, value: int) -> Self: - self.request_options.pagesize = value - return self - - def _parse_shorthand_filter(self: Self, key: str) -> Tuple[str, str]: + @staticmethod + def _parse_shorthand_filter(key: str) -> Tuple[str, str]: tokens = key.split("__", 1) if len(tokens) == 1: operator = RequestOptions.Operator.Equals @@ -173,7 +173,8 @@ def _parse_shorthand_filter(self: Self, key: str) -> Tuple[str, str]: raise ValueError("Field name `{}` is not valid.".format(field)) return (field, operator) - def _parse_shorthand_sort(self: Self, key: str) -> Tuple[str, str]: + @staticmethod + def _parse_shorthand_sort(key: str) -> Tuple[str, str]: direction = RequestOptions.Direction.Asc if key.startswith("-"): direction = RequestOptions.Direction.Desc diff --git a/test/test_request_option.py b/test/test_request_option.py index 5ade81ea..e48f8510 100644 --- a/test/test_request_option.py +++ b/test/test_request_option.py @@ -332,10 +332,29 @@ def test_filtering_parameters(self) -> None: self.assertIn("type", query_params) self.assertIn("tabloid", query_params["type"]) - def test_queryset_pagesize(self) -> None: + def test_queryset_endpoint_pagesize_all(self) -> None: for page_size in (1, 10, 100, 1000): with self.subTest(page_size): with requests_mock.mock() as m: m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text()) - queryset = self.server.views.all().with_page_size(page_size) + queryset = self.server.views.all(page_size=page_size) + assert queryset.request_options.pagesize == page_size + _ = list(queryset) + + def test_queryset_endpoint_pagesize_filter(self) -> None: + for page_size in (1, 10, 100, 1000): + with self.subTest(page_size): + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text()) + queryset = self.server.views.filter(page_size=page_size) + assert queryset.request_options.pagesize == page_size + _ = list(queryset) + + def test_queryset_pagesize_filter(self) -> None: + for page_size in (1, 10, 100, 1000): + with self.subTest(page_size): + with requests_mock.mock() as m: + m.get(f"{self.baseurl}/views?pageSize={page_size}", text=SLICING_QUERYSET_PAGE_1.read_text()) + queryset = self.server.views.all().filter(page_size=page_size) + assert queryset.request_options.pagesize == page_size _ = list(queryset)