diff --git a/DMBotNetwork/__init__.py b/DMBotNetwork/__init__.py index d269bf8..b229db8 100644 --- a/DMBotNetwork/__init__.py +++ b/DMBotNetwork/__init__.py @@ -3,4 +3,4 @@ from .main.utils.cl_unit import ClUnit __all__ = ["Client", "Server", "ClUnit"] -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/DMBotNetwork/main/client.py b/DMBotNetwork/main/client.py index 653998d..352b6d4 100644 --- a/DMBotNetwork/main/client.py +++ b/DMBotNetwork/main/client.py @@ -4,7 +4,8 @@ import logging from collections.abc import Callable from pathlib import Path -from typing import Any, Dict, Optional, get_type_hints +from typing import (Any, Dict, List, Optional, Type, Union, get_args, get_origin, + get_type_hints) import aiofiles @@ -18,6 +19,9 @@ class Client: _server_handler_task: Optional[asyncio.Task] = None _disconnect_lock = asyncio.Lock() + _data_cache: Dict[str, Any] = {} + _waiting_tasks: Dict[str, asyncio.Event] = {} + _server_name: str = "dev_server" _reader: Optional[asyncio.StreamReader] = None _writer: Optional[asyncio.StreamWriter] = None @@ -31,17 +35,21 @@ class Client: _content_path: Path = Path("") @classmethod - def register_methods_from_class(cls, external_class): + def register_methods_from_class(cls, external_classes: Type | List[Type]) -> None: """Регистрация методов с префиксом 'net_' из внешнего класса.""" - for name, func in inspect.getmembers( - external_class, predicate=inspect.isfunction - ): - if name.startswith("net_"): - method_name = name[4:] - cls._network_funcs[method_name] = func - logger.debug( - f"Registered method '{name}' from {external_class.__name__} as '{method_name}'" - ) + if not isinstance(external_classes, list): + external_classes = [external_classes] + + for external_class in external_classes: + for name, func in inspect.getmembers( + external_class, predicate=inspect.isfunction + ): + if name.startswith("net_"): + method_name = name[4:] + cls._network_funcs[method_name] = func + logger.debug( + f"Registered method '{name}' from {external_class.__name__} as '{method_name}'" + ) @classmethod async def _call_func( @@ -61,18 +69,26 @@ async def _call_func( for arg_name, arg_value in valid_kwargs.items(): expected_type = type_hints.get(arg_name, Any) - if not isinstance(arg_value, expected_type) and expected_type is not Any: - logger.error( - f"Type mismatch for argument '{arg_name}': expected {expected_type}, got {type(arg_value)}." - ) + if get_origin(expected_type) is Union: + if not isinstance(arg_value, get_args(expected_type)): + logger.error( + f"Type mismatch for argument '{arg_name}': expected {expected_type}, got {type(arg_value)}." + ) + return + + else: + if not isinstance(arg_value, expected_type): + logger.error( + f"Type mismatch for argument '{arg_name}': expected {expected_type}, got {type(arg_value)}." + ) return try: if inspect.iscoroutinefunction(func): - await func(cls, **valid_kwargs) + await func(**valid_kwargs) else: - func(cls, **valid_kwargs) + func(**valid_kwargs) except Exception as e: logger.error(f"Error calling method '{func_name}' in {cls.__name__}: {e}") @@ -87,6 +103,31 @@ async def send_package(cls, code: ResponseCode, **kwargs) -> None: async def req_net_func(cls, func_name: str, **kwargs) -> None: await cls.send_package(ResponseCode.NET_REQ, net_func_name=func_name, **kwargs) + @classmethod + async def req_get_data(cls, func_name: str, get_key: str, **kwargs) -> Any: + if get_key in cls._data_cache: + return cls._data_cache.pop(get_key) + + if get_key not in cls._waiting_tasks: + cls._waiting_tasks[get_key] = asyncio.Event() + await cls.send_package( + ResponseCode.GET_REQ, + net_func_name=func_name, + net_get_key=get_key, + **kwargs, + ) + + await cls._waiting_tasks[get_key].wait() + cls._waiting_tasks.pop(get_key, None) + return cls._data_cache.pop(get_key) + + @classmethod + async def _handle_data_from_server(cls, get_key: str, data: Any) -> None: + """Обработка полученных данных от сервера.""" + cls._data_cache[get_key] = data + if get_key in cls._waiting_tasks: + cls._waiting_tasks[get_key].set() + @classmethod def is_connected(cls) -> bool: return cls._is_auth and cls._is_connected @@ -175,12 +216,18 @@ async def _server_handler(cls) -> None: logger.error(f"Receive data must has 'code' key: {receive_package}") continue - if ResponseCode.is_net(code): + if code == ResponseCode.NET_REQ: await cls._call_func( receive_package.pop("net_func_name", None), **receive_package, ) + elif code == ResponseCode.GET_REQ: + get_key = receive_package.pop("get_key", None) + data = receive_package.pop("data", None) + if get_key: + await cls._handle_data_from_server(get_key, data) + elif ResponseCode.is_log(code): cls._log_handler(code, receive_package) diff --git a/DMBotNetwork/main/server.py b/DMBotNetwork/main/server.py index acac18a..33c9b35 100644 --- a/DMBotNetwork/main/server.py +++ b/DMBotNetwork/main/server.py @@ -3,7 +3,8 @@ import logging from collections.abc import Callable from pathlib import Path -from typing import Any, Dict, Optional, get_type_hints +from typing import (Any, Dict, List, Optional, Type, Union, get_args, + get_origin, get_type_hints) from .utils import ClUnit, ResponseCode, ServerDB @@ -14,6 +15,7 @@ class Server: _network_funcs: Dict[str, Callable] = {} _cl_units: Dict[str, ClUnit] = {} _server: Optional[asyncio.AbstractServer] = None + _cl_units_lock = asyncio.Lock() _is_online: bool = False @@ -23,24 +25,29 @@ class Server: _max_players: int = -1 @classmethod - def register_methods_from_class(cls, external_class): + def register_methods_from_class(cls, external_classes: Type | List[Type]) -> None: """Регистрация методов с префиксом 'net_' из внешнего класса.""" - for name, func in inspect.getmembers( - external_class, predicate=inspect.isfunction - ): - if name.startswith("net_"): - method_name = name[4:] - cls._network_funcs[method_name] = func - logger.debug( - f"Registered method '{name}' from {external_class.__name__} as '{method_name}'" - ) + if not isinstance(external_classes, list): + external_classes = [external_classes] + + for external_class in external_classes: + for name, func in inspect.getmembers( + external_class, predicate=inspect.isfunction + ): + if name.startswith("net_"): + method_name = name[4:] + cls._network_funcs[method_name] = func + logger.debug( + f"Registered method '{name}' from {external_class.__name__} as '{method_name}'" + ) @classmethod async def _call_func( cls, func_name: str, + cl_unit: ClUnit, **kwargs, - ) -> None: + ) -> Any: func = cls._network_funcs.get(func_name) if func is None: logger.debug(f"Network func '{func_name}' not found.") @@ -53,18 +60,26 @@ async def _call_func( for arg_name, arg_value in valid_kwargs.items(): expected_type = type_hints.get(arg_name, Any) - if not isinstance(arg_value, expected_type) and expected_type is not Any: - logger.error( - f"Type mismatch for argument '{arg_name}': expected {expected_type}, got {type(arg_value)}." - ) + if get_origin(expected_type) is Union: + if not isinstance(arg_value, get_args(expected_type)): + await cl_unit.send_log_error( + f"Type mismatch for argument '{arg_name}': expected {expected_type}, got {type(arg_value)}." + ) + return + + else: + if not isinstance(arg_value, expected_type): + await cl_unit.send_log_error( + f"Type mismatch for argument '{arg_name}': expected {expected_type}, got {type(arg_value)}." + ) return try: if inspect.iscoroutinefunction(func): - await func(cls, **valid_kwargs) + return await func(**valid_kwargs) else: - func(cls, **valid_kwargs) + return func(**valid_kwargs) except Exception as e: logger.error(f"Error calling method '{func_name}' in {cls.__name__}: {e}") @@ -117,7 +132,8 @@ async def start(cls) -> None: logger.error(f"Error starting server: {err}") finally: - await cls.stop() + if cls._is_online: + await cls.stop() @classmethod async def stop(cls) -> None: @@ -126,7 +142,9 @@ async def stop(cls) -> None: cls._is_online = False - asyncio.gather(*(cl_unit.disconnect() for cl_unit in cls._cl_units.values())) + await asyncio.gather( + *(cl_unit.disconnect() for cl_unit in cls._cl_units.values()) + ) cls._cl_units.clear() if cls._server: @@ -156,6 +174,10 @@ async def _cl_handler( ) -> None: cl_unit = ClUnit("init", reader, writer) + if not cls._is_online: + await cl_unit.send_log_error("Server is shutdown") + return + try: await cls._auth(cl_unit) @@ -174,7 +196,9 @@ async def _cl_handler( await cl_unit.disconnect() return - cls._cl_units[cl_unit.login] = cl_unit + async with cls._cl_units_lock: + cls._cl_units[cl_unit.login] = cl_unit + logger.info(f"{cl_unit.login} is connected.") try: @@ -189,14 +213,29 @@ async def _cl_handler( await cl_unit.send_log_error("Receive data must has 'code' key.") continue - if ResponseCode.is_net(code): + if code == ResponseCode.NET_REQ: func_name = receive_package.pop("net_func_name", None) await cls._call_func( func_name, - cl_unit=cl_unit, + cl_unit, **receive_package, ) + elif code == ResponseCode.GET_REQ: + func_name = receive_package.pop("net_func_name", None) + get_key = receive_package.pop("net_get_key", None) + if get_key is None: + continue + + data = await cls._call_func( + func_name, + cl_unit, + **receive_package, + ) + await cl_unit.send_package( + ResponseCode.GET_REQ, get_key=get_key, data=data + ) + else: await cl_unit.send_log_error("Unknown 'code' for net type.") @@ -209,10 +248,13 @@ async def _cl_handler( pass except Exception as err: + logger.exception(f"An unexpected error occurred: {err}") await cl_unit.send_log_error(f"An unexpected error occurred: {err}") finally: - cls._cl_units.pop(cl_unit.login, None) + async with cls._cl_units_lock: + cls._cl_units.pop(cl_unit.login, None) + await cl_unit.disconnect() logger.info(f"{cl_unit.login} is disconected.") diff --git a/DMBotNetwork/main/utils/response_code.py b/DMBotNetwork/main/utils/response_code.py index 43087e6..8a25b5f 100644 --- a/DMBotNetwork/main/utils/response_code.py +++ b/DMBotNetwork/main/utils/response_code.py @@ -10,6 +10,7 @@ class ResponseCode(IntEnum): # Сетевые запросы NET_REQ = 20 # Запрос сетевого метода + GET_REQ = 21 # Запрос данных с сервера и ожидание получения # Файловые операции FIL_REQ = 30 # Запрос на отправку фрагмента файла diff --git a/Tests/ServerDB.py b/Tests/ServerDB.py index 1357f52..e4a062d 100644 --- a/Tests/ServerDB.py +++ b/Tests/ServerDB.py @@ -1,7 +1,9 @@ import unittest from pathlib import Path + from DMBotNetwork.main.utils.server_db import ServerDB + class TestServerDB(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): self.temp_db_file: Path = Path("temp") diff --git a/setup.py b/setup.py index 79927a8..22b4cfa 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="DMBotNetwork", - version="0.2.1", + version="0.2.2", packages=find_packages(), install_requires=["aiosqlite", "aiofiles", "bcrypt", "msgpack"], author="Angels And Demons dev team", @@ -16,6 +16,6 @@ "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "Operating System :: OS Independent", ], - python_requires=">=3.11", + python_requires=">=3.12", license="GPL-3.0", )