diff --git a/cl/runtime/__main__.py b/cl/runtime/__main__.py index cba03d91..16207bc9 100644 --- a/cl/runtime/__main__.py +++ b/cl/runtime/__main__.py @@ -22,7 +22,6 @@ from starlette.responses import JSONResponse from starlette.staticfiles import StaticFiles from cl.runtime import Context -from cl.runtime.routers.context_middleware import ContextMiddleware from cl.runtime.context.process_context import ProcessContext from cl.runtime.log.exceptions.user_error import UserError from cl.runtime.log.log_entry import LogEntry @@ -30,6 +29,7 @@ from cl.runtime.log.user_log_entry import UserLogEntry from cl.runtime.routers.app import app_router from cl.runtime.routers.auth import auth_router +from cl.runtime.routers.context_middleware import ContextMiddleware from cl.runtime.routers.entity import entity_router from cl.runtime.routers.health import health_router from cl.runtime.routers.schema import schema_router diff --git a/cl/runtime/context/context.py b/cl/runtime/context/context.py index b95dff9e..b8cb66b7 100644 --- a/cl/runtime/context/context.py +++ b/cl/runtime/context/context.py @@ -15,7 +15,8 @@ import logging from contextvars import ContextVar from dataclasses import dataclass -from typing import Iterable, Dict +from typing import Dict +from typing import Iterable from typing import List from typing import Optional from typing import Type diff --git a/cl/runtime/context/context_util.py b/cl/runtime/context/context_util.py index 37f7460c..e2043fd7 100644 --- a/cl/runtime/context/context_util.py +++ b/cl/runtime/context/context_util.py @@ -14,7 +14,8 @@ import base64 from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding from cl.runtime import Context @@ -45,19 +46,13 @@ def decrypt_secret(cls, key: str) -> str | None: # Load the private key private_key = serialization.load_pem_private_key( - private_key_pem.encode(), - password=None, - backend=default_backend() + private_key_pem.encode(), password=None, backend=default_backend() ) # Decrypt the value decrypted_value_bytes = private_key.decrypt( encrypted_value_bytes, - padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None - ) + padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None), ) - return decrypted_value_bytes.decode('utf-8') + return decrypted_value_bytes.decode("utf-8") diff --git a/cl/runtime/exceptions/error_util.py b/cl/runtime/exceptions/error_util.py index cf7adb49..c93b4a47 100644 --- a/cl/runtime/exceptions/error_util.py +++ b/cl/runtime/exceptions/error_util.py @@ -14,7 +14,6 @@ from typing import Any from typing import Type - from cl.runtime.log.exceptions.user_error import UserError from cl.runtime.primitive.case_util import CaseUtil from cl.runtime.primitive.string_util import StringUtil diff --git a/cl/runtime/plots/heat_map_plot.py b/cl/runtime/plots/heat_map_plot.py index 1883c267..0dc136c9 100644 --- a/cl/runtime/plots/heat_map_plot.py +++ b/cl/runtime/plots/heat_map_plot.py @@ -77,4 +77,4 @@ def _create_figure(self) -> plt.Figure: fig.tight_layout() - return fig \ No newline at end of file + return fig diff --git a/cl/runtime/routers/entity/list_panels_response_item.py b/cl/runtime/routers/entity/list_panels_response_item.py index b6589ffd..8f022857 100644 --- a/cl/runtime/routers/entity/list_panels_response_item.py +++ b/cl/runtime/routers/entity/list_panels_response_item.py @@ -15,14 +15,12 @@ from __future__ import annotations from typing import List from pydantic import BaseModel - from cl.runtime import Context from cl.runtime.primitive.case_util import CaseUtil from cl.runtime.routers.entity.list_panels_request import ListPanelsRequest from cl.runtime.schema.handler_declare_block_decl import HandlerDeclareBlockDecl from cl.runtime.schema.handler_declare_decl import HandlerDeclareDecl from cl.runtime.schema.schema import Schema - from cl.runtime.serialization.string_serializer import StringSerializer @@ -63,16 +61,15 @@ def list_panels(cls, request: ListPanelsRequest) -> List[ListPanelsResponseItem] if handlers_block is not None and handlers_block: return [ - ListPanelsResponseItem( - name=handler.label, - type=cls.get_type(handler) - ) for handler in handlers_block if handler.type_ == "Viewer" + ListPanelsResponseItem(name=handler.label, type=cls.get_type(handler)) + for handler in handlers_block + if handler.type_ == "Viewer" ] return [] @classmethod def get_type(cls, handler: HandlerDeclareDecl) -> str | None: """Get type of the handler.""" - - if handler.type_ == 'Viewer' and handler.name == 'view_self': - return 'Primary' + + if handler.type_ == "Viewer" and handler.name == "view_self": + return "Primary" diff --git a/cl/runtime/routers/storage/record_response.py b/cl/runtime/routers/storage/record_response.py index dee2418f..1a28ffaf 100644 --- a/cl/runtime/routers/storage/record_response.py +++ b/cl/runtime/routers/storage/record_response.py @@ -13,14 +13,12 @@ # limitations under the License. from __future__ import annotations - import dataclasses -from pydantic import BaseModel -from pydantic import Field from typing import Any from typing import Dict from typing import List - +from pydantic import BaseModel +from pydantic import Field from cl.runtime import Context from cl.runtime.backend.core.ui_app_state import UiAppState from cl.runtime.backend.core.ui_app_state_key import UiAppStateKey @@ -147,8 +145,7 @@ def get_record(cls, request: RecordRequest) -> RecordResponse: deserialized_key = UiTypeStateKey( user=UserKey(username=username or "root"), type_=TypeDeclKey( - name=type_state_record_type_name, - module=ModuleDeclKey(module_name=type_state_record_module) + name=type_state_record_type_name, module=ModuleDeclKey(module_name=type_state_record_module) ), ) else: @@ -174,7 +171,8 @@ def get_record(cls, request: RecordRequest) -> RecordResponse: all_handlers.extend( [ - handler_name for handler_decl in handlers_block + handler_name + for handler_decl in handlers_block if (handler_name := handler_decl.get("Name")) not in all_handlers ] ) diff --git a/cl/runtime/routers/storage/save_permanently_request.py b/cl/runtime/routers/storage/save_permanently_request.py index 09a174e0..a9186a04 100644 --- a/cl/runtime/routers/storage/save_permanently_request.py +++ b/cl/runtime/routers/storage/save_permanently_request.py @@ -29,4 +29,3 @@ class SavePermanentlyRequest(BaseModel): with_dependencies: bool = False """Flag that indicated whether to include nested dependencies for Dag objects.""" - diff --git a/cl/runtime/routers/storage/save_permanently_response.py b/cl/runtime/routers/storage/save_permanently_response.py index 8148af13..5316e541 100644 --- a/cl/runtime/routers/storage/save_permanently_response.py +++ b/cl/runtime/routers/storage/save_permanently_response.py @@ -14,13 +14,12 @@ from collections import defaultdict from pathlib import Path -from typing import Iterable, Type, DefaultDict - +from typing import DefaultDict +from typing import Iterable +from typing import Type from urllib import parse - import pandas as pd from pydantic import BaseModel - from cl.runtime import Context from cl.runtime.db.protocols import TRecord from cl.runtime.file.file_util import FileUtil @@ -65,7 +64,7 @@ def _get_extension(cls) -> str: # TODO (Bohdan): Check if it makes sense to have a config which format/extension to use. # If not - simplify the code. - return 'csv' + return "csv" @classmethod def _get_path_to_save_permanently_folder(cls) -> Path: @@ -83,9 +82,9 @@ def _write_records(cls, file_path: Path, records: Iterable[TRecord]) -> None: serializer = FlatDictSerializer() # TODO (Bohdan): Provide a proper serializer serialized_records = [serializer.serialize_data(record) for record in records] - if file_extension == 'csv': + if file_extension == "csv": df = pd.DataFrame([serialized_records]) - df.to_csv(file_path, mode='w', index=False, header=True) + df.to_csv(file_path, mode="w", index=False, header=True) else: raise ValueError(f"File extension {file_extension} is not supported.") @@ -94,7 +93,7 @@ def save_permanently(cls, request: SavePermanentlyRequest) -> "SavePermanentlyRe """Save records to the database on the disk.""" for record_type, records in get_type_to_records_map(request).items(): - filename = f'{record_type.__name__}.{cls._get_extension()}' + filename = f"{record_type.__name__}.{cls._get_extension()}" FileUtil.check_valid_filename(filename) file_path = cls._get_path_to_save_permanently_folder() / filename file_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/cl/runtime/routers/tasks/run_response_item.py b/cl/runtime/routers/tasks/run_response_item.py index f4f4c425..9fc90383 100644 --- a/cl/runtime/routers/tasks/run_response_item.py +++ b/cl/runtime/routers/tasks/run_response_item.py @@ -16,7 +16,6 @@ import traceback from typing import List from pydantic import BaseModel - from cl.runtime import Context from cl.runtime.primitive.case_util import CaseUtil from cl.runtime.records.dataclasses_extensions import missing diff --git a/cl/runtime/settings/api_settings.py b/cl/runtime/settings/api_settings.py index 2f0fe66c..9f87537c 100644 --- a/cl/runtime/settings/api_settings.py +++ b/cl/runtime/settings/api_settings.py @@ -14,7 +14,6 @@ from dataclasses import dataclass from typing import List - from cl.runtime.settings.settings import Settings @@ -113,8 +112,9 @@ def init(self) -> None: elif isinstance(self.allow_headers, str) or not hasattr(self.allow_headers, "__iter__"): raise RuntimeError(f"{type(self).__name__} field 'allow_headers' must be a list or None.") - if self.expose_headers is not None and (isinstance(self.expose_headers, str) or - not hasattr(self.expose_headers, "__iter__")): + if self.expose_headers is not None and ( + isinstance(self.expose_headers, str) or not hasattr(self.expose_headers, "__iter__") + ): raise RuntimeError(f"{type(self).__name__} field 'expose_headers' must be a list or None.") if self.max_age is not None and not isinstance(self.max_age, int): diff --git a/cl/runtime/tasks/celery/celery_queue.py b/cl/runtime/tasks/celery/celery_queue.py index d43f5e8b..8be7ecbc 100644 --- a/cl/runtime/tasks/celery/celery_queue.py +++ b/cl/runtime/tasks/celery/celery_queue.py @@ -161,4 +161,3 @@ def submit_task(self, task: TaskKey): retry=False, # Do not retry in case the task fails ignore_result=True, # TODO: Do not publish to the Celery result backend ) - diff --git a/cl/runtime/tasks/function_task.py b/cl/runtime/tasks/function_task.py index ec8c9d1d..e6f58f32 100644 --- a/cl/runtime/tasks/function_task.py +++ b/cl/runtime/tasks/function_task.py @@ -39,11 +39,11 @@ def _execute(self) -> None: @classmethod def create( - cls, - *, - queue: TaskQueueKey, - record_type: Type, - method: Callable, + cls, + *, + queue: TaskQueueKey, + record_type: Type, + method: Callable, ) -> Self: """Create from static or class handler method callable.""" raise NotImplementedError() diff --git a/cl/runtime/tasks/process_queue.py b/cl/runtime/tasks/process_queue.py index 3acf6f93..e391f5da 100644 --- a/cl/runtime/tasks/process_queue.py +++ b/cl/runtime/tasks/process_queue.py @@ -46,14 +46,10 @@ def run_start_queue(self) -> None: # TODO: Use DB queries with filter by queue field all_tasks = context.load_all(Task) awaiting_tasks = [ - task for task in all_tasks - if task.queue.queue_id == queue_id - and task.status == TaskStatusEnum.AWAITING + task for task in all_tasks if task.queue.queue_id == queue_id and task.status == TaskStatusEnum.AWAITING ] pending_tasks = [ - task for task in all_tasks - if task.queue.queue_id == queue_id - and task.status == TaskStatusEnum.PENDING + task for task in all_tasks if task.queue.queue_id == queue_id and task.status == TaskStatusEnum.PENDING ] # Awaiting tasks have priority over pending tasks diff --git a/cl/runtime/tasks/static_method_task.py b/cl/runtime/tasks/static_method_task.py index 9d72166f..afc561ae 100644 --- a/cl/runtime/tasks/static_method_task.py +++ b/cl/runtime/tasks/static_method_task.py @@ -51,11 +51,11 @@ def _execute(self) -> None: @classmethod def create( - cls, - *, - queue: TaskQueueKey, - record_type: Type, - method_callable: Callable, + cls, + *, + queue: TaskQueueKey, + record_type: Type, + method_callable: Callable, ) -> Self: """Create from @staticmethod callable and record type.""" diff --git a/cl/runtime/tasks/task.py b/cl/runtime/tasks/task.py index ce2b9b61..cf38f35a 100644 --- a/cl/runtime/tasks/task.py +++ b/cl/runtime/tasks/task.py @@ -75,11 +75,7 @@ def init(self) -> None: self.task_id = Timestamp.create() else: # Otherwise validate - Timestamp.validate( - self.task_id, - value_name="task_id", - data_type="TaskKey" - ) + Timestamp.validate(self.task_id, value_name="task_id", data_type="TaskKey") # Set status and progress_pct if not yet set if self.status is None: diff --git a/cl/runtime/tasks/task_key.py b/cl/runtime/tasks/task_key.py index a35336bf..a642f0da 100644 --- a/cl/runtime/tasks/task_key.py +++ b/cl/runtime/tasks/task_key.py @@ -14,7 +14,6 @@ from dataclasses import dataclass from typing import Type - from cl.runtime.primitive.timestamp import Timestamp from cl.runtime.records.dataclasses_extensions import missing from cl.runtime.records.key_mixin import KeyMixin @@ -40,11 +39,7 @@ class TaskKey(KeyMixin): def init(self) -> None: # Check only if inside a key, will be set automatically if inside a record if is_key(self): - Timestamp.validate( - self.task_id, - value_name="task_id", - data_type="TaskKey" - ) + Timestamp.validate(self.task_id, value_name="task_id", data_type="TaskKey") @classmethod def get_key_type(cls) -> Type: diff --git a/cl/runtime/tasks/task_queue.py b/cl/runtime/tasks/task_queue.py index c874f498..4ad58d4f 100644 --- a/cl/runtime/tasks/task_queue.py +++ b/cl/runtime/tasks/task_queue.py @@ -44,4 +44,3 @@ def run_start_queue(self) -> None: @abstractmethod def run_stop_queue(self) -> None: """Exit after completing all currently executing tasks.""" - diff --git a/stubs/cl/runtime/configs/stub_runtime_config.py b/stubs/cl/runtime/configs/stub_runtime_config.py index 7b4ea9bc..ac9772a7 100644 --- a/stubs/cl/runtime/configs/stub_runtime_config.py +++ b/stubs/cl/runtime/configs/stub_runtime_config.py @@ -16,22 +16,24 @@ from cl.runtime.configs.config import Config from cl.runtime.context.context import Context from cl.runtime.plots.group_bar_plot import GroupBarPlot -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubFileViewers, StubDagViewers +from stubs.cl.runtime import StubDagViewers +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields from stubs.cl.runtime import StubDataclassListDictFields from stubs.cl.runtime import StubDataclassListFields from stubs.cl.runtime import StubDataclassNestedFields -from stubs.cl.runtime import StubDataclassComposite from stubs.cl.runtime import StubDataclassOptionalFields from stubs.cl.runtime import StubDataclassOtherDerivedRecord from stubs.cl.runtime import StubDataclassPrimitiveFields from stubs.cl.runtime import StubDataclassRecord from stubs.cl.runtime import StubDataclassSingleton +from stubs.cl.runtime import StubDataViewers +from stubs.cl.runtime import StubFileViewers from stubs.cl.runtime import StubHandlers from stubs.cl.runtime import StubPlotViewers -from stubs.cl.runtime import StubDataViewers @dataclass(slots=True, kw_only=True) diff --git a/stubs/cl/runtime/records/for_dataclasses/stub_dataclass_nested_fields.py b/stubs/cl/runtime/records/for_dataclasses/stub_dataclass_nested_fields.py index 9ff5c6bc..1e522f5d 100644 --- a/stubs/cl/runtime/records/for_dataclasses/stub_dataclass_nested_fields.py +++ b/stubs/cl/runtime/records/for_dataclasses/stub_dataclass_nested_fields.py @@ -49,4 +49,3 @@ class StubDataclassNestedFields(StubDataclassRecord): record_as_key_field: StubDataclassRecordKey = field(default_factory=lambda: StubDataclassRecord()) """Stub field with key type initialized to record type instance.""" - diff --git a/stubs/cl/runtime/views/stub_data_viewers.py b/stubs/cl/runtime/views/stub_data_viewers.py index b0fd1153..65e47901 100644 --- a/stubs/cl/runtime/views/stub_data_viewers.py +++ b/stubs/cl/runtime/views/stub_data_viewers.py @@ -15,9 +15,12 @@ from dataclasses import dataclass from typing import List from typing_extensions import Self - -from stubs.cl.runtime import StubDataclassRecordKey, StubDataclassRecord, StubDataclassNestedFields, \ - StubDataclassListFields, StubDataclassComposite, StubDataclassCompositeKey +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassCompositeKey +from stubs.cl.runtime import StubDataclassListFields +from stubs.cl.runtime import StubDataclassNestedFields +from stubs.cl.runtime import StubDataclassRecord +from stubs.cl.runtime import StubDataclassRecordKey from stubs.cl.runtime.views.stub_viewers import StubViewers from stubs.cl.runtime.views.stub_viewers_key import StubViewersKey diff --git a/tests/cl/runtime/context/test_context_util.py b/tests/cl/runtime/context/test_context_util.py index 64529be9..3393c592 100644 --- a/tests/cl/runtime/context/test_context_util.py +++ b/tests/cl/runtime/context/test_context_util.py @@ -12,31 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import base64 import pytest +import base64 from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives import serialization, hashes -from cryptography.hazmat.primitives.asymmetric import rsa, padding +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import padding +from cryptography.hazmat.primitives.asymmetric import rsa from cl.runtime.context.context_util import ContextUtil from cl.runtime.context.testing_context import TestingContext def _generate_rsa_private_cert() -> str: # Generate private key - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend() - ) + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) # Convert private key to PEM format pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) # Convert PEM bytes to string - pem_str = pem.decode('utf-8') + pem_str = pem.decode("utf-8") return pem_str @@ -47,11 +45,7 @@ def _encrypt_value(value: str) -> str: public_key = serialization.load_pem_public_key(public_key_pem.encode()) encrypted = public_key.encrypt( value.encode(), - padding=padding.OAEP( - mgf=padding.MGF1(algorithm=hashes.SHA256()), - algorithm=hashes.SHA256(), - label=None - ) + padding=padding.OAEP(mgf=padding.MGF1(algorithm=hashes.SHA256()), algorithm=hashes.SHA256(), label=None), ) return base64.b64encode(encrypted).decode() @@ -62,7 +56,7 @@ def test_decrypt_secret(): with TestingContext() as context: - key = 'test_key' + key = "test_key" value = "secret_value" encrypted_value = _encrypt_value(value) context.secrets[key] = encrypted_value diff --git a/tests/cl/runtime/db/sql/test_sqlite_db.py b/tests/cl/runtime/db/sql/test_sqlite_db.py index febe3ced..8ed0e5ad 100644 --- a/tests/cl/runtime/db/sql/test_sqlite_db.py +++ b/tests/cl/runtime/db/sql/test_sqlite_db.py @@ -19,7 +19,8 @@ from cl.runtime.context.testing_context import TestingContext from cl.runtime.db.sql.sqlite_db import SqliteDb from cl.runtime.records.class_info import ClassInfo -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields diff --git a/tests/cl/runtime/db/sql/test_sqlite_schema_manager.py b/tests/cl/runtime/db/sql/test_sqlite_schema_manager.py index 65add65f..ceda0ff3 100644 --- a/tests/cl/runtime/db/sql/test_sqlite_schema_manager.py +++ b/tests/cl/runtime/db/sql/test_sqlite_schema_manager.py @@ -18,12 +18,13 @@ from cl.runtime.db.sql.sqlite_schema_manager import SqliteSchemaManager from cl.runtime.schema.schema import Schema from cl.runtime.testing.regression_guard import RegressionGuard -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassNestedFields +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields from stubs.cl.runtime import StubDataclassListDictFields from stubs.cl.runtime import StubDataclassListFields +from stubs.cl.runtime import StubDataclassNestedFields from stubs.cl.runtime import StubDataclassOtherDerivedRecord from stubs.cl.runtime import StubDataclassRecord from stubs.cl.runtime import StubDataclassRecordKey diff --git a/tests/cl/runtime/exceptions/test_error_util.py b/tests/cl/runtime/exceptions/test_error_util.py index 669181bd..e624193a 100644 --- a/tests/cl/runtime/exceptions/test_error_util.py +++ b/tests/cl/runtime/exceptions/test_error_util.py @@ -23,15 +23,26 @@ def test_value_error(): guard.write(ErrorUtil.value_error(value=123)) guard.write(ErrorUtil.value_error(value=123, value_name="sample_value_name")) guard.write( - ErrorUtil.value_error(value=123, value_name="sample_value_name", method_name="sample_function",) + ErrorUtil.value_error( + value=123, + value_name="sample_value_name", + method_name="sample_function", + ) ) guard.write(ErrorUtil.value_error(value=123, method_name="sample_function")) guard.write( - ErrorUtil.value_error(value=123, method_name="sample_method", data_type="SampleRecord",) + ErrorUtil.value_error( + value=123, + method_name="sample_method", + data_type="SampleRecord", + ) ) guard.write( ErrorUtil.value_error( - value=123, value_name="sample_value_name", method_name="sample_method", data_type="SampleRecord", + value=123, + value_name="sample_value_name", + method_name="sample_method", + data_type="SampleRecord", ) ) guard.verify() @@ -55,9 +66,7 @@ def test_of_param(): guard.write(ErrorUtil._of_param(param_name="sample_param", method_name="sample_function")) guard.write(ErrorUtil._of_param(method_name="sample_function")) guard.write(ErrorUtil._of_param(method_name="sample_method", data_type="SampleRecord")) - guard.write( - ErrorUtil._of_param(param_name="sample_param", method_name="sample_method", data_type="SampleRecord") - ) + guard.write(ErrorUtil._of_param(param_name="sample_param", method_name="sample_method", data_type="SampleRecord")) guard.verify() diff --git a/tests/cl/runtime/file/test_csv_file_reader.py b/tests/cl/runtime/file/test_csv_file_reader.py index 12f87c8a..21727c40 100644 --- a/tests/cl/runtime/file/test_csv_file_reader.py +++ b/tests/cl/runtime/file/test_csv_file_reader.py @@ -17,7 +17,8 @@ from cl.runtime.context.env_util import EnvUtil from cl.runtime.context.testing_context import TestingContext from cl.runtime.file.csv_file_reader import CsvFileReader -from stubs.cl.runtime import StubDataclassDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassNestedFields from stubs.cl.runtime import StubDataclassRecord from stubs.cl.runtime import StubDataclassRecordKey diff --git a/tests/cl/runtime/file/test_csv_record_roundtrip.py b/tests/cl/runtime/file/test_csv_record_roundtrip.py index dad7b0e8..b4f81597 100644 --- a/tests/cl/runtime/file/test_csv_record_roundtrip.py +++ b/tests/cl/runtime/file/test_csv_record_roundtrip.py @@ -25,7 +25,8 @@ from cl.runtime.records.protocols import is_key from cl.runtime.serialization.flat_dict_serializer import FlatDictSerializer from cl.runtime.serialization.string_serializer import StringSerializer -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields diff --git a/tests/cl/runtime/plots/test_successor_dag_plot.py b/tests/cl/runtime/plots/test_successor_dag_plot.py index 8f9720a5..8d65941e 100644 --- a/tests/cl/runtime/plots/test_successor_dag_plot.py +++ b/tests/cl/runtime/plots/test_successor_dag_plot.py @@ -13,21 +13,23 @@ # limitations under the License. import pytest -from cl.runtime.context.testing_context import TestingContext from dataclasses import dataclass -from typing import List, Optional -import networkx as nx +from typing import List +from typing import Optional import matplotlib.pyplot as plt +import networkx as nx from matplotlib.patches import Rectangle +from cl.runtime.context.testing_context import TestingContext from cl.runtime.testing.pytest.pytest_fixtures import local_dir_fixture + def test_smoke(local_dir_fixture): with TestingContext() as context: @dataclass class Node: title: str - successors: Optional[List['Node']] = None + successors: Optional[List["Node"]] = None # Define the nodes with successors staff_a = Node(title="Staff A") @@ -68,8 +70,9 @@ def add_nodes_recursive(graph, node, current_id, x, y, pos, labels) -> int: graph.add_edge(current_id, successor_id) # Position each successor progressively lower y = y - (i + 1) * y_offset # Adjust vertical spacing between successors - y = add_nodes_recursive(graph, successor, successor_id, x + x_offset, y, pos, - labels) # Adjust horizontal spacing + y = add_nodes_recursive( + graph, successor, successor_id, x + x_offset, y, pos, labels + ) # Adjust horizontal spacing return y # Add CEO node and its successors recursively @@ -93,9 +96,12 @@ def draw_edges_with_custom_arrows(graph, pos, ax): # Draw the arrow ax.annotate( - '', xy=(entry_x, end_y), xytext=(exit_x, start_y), - arrowprops=dict(arrowstyle='-|>', lw=1.5, color='black', - connectionstyle='arc3,rad=0.0') # Straight arrow + "", + xy=(entry_x, end_y), + xytext=(exit_x, start_y), + arrowprops=dict( + arrowstyle="-|>", lw=1.5, color="black", connectionstyle="arc3,rad=0.0" + ), # Straight arrow ) # Call the function to draw custom arrows @@ -110,8 +116,9 @@ def draw_edges_with_custom_arrows(graph, pos, ax): width = 3 height = 1.5 # Draw a rectangle centered on the node's position - rect = Rectangle((x - width / 2, y - height / 2), width, height, - linewidth=1, edgecolor='black', facecolor='lightblue') + rect = Rectangle( + (x - width / 2, y - height / 2), width, height, linewidth=1, edgecolor="black", facecolor="lightblue" + ) ax.add_patch(rect) # Dynamically calculate plot limits to ensure all boxes fit @@ -131,5 +138,6 @@ def draw_edges_with_custom_arrows(graph, pos, ax): # Add a title and display the plot # plt.savefig("test_successor_dag.png") + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/cl/runtime/records/test_record_util.py b/tests/cl/runtime/records/test_record_util.py index f14be15d..f6fce478 100644 --- a/tests/cl/runtime/records/test_record_util.py +++ b/tests/cl/runtime/records/test_record_util.py @@ -16,7 +16,8 @@ from cl.runtime.db.protocols import TKey from cl.runtime.records.record_util import RecordUtil from cl.runtime.testing.regression_guard import RegressionGuard -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields diff --git a/tests/cl/runtime/routers/entity/test_list_panels.py b/tests/cl/runtime/routers/entity/test_list_panels.py index f626ff3d..147f36a8 100644 --- a/tests/cl/runtime/routers/entity/test_list_panels.py +++ b/tests/cl/runtime/routers/entity/test_list_panels.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -from typing import Type, List - import pytest +import inspect +from typing import List +from typing import Type from fastapi import FastAPI from fastapi.testclient import TestClient - from cl.runtime.context.testing_context import TestingContext from cl.runtime.primitive.case_util import CaseUtil from cl.runtime.routers.entity import entity_router @@ -77,7 +76,11 @@ def test_api(): for request in requests: # Split request headers and query request_headers = {"user": request.get("user")} - request_params = {"type": request.get("type"), "key": request.get("key"), "dataset": request.get("dataset")} + request_params = { + "type": request.get("type"), + "key": request.get("key"), + "dataset": request.get("dataset"), + } # Eliminate empty keys request_headers = {k: v for k, v in request_headers.items() if v is not None} diff --git a/tests/cl/runtime/routers/tasks/test_run.py b/tests/cl/runtime/routers/tasks/test_run.py index 726888d8..32a4c1cf 100644 --- a/tests/cl/runtime/routers/tasks/test_run.py +++ b/tests/cl/runtime/routers/tasks/test_run.py @@ -85,10 +85,7 @@ def test_method(celery_test_queue_fixture): request_object = RunRequest(**request) response_items = RunResponseItem.run_tasks(request_object) - [ - Task.wait_for_completion(TaskKey(task_id=response_item.task_run_id)) - for response_item in response_items - ] + [Task.wait_for_completion(TaskKey(task_id=response_item.task_run_id)) for response_item in response_items] actual_records = list(context.load_many(StubDataclassRecord, expected_keys)) assert actual_records == expected_records diff --git a/tests/cl/runtime/routers/tasks/test_run_result.py b/tests/cl/runtime/routers/tasks/test_run_result.py index c0451014..956d4008 100644 --- a/tests/cl/runtime/routers/tasks/test_run_result.py +++ b/tests/cl/runtime/routers/tasks/test_run_result.py @@ -13,7 +13,8 @@ # limitations under the License. import pytest -from typing import List, Dict +from typing import Dict +from typing import List from fastapi import FastAPI from starlette.testclient import TestClient from cl.runtime import Context diff --git a/tests/cl/runtime/routers/tasks/test_run_status.py b/tests/cl/runtime/routers/tasks/test_run_status.py index dd5d618b..f61ee142 100644 --- a/tests/cl/runtime/routers/tasks/test_run_status.py +++ b/tests/cl/runtime/routers/tasks/test_run_status.py @@ -13,7 +13,8 @@ # limitations under the License. import pytest -from typing import List, Dict +from typing import Dict +from typing import List from fastapi import FastAPI from starlette.testclient import TestClient from cl.runtime import Context diff --git a/tests/cl/runtime/schema/test_get_schema.py b/tests/cl/runtime/schema/test_get_schema.py index 84aecd82..48b44c45 100644 --- a/tests/cl/runtime/schema/test_get_schema.py +++ b/tests/cl/runtime/schema/test_get_schema.py @@ -15,7 +15,8 @@ import pytest from cl.runtime.schema.for_dataclasses.dataclass_type_decl import DataclassTypeDecl from cl.runtime.testing.regression_guard import RegressionGuard -from stubs.cl.runtime import StubDataclassListFields, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassListFields from stubs.cl.runtime import StubDataclassNestedFields from stubs.cl.runtime import StubDataclassPrimitiveFields from stubs.cl.runtime import StubDataclassRecord diff --git a/tests/cl/runtime/serialization/test_dict_serializer.py b/tests/cl/runtime/serialization/test_dict_serializer.py index 1b6251db..5dc9e79a 100644 --- a/tests/cl/runtime/serialization/test_dict_serializer.py +++ b/tests/cl/runtime/serialization/test_dict_serializer.py @@ -14,7 +14,8 @@ import pytest from cl.runtime.serialization.dict_serializer import DictSerializer -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields diff --git a/tests/cl/runtime/serialization/test_flat_dict_serialization.py b/tests/cl/runtime/serialization/test_flat_dict_serialization.py index 47188a22..61621a46 100644 --- a/tests/cl/runtime/serialization/test_flat_dict_serialization.py +++ b/tests/cl/runtime/serialization/test_flat_dict_serialization.py @@ -14,7 +14,8 @@ import pytest from cl.runtime.serialization.flat_dict_serializer import FlatDictSerializer -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields diff --git a/tests/cl/runtime/serialization/test_string_serializer.py b/tests/cl/runtime/serialization/test_string_serializer.py index d66a52d1..6e16e578 100644 --- a/tests/cl/runtime/serialization/test_string_serializer.py +++ b/tests/cl/runtime/serialization/test_string_serializer.py @@ -14,7 +14,8 @@ import pytest from cl.runtime.serialization.string_serializer import StringSerializer -from stubs.cl.runtime import StubDataclassListFields, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassListFields from stubs.cl.runtime import StubDataclassNestedFields from stubs.cl.runtime import StubDataclassOptionalFields from stubs.cl.runtime import StubDataclassPrimitiveFields diff --git a/tests/cl/runtime/serialization/test_ui_dict_serialization.py b/tests/cl/runtime/serialization/test_ui_dict_serialization.py index a88593cf..14187255 100644 --- a/tests/cl/runtime/serialization/test_ui_dict_serialization.py +++ b/tests/cl/runtime/serialization/test_ui_dict_serialization.py @@ -14,7 +14,8 @@ import pytest from cl.runtime.serialization.ui_dict_serializer import UiDictSerializer -from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord, StubDataclassComposite +from stubs.cl.runtime import StubDataclassComposite +from stubs.cl.runtime import StubDataclassDerivedFromDerivedRecord from stubs.cl.runtime import StubDataclassDerivedRecord from stubs.cl.runtime import StubDataclassDictFields from stubs.cl.runtime import StubDataclassDictListFields diff --git a/tests/cl/runtime/tasks/celery/test_celery_queue.py b/tests/cl/runtime/tasks/celery/test_celery_queue.py index 55ac23f5..eb090ede 100644 --- a/tests/cl/runtime/tasks/celery/test_celery_queue.py +++ b/tests/cl/runtime/tasks/celery/test_celery_queue.py @@ -13,7 +13,6 @@ # limitations under the License. import pytest - from cl.runtime import Context from cl.runtime.context.testing_context import TestingContext from cl.runtime.serialization.dict_serializer import DictSerializer @@ -34,10 +33,7 @@ def _create_task(queue: TaskQueueKey) -> TaskKey: """Create a test task.""" method_callable = StubHandlers.run_static_method_1a - task = StaticMethodTask.create( - queue=queue, - record_type=StubHandlers, - method_callable=method_callable) + task = StaticMethodTask.create(queue=queue, record_type=StubHandlers, method_callable=method_callable) Context.current().save_one(task) return task.get_key() diff --git a/tests/cl/runtime/tasks/test_static_method_task.py b/tests/cl/runtime/tasks/test_static_method_task.py index 6fda56ef..6c47df43 100644 --- a/tests/cl/runtime/tasks/test_static_method_task.py +++ b/tests/cl/runtime/tasks/test_static_method_task.py @@ -32,9 +32,7 @@ def test_create(): record_type = sample_input[0] method_callable = sample_input[1] task = StaticMethodTask.create( - queue=TaskQueueKey(queue_id="Sample Queue"), - record_type=record_type, - method_callable=method_callable + queue=TaskQueueKey(queue_id="Sample Queue"), record_type=record_type, method_callable=method_callable ) task.run_task()