diff --git a/docs/hypha_data_store.py b/docs/hypha_data_store.py index 73e391ab..52bc7f1c 100644 --- a/docs/hypha_data_store.py +++ b/docs/hypha_data_store.py @@ -4,6 +4,7 @@ import os from urllib.parse import parse_qs + class HyphaDataStore: def __init__(self): self.storage = {} @@ -12,15 +13,15 @@ def __init__(self): async def setup(self, server, service_id="data-store", visibility="public"): self._server = server - self._svc = await server.register_service({ - "id": service_id, - "type": "functions", - "config": { - "visibility": visibility, - "require_context": False + self._svc = await server.register_service( + { + "id": service_id, + "type": "functions", + "config": {"visibility": visibility, "require_context": False}, + "get": self.http_get, }, - "get": self.http_get, - }, overwrite=True) + overwrite=True, + ) def get_url(self, obj_id: str): assert self._svc, "Service not initialized, call `setup()`" @@ -30,30 +31,30 @@ def get_url(self, obj_id: str): def put(self, obj_type: str, value: any, name: str, comment: str = ""): assert self._svc, "Please call `setup()` before using the store" obj_id = str(uuid.uuid4()) - if obj_type == 'file': + if obj_type == "file": data = value assert isinstance(data, (str, bytes)), "Value must be a string or bytes" if isinstance(data, str) and data.startswith("file://"): # File URL examples: # Absolute URL: `file:///home/data/myfile.png` # Relative URL: `file://./myimage.png`, or `file://myimage.png` - with open(data.replace("file://", ""), 'rb') as fil: + with open(data.replace("file://", ""), "rb") as fil: data = fil.read() mime_type, _ = mimetypes.guess_type(name) self.storage[obj_id] = { - 'type': obj_type, - 'name': name, - 'value': data, - 'mime_type': mime_type or 'application/octet-stream', - 'comment': comment + "type": obj_type, + "name": name, + "value": data, + "mime_type": mime_type or "application/octet-stream", + "comment": comment, } else: self.storage[obj_id] = { - 'type': obj_type, - 'name': name, - 'value': value, - 'mime_type': 'application/json', - 'comment': comment + "type": obj_type, + "name": name, + "value": value, + "mime_type": "application/json", + "comment": comment, } return obj_id @@ -63,47 +64,51 @@ def get(self, id: str): return obj def http_get(self, scope, context=None): - query_string = scope['query_string'] - id = parse_qs(query_string).get('id', [])[0] + query_string = scope["query_string"] + id = parse_qs(query_string).get("id", [])[0] obj = self.storage.get(id) if obj is None: - return {'status': 404, 'headers': {}, 'body': "Not found: " + id} + return {"status": 404, "headers": {}, "body": "Not found: " + id} - if obj['type'] == 'file': - data = obj['value'] + if obj["type"] == "file": + data = obj["value"] if isinstance(data, str): if not os.path.isfile(data): return { "status": 404, - 'headers': {'Content-Type': 'text/plain'}, - "body": "File not found: " + data + "headers": {"Content-Type": "text/plain"}, + "body": "File not found: " + data, } - with open(data, 'rb') as fil: + with open(data, "rb") as fil: data = fil.read() headers = { - 'Content-Type': obj['mime_type'], - 'Content-Length': str(len(obj['value'])), - 'Content-Disposition': f'inline; filename="{obj["name"].split("/")[-1]}"' - } - - return { - 'status': 200, - 'headers': headers, - 'body': obj['value'] + "Content-Type": obj["mime_type"], + "Content-Length": str(len(obj["value"])), + "Content-Disposition": f'inline; filename="{obj["name"].split("/")[-1]}"', } + + return {"status": 200, "headers": headers, "body": obj["value"]} else: return { - 'status': 200, - 'headers': {'Content-Type': 'application/json'}, - 'body': json.dumps(obj['value']) + "status": 200, + "headers": {"Content-Type": "application/json"}, + "body": json.dumps(obj["value"]), } def http_list(self, scope, context=None): - query_string = scope.get('query_string', b'') - kws = parse_qs(query_string).get('keyword', []) + query_string = scope.get("query_string", b"") + kws = parse_qs(query_string).get("keyword", []) keyword = kws[0] if kws else None - result = [value for key, value in self.storage.items() if not keyword or keyword in value['name']] - return {'status': 200, 'headers': {'Content-Type': 'application/json'}, 'body': json.dumps(result)} + result = [ + value + for key, value in self.storage.items() + if not keyword or keyword in value["name"] + ] + return { + "status": 200, + "headers": {"Content-Type": "application/json"}, + "body": json.dumps(result), + } def remove(self, obj_id: str): assert self._svc, "Please call `setup()` before using the store" @@ -112,8 +117,10 @@ def remove(self, obj_id: str): return True raise IndexError("Not found: " + obj_id) + async def test_data_store(server_url="https://ai.imjoy.io"): from imjoy_rpc.hypha import connect_to_server, login + token = await login({"server_url": server_url}) server = await connect_to_server({"server_url": server_url, "token": token}) @@ -122,20 +129,22 @@ async def test_data_store(server_url="https://ai.imjoy.io"): await ds.setup(server) # Test PUT operation - file_id = ds.put('file', 'file:///home/data.txt', 'data.txt') - binary_id = ds.put('file', b'Some binary content', 'example.bin') - json_id = ds.put('json', {'hello': 'world'}, 'example.json') + file_id = ds.put("file", "file:///home/data.txt", "data.txt") + binary_id = ds.put("file", b"Some binary content", "example.bin") + json_id = ds.put("json", {"hello": "world"}, "example.json") # Test GET operation - assert ds.get(file_id)['type'] == 'file' - assert ds.get(binary_id)['type'] == 'file' - assert ds.get(json_id)['type'] == 'json' + assert ds.get(file_id)["type"] == "file" + assert ds.get(binary_id)["type"] == "file" + assert ds.get(json_id)["type"] == "json" # Test GET URL generation print("URL for getting file", ds.get_url(file_id)) print("URL for getting binary object", ds.get_url(binary_id)) print("URL for getting json object", ds.get_url(json_id)) + if __name__ == "__main__": import asyncio - asyncio.run(test_data_store()) \ No newline at end of file + + asyncio.run(test_data_store()) diff --git a/hypha/apps.py b/hypha/apps.py index 01af9859..4d0b7c72 100644 --- a/hypha/apps.py +++ b/hypha/apps.py @@ -272,6 +272,10 @@ async def save_file(key, content): "name" in att and "source" in att ), "Attachment should contain `name` and `source`" if att["source"].startswith("http") and "\n" not in att["source"]: + if not att["source"].startswith("https://"): + raise Exception( + "Only https sources are allowed: " + att["source"] + ) with urlopen(att["source"]) as stream: output = stream.read() att["source"] = output @@ -395,6 +399,8 @@ async def install( ) if source.startswith("http"): + if not source.startswith("https://"): + raise Exception("Only https sources are allowed: " + source) with urlopen(source) as stream: output = stream.read() source = output.decode("utf-8") diff --git a/hypha/asgi.py b/hypha/asgi.py index c6397bf4..72068552 100644 --- a/hypha/asgi.py +++ b/hypha/asgi.py @@ -64,7 +64,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: user_info = user_info.model_dump() else: user_info = None - scope['context'] = {"user": user_info, "_rkwargs": True} + scope["context"] = {"user": user_info, "_rkwargs": True} result = await func(scope) headers = Headers(headers=result.get("headers")) body = result.get("body") diff --git a/hypha/core/auth.py b/hypha/core/auth.py index 60529ef0..c284c71e 100644 --- a/hypha/core/auth.py +++ b/hypha/core/auth.py @@ -38,7 +38,7 @@ if not JWT_SECRET: logger.warning("JWT_SECRET is not defined") - JWT_SECRET = str(shortuuid.uuid()) + JWT_SECRET = shortuuid.ShortUUID().random(length=22) class AuthError(Exception): @@ -135,8 +135,7 @@ def get_rsa_key(kid, refresh=False): if JWKS is None or refresh: with urlopen( f"https://{AUTH0_DOMAIN}/.well-known/jwks.json", - # pylint: disable=protected-access - context=ssl._create_unverified_context(), + context=ssl._create_default_https_context(), ) as jsonurl: JWKS = json.loads(jsonurl.read()) rsa_key = {} diff --git a/hypha/websocket.py b/hypha/websocket.py index 25d6339c..35144af6 100644 --- a/hypha/websocket.py +++ b/hypha/websocket.py @@ -52,7 +52,9 @@ async def disconnect(code): try: user_info, ws, cid = parse_reconnection_token(reconnection_token) except jwt.JWTError as err: - logger.error("Invalid reconnection token: %s", {reconnection_token[:5]}) + logger.error( + "Invalid reconnection token: %s", {reconnection_token[:5]} + ) await disconnect(code=status.WS_1003_UNSUPPORTED_DATA) return if not await store.disconnected_client_exists(f"{ws}/{cid}"):