Paginate all SCIM list requests in the SDK (#440)
## Changes
This PR incorporates two hard-coded changes for the SCIM API in the
Python SDK:
1. `startIndex` starts at 1 for SCIM APIs, not 0. However, the existing
`.Pagination.Increment` controls both the start index and whether
the pagination is per-page or per-resource, so for now the template
special-cases SCIM paths. Later, we should replace this extension with
two independent OpenAPI options: `one_indexed` (defaulting to `false`)
and `pagination_basis` (defaulting to `resource`, overridable to `page`).
2. If users don't specify a limit, the SDK will include a hard-coded
limit of 100 resources per request. We could add this to the OpenAPI
spec as an option `default_limit`, which would be useful for any
non-paginated APIs that later expose pagination options and would allow
the SDK to support them gracefully. However, we don't want to encourage
folks to use this pattern: all new list APIs are required to be
paginated from the start.
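
The two behaviors above can be illustrated with a minimal sketch. It is not the generated SDK code: `list_scim`, `make_fake_fetch`, and the in-memory data are hypothetical names introduced here, and the loop assumes per-resource pagination (the start index advances by the number of items received).

```python
from typing import Any, Callable, Dict, Iterator, List, Optional

Page = Dict[str, Any]


def list_scim(fetch: Callable[[Dict[str, Any]], Page],
              query: Optional[Dict[str, Any]] = None,
              default_count: int = 100) -> Iterator[Any]:
    """Yield every resource from a SCIM-style list endpoint (illustrative only)."""
    query = dict(query or {})
    query['startIndex'] = 1  # SCIM start indexes are 1-based, not 0-based
    if 'count' not in query:
        query['count'] = default_count  # page size used when the caller sets none
    while True:
        page = fetch(query)
        resources = page.get('Resources')
        if not resources:
            return
        yield from resources
        # Per-resource pagination: advance by the number of items received.
        query['startIndex'] += len(resources)


def make_fake_fetch(items: List[Any]) -> Callable[[Dict[str, Any]], Page]:
    """Hypothetical in-memory stand-in for the HTTP call."""

    def fetch(query: Dict[str, Any]) -> Page:
        start = query['startIndex'] - 1  # convert the 1-based index to a list offset
        return {'totalResults': len(items), 'Resources': items[start:start + query['count']]}

    return fetch


users = [f'user-{i}' for i in range(250)]
fetch = make_fake_fetch(users)
all_users = list(list_scim(fetch))               # default page size of 100
small_pages = list(list_scim(fetch, {'count': 7}))  # caller-provided page size
```

Starting at index 0 against a 1-based backend would either be rejected or return a shifted window, which is why the start index has to follow the API's convention.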

## Tests

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied

---------

Co-authored-by: Xinjie Zheng <[email protected]>
mgyucht and xinjiezhen-db authored Nov 14, 2023
1 parent 50c71a1 commit 9ba48cc
Showing 4 changed files with 63 additions and 22 deletions.
11 changes: 10 additions & 1 deletion .codegen/service.py.tmpl
@@ -247,7 +247,16 @@ class {{.Name}}API:{{if .Description}}
# deduplicate items that may have been added during iteration
seen = set()
{{- end}}{{if and .Pagination.Offset (not (eq .Path "/api/2.0/clusters/events")) }}
-query['{{.Pagination.Offset.Name}}'] = {{if eq .Pagination.Increment 1}}1{{else}}0{{end}}{{end}}
+query['{{.Pagination.Offset.Name}}'] =
+{{- if eq .Pagination.Increment 1 -}}
+1
+{{- else if contains .Path "/scim/v2/" -}}
+1
+{{- else -}}
+0
+{{- end}}{{end}}{{if and .Pagination.Limit (contains .Path "/scim/v2/")}}
+if "{{.Pagination.Limit.Name}}" not in query: query['{{.Pagination.Limit.Name}}'] = 100
+{{- end}}
while True:
json = {{template "method-do" .}}
if '{{.Pagination.Results.Name}}' not in json or not json['{{.Pagination.Results.Name}}']:
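
As a reading aid, the branching in the template hunk can be restated in plain Python. This is a sketch only: `initial_offset` and `needs_default_limit` are hypothetical helpers that do not exist in the code generator, and `/api/2.0/example/list` is a made-up path used for contrast.

```python
def initial_offset(path: str, increment: int) -> int:
    """Mirror the template's choice of the first offset value."""
    if increment == 1:
        return 1  # existing behavior: an increment of 1 implies a start index of 1
    if '/scim/v2/' in path:
        return 1  # new special case: SCIM list APIs are 1-indexed
    return 0


def needs_default_limit(path: str, has_limit_field: bool) -> bool:
    """Mirror the condition under which the hard-coded limit of 100 is emitted."""
    return has_limit_field and '/scim/v2/' in path


scim_path = '/api/2.0/preview/scim/v2/Users'
other_path = '/api/2.0/example/list'  # hypothetical non-SCIM path
```

Under these assumptions, `initial_offset(scim_path, 0)` evaluates to 1 and `initial_offset(other_path, 0)` to 0, while only the SCIM path qualifies for the default limit.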
18 changes: 12 additions & 6 deletions databricks/sdk/service/iam.py

Some generated files are not rendered by default.

15 changes: 0 additions & 15 deletions tests/integration/test_groups.py

This file was deleted.

41 changes: 41 additions & 0 deletions tests/integration/test_iam.py
@@ -0,0 +1,41 @@
import pytest

from databricks.sdk.core import DatabricksError


def test_filtering_groups(w, random):
    all = w.groups.list(filter=f'displayName eq any-{random(12)}')
    found = len(list(all))
    assert found == 0


def test_scim_error_unmarshall(w, random):
    with pytest.raises(DatabricksError) as exc_info:
        # Consume the result so that a lazily evaluated list call actually issues the request.
        list(w.groups.list(filter=random(12)))
    assert 'Given filter operator is not supported' in str(exc_info.value)


@pytest.mark.parametrize(
    "path,call",
    [("/api/2.0/preview/scim/v2/Users", lambda w: w.users.list(count=10)),
     ("/api/2.0/preview/scim/v2/Groups", lambda w: w.groups.list(count=4)),
     ("/api/2.0/preview/scim/v2/ServicePrincipals", lambda w: w.service_principals.list(count=1)), ])
def test_workspace_users_list_pagination(w, path, call):
    raw = w.api_client.do('GET', path)
    total = raw['totalResults']
    all = call(w)
    found = len(list(all))
    assert found == total


@pytest.mark.parametrize(
    "path,call",
    [("/api/2.0/accounts/%s/scim/v2/Users", lambda a: a.users.list(count=3000)),
     ("/api/2.0/accounts/%s/scim/v2/Groups", lambda a: a.groups.list(count=5)),
     ("/api/2.0/accounts/%s/scim/v2/ServicePrincipals", lambda a: a.service_principals.list(count=1000)), ])
def test_account_users_list_pagination(a, path, call):
    raw = a.api_client.do('GET', path.replace("%s", a.config.account_id))
    total = raw['totalResults']
    all = call(a)
    found = len(list(all))
    assert found == total
