Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bulk operations #51

Merged
merged 3 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

## [4.1.0] - 2024-07-01
### Added
- support for bulk create and bulk update for Affects (OSIDB-3124)

## [4.0.0] - 2024-06-17
### Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ...types import UNSET, Response, Unset

QUERY_PARAMS = {}
REQUEST_BODY_TYPE = List[AffectPost]


def _get_kwargs(
Expand Down Expand Up @@ -47,7 +48,7 @@ def _get_kwargs(
return {
"url": url,
"headers": headers,
"json": form_data.to_dict(),
"json": json_json_body,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ...types import UNSET, Response, Unset

QUERY_PARAMS = {}
REQUEST_BODY_TYPE = List[AffectBulkPut]


def _get_kwargs(
Expand Down Expand Up @@ -47,7 +48,7 @@ def _get_kwargs(
return {
"url": url,
"headers": headers,
"json": form_data.to_dict(),
"json": json_json_body,
}


Expand Down
96 changes: 90 additions & 6 deletions osidb_bindings/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import importlib
import types
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, get_args, get_origin

import aiohttp
import requests
Expand All @@ -27,7 +27,11 @@
OSIDB_BINDINGS_PLACEHOLDER_FIELD,
OSIDB_BINDINGS_USERAGENT,
)
from .exceptions import OperationUnsupported, UndefinedRequestBody
from .exceptions import (
OperationUnsupported,
OSIDBBindingsException,
UndefinedRequestBody,
)
from .helpers import get_env
from .iterators import Paginator

Expand Down Expand Up @@ -108,6 +112,19 @@ def get_async_function(api_module: ModuleType) -> Callable:
)


def serialize_data(data, model):
"""
Serialize data into bindings model or list of bindings models
"""
if hasattr(model, "from_dict"):
return model.from_dict(data)
elif get_origin(model) is list:
inner_model = get_args(model)[0]
return [serialize_data(item, inner_model) for item in data]
else:
raise OSIDBBindingsException(f"Unserializable model '{model}'")


def new_session(
osidb_server_uri,
password=None,
Expand Down Expand Up @@ -207,6 +224,10 @@ def __init__(self, base_url, auth=None, verify_ssl=True):
"list",
"create",
"destroy",
"bulk_create",
"bulk_update",
# TODO: currently blocked by OSIDB-2996
# "bulk_delete",
),
subresources={
"cvss_scores": {
Expand Down Expand Up @@ -389,19 +410,40 @@ def create(self, form_data: Dict[str, Any], *args, **kwargs):
if model is None:
self.__raise_undefined_request_body("create")

transformed_data = model.from_dict(form_data)
serialized_data = serialize_data(form_data, model)
sync_fn = get_sync_function(method_module)
return sync_fn(
*args,
client=self.client(),
form_data=transformed_data,
form_data=serialized_data,
multipart_data=UNSET,
json_body=UNSET,
**kwargs,
)
else:
self.__raise_operation_unsupported("create")

def bulk_create(self, form_data: Dict[str, Any], *args, **kwargs):
if "bulk_create" in self.allowed_operations:
method_module = self.__get_method_module(
resource_name=self.resource_name, method="bulk_create"
)
model = getattr(method_module, "REQUEST_BODY_TYPE", None)
if model is None:
self.__raise_undefined_request_body("bulk_create")

serialized_data = serialize_data(form_data, model)
sync_fn = get_sync_function(method_module)
return sync_fn(
*args,
client=self.client(),
json_body=serialized_data,
multipart_data=UNSET,
**kwargs,
)
else:
self.__raise_operation_unsupported("bulk_create")

def update(self, id, form_data: Dict[str, Any], *args, **kwargs):
if "update" in self.allowed_operations:
method_module = self.__get_method_module(
Expand All @@ -411,20 +453,41 @@ def update(self, id, form_data: Dict[str, Any], *args, **kwargs):
if model is None:
self.__raise_undefined_request_body("update")

transformed_data = model.from_dict(form_data)
serialized_data = serialize_data(form_data, model)
sync_fn = get_sync_function(method_module)
return sync_fn(
id,
*args,
client=self.client(),
form_data=transformed_data,
form_data=serialized_data,
multipart_data=UNSET,
json_body=UNSET,
**kwargs,
)
else:
self.__raise_operation_unsupported("update")

def bulk_update(self, form_data: Dict[str, Any], *args, **kwargs):
if "bulk_update" in self.allowed_operations:
method_module = self.__get_method_module(
resource_name=self.resource_name, method="bulk_update"
)
model = getattr(method_module, "REQUEST_BODY_TYPE", None)
if model is None:
self.__raise_undefined_request_body("bulk_update")

serialized_data = serialize_data(form_data, model)
sync_fn = get_sync_function(method_module)
return sync_fn(
*args,
client=self.client(),
json_body=serialized_data,
multipart_data=UNSET,
**kwargs,
)
else:
self.__raise_operation_unsupported("bulk_update")

def delete(self, id, *args, **kwargs):
if "destroy" in self.allowed_operations:
method_module = self.__get_method_module(
Expand All @@ -440,6 +503,27 @@ def delete(self, id, *args, **kwargs):
else:
self.__raise_operation_unsupported("delete")

def bulk_delete(self, form_data: Dict[str, Any], *args, **kwargs):
if "bulk_delete" in self.allowed_operations:
method_module = self.__get_method_module(
resource_name=self.resource_name, method="bulk_destroy"
)
model = getattr(method_module, "REQUEST_BODY_TYPE", None)
if model is None:
self.__raise_undefined_request_body("bulk_delete")

serialized_data = serialize_data(form_data, model)
sync_fn = get_sync_function(method_module)
return sync_fn(
*args,
client=self.client(),
json_body=serialized_data,
multipart_data=UNSET,
**kwargs,
)
else:
self.__raise_operation_unsupported("bulk_delete")

# Extra operations

def count(self, *args, **kwargs):
Expand Down
8 changes: 5 additions & 3 deletions osidb_bindings/templates/endpoint_module.py.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ QUERY_PARAMS = {
}
{% if endpoint.form_body_class %}
REQUEST_BODY_TYPE = {{ endpoint.form_body_class.name }}
{% elif endpoint.json_body %}
REQUEST_BODY_TYPE = {{ endpoint.json_body.get_type_string() }}
{% endif %}

def _get_kwargs(
Expand Down Expand Up @@ -58,10 +60,10 @@ def _get_kwargs(
{% elif endpoint.json_body %}
"json": {{ "json_" + endpoint.json_body.python_name }},
{% endif %} #}
{% if endpoint.json_body %}
{% if endpoint.form_body_class %}
"json": form_data.to_dict(),
{% elif endpoint.form_body_class %}
"data": form_data.to_dict(),
{% elif endpoint.json_body %}
"json": {{ "json_" + endpoint.json_body.python_name }},
{% elif endpoint.multipart_body %}
"files": {{ "multipart_" + endpoint.multipart_body.python_name }},
{% endif %}
Expand Down