diff --git a/src/middlewared/middlewared/api/base/decorator.py b/src/middlewared/middlewared/api/base/decorator.py index c120f4ccf5742..5b4fe10c551be 100644 --- a/src/middlewared/middlewared/api/base/decorator.py +++ b/src/middlewared/middlewared/api/base/decorator.py @@ -15,6 +15,7 @@ def api_method( audit: str | None = None, audit_callback: bool = False, audit_extended: Callable[..., str] | None = None, + rate_limit=True, roles: list[str] | None = None, private: bool = False, ): @@ -33,6 +34,8 @@ def api_method( `audit_extended` is the function that takes the same arguments as the decorated function and returns the string that will be appended to the audit message to be logged. + `rate_limit` specifies whether the method calls should be rate limited when calling without authentication. + `roles` is a list of user roles that will gain access to this method. `private` is `True` when the method should not be exposed in the public API. By default, the method is public. @@ -63,6 +66,7 @@ def wrapped(*args): wrapped.audit = audit wrapped.audit_callback = audit_callback wrapped.audit_extended = audit_extended + wrapped.rate_limit = rate_limit wrapped.roles = roles or [] wrapped._private = private diff --git a/src/middlewared/middlewared/api/base/model.py b/src/middlewared/middlewared/api/base/model.py index d68142240de60..0b51b443c6f14 100644 --- a/src/middlewared/middlewared/api/base/model.py +++ b/src/middlewared/middlewared/api/base/model.py @@ -1,4 +1,6 @@ import copy +import inspect +from types import NoneType import typing from pydantic import BaseModel as PydanticBaseModel, ConfigDict, create_model, Field, model_serializer @@ -86,10 +88,19 @@ def wrapper(klass): return wrapper -def single_argument_result(klass): +def single_argument_result(klass, klass_name=None): + if klass is None: + klass = NoneType + + if klass.__module__ == "builtins": + if klass_name is None: + raise TypeError("You must specify class name when using `single_argument_result` for built-in types") + else: + klass_name = klass_name or klass.__name__ + return create_model( - klass.__name__, + klass_name, __base__=(BaseModel,), - __module__=klass.__module__, + __module__=inspect.getmodule(inspect.stack()[1][0]), **{"result": Annotated[klass, Field()]}, ) diff --git a/src/middlewared/middlewared/api/base/server/__init__.py b/src/middlewared/middlewared/api/base/server/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/middlewared/middlewared/api/base/server/app.py b/src/middlewared/middlewared/api/base/server/app.py new file mode 100644 index 0000000000000..16472cd188270 --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/app.py @@ -0,0 +1,18 @@ +import logging +import uuid + +from middlewared.auth import SessionManagerCredentials +from middlewared.utils.origin import Origin + +logger = logging.getLogger(__name__) + + +class App: + def __init__(self, origin: Origin): + self.origin = origin + self.session_id = str(uuid.uuid4()) + self.authenticated = False + self.authenticated_credentials: SessionManagerCredentials | None = None + self.py_exceptions = False + self.websocket = False + self.rest = False diff --git a/src/middlewared/middlewared/api/base/server/method.py b/src/middlewared/middlewared/api/base/server/method.py new file mode 100644 index 0000000000000..fa02b1a9ba66c --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/method.py @@ -0,0 +1,30 @@ +import types + +from middlewared.job import Job + + +class Method: + def __init__(self, middleware: "Middleware", name: str): + self.middleware = middleware + self.name = name + + async def call(self, app: "RpcWebSocketApp", params): + serviceobj, methodobj = self.middleware.get_method(self.name) + + await self.middleware.authorize_method_call(app, self.name, methodobj, params) + + if mock := self.middleware._mock_method(self.name, params): + methodobj = mock + + result = await self.middleware.call_with_audit(self.name, serviceobj, methodobj, params, app) + if isinstance(result, Job): + result = result.id + elif isinstance(result, types.GeneratorType): + result = list(result) + elif isinstance(result, types.AsyncGeneratorType): + result = [i async for i in result] + + return result + + def dump_args(self, params): + return self.middleware.dump_args(params, method_name=self.name) diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/__init__.py b/src/middlewared/middlewared/api/base/server/ws_handler/__init__.py new file mode 100644 index 0000000000000..d3b4c09c780ef --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/__init__.py @@ -0,0 +1,6 @@ +# -*- coding=utf-8 -*- +import logging + +logger = logging.getLogger(__name__) + +__all__ = [] diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/base.py b/src/middlewared/middlewared/api/base/server/ws_handler/base.py new file mode 100644 index 0000000000000..975369abe0b7f --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/base.py @@ -0,0 +1,72 @@ +import socket +import struct + +from aiohttp.http_websocket import WSCloseCode +from aiohttp.web import Request, WebSocketResponse + +from middlewared.auth import is_ha_connection +from middlewared.utils.nginx import get_remote_addr_port +from middlewared.utils.origin import Origin, UnixSocketOrigin, TCPIPOrigin +from middlewared.webui_auth import addr_in_allowlist + + +class BaseWebSocketHandler: + def __init__(self, middleware: "Middleware"): + self.middleware = middleware + + async def __call__(self, request: Request): + ws = WebSocketResponse() + try: + await ws.prepare(request) + except ConnectionResetError: + # Happens when we're preparing a new session, and during the time we prepare, the server is + # stopped/killed/restarted etc. Ignore these to prevent log spam. + return ws + + origin = await self.get_origin(request) + if origin is None: + await ws.close() + return ws + if not await self.can_access(origin): + await ws.close( + code=WSCloseCode.POLICY_VIOLATION, + message=b"You are not allowed to access this resource", + ) + return ws + + await self.process(origin, ws) + return ws + + async def get_origin(self, request: Request) -> Origin | None: + try: + sock = request.transport.get_extra_info("socket") + except AttributeError: + # request.transport can be None by the time this is called on HA systems because remote node could have been + # rebooted + return + + if sock.family == socket.AF_UNIX: + peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize("3i")) + pid, uid, gid = struct.unpack("3i", peercred) + return UnixSocketOrigin(pid, uid, gid) + + remote_addr, remote_port = await self.middleware.run_in_thread(get_remote_addr_port, request) + return TCPIPOrigin(remote_addr, remote_port) + + async def can_access(self, origin: Origin | None) -> bool: + if not isinstance(origin, TCPIPOrigin): + return True + + if not (ui_allowlist := await self.middleware.call("system.general.get_ui_allowlist")): + return True + + if is_ha_connection(origin.addr, origin.port): + return True + + if addr_in_allowlist(origin.addr, ui_allowlist): + return True + + return False + + async def process(self, origin: Origin, ws: WebSocketResponse): + raise NotImplementedError diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py b/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py new file mode 100644 index 0000000000000..989057fbcf7a8 --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py @@ -0,0 +1,301 @@ +import asyncio +import binascii +from collections import defaultdict +import enum +import errno +import pickle +import sys +import traceback +from typing import Any, Callable + +from aiohttp.http_websocket import WSCloseCode, WSMessage +from aiohttp.web import WebSocketResponse, WSMsgType +import jsonschema + +from truenas_api_client import json +from truenas_api_client.jsonrpc import JSONRPCError + +from middlewared.schema import Error +from middlewared.service_exception import (CallException, CallError, ValidationError, ValidationErrors, adapt_exception, + get_errname) +from middlewared.utils.debug import get_frame_details +from middlewared.utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit +from middlewared.utils.origin import Origin +from .base import BaseWebSocketHandler +from ..app import App +from ..method import Method + +REQUEST_SCHEMA = { + "type": "object", + "additionalProperties": False, + "required": ["jsonrpc", "method"], + "properties": { + "jsonrpc": {"enum": ["2.0"]}, + "method": {"type": "string"}, + "params": {"type": "array"}, + "id": {"type": ["null", "number", "string"]}, + } +} + + +class RpcWebSocketAppEvent(enum.Enum): + MESSAGE = enum.auto() + CLOSE = enum.auto() + + +class RpcWebSocketApp(App): + def __init__(self, middleware: "Middleware", origin: Origin, ws: WebSocketResponse): + super().__init__(origin) + + self.websocket = True + + self.middleware = middleware + self.ws = ws + self.softhardsemaphore = SoftHardSemaphore(10, 20) + self.callbacks = defaultdict(list) + self.subscriptions = {} + + def send(self, data): + fut = asyncio.run_coroutine_threadsafe(self.ws.send_str(json.dumps(data)), self.middleware.loop) + + def send_error(self, id_: Any, code: int, message: str, data: Any = None): + error = { + "jsonrpc": "2.0", + "error": { + "code": code, + "message": message, + }, + "id": id_, + } + if data is not None: + error["error"]["data"] = data + + self.send(error) + + def send_truenas_error(self, id_: Any, code: int, message: str, errno_: int, reason: str, + exc_info=None, extra: list | None = None): + self.send_error(id_, code, message, self.format_truenas_error(errno_, reason, exc_info, extra)) + + def format_truenas_error(self, errno_: int, reason: str, exc_info=None, extra: list | None = None): + return { + "error": errno_, + "errname": get_errname(errno_), + "reason": reason, + "trace": self.truenas_error_traceback(exc_info) if exc_info else None, + "extra": extra, + **({"py_exception": binascii.b2a_base64(pickle.dumps(exc_info[1])).decode()} + if self.py_exceptions and exc_info else {}), + } + + def truenas_error_traceback(self, exc_info): + etype, value, tb = exc_info + + frames = [] + cur_tb = tb + while cur_tb: + tb_frame = cur_tb.tb_frame + cur_tb = cur_tb.tb_next + + cur_frame = get_frame_details(tb_frame, self.middleware.logger) + if cur_frame: + frames.append(cur_frame) + + return { + "class": etype.__name__, + "frames": frames, + "formatted": "".join(traceback.format_exception(*exc_info)), + "repr": repr(value), + } + + def send_truenas_validation_error(self, id_: Any, exc_info, errors: list): + self.send_error(id_, JSONRPCError.INVALID_PARAMS.value, "Invalid params", + self.format_truenas_validation_error(exc_info[1], exc_info, errors)) + + def format_truenas_validation_error(self, exception, exc_info=None, errors: list | None = None): + return self.format_truenas_error(errno.EINVAL, str(exception), exc_info, errors) + + def register_callback(self, event: RpcWebSocketAppEvent, callback: Callable): + self.callbacks[event.value].append(callback) + + def run_callback(self, event, *args, **kwargs): + for callback in self.callbacks[event.value]: + try: + callback(self, *args, **kwargs) + except Exception: + logger.error(f"Failed to run {event} callback", exc_info=True) + + async def subscribe(self, ident: str, name: str): + shortname, arg = self.middleware.event_source_manager.short_name_arg(name) + if shortname in self.middleware.event_source_manager.event_sources: + await self.middleware.event_source_manager.subscribe_app(self, self.__esm_ident(ident), shortname, arg) + else: + self.subscriptions[ident] = name + + async def unsubscribe(self, ident: str): + if ident in self.subscriptions: + self.subscriptions.pop(ident) + elif self.__esm_ident(ident) in self.middleware.event_source_manager.idents: + await self.middleware.event_source_manager.unsubscribe(self.__esm_ident(ident)) + + def __esm_ident(self, ident): + return self.session_id + ident + + def send_event(self, name: str, event_type: str, **kwargs): + if ( + not any(i in [name, "*"] for i in self.subscriptions.values()) and + ( + self.middleware.event_source_manager.short_name_arg(name)[0] not in + self.middleware.event_source_manager.event_sources + ) + ): + return + + event = { + "msg": event_type.lower(), + "collection": name, + } + kwargs = kwargs.copy() + if "id" in kwargs: + event["id"] = kwargs.pop("id") + if event_type in ("ADDED", "CHANGED"): + if "fields" in kwargs: + event["fields"] = kwargs.pop("fields") + if kwargs: + event["extra"] = kwargs + + self.send_notification("collection_update", event) + + def notify_unsubscribed(self, collection: str, error: Exception | None): + params = {"collection": collection, "error": None} + if error: + if isinstance(error, ValidationErrors): + params["error"] = self.format_truenas_validation_error(error, extra=list(error)) + elif isinstance(error, CallError): + params["error"] = self.format_truenas_error(error.errno, str(error), extra=error.extra) + else: + params["error"] = self.format_truenas_error(errno.EINVAL, str(error)) + + self.send_notification("notify_unsubscribed", params) + + def send_notification(self, method, params): + self.send({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }) + + +class RpcWebSocketHandler(BaseWebSocketHandler): + def __init__(self, middleware: "Middleware", methods: {str: Method}): + super().__init__(middleware) + self.methods = methods + + async def process(self, origin: Origin, ws: WebSocketResponse): + app = RpcWebSocketApp(self.middleware, origin, ws) + + self.middleware.register_wsclient(app) + try: + # aiohttp can cancel tasks if a request take too long to finish. + # It is desired to prevent that in this stage in case we are debugging middlewared via gdb (which makes the + # program execution a lot slower) + await asyncio.shield(self.middleware.call_hook("core.on_connect", app)) + + msg: WSMessage + async for msg in ws: + if msg.type == WSMsgType.ERROR: + self.middleware.logger.error("Websocket error: %r", msg.data) + break + + if msg.type != WSMsgType.TEXT: + await ws.close( + code=WSCloseCode.UNSUPPORTED_DATA, + message=f"Invalid websocket message type: {msg.type!r}".encode("utf-8"), + ) + break + + if not app.authenticated and len(msg.data) > 8192: + await ws.close( + code=WSCloseCode.INVALID_TEXT, + message=b"Anonymous connection max message length is 8 kB", + ) + break + + try: + message = json.loads(msg.data) + except ValueError as e: + app.send_error(None, JSONRPCError.INVALID_JSON.value, str(e)) + continue + + app.run_callback(RpcWebSocketAppEvent.MESSAGE, message) + + try: + await self.process_message(app, message) + except Exception as e: + self.middleware.logger.error("Unhandled exception in JSON-RPC message handler", exc_info=True) + await ws.close( + code=WSCloseCode.INTERNAL_ERROR, + message=str(e).encode("utf-8"), + ) + break + finally: + app.run_callback(RpcWebSocketAppEvent.CLOSE) + + await self.middleware.event_source_manager.unsubscribe_app(app) + + self.middleware.unregister_wsclient(app) + + async def process_message(self, app: RpcWebSocketApp, message: Any): + try: + jsonschema.validate(message, REQUEST_SCHEMA) + except jsonschema.ValidationError as e: + app.send_error(app, None, JSONRPCError.INVALID_REQUEST.value, str(e)) + return + + id_ = message.get("id") + method = self.methods.get(message["method"]) + if method is None: + app.send_error(id_, JSONRPCError.METHOD_NOT_FOUND.value, "Method does not exist") + return + + asyncio.ensure_future(self.process_method_call(app, id_, method, message["params"])) + + async def process_method_call(self, app: RpcWebSocketApp, id_: Any, method: Method, params: {str: Any}): + try: + async with app.softhardsemaphore: + result = await method.call(app, params) + except SoftHardSemaphoreLimit as e: + app.send_error(id_, JSONRPCError.TRUENAS_TOO_MANY_CONCURRENT_CALLS.value, + f"Maximum number of concurrent calls ({e.args[0]}) has exceeded") + except ValidationError as e: + app.send_truenas_validation_error(id_, sys.exc_info(), [ + (e.attribute, e.errmsg, e.errno), + ]) + except ValidationErrors as e: + app.send_truenas_validation_error(id_, sys.exc_info(), list(e)) + except (CallException, Error) as e: + # CallException and subclasses are the way to gracefully send errors to the client + app.send_truenas_error(id_, JSONRPCError.TRUENAS_CALL_ERROR.value, "Method call error", e.errno, str(e), + sys.exc_info(), e.extra) + except Exception as e: + adapted = adapt_exception(e) + if adapted: + errno_ = adapted.errno + error = adapted + extra = adapted.extra + else: + errno_ = errno.EINVAL + error = e + extra = None + + app.send_truenas_error(id_, JSONRPCError.TRUENAS_CALL_ERROR.value, "Method call error", errno_, + str(error) or repr(error), sys.exc_info(), extra) + + if not adapted and not app.py_exceptions: + self.middleware.logger.warning(f"Exception while calling {method.name}(*{method.dump_args(params)!r})", + exc_info=True) + else: + app.send({ + "jsonrpc": "2.0", + "result": result, + "id": id_, + }) diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py b/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py new file mode 100644 index 0000000000000..c0390794b20fb --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py @@ -0,0 +1,19 @@ +from ..method import Method +from .rpc import RpcWebSocketHandler + + +def create_rpc_ws_handler(middleware: "Middleware"): + methods = {} + for service_name, service in middleware.get_services().items(): + for attribute in dir(service): + if attribute.startswith("_"): + continue + + if not callable(getattr(service, attribute)): + continue + + method_name = f"{service_name}.{attribute}" + + methods[method_name] = Method(middleware, method_name) + + return RpcWebSocketHandler(middleware, methods) diff --git a/src/middlewared/middlewared/api/v25_04_0/__init__.py b/src/middlewared/middlewared/api/v25_04_0/__init__.py index c78e0dd7d2f29..4e51c596d68b0 100644 --- a/src/middlewared/middlewared/api/v25_04_0/__init__.py +++ b/src/middlewared/middlewared/api/v25_04_0/__init__.py @@ -1,5 +1,6 @@ from .api_key import * # noqa from .cloud_sync import * # noqa from .common import * # noqa +from .core import * # noqa from .user import * # noqa from .vendor import * # noqa diff --git a/src/middlewared/middlewared/api/v25_04_0/core.py b/src/middlewared/middlewared/api/v25_04_0/core.py new file mode 100644 index 0000000000000..202f3f292e814 --- /dev/null +++ b/src/middlewared/middlewared/api/v25_04_0/core.py @@ -0,0 +1,29 @@ +from middlewared.api.base import BaseModel, ForUpdateMetaclass, single_argument_result + +__all__ = ["CoreSetOptionsArgs", "CoreSetOptionsResult", "CoreSubscribeArgs", "CoreSubscribeResult", + "CoreUnsubscribeArgs", "CoreUnsubscribeResult"] + + +class CoreSetOptionsOptions(BaseModel, metaclass=ForUpdateMetaclass): + py_exceptions: bool + + +class CoreSetOptionsArgs(BaseModel): + options: CoreSetOptionsOptions + + +CoreSetOptionsResult = single_argument_result(None, "CoreSetOptionsResult") + + +class CoreSubscribeArgs(BaseModel): + event: str + + +CoreSubscribeResult = single_argument_result(str, "CoreSubscribeResult") + + +class CoreUnsubscribeArgs(BaseModel): + id_: str + + +CoreUnsubscribeResult = single_argument_result(None, "CoreUnsubscribeResult") diff --git a/src/middlewared/middlewared/api/v25_04_0/user.py b/src/middlewared/middlewared/api/v25_04_0/user.py index 2da4fb7aff5bb..782b649532155 100644 --- a/src/middlewared/middlewared/api/v25_04_0/user.py +++ b/src/middlewared/middlewared/api/v25_04_0/user.py @@ -103,4 +103,4 @@ class UserRenew2faSecretArgs(BaseModel): twofactor_options: TwofactorOptions -UserRenew2faSecretResult = single_argument_result(UserEntry) +UserRenew2faSecretResult = single_argument_result(UserEntry, "UserRenew2faSecretResult") diff --git a/src/middlewared/middlewared/common/event_source/manager.py b/src/middlewared/middlewared/common/event_source/manager.py index 56df54efee0b6..bc5e4efbc6c60 100644 --- a/src/middlewared/middlewared/common/event_source/manager.py +++ b/src/middlewared/middlewared/common/event_source/manager.py @@ -1,12 +1,10 @@ import asyncio from collections import defaultdict, namedtuple -import errno import functools from uuid import uuid4 from middlewared.event import EventSource from middlewared.schema import ValidationErrors -from middlewared.service_exception import CallError IdentData = namedtuple("IdentData", ["subscriber", "name", "arg"]) @@ -28,20 +26,7 @@ def send_event(self, event_type, **kwargs): self.app.send_event(self.collection, event_type, **kwargs) def terminate(self, error): - error_dict = {} - if error: - if isinstance(error, ValidationErrors): - error_dict['error'] = self.app.get_error_dict( - errno.EAGAIN, str(error), etype='VALIDATION', extra=list(error) - ) - elif isinstance(error, CallError): - error_dict['error'] = self.app.get_error_dict( - error.errno, str(error), extra=error.extra - ) - else: - error_dict['error'] = self.app.get_error_dict(errno.EINVAL, str(error)) - - self.app._send({'msg': 'nosub', 'collection': self.collection, **error_dict}) + self.app.notify_unsubscribed(self.collection, error) class InternalSubscriber(Subscriber): diff --git a/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako b/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako index 99ab6584e3935..f953ff5a11953 100644 --- a/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako +++ b/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako @@ -218,6 +218,17 @@ http { report_uploads proxied; } + location /api { + allow all; # This is handled by `Middleware.ws_can_access` because if we return HTTP 403, browser security + # won't allow us to understand that connection error was due to client IP not being allowlisted. + proxy_pass http://127.0.0.1:6000/api; + proxy_http_version 1.1; + proxy_set_header X-Real-Remote-Addr $remote_addr; + proxy_set_header X-Real-Remote-Port $remote_port; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + } + location /api/docs { proxy_pass http://127.0.0.1:6000/api/docs; } diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py index 3d9cad504b855..43d8eb1f63b9b 100644 --- a/src/middlewared/middlewared/main.py +++ b/src/middlewared/middlewared/main.py @@ -1,7 +1,9 @@ from .api.base.handler.dump_params import dump_params from .api.base.handler.result import serialize_result +from .api.base.server.ws_handler.base import BaseWebSocketHandler +from .api.base.server.ws_handler.rpc import RpcWebSocketApp, RpcWebSocketAppEvent +from .api.base.server.ws_handler.rpc_factory import create_rpc_ws_handler from .apidocs import routes as apidocs_routes -from .auth import is_ha_connection from .common.event_source.manager import EventSourceManager from .event import Events from .job import Job, JobsQueue @@ -18,8 +20,7 @@ from .utils import MIDDLEWARE_RUN_DIR, sw_version from .utils.debug import get_frame_details, get_threads_stacks from .utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit -from .utils.nginx import get_remote_addr_port -from .utils.origin import UnixSocketOrigin, TCPIPOrigin +from .utils.origin import Origin from .utils.os import close_fds from .utils.plugins import LoadPluginsMixin from .utils.privilege import credential_has_full_admin @@ -29,7 +30,7 @@ from .utils.syslog import syslog_message from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor from .utils.type import copy_function_metadata -from .webui_auth import addr_in_allowlist, WebUIAuth +from .webui_auth import WebUIAuth from .worker import main_worker, worker_init from aiohttp import web from aiohttp.http_websocket import WSCloseCode @@ -59,7 +60,6 @@ import queue import setproctitle import signal -import socket import struct import sys import termios @@ -101,56 +101,23 @@ def real_crud_method(method): return child_method -class Application: +class Application(RpcWebSocketApp): + def __init__(self, middleware: 'Middleware', origin: Origin, loop: asyncio.AbstractEventLoop, request, response): + super().__init__(middleware, origin, response) + self.websocket = True - def __init__(self, middleware: 'Middleware', loop: asyncio.AbstractEventLoop, request, response): - self.middleware = middleware self.loop = loop self.request = request self.response = response - self.authenticated = False - self.authenticated_credentials = None self.handshake = False self.logger = logger.Logger('application').getLogger() - self.session_id = str(uuid.uuid4()) - self.rest = False - self.websocket = True # Allow at most 10 concurrent calls and only queue up until 20 self._softhardsemaphore = SoftHardSemaphore(10, 20) self._py_exceptions = False - """ - Callback index registered by services. They are blocking. - - Currently the following events are allowed: - on_message(app, message) - on_close(app) - """ - self.__callbacks = defaultdict(list) self.__subscribed = {} - @functools.cached_property - def origin(self) -> typing.Union[UnixSocketOrigin, TCPIPOrigin, None]: - try: - sock = self.request.transport.get_extra_info("socket") - except AttributeError: - # self.request.transport can be None by the time this is called - # on HA systems because remote node could have been rebooted - return - - if sock.family == socket.AF_UNIX: - peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize('3i')) - pid, uid, gid = struct.unpack('3i', peercred) - return UnixSocketOrigin(pid, uid, gid) - - remote_addr, remote_port = get_remote_addr_port(self.request) - return TCPIPOrigin(remote_addr, remote_port) - - def register_callback(self, name: str, method): - assert name in ('on_message', 'on_close') - self.__callbacks[name].append(method) - def _send(self, data: typing.Dict[str, typing.Any]): serialized = json.dumps(data) asyncio.run_coroutine_threadsafe(self.response.send_str(serialized), loop=self.loop) @@ -244,20 +211,6 @@ async def call_method(self, message, serviceobj, methodobj): self.middleware.dump_args(message.get('params', []), method_name=message['method']) ), exc_info=True) - def can_subscribe(self, name): - if event := self.middleware.events.get_event(name): - if event['no_auth_required']: - return True - - if not self.authenticated: - return False - - if event: - if event['no_authz_required']: - return True - - return self.authenticated_credentials.authorize('SUBSCRIBE', name) - async def subscribe(self, ident, name): shortname, arg = self.middleware.event_source_manager.short_name_arg(name) if shortname in self.middleware.event_source_manager.event_sources: @@ -301,8 +254,21 @@ def send_event(self, name, event_type, **kwargs): event['extra'] = kwargs self._send(event) - async def log_audit_message(self, event, event_data, success): - return await self.middleware.log_audit_message(self, event, event_data, success) + def notify_unsubscribed(self, collection, error): + error_dict = {} + if error: + if isinstance(error, ValidationErrors): + error_dict['error'] = self.get_error_dict( + errno.EAGAIN, str(error), etype='VALIDATION', extra=list(error) + ) + elif isinstance(error, CallError): + error_dict['error'] = self.get_error_dict( + error.errno, str(error), extra=error.extra + ) + else: + error_dict['error'] = self.get_error_dict(errno.EINVAL, str(error)) + + self._send({'msg': 'nosub', 'collection': collection, **error_dict}) async def __log_audit_message_for_method(self, message, methodobj, authenticated, authorized, success): return await self.middleware.log_audit_message_for_method( @@ -312,25 +278,15 @@ async def __log_audit_message_for_method(self, message, methodobj, authenticated def on_open(self): self.middleware.register_wsclient(self) - async def on_close(self, *args, **kwargs): - # Run callbacks registered in plugins for on_close - for method in self.__callbacks['on_close']: - try: - method(self) - except Exception: - self.logger.error('Failed to run on_close callback.', exc_info=True) + async def on_close(self): + self.run_callback(RpcWebSocketAppEvent.CLOSE) await self.middleware.event_source_manager.unsubscribe_app(self) self.middleware.unregister_wsclient(self) async def on_message(self, message: typing.Dict[str, typing.Any]): - # Run callbacks registered in plugins for on_message - for method in self.__callbacks['on_message']: - try: - method(self, message) - except Exception: - self.logger.error('Failed to run on_message callback.', exc_info=True) + self.run_callback(RpcWebSocketAppEvent.MESSAGE, message) if message['msg'] == 'connect': if message.get('version') != '1': @@ -348,68 +304,26 @@ async def on_message(self, message: typing.Dict[str, typing.Any]): elif not self.handshake: self._send({'msg': 'failed', 'version': '1'}) elif message['msg'] == 'method': - error = False if 'method' not in message: self.send_error(message, errno.EINVAL, "Message is malformed: 'method' is absent.") - error = True else: try: - serviceobj, methodobj = self.middleware._method_lookup(message['method']) + serviceobj, methodobj = self.middleware.get_method(message['method']) + + await self.middleware.authorize_method_call( + self, message['method'], methodobj, message.get('params') or [], + ) except CallError as e: self.send_error(message, e.errno, str(e), sys.exc_info(), extra=e.extra) - error = True - - if not error: - auth_required = not hasattr(methodobj, '_no_auth_required') - if not auth_required: - ip_added = await RateLimitCache.add(message['method'], self.origin) - if ip_added is not None: - if any(( - RateLimitCache.max_entries_reached, - RateLimitCache.rate_limit_exceeded(message['method'], ip_added), - )): - # 1 of 2 things happened: - # 1. we've hit maximum amount of entries for global rate limit - # cache (this is an edge-case and something bad is going on) - # 2. OR this endpoint has been hit too many times by the same - # origin IP address - # In either scenario, sleep a random delay and send an error - await self.__log_audit_message_for_method(message, methodobj, False, True, False) - await RateLimitCache.random_sleep() - self.send_error(message, errno.EBUSY, 'Rate Limit Exceeded') - error = True - else: - # was added to rate limit cache but rate limit thresholds haven't - # been met so no error - error = False - else: - # the origin of the request for the unauthenticated method is an - # internal call or comes from the other controller on an HA system - error = False - elif auth_required and not self.authenticated: - await self.__log_audit_message_for_method(message, methodobj, False, False, False) - self.send_error(message, ErrnoMixin.ENOTAUTHENTICATED, 'Not authenticated') - error = True - elif self.authenticated_credentials.is_user_session and hasattr(methodobj, '_no_authz_required'): - # Some methods require authentication to the NAS (a valid account) - # but not explicit authorization. In this case the authorization - # check is bypassed as long as it is a user session. API keys - # explicitly whitelist particular methods and are used for targeted - # purposes, and so authorization is _always_ enforced. - error = False - elif not self.authenticated_credentials.authorize('CALL', message['method']): - await self.__log_audit_message_for_method(message, methodobj, True, False, False) - self.send_error(message, errno.EACCES, 'Not authorized') - error = True - if not error: - self.middleware.create_task(self.call_method(message, serviceobj, methodobj)) + else: + self.middleware.create_task(self.call_method(message, serviceobj, methodobj)) elif message['msg'] == 'ping': pong = {'msg': 'pong'} if 'id' in message: pong['id'] = message['id'] self._send(pong) elif message['msg'] == 'sub': - if not self.can_subscribe(message['name'].split(':', 1)[0]): + if not self.middleware.can_subscribe(self, message['name'].split(':', 1)[0]): self.send_error(message, errno.EACCES, 'Not authorized') else: await self.subscribe(message['id'], message['name']) @@ -570,7 +484,7 @@ async def upload(self, request): return resp try: - serviceobj, methodobj = self.middleware._method_lookup(data['method']) + serviceobj, methodobj = self.middleware.get_method(data['method']) if authenticated_credentials.authorize('CALL', data['method']): job = await self.middleware.call_with_audit(data['method'], serviceobj, methodobj, data.get('params') or [], app, @@ -753,7 +667,7 @@ class ShellConnectionData(object): t_worker = None -class ShellApplication(object): +class ShellApplication: shells = {} def __init__(self, middleware): @@ -764,14 +678,16 @@ async def ws_handler(self, request): if not prepared: return ws - if not await self.middleware.ws_can_access(request, ws): + handler = BaseWebSocketHandler(self.middleware) + origin = await handler.get_origin(request) + if not await self.middleware.ws_can_access(ws, origin): return ws conndata = ShellConnectionData() conndata.id = str(uuid.uuid4()) try: - await self.run(ws, request, conndata) + await self.run(ws, origin, conndata) except Exception: if conndata.t_worker: await self.worker_kill(conndata.t_worker) @@ -779,7 +695,7 @@ async def ws_handler(self, request): self.shells.pop(conndata.id, None) return ws - async def run(self, ws, request, conndata): + async def run(self, ws, origin, conndata): # Each connection will have its own input queue input_queue = queue.Queue() @@ -799,7 +715,6 @@ async def run(self, ws, request, conndata): if not token: continue - origin = TCPIPOrigin(*await self.middleware.run_in_thread(get_remote_addr_port, request)) token = await self.middleware.call('auth.get_token_for_shell_application', token, origin) if not token: await ws.send_json({ @@ -1195,9 +1110,6 @@ def __notify_startup_complete(self): def plugin_route_add(self, plugin_name, route, method): self.app.router.add_route('*', f'/_plugins/{plugin_name}/{route}', method) - def get_wsclients(self): - return self.__wsclients - def register_wsclient(self, client): self.__wsclients[client.session_id] = client @@ -1444,7 +1356,7 @@ def dump_args(self, args, method=None, method_name=None): if method is None: if method_name is not None: try: - method = self._method_lookup(method_name)[1] + method = self.get_method(method_name)[1] except Exception: return args @@ -1480,6 +1392,70 @@ def dump_result(self, method, result, expose_secrets): return result + async def authorize_method_call(self, app, method_name, methodobj, params): + if hasattr(methodobj, '_no_auth_required'): + if app.authenticated: + # Do not rate limit authenticated users + return + + if not getattr(methodobj, 'rate_limit', True): + # The method is not subjected to rate limit. + return + + ip_added = await RateLimitCache.add(method_name, app.origin) + if ip_added is None: + # the origin of the request for the unauthenticated method is an + # internal call or comes from the other controller on an HA system + return + + if any(( + RateLimitCache.max_entries_reached, + RateLimitCache.rate_limit_exceeded(method_name, ip_added), + )): + # 1 of 2 things happened: + # 1. we've hit maximum amount of entries for global rate limit + # cache (this is an edge-case and something bad is going on) + # 2. OR this endpoint has been hit too many times by the same + # origin IP address + # In either scenario, sleep a random delay and send an error + await self.log_audit_message_for_method(method_name, methodobj, params, app, False, False, False) + await RateLimitCache.random_sleep() + raise CallError('Rate Limit Exceeded', errno.EBUSY) + + # was added to rate limit cache but rate limit thresholds haven't + # been met so no error + return + + if not app.authenticated: + await self.log_audit_message_for_method(method_name, methodobj, params, app, False, False, False) + raise CallError('Not authenticated', ErrnoMixin.ENOTAUTHENTICATED) + + # Some methods require authentication to the NAS (a valid account) + # but not explicit authorization. In this case the authorization + # check is bypassed as long as it is a user session. API keys + # explicitly whitelist particular methods and are used for targeted + # purposes, and so authorization is _always_ enforced. + if app.authenticated_credentials.is_user_session and hasattr(methodobj, '_no_authz_required'): + return + + if not app.authenticated_credentials.authorize('CALL', method_name): + await self.log_audit_message_for_method(method_name, methodobj, params, app, True, False, False) + raise CallError('Not authorized', errno.EACCES) + + def can_subscribe(self, app, name): + if event := self.events.get_event(name): + if event['no_auth_required']: + return True + + if not app.authenticated: + return False + + if event: + if event['no_authz_required']: + return True + + return app.authenticated_credentials.authorize('SUBSCRIBE', name) + async def call_with_audit(self, method, serviceobj, methodobj, params, app, **kwargs): audit_callback_messages = [] success = False @@ -1572,7 +1548,7 @@ async def log_audit_message(self, app, event, event_data, success): async def call(self, name, *params, app=None, audit_callback=None, job_on_progress_cb=None, pipes=None, profile=False): - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) if mock := self._mock_method(name, params): methodobj = mock @@ -1592,7 +1568,7 @@ def call_sync(self, name, *params, job_on_progress_cb=None, app=None, audit_call if background: return self.loop.call_soon_threadsafe(lambda: self.create_task(self.call(name, *params, app=app))) - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) if mock := self._mock_method(name, params): methodobj = mock @@ -1788,7 +1764,7 @@ def set_mock(self, name, args, mock): if args == _args: raise ValueError(f'{name!r} is already mocked with {args!r}') - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) if inspect.iscoroutinefunction(mock): async def f(*args, **kwargs): @@ -1838,10 +1814,12 @@ async def ws_handler(self, request): if not prepared: return ws - if not await self.ws_can_access(request, ws): + handler = BaseWebSocketHandler(self) + origin = await handler.get_origin(request) + if not await self.ws_can_access(ws, origin): return ws - connection = Application(self, self.loop, request, ws) + connection = Application(self, origin, self.loop, request, ws) connection.on_open() try: @@ -1887,26 +1865,14 @@ async def ws_handler(self, request): return ws - async def ws_can_access(self, request, ws): - if not (ui_allowlist := await self.call('system.general.get_ui_allowlist')): - return True - - sock = request.transport.get_extra_info('socket') - if sock.family == socket.AF_UNIX: - return True - - remote_addr, remote_port = await self.run_in_thread(get_remote_addr_port, request) - if is_ha_connection(remote_addr, remote_port): - return True - - if addr_in_allowlist(remote_addr, ui_allowlist): - return True - - await ws.close( - code=WSCloseCode.POLICY_VIOLATION, - message='You are not allowed to access this resource'.encode('utf-8'), - ) - return False + async def ws_can_access(self, ws, origin): + if not await BaseWebSocketHandler(self).can_access(origin): + await ws.close( + code=WSCloseCode.POLICY_VIOLATION, + message='You are not allowed to access this resource'.encode('utf-8'), + ) + return False + return True _loop_monitor_ignore_frames = ( LoopMonitorIgnoreFrame( @@ -2011,6 +1977,10 @@ async def __initialize(self): self.loop.add_signal_handler(signal.SIGUSR1, self.pdb) self.loop.add_signal_handler(signal.SIGUSR2, self.log_threads_stacks) + rpc_ws_handler = create_rpc_ws_handler(self) + app.router.add_route('GET', '/api/current', rpc_ws_handler) + app.router.add_route('GET', '/api/v25.04.0', rpc_ws_handler) + app.router.add_route('GET', '/websocket', self.ws_handler) app.router.add_routes(apidocs_routes) diff --git a/src/middlewared/middlewared/plugins/auth.py b/src/middlewared/middlewared/plugins/auth.py index fe06f526628ff..00b1fc8639892 100644 --- a/src/middlewared/middlewared/plugins/auth.py +++ b/src/middlewared/middlewared/plugins/auth.py @@ -7,6 +7,7 @@ import psutil +from middlewared.api.base.server.ws_handler.rpc import RpcWebSocketAppEvent from middlewared.auth import (SessionManagerCredentials, UserSessionManagerCredentials, UnixSocketSessionManagerCredentials, RootTcpSocketSessionManagerCredentials, LoginPasswordSessionManagerCredentials, ApiKeySessionManagerCredentials, @@ -108,12 +109,12 @@ async def login(self, app, credentials): app.authenticated = True app.authenticated_credentials = credentials - app.register_callback("on_message", self._app_on_message) - app.register_callback("on_close", self._app_on_close) + app.register_callback(RpcWebSocketAppEvent.MESSAGE, self._app_on_message) + app.register_callback(RpcWebSocketAppEvent.CLOSE, self._app_on_close) if not is_internal_session(session): self.middleware.send_event("auth.sessions", "ADDED", fields=dict(id=app.session_id, **session.dump())) - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": dump_credentials(credentials), "error": None, }, True) @@ -205,7 +206,6 @@ def dump(self): return data - def is_internal_session(session): if isinstance(session.app.origin, UnixSocketOrigin) and session.app.origin.uid == 0: return True @@ -311,7 +311,7 @@ async def terminate_session(self, id_): self.token_manager.destroy_by_session_id(id_) - await session.app.response.close() + await session.app.ws.close() @accepts(roles=['AUTH_SESSIONS_WRITE']) @returns() @@ -454,7 +454,7 @@ async def login(self, app, username, password, otp_token): """ user = await self.get_login_user(username, password, otp_token) if user is None: - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "LOGIN_PASSWORD", "credentials_data": {"username": username}, @@ -494,7 +494,7 @@ async def login_with_api_key(self, app, api_key): await self.session_manager.login(app, ApiKeySessionManagerCredentials(api_key_object)) return True - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "API_KEY", "credentials_data": { @@ -516,7 +516,7 @@ async def login_with_token(self, app, token_str): """ token = self.token_manager.get(token_str, app.origin) if token is None: - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "TOKEN", "credentials_data": { @@ -528,7 +528,7 @@ async def login_with_token(self, app, token_str): return False if token.attributes: - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "TOKEN", "credentials_data": { diff --git a/src/middlewared/middlewared/plugins/failover.py b/src/middlewared/middlewared/plugins/failover.py index a343b012a19bc..703b6177df2ed 100644 --- a/src/middlewared/middlewared/plugins/failover.py +++ b/src/middlewared/middlewared/plugins/failover.py @@ -6,7 +6,6 @@ import logging import os import shutil -import socket import stat import textwrap import time @@ -21,13 +20,14 @@ import middlewared.sqlalchemy as sa from middlewared.plugins.auth import AuthService from middlewared.plugins.config import FREENAS_DATABASE -from middlewared.utils.contextlib import asyncnullcontext from middlewared.plugins.failover_.zpool_cachefile import ZPOOL_CACHE_FILE, ZPOOL_CACHE_FILE_OVERWRITE from middlewared.plugins.failover_.configure import HA_LICENSE_CACHE_KEY from middlewared.plugins.failover_.remote import NETWORK_ERRORS from middlewared.plugins.update_.install import STARTING_INSTALLER from middlewared.plugins.update_.utils import DOWNLOAD_UPDATE_FILE, can_update from middlewared.plugins.update_.utils_linux import mount_update +from middlewared.utils.contextlib import asyncnullcontext +from middlewared.utils.origin import TCPIPOrigin ENCRYPTION_CACHE_LOCK = asyncio.Lock() @@ -1121,18 +1121,9 @@ async def ha_permission(middleware, app): return # We only care for remote connections (IPv4), in the interlink - try: - sock = app.request.transport.get_extra_info('socket') - except AttributeError: - # app.request or app.request.transport can be None - return - - if sock.family != socket.AF_INET: - return - - remote_addr, remote_port = app.request.transport.get_extra_info('peername') - if is_ha_connection(remote_addr, remote_port): - await AuthService.session_manager.login(app, TrueNasNodeSessionManagerCredentials()) + if isinstance(app.origin, TCPIPOrigin): + if is_ha_connection(app.origin.addr, app.origin.port): + await AuthService.session_manager.login(app, TrueNasNodeSessionManagerCredentials()) async def interface_pre_sync_hook(middleware): diff --git a/src/middlewared/middlewared/plugins/failover_/remote.py b/src/middlewared/middlewared/plugins/failover_/remote.py index 7d47d05358b83..fce58d3a09380 100644 --- a/src/middlewared/middlewared/plugins/failover_/remote.py +++ b/src/middlewared/middlewared/plugins/failover_/remote.py @@ -59,7 +59,7 @@ def run(self): def connect_and_wait(self): try: - with Client(f'ws://{self.remote_ip}:6000/websocket', reserved_ports=True) as c: + with Client(f'ws://{self.remote_ip}:6000/api/current', reserved_ports=True) as c: self.client = c self.connected.set() # Subscribe to all events on connection diff --git a/src/middlewared/middlewared/plugins/pool_/dataset.py b/src/middlewared/middlewared/plugins/pool_/dataset.py index f88800d365ad9..6049f2e029a3d 100644 --- a/src/middlewared/middlewared/plugins/pool_/dataset.py +++ b/src/middlewared/middlewared/plugins/pool_/dataset.py @@ -13,6 +13,7 @@ CallError, CRUDService, filterable, InstanceNotFound, item_method, job, pass_app, private, ValidationErrors ) from middlewared.utils import filter_list +from middlewared.utils.origin import TCPIPOrigin from middlewared.validators import Exact, Match, Or, Range from .utils import ( @@ -699,14 +700,12 @@ async def do_create(self, app, data): if app: uri = None - if app.rest and app.host: - uri = app.host - elif app.websocket and app.request.headers.get('X-Real-Remote-Addr'): - uri = app.request.headers.get('X-Real-Remote-Addr') + if isinstance(app.origin, TCPIPOrigin): + uri = app.origin.addr if uri and uri not in [ '::1', '127.0.0.1', *[d['address'] for d in await self.middleware.call('interface.ip_in_use')] ]: - data['managedby'] = uri if not data['managedby'] != 'INHERIT' else f'{data["managedby"]}@{uri}' + data['managedby'] = uri if data['managedby'] == 'INHERIT' else f'{data["managedby"]}@{uri}' props = {} for i, real_name, transform, inheritable in ( diff --git a/src/middlewared/middlewared/plugins/rate_limit/__init__.py b/src/middlewared/middlewared/plugins/rate_limit/__init__.py index ec3814e4a1fa5..2ec3259ce7af8 100644 --- a/src/middlewared/middlewared/plugins/rate_limit/__init__.py +++ b/src/middlewared/middlewared/plugins/rate_limit/__init__.py @@ -1,5 +1,4 @@ from middlewared.service import periodic, Service - from middlewared.utils.rate_limit.cache import RateLimitCache CLEAR_CACHE_INTERVAL = 600 diff --git a/src/middlewared/middlewared/plugins/smb_/sharesec.py b/src/middlewared/middlewared/plugins/smb_/sharesec.py index 53312924a5b7a..193176d28739c 100644 --- a/src/middlewared/middlewared/plugins/smb_/sharesec.py +++ b/src/middlewared/middlewared/plugins/smb_/sharesec.py @@ -1,5 +1,5 @@ from middlewared.plugins.sysdataset import SYSDATASET_PATH -from middlewared.service import filterable, periodic, CRUDService +from middlewared.service import filterable, periodic, private, CRUDService from middlewared.service_exception import CallError, MatchNotFound from middlewared.utils import run, filter_list from middlewared.utils.tdb import ( @@ -184,7 +184,8 @@ async def setacl(self, data, db_commit=True): await self.middleware.call('datastore.update', 'sharing.cifs_share', config_share['id'], {'cifs_share_acl': new_acl_blob}) - async def _flush_share_info(self): + @private + async def flush_share_info(self): """ Write stored share acls to share_info.tdb. This should only be called if share_info.tdb contains default entries. @@ -207,7 +208,7 @@ def check_share_info_tdb(self): if not self.middleware.call_sync('service.started', 'cifs'): return else: - return self.middleware.call_sync('smb.sharesec._flush_share_info') + return self.middleware.call_sync('smb.sharesec.flush_share_info') self.middleware.call_sync('smb.sharesec.synchronize_acls') diff --git a/src/middlewared/middlewared/restful.py b/src/middlewared/middlewared/restful.py index bf0d364c991d2..4b8bd5f80f190 100644 --- a/src/middlewared/middlewared/restful.py +++ b/src/middlewared/middlewared/restful.py @@ -12,6 +12,7 @@ from truenas_api_client import json +from .api.base.server.app import App from .auth import ApiKeySessionManagerCredentials, LoginPasswordSessionManagerCredentials from .job import Job from .pipe import Pipes @@ -101,11 +102,12 @@ async def authenticate(middleware, request, credentials, method, resource): def create_application(request, credentials=None): - return Application( - request.headers.get('X-Real-Remote-Addr'), - request.headers.get('X-Real-Remote-Port'), - credentials, - ) + try: + origin = TCPIPOrigin(request.headers['X-Real-Remote-Addr'], int(request.headers['X-Real-Remote-Port'])) + except (KeyError, ValueError): + origin = TCPIPOrigin(*request.transport.get_extra_info('peername')) + + return Application(origin, credentials) def normalize_query_parameter(value): @@ -115,16 +117,13 @@ def normalize_query_parameter(value): return value -class Application: - def __init__(self, host, remote_port, authenticated_credentials): - self.host = host - self.remote_port = remote_port - self.origin = TCPIPOrigin(self.host, self.remote_port) - self.websocket = False - self.rest = True +class Application(App): + def __init__(self, origin, authenticated_credentials): + super().__init__(origin) + self.session_id = None self.authenticated = authenticated_credentials is not None self.authenticated_credentials = authenticated_credentials - self.session_id = None + self.rest = True class RESTfulAPI(object): @@ -850,7 +849,7 @@ async def do(self, http_method, req, resp, app, authorized, **kwargs): method_args.insert(0, id_) try: - serviceobj, methodobj = self.middleware._method_lookup(methodname) + serviceobj, methodobj = self.middleware.get_method(methodname) if authorized: result = await self.middleware.call_with_audit(methodname, serviceobj, methodobj, method_args, **method_kwargs) diff --git a/src/middlewared/middlewared/service/core_service.py b/src/middlewared/middlewared/service/core_service.py index 979bd9441147b..5a91a17c26208 100644 --- a/src/middlewared/middlewared/service/core_service.py +++ b/src/middlewared/middlewared/service/core_service.py @@ -8,6 +8,7 @@ import threading import time import traceback +import uuid from collections import defaultdict from remote_pdb import RemotePdb @@ -15,7 +16,12 @@ import middlewared.main +from middlewared.api import api_method from middlewared.api.base.jsonschema import get_json_schema +from middlewared.api.current import ( + CoreSetOptionsArgs, CoreSetOptionsResult, CoreSubscribeArgs, CoreSubscribeResult, CoreUnsubscribeArgs, + CoreUnsubscribeResult, +) from middlewared.common.environ import environ_update from middlewared.job import Job from middlewared.pipe import Pipes @@ -60,52 +66,6 @@ async def resize_shell(self, id_, cols, rows): shell.resize(cols, rows) - @filterable - @filterable_returns(Dict( - 'session', - Str('id'), - Str('socket_family'), - Str('address'), - Bool('authenticated'), - Int('call_count'), - )) - def sessions(self, filters, options): - """ - Get currently open websocket sessions. - """ - sessions = [] - for i in self.middleware.get_wsclients().values(): - try: - session_id = i.session_id - authenticated = i.authenticated - call_count = i._softhardsemaphore.counter - socket_family = socket.AddressFamily(i.request.transport.get_extra_info('socket').family).name - address = '' - if addr := i.request.headers.get('X-Real-Remote-Addr'): - port = i.request.headers.get('X-Real-Remote-Port') - address = f'{addr}:{port}' if all((addr, port)) else address - else: - if (info := i.request.transport.get_extra_info('peername')): - if isinstance(info, list) and len(info) == 2: - address = f'{info[0]}:{info[1]}' - except AttributeError: - # underlying websocket connection can be ripped down in process - # of enumerating this information. This is non-fatal, so ignore it. - pass - except Exception: - self.logger.warning('Failed enumerating websocket session.', exc_info=True) - break - else: - sessions.append({ - 'id': session_id, - 'socket_family': socket_family, - 'address': address, - 'authenticated': authenticated, - 'call_count': call_count, - }) - - return filter_list(sessions, filters, options) - @accepts(Bool('debug_mode')) async def set_debug_mode(self, debug_mode): """ @@ -706,7 +666,7 @@ async def download(self, app, method, args, filename, buffered): return await self._download(app, method, args, filename, buffered) async def _download(self, app, method, args, filename, buffered): - serviceobj, methodobj = self.middleware._method_lookup(method) + serviceobj, methodobj = self.middleware.get_method(method) job = await self.middleware.call_with_audit( method, serviceobj, methodobj, args, app=app, pipes=Pipes(output=self.middleware.pipe(buffered)) @@ -815,7 +775,7 @@ async def bulk(self, app, job, method, params, description): `description` contains format string for job progress (e.g. "Deleting snapshot {0[dataset]}@{0[name]}") """ - serviceobj, methodobj = self.middleware._method_lookup(method) + serviceobj, methodobj = self.middleware.get_method(method) if params: if mock := self.middleware._mock_method(method, params[0]): @@ -912,3 +872,27 @@ def _cli_args_descriptions(self, doc, names): k: '\n'.join(v) for k, v in descriptions.items() } + + @no_auth_required + @api_method(CoreSetOptionsArgs, CoreSetOptionsResult, rate_limit=False) + @pass_app() + async def set_options(self, app, options): + if "py_exceptions" in options: + app.py_exceptions = options["py_exceptions"] + + @no_auth_required + @api_method(CoreSubscribeArgs, CoreSubscribeResult) + @pass_app() + async def subscribe(self, app, event): + if not self.middleware.can_subscribe(app, event): + raise CallError('Not authorized', errno.EACCES) + + ident = str(uuid.uuid4()) + await app.subscribe(ident, event) + return ident + + @no_auth_required + @api_method(CoreUnsubscribeArgs, CoreUnsubscribeResult) + @pass_app() + async def unsubscribe(self, app, ident): + await app.unsubscribe(ident) diff --git a/src/middlewared/middlewared/test/integration/assets/roles.py b/src/middlewared/middlewared/test/integration/assets/roles.py index 6fa0232a0c8b9..74501e51a2753 100644 --- a/src/middlewared/middlewared/test/integration/assets/roles.py +++ b/src/middlewared/middlewared/test/integration/assets/roles.py @@ -5,8 +5,7 @@ import pytest import string -from truenas_api_client import ClientException - +from middlewared.service_exception import CallError from middlewared.test.integration.assets.account import unprivileged_user from middlewared.test.integration.utils import call, client @@ -53,8 +52,10 @@ def common_checks( with pytest.raises(Exception) as exc_info: client.call(method, *method_args, **method_kwargs) - assert isinstance(exc_info.value, ClientException) is False or ( - exc_info.value.errno != errno.EACCES and exc_info.value.error != 'Not authorized' + assert not ( + isinstance(exc_info.value, CallError) and + exc_info.value.errno == errno.EACCES and + exc_info.value.errmsg == 'Not authorized' ) elif is_return_type_none: @@ -62,7 +63,7 @@ def common_checks( else: assert client.call(method, *method_args, **method_kwargs) is not None else: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: client.call(method, *method_args, **method_kwargs) assert ve.value.errno == errno.EACCES - assert ve.value.error == 'Not authorized' + assert ve.value.errmsg == 'Not authorized' diff --git a/src/middlewared/middlewared/test/integration/utils/client.py b/src/middlewared/middlewared/test/integration/utils/client.py index f4c43035d9ec9..4793bc6a3ba11 100644 --- a/src/middlewared/middlewared/test/integration/utils/client.py +++ b/src/middlewared/middlewared/test/integration/utils/client.py @@ -1,25 +1,27 @@ import contextlib import errno import os +import logging import socket import requests -from truenas_api_client import Client, ClientException +from middlewared.service_exception import CallError +from truenas_api_client import Client from truenas_api_client.utils import undefined from .pytest import fail __all__ = ["client", "host", "host_websocket_uri", "password", "session", "url", "websocket_url"] + +logger = logging.getLogger(__name__) + """ truenas_server object is used by both websocket client and REST client for determining which server to access for API calls. For HA, the `ip` attribute should be set to the virtual IP of the truenas server. """ - - class TrueNAS_Server: - __slots__ = ( '_ip', '_nodea_ip', @@ -96,7 +98,8 @@ def client(self) -> Client: try: self._client.ping() return self._client - except Exception: + except Exception as e: + logger.warning('Re-connecting test client due to %r', e) # failed liveness check, perhaps server rebooted # if target is truly broken we'll pick up error # when trying to establish a new client connection @@ -160,8 +163,8 @@ def client(*, auth=undefined, auth_required=True, py_exceptions=True, log_py_exc if auth is not None: try: logged_in = c.call("auth.login", *auth) - except ClientException as e: - if e.errno == errno.EBUSY and e.error == 'Rate Limit Exceeded': + except CallError as e: + if e.errno == errno.EBUSY and e.errmsg == 'Rate Limit Exceeded': # our "roles" tests (specifically common_checks() function) # isn't designed very well since it's generating random users # for every unique test_* function in every test file.... @@ -170,6 +173,8 @@ def client(*, auth=undefined, auth_required=True, py_exceptions=True, log_py_exc # related tests to trip on our rate limiting functionality truenas_server.client.call("rate.limit.cache_clear") logged_in = c.call("auth.login", *auth) + else: + raise if auth_required: assert logged_in yield c @@ -199,7 +204,7 @@ def host(): def host_websocket_uri(host_ip=None): - return f"ws://{host_ip or host().ip}/websocket" + return f"ws://{host_ip or host().ip}/api/current" def password(): diff --git a/src/middlewared/middlewared/utils/nginx.py b/src/middlewared/middlewared/utils/nginx.py index 7c81036e89fa1..bf4427f6481e1 100644 --- a/src/middlewared/middlewared/utils/nginx.py +++ b/src/middlewared/middlewared/utils/nginx.py @@ -17,7 +17,7 @@ def get_remote_addr_port(request): remote_addr, remote_port = request.transport.get_extra_info("peername") except Exception: # request can be NoneType or request.transport could be NoneType as well - return '', '' + return "", "" if remote_addr in ["127.0.0.1", "::1"]: try: diff --git a/src/middlewared/middlewared/utils/rate_limit/cache.py b/src/middlewared/middlewared/utils/rate_limit/cache.py index 00b24a77a170d..0593af1a9fb8c 100644 --- a/src/middlewared/middlewared/utils/rate_limit/cache.py +++ b/src/middlewared/middlewared/utils/rate_limit/cache.py @@ -2,12 +2,12 @@ from dataclasses import dataclass from random import uniform from time import monotonic -from typing import Self, TypedDict +from typing import TypedDict from middlewared.auth import is_ha_connection from middlewared.utils.origin import TCPIPOrigin -__all__ = ('RateLimitCache') +__all__ = ['RateLimitCache'] @dataclass(frozen=True) @@ -41,7 +41,6 @@ class RateLimitObject(TypedDict): RL_CACHE: dict[str, RateLimitObject] = dict() - class RateLimit: def cache_key(self, method_name: str, ip: str) -> str: """Generate a unique key per endpoint/consumer""" @@ -108,4 +107,5 @@ def max_entries_reached(self) -> bool: in the global cache has reached `self.max_cache_entries`.""" return len(RL_CACHE) == RateLimitConfig.max_cache_entries + RateLimitCache = RateLimit() diff --git a/src/middlewared/middlewared/utils/service/call.py b/src/middlewared/middlewared/utils/service/call.py index 8f6bf42a048f3..48746a1d00d03 100644 --- a/src/middlewared/middlewared/utils/service/call.py +++ b/src/middlewared/middlewared/utils/service/call.py @@ -15,7 +15,7 @@ def __init__(self, method_name, service): class ServiceCallMixin: - def _method_lookup(self, name): + def get_method(self, name): if '.' not in name: raise CallError('Invalid method name', errno.EBADMSG) diff --git a/src/middlewared/middlewared/worker.py b/src/middlewared/middlewared/worker.py index 1e93877d07e75..43d14411d6ae2 100755 --- a/src/middlewared/middlewared/worker.py +++ b/src/middlewared/middlewared/worker.py @@ -42,14 +42,14 @@ def _call(self, name, serviceobj, methodobj, params=None, app=None, pipes=None, self.client = None def _run(self, name, args, job): - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) return self._call(name, serviceobj, methodobj, args, job=job) def call_sync(self, method, *params, timeout=None, **kwargs): """ Calls a method using middleware client """ - serviceobj, methodobj = self._method_lookup(method) + serviceobj, methodobj = self.get_method(method) if serviceobj._config.process_pool and not hasattr(method, '_job'): if asyncio.iscoroutinefunction(methodobj): @@ -61,7 +61,7 @@ def call_sync(self, method, *params, timeout=None, **kwargs): # process pool. If the process pool is already exhausted, it will lead to a deadlock. # By executing a synchronous implementation of the same method in the same process pool we # eliminate `Hold and wait` condition and prevent deadlock situation from arising. - _, sync_methodobj = self._method_lookup(f'{method}__sync') + _, sync_methodobj = self.get_method(f'{method}__sync') except MethodNotFoundError: # FIXME: Make this an exception in 22.MM self.logger.warning('Service uses a process pool but has an asynchronous method: %r', method) diff --git a/tests/api2/test_430_smb_sharesec.py b/tests/api2/test_430_smb_sharesec.py index e8db0a8bde19d..9aed1af62d564 100644 --- a/tests/api2/test_430_smb_sharesec.py +++ b/tests/api2/test_430_smb_sharesec.py @@ -152,7 +152,7 @@ def test_25_verify_share_info_tdb_is_deleted(request): def test_27_restore_sharesec_with_flush_share_info(request, sharesec_user): depends(request, ["sharesec_acl_set"], scope="session") with client() as c: - c.call('smb.sharesec._flush_share_info') + c.call('smb.sharesec.flush_share_info') results = POST("/sharing/smb/getacl", {'share_name': share_info['name']}) assert results.status_code == 200, results.text diff --git a/tests/api2/test_account_privilege_authentication.py b/tests/api2/test_account_privilege_authentication.py index 90b74a049a4fa..70516a257aa58 100644 --- a/tests/api2/test_account_privilege_authentication.py +++ b/tests/api2/test_account_privilege_authentication.py @@ -5,7 +5,7 @@ import pytest import websocket -from truenas_api_client import ClientException +from middlewared.service_exception import CallError from middlewared.test.integration.assets.account import user, unprivileged_user as unprivileged_user_template from middlewared.test.integration.assets.pool import dataset from middlewared.test.integration.utils import call, client, ssh, websocket_url @@ -104,7 +104,7 @@ def test_websocket_auth_calls_allowed_method(unprivileged_user): def test_websocket_auth_fails_to_call_forbidden_method(unprivileged_user): with client(auth=(unprivileged_user.username, unprivileged_user.password)) as c: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call("pool.create") assert ve.value.errno == errno.EACCES @@ -172,7 +172,7 @@ def test_token_auth_fails_to_call_forbidden_method(unprivileged_user_token): with client(auth=None) as c: assert c.call("auth.login_with_token", unprivileged_user_token) - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call("pool.create") assert ve.value.errno == errno.EACCES @@ -183,7 +183,7 @@ def test_drop_privileges(unprivileged_user_token): # This should drop privileges for the current root session assert c.call("auth.login_with_token", unprivileged_user_token) - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call("pool.create") assert ve.value.errno == errno.EACCES diff --git a/tests/api2/test_account_privilege_role.py b/tests/api2/test_account_privilege_role.py index 3f5e46fd6dd6c..88d572b74ce29 100644 --- a/tests/api2/test_account_privilege_role.py +++ b/tests/api2/test_account_privilege_role.py @@ -1,13 +1,13 @@ import errno import logging +from time import sleep import pytest -from truenas_api_client import ClientException +from middlewared.service_exception import CallError from middlewared.test.integration.assets.account import unprivileged_user_client from middlewared.test.integration.assets.pool import dataset, snapshot from middlewared.test.integration.utils import client -from time import sleep logger = logging.getLogger(__name__) @@ -23,7 +23,7 @@ def test_can_read_with_read_or_write_role(role): def test_can_not_write_with_read_role(): with dataset("test_snapshot_write1") as ds: with unprivileged_user_client(["SNAPSHOT_READ"]) as c: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call("zfs.snapshot.create", { "dataset": ds, "name": "test", @@ -52,7 +52,7 @@ def test_can_not_delete_with_write_role_with_separate_delete(): with dataset("test_snapshot_delete2") as ds: with snapshot(ds, "test") as id: with unprivileged_user_client(["SNAPSHOT_WRITE"]) as c: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call("zfs.snapshot.delete", id) assert ve.value.errno == errno.EACCES @@ -106,12 +106,12 @@ def test_readonly_can_call_method(method, params): def test_readonly_can_not_call_method(): with unprivileged_user_client(["READONLY_ADMIN"]) as c: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call("user.create") assert ve.value.errno == errno.EACCES - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: # fails with EPERM if API access granted c.call("filesystem.mkdir", "/foo") @@ -176,10 +176,10 @@ def test_foreign_job_access(): def test_can_not_subscribe_to_event(): with unprivileged_user_client() as unprivileged: - with pytest.raises(ValueError) as ve: + with pytest.raises(CallError) as ve: unprivileged.subscribe("alert.list", lambda *args, **kwargs: None) - assert ve.value.args[0]["errname"] == "EACCES" + assert ve.value.errno == errno.EACCES def test_can_subscribe_to_event(): diff --git a/tests/api2/test_account_root_password.py b/tests/api2/test_account_root_password.py index de7b4c7178dd2..8811a8c760971 100644 --- a/tests/api2/test_account_root_password.py +++ b/tests/api2/test_account_root_password.py @@ -1,6 +1,6 @@ import pytest -from truenas_api_client import ClientException +from middlewared.service_exception import CallError from middlewared.test.integration.utils import call, client from middlewared.test.integration.assets.account import user from middlewared.test.integration.assets.pool import dataset @@ -48,7 +48,7 @@ def callback(type, **message): assert not any(alert["klass"] == "WebUiRootLogin" for alert in alerts), alerts # Root should not be able to log in with password anymore - with pytest.raises(ClientException): + with pytest.raises(CallError): call("system.info", client_kwargs=dict(auth_required=False)) assert events[0][1]["fields"]["usernames"] == ["admin"] diff --git a/tests/api2/test_api_key.py b/tests/api2/test_api_key.py index 4ed0b32e8819d..602a02b1db765 100644 --- a/tests/api2/test_api_key.py +++ b/tests/api2/test_api_key.py @@ -36,7 +36,7 @@ def test_root_api_key_websocket(request): ip = truenas_server.ip with api_key([{"method": "*", "resource": "*"}]) as key: with user(): - cmd = f"sudo -u testuser midclt -u ws://{ip}/websocket --api-key {key} call system.info" + cmd = f"sudo -u testuser midclt -u ws://{ip}/api/current --api-key {key} call system.info" results = SSH_TEST(cmd, user_, password) assert results['result'] is True, f'out: {results["output"]}, err: {results["stderr"]}' assert 'uptime' in str(results['stdout']) @@ -58,7 +58,7 @@ def test_allowed_api_key_websocket(request): ip = truenas_server.ip with api_key([{"method": "CALL", "resource": "system.info"}]) as key: with user(): - cmd = f"sudo -u testuser midclt -u ws://{ip}/websocket --api-key {key} call system.info" + cmd = f"sudo -u testuser midclt -u ws://{ip}/api/current --api-key {key} call system.info" results = SSH_TEST(cmd, user_, password) assert results['result'] is True, f'out: {results["output"]}, err: {results["stderr"]}' assert 'uptime' in str(results['stdout']) @@ -69,7 +69,7 @@ def test_denied_api_key_websocket(request): ip = truenas_server.ip with api_key([{"method": "CALL", "resource": "system.info_"}]) as key: with user(): - cmd = f"sudo -u testuser midclt -u ws://{ip}/websocket --api-key {key} call system.info" + cmd = f"sudo -u testuser midclt -u ws://{ip}/api/current --api-key {key} call system.info" results = SSH_TEST(cmd, user_, password) assert results['result'] is False diff --git a/tests/api2/test_audit_websocket.py b/tests/api2/test_audit_websocket.py index 93ef6ec9ac8af..8f94634682293 100644 --- a/tests/api2/test_audit_websocket.py +++ b/tests/api2/test_audit_websocket.py @@ -3,8 +3,7 @@ import pytest -from truenas_api_client import ClientException -from middlewared.service_exception import ValidationErrors +from middlewared.service_exception import CallError, ValidationErrors from middlewared.test.integration.assets.account import unprivileged_user_client, user from middlewared.test.integration.assets.api_key import api_key from middlewared.test.integration.utils import call, client, ssh @@ -35,7 +34,7 @@ def test_unauthenticated_call(): "success": False, } ]): - with pytest.raises(ClientException): + with pytest.raises(CallError): c.call("user.create", {"username": "sergey", "full_name": "Sergey"}) @@ -66,7 +65,7 @@ def test_unauthorized_call(): "success": False, } ]): - with pytest.raises(ClientException): + with pytest.raises(CallError): c.call("user.create", {"username": "sergey", "full_name": "Sergey"}) diff --git a/tests/api2/test_auth_token.py b/tests/api2/test_auth_token.py index 2fec805bb223c..ba57a1dcfbe14 100644 --- a/tests/api2/test_auth_token.py +++ b/tests/api2/test_auth_token.py @@ -74,7 +74,7 @@ def unprivileged_user(): def test_login_with_token_match_origin(unprivileged_user): token = ssh( - "sudo -u test midclt -u ws://localhost/websocket -U test -P test1234 call auth.generate_token 300 '{}' true" + "sudo -u test midclt -u ws://localhost/api/current -U test -P test1234 call auth.generate_token 300 '{}' true" ).strip() with client(auth=None) as c: @@ -83,7 +83,7 @@ def test_login_with_token_match_origin(unprivileged_user): def test_login_with_token_no_match_origin(unprivileged_user): token = ssh( - "sudo -u test midclt -u ws://localhost/websocket -U test -P test1234 call auth.generate_token 300" + "sudo -u test midclt -u ws://localhost/api/current -U test -P test1234 call auth.generate_token 300" ).strip() with client(auth=None) as c: diff --git a/tests/api2/test_core_bulk.py b/tests/api2/test_core_bulk.py index bed017ffd96f0..c519d523a3522 100644 --- a/tests/api2/test_core_bulk.py +++ b/tests/api2/test_core_bulk.py @@ -2,10 +2,10 @@ import pytest -from truenas_api_client import ClientException from middlewared.test.integration.assets.account import unprivileged_user_client from middlewared.test.integration.utils import call, mock from middlewared.test.integration.utils.audit import expect_audit_log +from truenas_api_client import ClientException def test_core_bulk_reports_job_id(): diff --git a/tests/api2/test_events.py b/tests/api2/test_events.py index 3f6a0e586077e..23b6ab81723f4 100644 --- a/tests/api2/test_events.py +++ b/tests/api2/test_events.py @@ -1,5 +1,8 @@ +import errno + import pytest +from middlewared.service_exception import CallError from middlewared.test.integration.utils import client @@ -10,5 +13,7 @@ def test_can_subscribe_to_failover_status_event_without_authorization(): def test_can_not_subscribe_to_an_event_without_authorization(): with client(auth=None) as c: - with pytest.raises(ValueError): + with pytest.raises(CallError) as ve: c.subscribe("core.get_jobs", lambda *args, **kwargs: None) + + assert ve.value.errno == errno.EACCES diff --git a/tests/api2/test_ip_auth.py b/tests/api2/test_ip_auth.py index a249f634aec74..49d07b37234a1 100644 --- a/tests/api2/test_ip_auth.py +++ b/tests/api2/test_ip_auth.py @@ -1,5 +1,4 @@ import json -import shlex import pytest @@ -10,7 +9,7 @@ @pytest.mark.parametrize("url", ["127.0.0.1", "127.0.0.1:6000"]) @pytest.mark.parametrize("root", [True, False]) def test_tcp_connection_from_localhost(url, root): - cmd = f"midclt -u ws://{url}/websocket call auth.sessions '[[\"current\", \"=\", true]]' '{{\"get\": true}}'" + cmd = f"midclt -u ws://{url}/api/current call auth.sessions '[[\"current\", \"=\", true]]' '{{\"get\": true}}'" if root: assert json.loads(ssh(cmd))["credentials"] == "ROOT_TCP_SOCKET" else: diff --git a/tests/api2/test_lock.py b/tests/api2/test_lock.py index 7cd43eb5e6859..d5d621d786950 100644 --- a/tests/api2/test_lock.py +++ b/tests/api2/test_lock.py @@ -17,9 +17,9 @@ async def mock(self, *args): start = time.monotonic() with client() as c: - c1 = c.call("test.test1", background=True) - c2 = c.call("test.test1", background=True) - c.wait(c1) + c1 = c.call("test.test1", background=True, register_call=True) + c2 = c.call("test.test1", background=True, register_call=True) + c.wait(c1, timeout=10) c.wait(c2) assert time.monotonic() - start < 6 @@ -38,8 +38,8 @@ async def mock(self, *args): start = time.monotonic() with client() as c: - c1 = c.call("test.test1", background=True) - c2 = c.call("test.test1", background=True) + c1 = c.call("test.test1", background=True, register_call=True) + c2 = c.call("test.test1", background=True, register_call=True) c.wait(c1) c.wait(c2) @@ -59,8 +59,8 @@ def mock(self, *args): start = time.monotonic() with client() as c: - c1 = c.call("test.test1", background=True) - c2 = c.call("test.test1", background=True) + c1 = c.call("test.test1", background=True, register_call=True) + c2 = c.call("test.test1", background=True, register_call=True) c.wait(c1) c.wait(c2) diff --git a/tests/api2/test_rate_limit.py b/tests/api2/test_rate_limit.py index c3419dcd9897a..15eafc7d76f5b 100644 --- a/tests/api2/test_rate_limit.py +++ b/tests/api2/test_rate_limit.py @@ -1,7 +1,6 @@ import errno import pytest -from pytest_dependency import depends from middlewared.test.integration.utils import call, client @@ -9,7 +8,6 @@ SEP = '_##_' -@pytest.mark.dependency(name='rate_limited') def test_unauth_requests_are_rate_limited(): """Test that the truenas server rate limits a caller that is hammering an endpoint that requires no authentication.""" @@ -24,11 +22,8 @@ def test_unauth_requests_are_rate_limited(): c.call(NOAUTH_METHOD) assert ve.value.errno == errno.EBUSY - -def test_rate_limit_global_cache_entries(request): """Test that middleware's rate limit plugin for interacting with the global cache behaves as intended.""" - depends(request, ['rate_limited']) cache = call('rate.limit.cache_get') # the mechanism by which the rate limit chooses a unique key # for inserting into the dictionary is by using the api endpoint @@ -48,9 +43,9 @@ def test_rate_limit_global_cache_entries(request): assert len(new_new_cache) == 0, new_new_cache -def test_auth_requests_are_not_rate_limited(): +@pytest.mark.parametrize('method_name', [NOAUTH_METHOD, 'system.host_id']) +def test_authorized_requests_are_not_rate_limited(method_name): """Test that the truenas server does NOT rate limit a caller - that hammers an endpoint when said caller has been authenticated - and that method requires authentication.""" + that hammers an endpoint when said caller has been authenticated""" for i in range(1, 22): - assert call('system.host_id') + assert call(method_name) diff --git a/tests/api2/test_webui_crypto_service.py b/tests/api2/test_webui_crypto_service.py index 830000d4c2f86..1b2b49f9e805a 100644 --- a/tests/api2/test_webui_crypto_service.py +++ b/tests/api2/test_webui_crypto_service.py @@ -1,7 +1,7 @@ import errno import pytest -from truenas_api_client import ClientException +from middlewared.service_exception import CallError from middlewared.test.integration.assets.account import unprivileged_user_client from middlewared.test.integration.utils import call @@ -17,11 +17,11 @@ def test_ui_crypto_profiles_readonly_role(role, endpoint, valid_role): if valid_role: c.call(endpoint) else: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call(endpoint) assert ve.value.errno == errno.EACCES - assert ve.value.error == 'Not authorized' + assert ve.value.errmsg == 'Not authorized' @pytest.mark.parametrize('role,valid_role', ( @@ -39,8 +39,8 @@ def test_ui_crypto_domain_names_readonly_role(role, valid_role): if valid_role: c.call('webui.crypto.get_certificate_domain_names', default_certificate['id']) else: - with pytest.raises(ClientException) as ve: + with pytest.raises(CallError) as ve: c.call('webui.crypto.get_certificate_domain_names', default_certificate['id']) assert ve.value.errno == errno.EACCES - assert ve.value.error == 'Not authorized' + assert ve.value.errmsg == 'Not authorized' diff --git a/tests/conftest.py b/tests/conftest.py index 016206a6e53dd..20223e2b58e84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ import pytest from middlewared.test.integration.assets.roles import unprivileged_user_fixture # noqa -from middlewared.test.integration.utils.client import client, truenas_server +from middlewared.test.integration.utils.client import truenas_server from middlewared.test.integration.utils.pytest import failed pytest.register_assert_rewrite("middlewared.test")