Skip to content

Commit

Permalink
Adjust session for new bindings version
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubFrejlach committed Dec 19, 2024
1 parent 3512b4f commit 4cc3b7c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 34 deletions.
4 changes: 2 additions & 2 deletions osidb_bindings/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,10 @@ def make_response_iterable(response, retrieve_list_fn, *args, **kwargs):
if param is None:
setattr(response, func_name, lambda: None)
else:
limit = re.search("limit=(\d+)", param)
limit = re.search(r"limit=(\d+)", param)
if limit is not None:
kwargs["limit"] = limit.group(1)
offset = re.search("offset=(\d+)", param)
offset = re.search(r"offset=(\d+)", param)
if offset is not None:
kwargs["offset"] = offset.group(1)

Expand Down
45 changes: 13 additions & 32 deletions osidb_bindings/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def file_trackers(self, form_data: Dict[str, Any], *args, **kwargs):
return sync_fn(
*args,
client=self.client(),
form_data=transformed_data,
multipart_data=UNSET,
json_body=UNSET,
body=transformed_data,
**kwargs,
)

Expand Down Expand Up @@ -102,9 +100,7 @@ def reject_flaw(self, id: str, form_data: Dict[str, Any], *args, **kwargs):
id,
*args,
client=self.client(),
form_data=transformed_data,
multipart_data=UNSET,
json_body=UNSET,
body=transformed_data,
**kwargs,
)

Expand Down Expand Up @@ -140,13 +136,13 @@ def get_sync_function(api_module: ModuleType) -> Callable:
)


def get_async_function(api_module: ModuleType) -> Callable:
def get_asyncio_function(api_module: ModuleType) -> Callable:
"""
Get 'sync' function from API module if available (response example is defined in schema)
or get basic 'sync_detailed' function (response example is not defined in schema)
Get 'asyncio' function from API module if available (response example is defined in schema)
or get basic 'asyncio_detailed' function (response example is not defined in schema)
"""
return double_underscores_to_single_underscores(
getattr(api_module, "async_", getattr(api_module, "async_detailed"))
getattr(api_module, "asyncio_", getattr(api_module, "asyncio_detailed"))
)


Expand Down Expand Up @@ -304,11 +300,9 @@ def __get_refresh_token(self) -> str:
if isinstance(self.auth, tuple):
response = auth_token_create.sync(
client=self.__client,
form_data=models.TokenObtainPair.from_dict(
body=models.TokenObtainPair.from_dict(
{"username": self.auth[0], "password": self.auth[1]}
),
multipart_data=UNSET,
json_body=UNSET,
)
else:
response = auth_token_retrieve.sync(
Expand All @@ -322,23 +316,19 @@ def __get_access_token(self) -> str:
try:
response = auth_token_refresh_create.sync(
client=self.__client,
form_data=models.TokenRefresh.from_dict(
body=models.TokenRefresh.from_dict(
{"refresh": self.refresh_token}
),
multipart_data=UNSET,
json_body=UNSET,
)
except requests.HTTPError:

# expired refresh token, renew it and try again
self.refresh_token = self.__get_refresh_token()
response = auth_token_refresh_create.sync(
client=self.__client,
form_data=models.TokenRefresh.from_dict(
body=models.TokenRefresh.from_dict(
{"refresh": self.refresh_token}
),
multipart_data=UNSET,
json_body=UNSET,
)

return response.access
Expand Down Expand Up @@ -454,9 +444,7 @@ def create(self, form_data: Dict[str, Any], *args, **kwargs):
return sync_fn(
*args,
client=self.client(),
form_data=serialized_data,
multipart_data=UNSET,
json_body=UNSET,
body=serialized_data,
**kwargs,
)
else:
Expand All @@ -476,8 +464,6 @@ def bulk_create(self, form_data: Dict[str, Any], *args, **kwargs):
return sync_fn(
*args,
client=self.client(),
json_body=serialized_data,
multipart_data=UNSET,
**kwargs,
)
else:
Expand All @@ -498,9 +484,7 @@ def update(self, id, form_data: Dict[str, Any], *args, **kwargs):
id,
*args,
client=self.client(),
form_data=serialized_data,
multipart_data=UNSET,
json_body=UNSET,
body=serialized_data,
**kwargs,
)
else:
Expand All @@ -520,8 +504,6 @@ def bulk_update(self, form_data: Dict[str, Any], *args, **kwargs):
return sync_fn(
*args,
client=self.client(),
json_body=serialized_data,
multipart_data=UNSET,
**kwargs,
)
else:
Expand Down Expand Up @@ -556,8 +538,7 @@ def bulk_delete(self, form_data: Dict[str, Any], *args, **kwargs):
return sync_fn(
*args,
client=self.client(),
json_body=serialized_data,
multipart_data=UNSET,
body=serialized_data,
**kwargs,
)
else:
Expand Down Expand Up @@ -623,7 +604,7 @@ async def __retrieve_list_async(
method_module = self.__get_method_module(
resource_name=self.resource_name, method="list"
)
async_fn = get_async_function(method_module)
async_fn = get_asyncio_function(method_module)

kwargs.pop("offset", None)
limit = kwargs.pop("limit", None) or DEFAULT_LIMIT
Expand Down

0 comments on commit 4cc3b7c

Please sign in to comment.