Skip to content

Commit

Permalink
jsonrpc server
Browse files Browse the repository at this point in the history
  • Loading branch information
themylogin committed Jul 12, 2024
1 parent a815b64 commit 7316bab
Show file tree
Hide file tree
Showing 43 changed files with 777 additions and 343 deletions.
4 changes: 4 additions & 0 deletions src/middlewared/middlewared/api/base/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def api_method(
audit: str | None = None,
audit_callback: bool = False,
audit_extended: Callable[..., str] | None = None,
rate_limit=True,
roles: list[str] | None = None,
private: bool = False,
):
Expand All @@ -33,6 +34,8 @@ def api_method(
`audit_extended` is the function that takes the same arguments as the decorated function and returns the string
that will be appended to the audit message to be logged.
`rate_limit` specifies whether the method calls should be rate limited when calling without authentication.
`roles` is a list of user roles that will gain access to this method.
`private` is `True` when the method should not be exposed in the public API. By default, the method is public.
Expand Down Expand Up @@ -63,6 +66,7 @@ def wrapped(*args):
wrapped.audit = audit
wrapped.audit_callback = audit_callback
wrapped.audit_extended = audit_extended
wrapped.rate_limit = rate_limit
wrapped.roles = roles or []
wrapped._private = private

Expand Down
17 changes: 14 additions & 3 deletions src/middlewared/middlewared/api/base/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import copy
import inspect
from types import NoneType
import typing

from pydantic import BaseModel as PydanticBaseModel, ConfigDict, create_model, Field, model_serializer
Expand Down Expand Up @@ -86,10 +88,19 @@ def wrapper(klass):
return wrapper


def single_argument_result(klass):
def single_argument_result(klass, klass_name=None):
if klass is None:
klass = NoneType

if klass.__module__ == "builtins":
if klass_name is None:
raise TypeError("You must specify class name when using `single_argument_result` for built-in types")
else:
klass_name = klass_name or klass.__name__

return create_model(
klass.__name__,
klass_name,
__base__=(BaseModel,),
__module__=klass.__module__,
__module__=inspect.getmodule(inspect.stack()[1][0]),
**{"result": Annotated[klass, Field()]},
)
Empty file.
18 changes: 18 additions & 0 deletions src/middlewared/middlewared/api/base/server/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import logging
import uuid

from middlewared.auth import SessionManagerCredentials
from middlewared.utils.origin import Origin

logger = logging.getLogger(__name__)


class App:
def __init__(self, origin: Origin):
self.origin = origin
self.session_id = str(uuid.uuid4())
self.authenticated = False
self.authenticated_credentials: SessionManagerCredentials | None = None
self.py_exceptions = False
self.websocket = False
self.rest = False
30 changes: 30 additions & 0 deletions src/middlewared/middlewared/api/base/server/method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import types

from middlewared.job import Job


class Method:
def __init__(self, middleware: "Middleware", name: str):
self.middleware = middleware
self.name = name

async def call(self, app: "RpcWebSocketApp", params):
serviceobj, methodobj = self.middleware.get_method(self.name)

await self.middleware.authorize_method_call(app, self.name, methodobj, params)

if mock := self.middleware._mock_method(self.name, params):
methodobj = mock

result = await self.middleware.call_with_audit(self.name, serviceobj, methodobj, params, app)
if isinstance(result, Job):
result = result.id
elif isinstance(result, types.GeneratorType):
result = list(result)
elif isinstance(result, types.AsyncGeneratorType):
result = [i async for i in result]

return result

def dump_args(self, params):
return self.middleware.dump_args(params, method_name=self.name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -*- coding=utf-8 -*-
import logging

logger = logging.getLogger(__name__)

__all__ = []
72 changes: 72 additions & 0 deletions src/middlewared/middlewared/api/base/server/ws_handler/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import socket
import struct

from aiohttp.http_websocket import WSCloseCode
from aiohttp.web import Request, WebSocketResponse

from middlewared.auth import is_ha_connection
from middlewared.utils.nginx import get_remote_addr_port
from middlewared.utils.origin import Origin, UnixSocketOrigin, TCPIPOrigin
from middlewared.webui_auth import addr_in_allowlist


class BaseWebSocketHandler:
def __init__(self, middleware: "Middleware"):
self.middleware = middleware

async def __call__(self, request: Request):
ws = WebSocketResponse()
try:
await ws.prepare(request)
except ConnectionResetError:
# Happens when we're preparing a new session, and during the time we prepare, the server is
# stopped/killed/restarted etc. Ignore these to prevent log spam.
return ws

origin = await self.get_origin(request)
if origin is None:
await ws.close()
return ws
if not await self.can_access(origin):
await ws.close(
code=WSCloseCode.POLICY_VIOLATION,
message=b"You are not allowed to access this resource",
)
return ws

await self.process(origin, ws)
return ws

async def get_origin(self, request: Request) -> Origin | None:
try:
sock = request.transport.get_extra_info("socket")
except AttributeError:
# request.transport can be None by the time this is called on HA systems because remote node could have been
# rebooted
return

if sock.family == socket.AF_UNIX:
peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize("3i"))
pid, uid, gid = struct.unpack("3i", peercred)
return UnixSocketOrigin(pid, uid, gid)

remote_addr, remote_port = await self.middleware.run_in_thread(get_remote_addr_port, request)
return TCPIPOrigin(remote_addr, remote_port)

async def can_access(self, origin: Origin | None) -> bool:
if not isinstance(origin, TCPIPOrigin):
return True

if not (ui_allowlist := await self.middleware.call("system.general.get_ui_allowlist")):
return True

if is_ha_connection(origin.addr, origin.port):
return True

if addr_in_allowlist(origin.addr, ui_allowlist):
return True

return False

async def process(self, origin: Origin, ws: WebSocketResponse):
raise NotImplementedError
Loading

0 comments on commit 7316bab

Please sign in to comment.