-
Notifications
You must be signed in to change notification settings - Fork 490
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9dc05ec
commit 40fce86
Showing
43 changed files
with
781 additions
and
348 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import logging | ||
import uuid | ||
|
||
from middlewared.auth import SessionManagerCredentials | ||
from middlewared.utils.origin import Origin | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class App: | ||
def __init__(self, origin: Origin): | ||
self.origin = origin | ||
self.session_id = str(uuid.uuid4()) | ||
self.authenticated = False | ||
self.authenticated_credentials: SessionManagerCredentials | None = None | ||
self.py_exceptions = False | ||
self.websocket = False | ||
self.rest = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import types | ||
|
||
from middlewared.job import Job | ||
|
||
|
||
class Method: | ||
def __init__(self, middleware: "Middleware", name: str): | ||
self.middleware = middleware | ||
self.name = name | ||
|
||
async def call(self, app: "RpcWebSocketApp", params): | ||
serviceobj, methodobj = self.middleware.get_method(self.name) | ||
|
||
await self.middleware.authorize_method_call(app, self.name, methodobj, params) | ||
|
||
if mock := self.middleware._mock_method(self.name, params): | ||
methodobj = mock | ||
|
||
result = await self.middleware.call_with_audit(self.name, serviceobj, methodobj, params, app) | ||
if isinstance(result, Job): | ||
result = result.id | ||
elif isinstance(result, types.GeneratorType): | ||
result = list(result) | ||
elif isinstance(result, types.AsyncGeneratorType): | ||
result = [i async for i in result] | ||
|
||
return result | ||
|
||
def dump_args(self, params): | ||
return self.middleware.dump_args(params, method_name=self.name) |
6 changes: 6 additions & 0 deletions
6
src/middlewared/middlewared/api/base/server/ws_handler/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# -*- coding=utf-8 -*- | ||
import logging | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
__all__ = [] |
72 changes: 72 additions & 0 deletions
72
src/middlewared/middlewared/api/base/server/ws_handler/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import socket | ||
import struct | ||
|
||
from aiohttp.http_websocket import WSCloseCode | ||
from aiohttp.web import Request, WebSocketResponse | ||
|
||
from middlewared.auth import is_ha_connection | ||
from middlewared.utils.nginx import get_remote_addr_port | ||
from middlewared.utils.origin import Origin, UnixSocketOrigin, TCPIPOrigin | ||
from middlewared.webui_auth import addr_in_allowlist | ||
|
||
|
||
class BaseWebSocketHandler: | ||
def __init__(self, middleware: "Middleware"): | ||
self.middleware = middleware | ||
|
||
async def __call__(self, request: Request): | ||
ws = WebSocketResponse() | ||
try: | ||
await ws.prepare(request) | ||
except ConnectionResetError: | ||
# Happens when we're preparing a new session, and during the time we prepare, the server is | ||
# stopped/killed/restarted etc. Ignore these to prevent log spam. | ||
return ws | ||
|
||
origin = await self.get_origin(request) | ||
if origin is None: | ||
await ws.close() | ||
return ws | ||
if not await self.can_access(origin): | ||
await ws.close( | ||
code=WSCloseCode.POLICY_VIOLATION, | ||
message=b"You are not allowed to access this resource", | ||
) | ||
return ws | ||
|
||
await self.process(origin, ws) | ||
return ws | ||
|
||
async def get_origin(self, request: Request) -> Origin | None: | ||
try: | ||
sock = request.transport.get_extra_info("socket") | ||
except AttributeError: | ||
# request.transport can be None by the time this is called on HA systems because remote node could have been | ||
# rebooted | ||
return | ||
|
||
if sock.family == socket.AF_UNIX: | ||
peercred = sock.getsockopt(socket.SOL_SOCKET, socket.SO_PEERCRED, struct.calcsize("3i")) | ||
pid, uid, gid = struct.unpack("3i", peercred) | ||
return UnixSocketOrigin(pid, uid, gid) | ||
|
||
remote_addr, remote_port = await self.middleware.run_in_thread(get_remote_addr_port, request) | ||
return TCPIPOrigin(remote_addr, remote_port) | ||
|
||
async def can_access(self, origin: Origin | None) -> bool: | ||
if not isinstance(origin, TCPIPOrigin): | ||
return True | ||
|
||
if not (ui_allowlist := await self.middleware.call("system.general.get_ui_allowlist")): | ||
return True | ||
|
||
if is_ha_connection(origin.addr, origin.port): | ||
return True | ||
|
||
if addr_in_allowlist(origin.addr, ui_allowlist): | ||
return True | ||
|
||
return False | ||
|
||
async def process(self, origin: Origin, ws: WebSocketResponse): | ||
raise NotImplementedError |
Oops, something went wrong.