diff --git a/mopidy_iris/handlers.py b/mopidy_iris/handlers.py index 3667082b6..648387914 100755 --- a/mopidy_iris/handlers.py +++ b/mopidy_iris/handlers.py @@ -4,10 +4,12 @@ import tornado.web import tornado.websocket import tornado.template +import urllib.parse import logging import json import time import asyncio +from typing import cast from .mem import iris @@ -20,11 +22,15 @@ class WebsocketHandler(tornado.websocket.WebSocketHandler): def initialize(self, core, config): self.core = core self.config = config + self.allowed_origins = config["http"]["allowed_origins"] + self.csrf_protection = config["http"]["csrf_protection"] self.ioloop = tornado.ioloop.IOLoop.current() iris.ioloop = self.ioloop # Make available elsewhere in the Frontend - def check_origin(self, origin): - return True + def check_origin(self, origin: str) -> bool: + if not self.csrf_protection: + return True + return check_origin(origin, self.request.headers, self.allowed_origins) def open(self): @@ -173,7 +179,6 @@ def handle_result(self, *args, **kwargs): data["recipient"] = self.connection_id iris.send_message(data=data) - class HttpHandler(tornado.web.RequestHandler): def set_default_headers(self): self.set_header("Access-Control-Allow-Origin", "*") @@ -189,6 +194,19 @@ def initialize(self, core, config): self.core = core self.config = config self.ioloop = tornado.ioloop.IOLoop.current() + self.allowed_origins = config["http"]["allowed_origins"] + self.csrf_protection = config["http"]["csrf_protection"] + + def check_origin(self, origin: str) -> bool: + if self.csrf_protection: + origin = cast(str | None, self.request.headers.get("Origin")) + if not check_origin(origin, self.request.headers, self.allowed_origins): + self.set_status(403, f"Access denied for origin {origin}") + return + + assert origin + self.set_cors_headers(origin) + return check_origin(origin, self.request.headers, self.allowed_origins) # Options request # This is a preflight request for CORS requests @@ -343,3 +361,23 @@ def initialize(self, path): def get(self, path=None, include_body=True): return super().get(self.path, include_body) + + +def check_origin( + origin: str, + request_headers: tornado.httputil.HTTPHeaders, + allowed_origins: set[str], +) -> bool: + if origin is None: + logger.warning("HTTP request denied for missing Origin header") + return False + host_header = request_headers.get("Host") + parsed_origin = urllib.parse.urlparse(origin).netloc.lower() + # Some frameworks (e.g. Apache Cordova) use local files. Requests from + # these files don't really have a sensible Origin so the browser sets the + # header to something like 'file://' or 'null'. This results here in an + # empty parsed_origin which we choose to allow. + if parsed_origin and parsed_origin not in allowed_origins: + logger.warning('HTTP request denied for Origin "%s"', origin) + return False + return True