diff --git a/src/middlewared/middlewared/api/base/model.py b/src/middlewared/middlewared/api/base/model.py index d68142240de60..0b51b443c6f14 100644 --- a/src/middlewared/middlewared/api/base/model.py +++ b/src/middlewared/middlewared/api/base/model.py @@ -1,4 +1,6 @@ import copy +import inspect +from types import NoneType import typing from pydantic import BaseModel as PydanticBaseModel, ConfigDict, create_model, Field, model_serializer @@ -86,10 +88,19 @@ def wrapper(klass): return wrapper -def single_argument_result(klass): +def single_argument_result(klass, klass_name=None): + if klass is None: + klass = NoneType + + if klass.__module__ == "builtins": + if klass_name is None: + raise TypeError("You must specify class name when using `single_argument_result` for built-in types") + else: + klass_name = klass_name or klass.__name__ + return create_model( - klass.__name__, + klass_name, __base__=(BaseModel,), - __module__=klass.__module__, + __module__=inspect.getmodule(inspect.stack()[1][0]), **{"result": Annotated[klass, Field()]}, ) diff --git a/src/middlewared/middlewared/api/base/server/__init__.py b/src/middlewared/middlewared/api/base/server/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/middlewared/middlewared/api/base/server/app.py b/src/middlewared/middlewared/api/base/server/app.py new file mode 100644 index 0000000000000..16472cd188270 --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/app.py @@ -0,0 +1,18 @@ +import logging +import uuid + +from middlewared.auth import SessionManagerCredentials +from middlewared.utils.origin import Origin + +logger = logging.getLogger(__name__) + + +class App: + def __init__(self, origin: Origin): + self.origin = origin + self.session_id = str(uuid.uuid4()) + self.authenticated = False + self.authenticated_credentials: SessionManagerCredentials | None = None + self.py_exceptions = False + self.websocket = False + self.rest = False diff --git a/src/middlewared/middlewared/api/base/server/method.py b/src/middlewared/middlewared/api/base/server/method.py new file mode 100644 index 0000000000000..0037ff0a0234f --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/method.py @@ -0,0 +1,30 @@ +import types + +from middlewared.job import Job + + +class Method: + def __init__(self, middleware: "Middleware", name: str): + self.middleware = middleware + self.name = name + + async def call(self, app: "RpcWebSocketApp", params): + serviceobj, methodobj = self.middleware.get_method(self.name) + + await self.middleware.authorize_method_call(app, self.name, methodobj, params) + + if mock := self.middleware._mock_method(self.name, params): + methodobj = mock + + result = await self.middleware.call_with_audit(self.name, serviceobj, methodobj, params, app) + if isinstance(result, Job): + result = result.id + elif isinstance(result, types.GeneratorType): + result = list(result) + elif isinstance(result, types.AsyncGeneratorType): + result = [i async for i in result] + + return result + + def dump_args(self, params): + self.middleware.dump_args(params, method_name=self.name) diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/__init__.py b/src/middlewared/middlewared/api/base/server/ws_handler/__init__.py new file mode 100644 index 0000000000000..d3b4c09c780ef --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/__init__.py @@ -0,0 +1,6 @@ +# -*- coding=utf-8 -*- +import logging + +logger = logging.getLogger(__name__) + +__all__ = [] diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/base.py b/src/middlewared/middlewared/api/base/server/ws_handler/base.py new file mode 100644 index 0000000000000..975369abe0b7f --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/base.py @@ -0,0 +1,72 @@ +import socket +import struct + +from aiohttp.http_websocket import WSCloseCode +from aiohttp.web import Request, WebSocketResponse + +from middlewared.auth import is_ha_connection +from middlewared.utils.nginx import get_remote_addr_port +from middlewared.utils.origin import Origin, UnixSocketOrigin, TCPIPOrigin +from middlewared.webui_auth import addr_in_allowlist + + +class BaseWebSocketHandler: + def __init__(self, middleware: "Middleware"): + self.middleware = middleware + + async def __call__(self, request: Request): + ws = WebSocketResponse() + try: + await ws.prepare(request) + except ConnectionResetError: + # Happens when we're preparing a new session, and during the time we prepare, the server is + # stopped/killed/restarted etc. Ignore these to prevent log spam. + return ws + + origin = await self.get_origin(request) + if origin is None: + await ws.close() + return ws + if not await self.can_access(origin): + await ws.close( + code=WSCloseCode.POLICY_VIOLATION, + message=b"You are not allowed to access this resource", + ) + return ws + + await self.process(origin, ws) + return ws + + async def get_origin(self, request: Request) -> Origin | None: + try: + sock = request.transport.get_extra_info("socket") + except AttributeError: + # request.transport can be None by the time this is called on HA systems because remote node could have been + # rebooted + return + + if sock.family == socket.AF_UNIX: + peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize("3i")) + pid, uid, gid = struct.unpack("3i", peercred) + return UnixSocketOrigin(pid, uid, gid) + + remote_addr, remote_port = await self.middleware.run_in_thread(get_remote_addr_port, request) + return TCPIPOrigin(remote_addr, remote_port) + + async def can_access(self, origin: Origin | None) -> bool: + if not isinstance(origin, TCPIPOrigin): + return True + + if not (ui_allowlist := await self.middleware.call("system.general.get_ui_allowlist")): + return True + + if is_ha_connection(origin.addr, origin.port): + return True + + if addr_in_allowlist(origin.addr, ui_allowlist): + return True + + return False + + async def process(self, origin: Origin, ws: WebSocketResponse): + raise NotImplementedError diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py b/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py new file mode 100644 index 0000000000000..27187083b166a --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/rpc.py @@ -0,0 +1,301 @@ +import asyncio +import binascii +from collections import defaultdict +import enum +import errno +import pickle +import sys +import traceback +from typing import Any, Callable + +from aiohttp.http_websocket import WSCloseCode, WSMessage +from aiohttp.web import WebSocketResponse, WSMsgType +import jsonschema + +from truenas_api_client import json +from truenas_api_client.jsonrpc import JSONRPCError + +from middlewared.schema import Error +from middlewared.service_exception import (CallException, CallError, ValidationError, ValidationErrors, adapt_exception, + get_errname) +from middlewared.utils.debug import get_frame_details +from middlewared.utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit +from middlewared.utils.origin import Origin +from .base import BaseWebSocketHandler +from ..app import App +from ..method import Method + +REQUEST_SCHEMA = { + "type": "object", + "additionalProperties": False, + "required": ["jsonrpc", "method"], + "properties": { + "jsonrpc": {"enum": ["2.0"]}, + "method": {"type": "string"}, + "params": {"type": "array"}, + "id": {"type": ["null", "number", "string"]}, + } +} + + +class RpcWebSocketAppEvent(enum.Enum): + MESSAGE = enum.auto() + CLOSE = enum.auto() + + +class RpcWebSocketApp(App): + def __init__(self, middleware: "Middleware", origin: Origin, ws: WebSocketResponse): + super().__init__(origin) + + self.websocket = True + + self.middleware = middleware + self.ws = ws + self.softhardsemaphore = SoftHardSemaphore(10, 20) + self.callbacks = defaultdict(list) + self.subscriptions = {} + + def send(self, data): + fut = asyncio.run_coroutine_threadsafe(self.ws.send_str(json.dumps(data)), self.middleware.loop) + + def send_error(self, id_: Any, code: int, message: str, data: Any = None): + error = { + "jsonrpc": "2.0", + "error": { + "code": code, + "message": message, + }, + "id": id_, + } + if data is not None: + error["error"]["data"] = data + + self.send(error) + + def send_truenas_error(self, id_: Any, code: int, message: str, errno_: int, reason: str, + exc_info=None, extra: list | None = None): + self.send_error(id_, code, message, self.format_truenas_error(errno_, reason, exc_info, extra)) + + def format_truenas_error(self, errno_: int, reason: str, exc_info=None, extra: list | None = None): + return { + "error": errno_, + "errname": get_errname(errno_), + "reason": reason, + "trace": self.truenas_error_traceback(exc_info) if exc_info else None, + "extra": extra, + **({"py_exception": binascii.b2a_base64(pickle.dumps(exc_info[1])).decode()} + if self.py_exceptions and exc_info else {}), + } + + def truenas_error_traceback(self, exc_info): + etype, value, tb = exc_info + + frames = [] + cur_tb = tb + while cur_tb: + tb_frame = cur_tb.tb_frame + cur_tb = cur_tb.tb_next + + cur_frame = get_frame_details(tb_frame, self.middleware.logger) + if cur_frame: + frames.append(cur_frame) + + return { + "class": etype.__name__, + "frames": frames, + "formatted": "".join(traceback.format_exception(*exc_info)), + "repr": repr(value), + } + + def send_truenas_validation_error(self, id_: Any, exc_info, errors: list): + self.send_error(id_, JSONRPCError.INVALID_PARAMS.value, "Invalid params", + self.format_truenas_validation_error(exc_info[1], exc_info, errors)) + + def format_truenas_validation_error(self, exception, exc_info=None, errors: list | None = None): + return self.format_truenas_error(errno.EINVAL, str(exception), exc_info, errors) + + def register_callback(self, event: RpcWebSocketAppEvent, callback: Callable): + self.callbacks[event.value].append(callback) + + def run_callback(self, event, *args, **kwargs): + for callback in self.callbacks[event.value]: + try: + callback(self, *args, **kwargs) + except Exception: + logger.error(f"Failed to run {event} callback", exc_info=True) + + async def subscribe(self, ident: str, name: str): + shortname, arg = self.middleware.event_source_manager.short_name_arg(name) + if shortname in self.middleware.event_source_manager.event_sources: + await self.middleware.event_source_manager.subscribe_app(self, self.__esm_ident(ident), shortname, arg) + else: + self.subscriptions[ident] = name + + async def unsubscribe(self, ident: str): + if ident in self.subscriptions: + self.subscriptions.pop(ident) + elif self.__esm_ident(ident) in self.middleware.event_source_manager.idents: + await self.middleware.event_source_manager.unsubscribe(self.__esm_ident(ident)) + + def __esm_ident(self, ident): + return self.session_id + ident + + def send_event(self, name: str, event_type: str, **kwargs): + if ( + not any(i in [name, "*"] for i in self.subscriptions.values()) and + ( + self.middleware.event_source_manager.short_name_arg(name)[0] not in + self.middleware.event_source_manager.event_sources + ) + ): + return + + event = { + "msg": event_type.lower(), + "collection": name, + } + kwargs = kwargs.copy() + if "id" in kwargs: + event["id"] = kwargs.pop("id") + if event_type in ("ADDED", "CHANGED"): + if "fields" in kwargs: + event["fields"] = kwargs.pop("fields") + if kwargs: + event["extra"] = kwargs + + self.send_notification("collection_update", event) + + def notify_unsubscribed(self, collection: str, error: Exception | None): + params = {"collection": collection, "error": None} + if error: + if isinstance(error, ValidationErrors): + params["error"] = self.format_truenas_validation_error(error, extra=list(error)) + elif isinstance(error, CallError): + params["error"] = self.format_truenas_error(error.errno, str(error), extra=error.extra) + else: + params["error"] = self.format_truenas_error(errno.EINVAL, str(error)) + + self.send_notification("notify_unsubscribed", params) + + def send_notification(self, method, params): + self.send({ + "jsonrpc": "2.0", + "method": method, + "params": params, + }) + + +class RpcWebSocketHandler(BaseWebSocketHandler): + def __init__(self, middleware: "Middleware", methods: {str: Method}): + super().__init__(middleware) + self.methods = methods + + async def process(self, origin: Origin, ws: WebSocketResponse): + app = RpcWebSocketApp(self.middleware, origin, ws) + + self.middleware.register_wsclient(app) + try: + # aiohttp can cancel tasks if a request take too long to finish. + # It is desired to prevent that in this stage in case we are debugging middlewared via gdb (which makes the + # program execution a lot slower) + await asyncio.shield(self.middleware.call_hook("core.on_connect", app)) + + msg: WSMessage + async for msg in ws: + if msg.type == WSMsgType.ERROR: + self.middleware.logger.error("Websocket error: %r", msg.data) + break + + if msg.type != WSMsgType.TEXT: + await ws.close( + code=WSCloseCode.UNSUPPORTED_DATA, + message=f"Invalid websocket message type: {msg.type!r}".encode("utf-8"), + ) + break + + if not app.authenticated and len(msg.data) > 8192: + await ws.close( + code=WSCloseCode.INVALID_TEXT, + message=b"Anonymous connection max message length is 8 kB", + ) + break + + try: + message = json.loads(msg.data) + except ValueError as e: + app.send_error(None, JSONRPCError.INVALID_JSON.value, str(e)) + continue + + app.run_callback(RpcWebSocketAppEvent.MESSAGE, message) + + try: + await self.process_message(app, message) + except Exception as e: + self.middleware.logger.error("Unhandled exception in JSON-RPC message handler", exc_info=True) + await ws.close( + code=WSCloseCode.INTERNAL_ERROR, + message=str(e).encode("utf-8"), + ) + break + finally: + app.run_callback(RpcWebSocketAppEvent.CLOSE) + + await self.middleware.event_source_manager.unsubscribe_app(app) + + self.middleware.unregister_wsclient(app) + + async def process_message(self, app: RpcWebSocketApp, message: Any): + try: + jsonschema.validate(message, REQUEST_SCHEMA) + except jsonschema.ValidationError as e: + app.send_error(app, None, JSONRPCError.INVALID_REQUEST.value, str(e)) + return + + id_ = message.get("id") + method = self.methods.get(message["method"]) + if method is None: + app.send_error(id, JSONRPCError.METHOD_NOT_FOUND.value, "Method does not exist") + return + + asyncio.ensure_future(self.process_method_call(app, id_, method, message["params"])) + + async def process_method_call(self, app: RpcWebSocketApp, id_: Any, method: Method, params: {str: Any}): + try: + async with app.softhardsemaphore: + result = await method.call(app, params) + except SoftHardSemaphoreLimit as e: + app.send_error(id_, JSONRPCError.TRUENAS_TOO_MANY_CONCURRENT_CALLS.value, + f"Maximum number of concurrent calls ({e.args[0]}) has exceeded") + except ValidationError as e: + app.send_truenas_validation_error(id_, sys.exc_info(), [ + (e.attribute, e.errmsg, e.errno), + ]) + except ValidationErrors as e: + app.send_truenas_validation_error(id_, sys.exc_info(), list(e)) + except (CallException, Error) as e: + # CallException and subclasses are the way to gracefully send errors to the client + app.send_truenas_error(id_, JSONRPCError.TRUENAS_CALL_ERROR.value, "Method call error", e.errno, str(e), + sys.exc_info(), e.extra) + except Exception as e: + adapted = adapt_exception(e) + if adapted: + errno_ = adapted.errno + error = adapted + extra = adapted.extra + else: + errno_ = errno.EINVAL + error = e + extra = None + + app.send_truenas_error(id_, JSONRPCError.TRUENAS_CALL_ERROR.value, "Method call error", errno_, + str(error) or repr(error), sys.exc_info(), extra) + + if not adapted and not app.py_exceptions: + self.middleware.logger.warning(f"Exception while calling {method.name}({method.dump_args(params)!r})", + exc_info=True) + else: + app.send({ + "jsonrpc": "2.0", + "result": result, + "id": id_, + }) diff --git a/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py b/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py new file mode 100644 index 0000000000000..c0390794b20fb --- /dev/null +++ b/src/middlewared/middlewared/api/base/server/ws_handler/rpc_factory.py @@ -0,0 +1,19 @@ +from ..method import Method +from .rpc import RpcWebSocketHandler + + +def create_rpc_ws_handler(middleware: "Middleware"): + methods = {} + for service_name, service in middleware.get_services().items(): + for attribute in dir(service): + if attribute.startswith("_"): + continue + + if not callable(getattr(service, attribute)): + continue + + method_name = f"{service_name}.{attribute}" + + methods[method_name] = Method(middleware, method_name) + + return RpcWebSocketHandler(middleware, methods) diff --git a/src/middlewared/middlewared/api/v25_04_0/__init__.py b/src/middlewared/middlewared/api/v25_04_0/__init__.py index ca1f985d122bc..51ab996663221 100644 --- a/src/middlewared/middlewared/api/v25_04_0/__init__.py +++ b/src/middlewared/middlewared/api/v25_04_0/__init__.py @@ -1,3 +1,4 @@ from .cloud_sync import * # noqa from .common import * # noqa +from .core import * # noqa from .user import * # noqa diff --git a/src/middlewared/middlewared/api/v25_04_0/core.py b/src/middlewared/middlewared/api/v25_04_0/core.py new file mode 100644 index 0000000000000..202f3f292e814 --- /dev/null +++ b/src/middlewared/middlewared/api/v25_04_0/core.py @@ -0,0 +1,29 @@ +from middlewared.api.base import BaseModel, ForUpdateMetaclass, single_argument_result + +__all__ = ["CoreSetOptionsArgs", "CoreSetOptionsResult", "CoreSubscribeArgs", "CoreSubscribeResult", + "CoreUnsubscribeArgs", "CoreUnsubscribeResult"] + + +class CoreSetOptionsOptions(BaseModel, metaclass=ForUpdateMetaclass): + py_exceptions: bool + + +class CoreSetOptionsArgs(BaseModel): + options: CoreSetOptionsOptions + + +CoreSetOptionsResult = single_argument_result(None, "CoreSetOptionsResult") + + +class CoreSubscribeArgs(BaseModel): + event: str + + +CoreSubscribeResult = single_argument_result(str, "CoreSubscribeResult") + + +class CoreUnsubscribeArgs(BaseModel): + id_: str + + +CoreUnsubscribeResult = single_argument_result(None, "CoreUnsubscribeResult") diff --git a/src/middlewared/middlewared/api/v25_04_0/user.py b/src/middlewared/middlewared/api/v25_04_0/user.py index b45168c617d6d..f05df38ebcb66 100644 --- a/src/middlewared/middlewared/api/v25_04_0/user.py +++ b/src/middlewared/middlewared/api/v25_04_0/user.py @@ -103,4 +103,4 @@ class UserRenew2faSecretArgs(BaseModel): twofactor_options: TwofactorOptions -UserRenew2faSecretResult = single_argument_result(UserEntry) +UserRenew2faSecretResult = single_argument_result(UserEntry, "UserRenew2faSecretResult") diff --git a/src/middlewared/middlewared/common/event_source/manager.py b/src/middlewared/middlewared/common/event_source/manager.py index 56df54efee0b6..bc5e4efbc6c60 100644 --- a/src/middlewared/middlewared/common/event_source/manager.py +++ b/src/middlewared/middlewared/common/event_source/manager.py @@ -1,12 +1,10 @@ import asyncio from collections import defaultdict, namedtuple -import errno import functools from uuid import uuid4 from middlewared.event import EventSource from middlewared.schema import ValidationErrors -from middlewared.service_exception import CallError IdentData = namedtuple("IdentData", ["subscriber", "name", "arg"]) @@ -28,20 +26,7 @@ def send_event(self, event_type, **kwargs): self.app.send_event(self.collection, event_type, **kwargs) def terminate(self, error): - error_dict = {} - if error: - if isinstance(error, ValidationErrors): - error_dict['error'] = self.app.get_error_dict( - errno.EAGAIN, str(error), etype='VALIDATION', extra=list(error) - ) - elif isinstance(error, CallError): - error_dict['error'] = self.app.get_error_dict( - error.errno, str(error), extra=error.extra - ) - else: - error_dict['error'] = self.app.get_error_dict(errno.EINVAL, str(error)) - - self.app._send({'msg': 'nosub', 'collection': self.collection, **error_dict}) + self.app.notify_unsubscribed(self.collection, error) class InternalSubscriber(Subscriber): diff --git a/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako b/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako index 99ab6584e3935..a0d002a9130ba 100644 --- a/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako +++ b/src/middlewared/middlewared/etc_files/local/nginx/nginx.conf.mako @@ -218,8 +218,15 @@ http { report_uploads proxied; } - location /api/docs { - proxy_pass http://127.0.0.1:6000/api/docs; + location /api { + allow all; # This is handled by `Middleware.ws_can_access` because if we return HTTP 403, browser security + # won't allow us to understand that connection error was due to client IP not being allowlisted. + proxy_pass http://127.0.0.1:6000/api; + proxy_http_version 1.1; + proxy_set_header X-Real-Remote-Addr $remote_addr; + proxy_set_header X-Real-Remote-Port $remote_port; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; } location /api/docs/restful/static { diff --git a/src/middlewared/middlewared/main.py b/src/middlewared/middlewared/main.py index 1d1bf2abe2ebb..d2dd9a888be33 100644 --- a/src/middlewared/middlewared/main.py +++ b/src/middlewared/middlewared/main.py @@ -1,7 +1,9 @@ from .api.base.handler.dump_params import dump_params from .api.base.handler.result import serialize_result +from .api.base.server.ws_handler.base import BaseWebSocketHandler +from .api.base.server.ws_handler.rpc import RpcWebSocketApp, RpcWebSocketAppEvent +from .api.base.server.ws_handler.rpc_factory import create_rpc_ws_handler from .apidocs import routes as apidocs_routes -from .auth import is_ha_connection from .common.event_source.manager import EventSourceManager from .event import Events from .job import Job, JobsQueue @@ -18,8 +20,7 @@ from .utils import MIDDLEWARE_RUN_DIR, sw_version from .utils.debug import get_frame_details, get_threads_stacks from .utils.lock import SoftHardSemaphore, SoftHardSemaphoreLimit -from .utils.nginx import get_remote_addr_port -from .utils.origin import UnixSocketOrigin, TCPIPOrigin +from .utils.origin import Origin from .utils.os import close_fds from .utils.plugins import LoadPluginsMixin from .utils.privilege import credential_has_full_admin @@ -28,7 +29,7 @@ from .utils.syslog import syslog_message from .utils.threading import set_thread_name, IoThreadPoolExecutor, io_thread_pool_executor from .utils.type import copy_function_metadata -from .webui_auth import addr_in_allowlist, WebUIAuth +from .webui_auth import WebUIAuth from .worker import main_worker, worker_init from aiohttp import web from aiohttp.http_websocket import WSCloseCode @@ -58,7 +59,6 @@ import queue import setproctitle import signal -import socket import struct import sys import termios @@ -85,6 +85,7 @@ # Type of the output of sys.exc_info() ExcInfoType = typing.Union[tuple[typing.Type[BaseException], BaseException, types.TracebackType], tuple[None, None, None]] + @dataclass class LoopMonitorIgnoreFrame: regex: typing.Pattern @@ -99,56 +100,23 @@ def real_crud_method(method): return child_method -class Application: +class Application(RpcWebSocketApp): + def __init__(self, middleware: 'Middleware', origin: Origin, loop: asyncio.AbstractEventLoop, request, response): + super().__init__(middleware, origin, response) + self.websocket = True - def __init__(self, middleware: 'Middleware', loop: asyncio.AbstractEventLoop, request, response): - self.middleware = middleware self.loop = loop self.request = request self.response = response - self.authenticated = False - self.authenticated_credentials = None self.handshake = False self.logger = logger.Logger('application').getLogger() - self.session_id = str(uuid.uuid4()) - self.rest = False - self.websocket = True # Allow at most 10 concurrent calls and only queue up until 20 self._softhardsemaphore = SoftHardSemaphore(10, 20) self._py_exceptions = False - """ - Callback index registered by services. They are blocking. - - Currently the following events are allowed: - on_message(app, message) - on_close(app) - """ - self.__callbacks = defaultdict(list) self.__subscribed = {} - @functools.cached_property - def origin(self) -> typing.Union[UnixSocketOrigin, TCPIPOrigin, None]: - try: - sock = self.request.transport.get_extra_info("socket") - except AttributeError: - # self.request.transport can be None by the time this is called - # on HA systems because remote node could have been rebooted - return - - if sock.family == socket.AF_UNIX: - peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize('3i')) - pid, uid, gid = struct.unpack('3i', peercred) - return UnixSocketOrigin(pid, uid, gid) - - remote_addr, remote_port = get_remote_addr_port(self.request) - return TCPIPOrigin(remote_addr, remote_port) - - def register_callback(self, name: str, method): - assert name in ('on_message', 'on_close') - self.__callbacks[name].append(method) - def _send(self, data: typing.Dict[str, typing.Any]): serialized = json.dumps(data) asyncio.run_coroutine_threadsafe(self.response.send_str(serialized), loop=self.loop) @@ -242,20 +210,6 @@ async def call_method(self, message, serviceobj, methodobj): self.middleware.dump_args(message.get('params', []), method_name=message['method']) ), exc_info=True) - def can_subscribe(self, name): - if event := self.middleware.events.get_event(name): - if event['no_auth_required']: - return True - - if not self.authenticated: - return False - - if event: - if event['no_authz_required']: - return True - - return self.authenticated_credentials.authorize('SUBSCRIBE', name) - async def subscribe(self, ident, name): shortname, arg = self.middleware.event_source_manager.short_name_arg(name) if shortname in self.middleware.event_source_manager.event_sources: @@ -299,8 +253,21 @@ def send_event(self, name, event_type, **kwargs): event['extra'] = kwargs self._send(event) - async def log_audit_message(self, event, event_data, success): - return await self.middleware.log_audit_message(self, event, event_data, success) + def notify_unsubscribed(self, collection, error): + error_dict = {} + if error: + if isinstance(error, ValidationErrors): + error_dict['error'] = self.get_error_dict( + errno.EAGAIN, str(error), etype='VALIDATION', extra=list(error) + ) + elif isinstance(error, CallError): + error_dict['error'] = self.get_error_dict( + error.errno, str(error), extra=error.extra + ) + else: + error_dict['error'] = self.get_error_dict(errno.EINVAL, str(error)) + + self._send({'msg': 'nosub', 'collection': collection, **error_dict}) async def __log_audit_message_for_method(self, message, methodobj, authenticated, authorized, success): return await self.middleware.log_audit_message_for_method( @@ -310,25 +277,15 @@ async def __log_audit_message_for_method(self, message, methodobj, authenticated def on_open(self): self.middleware.register_wsclient(self) - async def on_close(self, *args, **kwargs): - # Run callbacks registered in plugins for on_close - for method in self.__callbacks['on_close']: - try: - method(self) - except Exception: - self.logger.error('Failed to run on_close callback.', exc_info=True) + async def on_close(self): + self.run_callback(RpcWebSocketAppEvent.CLOSE) await self.middleware.event_source_manager.unsubscribe_app(self) self.middleware.unregister_wsclient(self) async def on_message(self, message: typing.Dict[str, typing.Any]): - # Run callbacks registered in plugins for on_message - for method in self.__callbacks['on_message']: - try: - method(self, message) - except Exception: - self.logger.error('Failed to run on_message callback.', exc_info=True) + self.run_callback(RpcWebSocketAppEvent.MESSAGE, message) if message['msg'] == 'connect': if message.get('version') != '1': @@ -346,42 +303,26 @@ async def on_message(self, message: typing.Dict[str, typing.Any]): elif not self.handshake: self._send({'msg': 'failed', 'version': '1'}) elif message['msg'] == 'method': - error = False if 'method' not in message: self.send_error(message, errno.EINVAL, "Message is malformed: 'method' is absent.") - error = True else: try: - serviceobj, methodobj = self.middleware._method_lookup(message['method']) + serviceobj, methodobj = self.middleware.get_method(message['method']) + + await self.middleware.authorize_method_call( + self, message['method'], methodobj, message.get('params') or [], + ) except CallError as e: self.send_error(message, e.errno, str(e), sys.exc_info(), extra=e.extra) - error = True - if not error and not hasattr(methodobj, '_no_auth_required'): - if not self.authenticated: - await self.__log_audit_message_for_method(message, methodobj, False, False, False) - self.send_error(message, ErrnoMixin.ENOTAUTHENTICATED, 'Not authenticated') - error = True - - # Some methods require authentication to the NAS (a valid account) - # but not explicit authorization. In this case the authorization - # check is bypassed as long as it is a user session. API keys - # explicitly whitelist particular methods and are used for targeted - # purposes, and so authorization is _always_ enforced. - elif self.authenticated_credentials.is_user_session and hasattr(methodobj, '_no_authz_required'): - pass - elif not self.authenticated_credentials.authorize('CALL', message['method']): - await self.__log_audit_message_for_method(message, methodobj, True, False, False) - self.send_error(message, errno.EACCES, 'Not authorized') - error = True - if not error: - self.middleware.create_task(self.call_method(message, serviceobj, methodobj)) + else: + self.middleware.create_task(self.call_method(message, serviceobj, methodobj)) elif message['msg'] == 'ping': pong = {'msg': 'pong'} if 'id' in message: pong['id'] = message['id'] self._send(pong) elif message['msg'] == 'sub': - if not self.can_subscribe(message['name'].split(':', 1)[0]): + if not self.middleware.can_subscribe(message['name'].split(':', 1)[0]): self.send_error(message, errno.EACCES, 'Not authorized') else: await self.subscribe(message['id'], message['name']) @@ -542,7 +483,7 @@ async def upload(self, request): return resp try: - serviceobj, methodobj = self.middleware._method_lookup(data['method']) + serviceobj, methodobj = self.middleware.get_method(data['method']) if authenticated_credentials.authorize('CALL', data['method']): job = await self.middleware.call_with_audit(data['method'], serviceobj, methodobj, data.get('params') or [], app, @@ -736,14 +677,16 @@ async def ws_handler(self, request): if not prepared: return ws - if not await self.middleware.ws_can_access(request, ws): + handler = BaseWebSocketHandler(self) + origin = await handler.get_origin(request) + if not await self.middleware.ws_can_access(ws, origin): return ws conndata = ShellConnectionData() conndata.id = str(uuid.uuid4()) try: - await self.run(ws, request, conndata) + await self.run(ws, origin, conndata) except Exception: if conndata.t_worker: await self.worker_kill(conndata.t_worker) @@ -751,7 +694,7 @@ async def ws_handler(self, request): self.shells.pop(conndata.id, None) return ws - async def run(self, ws, request, conndata): + async def run(self, ws, origin, conndata): # Each connection will have its own input queue input_queue = queue.Queue() @@ -771,7 +714,6 @@ async def run(self, ws, request, conndata): if not token: continue - origin = TCPIPOrigin(*await self.middleware.run_in_thread(get_remote_addr_port, request)) token = await self.middleware.call('auth.get_token_for_shell_application', token, origin) if not token: await ws.send_json({ @@ -1166,9 +1108,6 @@ def __notify_startup_complete(self): def plugin_route_add(self, plugin_name, route, method): self.app.router.add_route('*', f'/_plugins/{plugin_name}/{route}', method) - def get_wsclients(self): - return self.__wsclients - def register_wsclient(self, client): self.__wsclients[client.session_id] = client @@ -1417,7 +1356,7 @@ def dump_args(self, args, method=None, method_name=None): if method is None: if method_name is not None: try: - method = self._method_lookup(method_name)[1] + method = self.get_method(method_name)[1] except Exception: return args @@ -1450,6 +1389,40 @@ def dump_result(self, method, result, expose_secrets): return result + async def authorize_method_call(self, app, method_name, methodobj, params): + if hasattr(methodobj, '_no_auth_required'): + return + + if not app.authenticated: + await self.log_audit_message_for_method(method_name, methodobj, params, app, False, False, False) + raise CallError('Not authenticated', ErrnoMixin.ENOTAUTHENTICATED) + + # Some methods require authentication to the NAS (a valid account) + # but not explicit authorization. In this case the authorization + # check is bypassed as long as it is a user session. API keys + # explicitly whitelist particular methods and are used for targeted + # purposes, and so authorization is _always_ enforced. + if app.authenticated_credentials.is_user_session and hasattr(methodobj, '_no_authz_required'): + return + + if not app.authenticated_credentials.authorize('CALL', method_name): + await self.log_audit_message_for_method(method_name, methodobj, params, app, True, False, False) + raise CallError('Not authorized', errno.EACCES) + + def can_subscribe(self, app, name): + if event := self.events.get_event(name): + if event['no_auth_required']: + return True + + if not app.authenticated: + return False + + if event: + if event['no_authz_required']: + return True + + return app.authenticated_credentials.authorize('SUBSCRIBE', name) + async def call_with_audit(self, method, serviceobj, methodobj, params, app, **kwargs): audit_callback_messages = [] success = False @@ -1542,7 +1515,7 @@ async def log_audit_message(self, app, event, event_data, success): async def call(self, name, *params, app=None, audit_callback=None, job_on_progress_cb=None, pipes=None, profile=False): - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) if mock := self._mock_method(name, params): methodobj = mock @@ -1562,7 +1535,7 @@ def call_sync(self, name, *params, job_on_progress_cb=None, app=None, audit_call if background: return self.loop.call_soon_threadsafe(lambda: self.create_task(self.call(name, *params, app=app))) - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) if mock := self._mock_method(name, params): methodobj = mock @@ -1758,7 +1731,7 @@ def set_mock(self, name, args, mock): if args == _args: raise ValueError(f'{name!r} is already mocked with {args!r}') - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) if inspect.iscoroutinefunction(mock): async def f(*args, **kwargs): @@ -1808,10 +1781,12 @@ async def ws_handler(self, request): if not prepared: return ws - if not await self.ws_can_access(request, ws): + handler = BaseWebSocketHandler(self) + origin = await handler.get_origin(request) + if not await self.ws_can_access(ws, origin): return ws - connection = Application(self, self.loop, request, ws) + connection = Application(self, origin, self.loop, request, ws) connection.on_open() try: @@ -1857,26 +1832,14 @@ async def ws_handler(self, request): return ws - async def ws_can_access(self, request, ws): - if not (ui_allowlist := await self.call('system.general.get_ui_allowlist')): - return True - - sock = request.transport.get_extra_info('socket') - if sock.family == socket.AF_UNIX: - return True - - remote_addr, remote_port = await self.run_in_thread(get_remote_addr_port, request) - if is_ha_connection(remote_addr, remote_port): - return True - - if addr_in_allowlist(remote_addr, ui_allowlist): - return True - - await ws.close( - code=WSCloseCode.POLICY_VIOLATION, - message='You are not allowed to access this resource'.encode('utf-8'), - ) - return False + async def ws_can_access(self, ws, origin): + if not await BaseWebSocketHandler(self).can_access(origin): + await ws.close( + code=WSCloseCode.POLICY_VIOLATION, + message='You are not allowed to access this resource'.encode('utf-8'), + ) + return False + return True _loop_monitor_ignore_frames = ( LoopMonitorIgnoreFrame( @@ -1981,6 +1944,10 @@ async def __initialize(self): self.loop.add_signal_handler(signal.SIGUSR1, self.pdb) self.loop.add_signal_handler(signal.SIGUSR2, self.log_threads_stacks) + rpc_ws_handler = create_rpc_ws_handler(self) + app.router.add_route('GET', '/api/current', rpc_ws_handler) + app.router.add_route('GET', '/api/v25.04.0', rpc_ws_handler) + app.router.add_route('GET', '/websocket', self.ws_handler) app.router.add_routes(apidocs_routes) diff --git a/src/middlewared/middlewared/plugins/auth.py b/src/middlewared/middlewared/plugins/auth.py index 9202be3433466..558c8329e3a44 100644 --- a/src/middlewared/middlewared/plugins/auth.py +++ b/src/middlewared/middlewared/plugins/auth.py @@ -7,6 +7,7 @@ import psutil +from middlewared.api.base.server.ws_handler.rpc import RpcWebSocketAppEvent from middlewared.auth import (SessionManagerCredentials, UserSessionManagerCredentials, UnixSocketSessionManagerCredentials, RootTcpSocketSessionManagerCredentials, LoginPasswordSessionManagerCredentials, ApiKeySessionManagerCredentials, @@ -108,12 +109,12 @@ async def login(self, app, credentials): app.authenticated = True app.authenticated_credentials = credentials - app.register_callback("on_message", self._app_on_message) - app.register_callback("on_close", self._app_on_close) + app.register_callback(RpcWebSocketAppEvent.MESSAGE, self._app_on_message) + app.register_callback(RpcWebSocketAppEvent.CLOSE, self._app_on_close) if not is_internal_session(session): self.middleware.send_event("auth.sessions", "ADDED", fields=dict(id=app.session_id, **session.dump())) - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": dump_credentials(credentials), "error": None, }, True) @@ -205,7 +206,6 @@ def dump(self): return data - def is_internal_session(session): if isinstance(session.app.origin, UnixSocketOrigin) and session.app.origin.uid == 0: return True @@ -454,7 +454,7 @@ async def login(self, app, username, password, otp_token): """ user = await self.get_login_user(username, password, otp_token) if user is None: - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "LOGIN_PASSWORD", "credentials_data": {"username": username}, @@ -494,7 +494,7 @@ async def login_with_api_key(self, app, api_key): await self.session_manager.login(app, ApiKeySessionManagerCredentials(api_key_object)) return True - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "API_KEY", "credentials_data": { @@ -516,7 +516,7 @@ async def login_with_token(self, app, token_str): """ token = self.token_manager.get(token_str, app.origin) if token is None: - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "TOKEN", "credentials_data": { @@ -528,7 +528,7 @@ async def login_with_token(self, app, token_str): return False if token.attributes: - await app.log_audit_message("AUTHENTICATION", { + await self.middleware.log_audit_message(app, "AUTHENTICATION", { "credentials": { "credentials": "TOKEN", "credentials_data": { diff --git a/src/middlewared/middlewared/restful.py b/src/middlewared/middlewared/restful.py index bf0d364c991d2..a2c655574e603 100644 --- a/src/middlewared/middlewared/restful.py +++ b/src/middlewared/middlewared/restful.py @@ -12,6 +12,7 @@ from truenas_api_client import json +from .api.base.server.app import App from .auth import ApiKeySessionManagerCredentials, LoginPasswordSessionManagerCredentials from .job import Job from .pipe import Pipes @@ -101,11 +102,9 @@ async def authenticate(middleware, request, credentials, method, resource): def create_application(request, credentials=None): - return Application( - request.headers.get('X-Real-Remote-Addr'), - request.headers.get('X-Real-Remote-Port'), - credentials, - ) + return Application(TCPIPOrigin(request.headers.get('X-Real-Remote-Addr'), + int(request.headers.get('X-Real-Remote-Port'))), + credentials) def normalize_query_parameter(value): @@ -115,16 +114,13 @@ def normalize_query_parameter(value): return value -class Application: - def __init__(self, host, remote_port, authenticated_credentials): - self.host = host - self.remote_port = remote_port - self.origin = TCPIPOrigin(self.host, self.remote_port) - self.websocket = False - self.rest = True +class Application(App): + def __init__(self, origin, authenticated_credentials): + super().__init__(origin) + self.session_id = None self.authenticated = authenticated_credentials is not None self.authenticated_credentials = authenticated_credentials - self.session_id = None + self.rest = True class RESTfulAPI(object): @@ -850,7 +846,7 @@ async def do(self, http_method, req, resp, app, authorized, **kwargs): method_args.insert(0, id_) try: - serviceobj, methodobj = self.middleware._method_lookup(methodname) + serviceobj, methodobj = self.middleware.get_method(methodname) if authorized: result = await self.middleware.call_with_audit(methodname, serviceobj, methodobj, method_args, **method_kwargs) diff --git a/src/middlewared/middlewared/service/core_service.py b/src/middlewared/middlewared/service/core_service.py index 368e8d7b3cf51..20b94c60e3a23 100644 --- a/src/middlewared/middlewared/service/core_service.py +++ b/src/middlewared/middlewared/service/core_service.py @@ -8,6 +8,7 @@ import threading import time import traceback +import uuid from collections import defaultdict from remote_pdb import RemotePdb @@ -15,7 +16,12 @@ import middlewared.main +from middlewared.api import api_method from middlewared.api.base.jsonschema import get_json_schema +from middlewared.api.current import ( + CoreSetOptionsArgs, CoreSetOptionsResult, CoreSubscribeArgs, CoreSubscribeResult, CoreUnsubscribeArgs, + CoreUnsubscribeResult, +) from middlewared.common.environ import environ_update from middlewared.job import Job from middlewared.pipe import Pipes @@ -60,52 +66,6 @@ async def resize_shell(self, id_, cols, rows): shell.resize(cols, rows) - @filterable - @filterable_returns(Dict( - 'session', - Str('id'), - Str('socket_family'), - Str('address'), - Bool('authenticated'), - Int('call_count'), - )) - def sessions(self, filters, options): - """ - Get currently open websocket sessions. - """ - sessions = [] - for i in self.middleware.get_wsclients().values(): - try: - session_id = i.session_id - authenticated = i.authenticated - call_count = i._softhardsemaphore.counter - socket_family = socket.AddressFamily(i.request.transport.get_extra_info('socket').family).name - address = '' - if addr := i.request.headers.get('X-Real-Remote-Addr'): - port = i.request.headers.get('X-Real-Remote-Port') - address = f'{addr}:{port}' if all((addr, port)) else address - else: - if (info := i.request.transport.get_extra_info('peername')): - if isinstance(info, list) and len(info) == 2: - address = f'{info[0]}:{info[1]}' - except AttributeError: - # underlying websocket connection can be ripped down in process - # of enumerating this information. This is non-fatal, so ignore it. - pass - except Exception: - self.logger.warning('Failed enumerating websocket session.', exc_info=True) - break - else: - sessions.append({ - 'id': session_id, - 'socket_family': socket_family, - 'address': address, - 'authenticated': authenticated, - 'call_count': call_count, - }) - - return filter_list(sessions, filters, options) - @accepts(Bool('debug_mode')) async def set_debug_mode(self, debug_mode): """ @@ -706,7 +666,7 @@ async def download(self, app, method, args, filename, buffered): return await self._download(app, method, args, filename, buffered) async def _download(self, app, method, args, filename, buffered): - serviceobj, methodobj = self.middleware._method_lookup(method) + serviceobj, methodobj = self.middleware.get_method(method) job = await self.middleware.call_with_audit( method, serviceobj, methodobj, *args, app=app, pipes=Pipes(output=self.middleware.pipe(buffered)) @@ -815,7 +775,7 @@ async def bulk(self, app, job, method, params, description): `description` contains format string for job progress (e.g. "Deleting snapshot {0[dataset]}@{0[name]}") """ - serviceobj, methodobj = self.middleware._method_lookup(method) + serviceobj, methodobj = self.middleware.get_method(method) if params: if mock := self.middleware._mock_method(method, params[0]): @@ -912,3 +872,25 @@ def _cli_args_descriptions(self, doc, names): k: '\n'.join(v) for k, v in descriptions.items() } + + @no_auth_required + @api_method(CoreSetOptionsArgs, CoreSetOptionsResult) + @pass_app() + async def set_options(self, app, options): + if "py_exceptions" in options: + app.py_exceptions = options["py_exceptions"] + + @api_method(CoreSubscribeArgs, CoreSubscribeResult) + @pass_app() + async def subscribe(self, app, event): + if not self.middleware.can_subscribe(app, event): + raise CallError('Not authorized', errno.EACCES) + + ident = str(uuid.uuid4()) + await app.subscribe(ident, event) + return ident + + @api_method(CoreUnsubscribeArgs, CoreUnsubscribeResult) + @pass_app() + async def unsubscribe(self, app, ident): + await app.unsubscribe(ident) diff --git a/src/middlewared/middlewared/utils/nginx.py b/src/middlewared/middlewared/utils/nginx.py index 7c81036e89fa1..bf4427f6481e1 100644 --- a/src/middlewared/middlewared/utils/nginx.py +++ b/src/middlewared/middlewared/utils/nginx.py @@ -17,7 +17,7 @@ def get_remote_addr_port(request): remote_addr, remote_port = request.transport.get_extra_info("peername") except Exception: # request can be NoneType or request.transport could be NoneType as well - return '', '' + return "", "" if remote_addr in ["127.0.0.1", "::1"]: try: diff --git a/src/middlewared/middlewared/utils/service/call.py b/src/middlewared/middlewared/utils/service/call.py index 8f6bf42a048f3..48746a1d00d03 100644 --- a/src/middlewared/middlewared/utils/service/call.py +++ b/src/middlewared/middlewared/utils/service/call.py @@ -15,7 +15,7 @@ def __init__(self, method_name, service): class ServiceCallMixin: - def _method_lookup(self, name): + def get_method(self, name): if '.' not in name: raise CallError('Invalid method name', errno.EBADMSG) diff --git a/src/middlewared/middlewared/worker.py b/src/middlewared/middlewared/worker.py index 1e93877d07e75..43d14411d6ae2 100755 --- a/src/middlewared/middlewared/worker.py +++ b/src/middlewared/middlewared/worker.py @@ -42,14 +42,14 @@ def _call(self, name, serviceobj, methodobj, params=None, app=None, pipes=None, self.client = None def _run(self, name, args, job): - serviceobj, methodobj = self._method_lookup(name) + serviceobj, methodobj = self.get_method(name) return self._call(name, serviceobj, methodobj, args, job=job) def call_sync(self, method, *params, timeout=None, **kwargs): """ Calls a method using middleware client """ - serviceobj, methodobj = self._method_lookup(method) + serviceobj, methodobj = self.get_method(method) if serviceobj._config.process_pool and not hasattr(method, '_job'): if asyncio.iscoroutinefunction(methodobj): @@ -61,7 +61,7 @@ def call_sync(self, method, *params, timeout=None, **kwargs): # process pool. If the process pool is already exhausted, it will lead to a deadlock. # By executing a synchronous implementation of the same method in the same process pool we # eliminate `Hold and wait` condition and prevent deadlock situation from arising. - _, sync_methodobj = self._method_lookup(f'{method}__sync') + _, sync_methodobj = self.get_method(f'{method}__sync') except MethodNotFoundError: # FIXME: Make this an exception in 22.MM self.logger.warning('Service uses a process pool but has an asynchronous method: %r', method)