diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..43eb191c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,23 @@ +# Hypha Change Log + +### 0.20.15 + + - Add `revoke_token` to the workspace api. + - Simplify http endpoints to a fixed pattern such as "{workspace}/services/*" and "{workspace}/apps/*". + - To avoid naming convension, workspace names now must contain at least one hyphens, and only lowercase letters, numbers and hyphens are allowed. + +### 0.20.14 + + - Make `get_service` more restricted to support only service id string, see [migration guide](./docs/migration-guide.md) for more details. + - Clean up http endpoints for the services. + - Remove local cache of the server apps, we now always use s3 as the primary storage. + +### 0.20.12 + + - New Feature: In order to support large language models' function calling feature, hypha support built-in type annotation. With `hypha-rpc>=0.20.12`, we also support type annotation for the service functions in JSON Schema format. In Python, you can use `Pydantic` or simple python type hint, or directly write the json schema for Javascript service functions. This allows you to specify the inputs spec for functions. + - Add type support for the `hypha` module. It allows you to register a type in the workspace using `register_service_type`, `get_service_type`, `list_service_types`. When registering a new service, you can specify the type and enable type check for the service. The type check will be performed when calling the service function. The type check is only available in Python. + - Fix reconnecton issue in the client. + - Support case conversion, which allows converting the service functions to snake_case or camelCase in `get_service` (Python) or `getService` (JavaScript). + - **Breaking Changes**: In Python, all the function names uses snake case, and in JavaScript, all the function names uses camel case. For example, you should call `server.getService` instead of `server.get_service` in JavaScript, and `server.get_service` instead of `server.getService` in Python. + - **Breaking Changes**: The new version of Hypha (0.20.0+) improves the RPC connection to make it more stable and secure, most importantly it supports automatic reconnection when the connection is lost. This also means breaking changes to the previous version. In the new version you will need a new library called `hypha-rpc` (instead of the hypha submodule in the `imjoy-rpc` module) to connect to the server. + diff --git a/README.md b/README.md index 109538e1..2bac5fee 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,51 @@ +# Hypha + ![PyPI](https://img.shields.io/pypi/v/hypha.svg?style=popout) Hypha -# Hypha +**Hypha** is a generative AI-powered application framework designed for large-scale data management, AI model serving, and real-time communication. Hypha allows the creation of computational platforms consisting of both computational and user interface components. + +## Key Features + +- **Generative AI-Powered:** Leverage the power of generative AI to build advanced data management and AI model serving solutions. +- **Hypha-RPC:** Utilize [hypha-rpc](https://github.com/oeway/hypha-rpc), a bidirectional remote procedure call system, enabling seamless communication and integration across distributed components. +- **Real-Time Communication:** Support for real-time communication within virtual workspaces, similar to platforms like Zoom. +- **Scalable and Flexible:** Connect and orchestrate various compute services, AI models, tools, and services running on distributed locations. +- **GenAI-Powered Automation:** Build GenAI-powered automation platforms and agentic workflows, enabling fully autonomous agent systems. + +## How It Works + +Hypha acts as a hub that connects different components through **Hypha-RPC**. Users and programmatic clients connect to the platform in virtual workspaces, where they can seamlessly make remote procedure calls (RPC) as if they are calling local functions. Developers can integrate all types of compute services, including AI models, tools, and services, across distributed locations. + +### Virtual Workspaces + +- Hypha's virtual workspaces are akin to Zoom rooms, where clients (users and programmatic) can connect and interact. +- Within these workspaces, all clients can perform seamless RPCs, facilitating easy and efficient collaboration and computation. + +### Integration with Compute Services + +- Hypha supports the integration of various compute services, including AI models and data analytics tools. +- These services, when connected to the platform, can be understood and orchestrated by large language models (LLMs), paving the way for creating next-generation GenAI-powered automation platforms and agentic workflows. + +## Use Cases + +- **Research Institutions:** Build integrated platforms for data management and AI-powered services. +- **AI-Powered Automation:** Create autonomous agent systems and GenAI-powered automation platforms. +- **Data Services:** Facilitate scalable and efficient data services and analytics through distributed computing. -Hypha is an application framework for large-scale data management and AI model serving, it allows creating computational platforms consists of computational and user interface components. +## Getting Started -Hypha server act as a hub for connecting different components through [hypya-rpc](https://github.com/oeway/hypha-rpc). +To get started with Hypha, follow the installation and setup instructions in the [documentation](https://docs.amun.ai). -## Change log +## License -### 0.20.14 +Hypha is released under the MIT License. See the [LICENSE](./LICENSE) file for more details. - - Make `get_service` more restricted to support only service id string, see [migration guide](./docs/migration-guide.md) for more details. - - Clean up http endpoints for the services, now we also support /services/:service_id/:function_name - - Remove local cache of the server apps, we now always use s3 as the primary storage. +## Contributing -### 0.20.12 +We welcome contributions from the community. Please see our [Contributing Guidelines](./CONTRIBUTING.md) for more information. - - New Feature: In order to support large language models' function calling feature, hypha support built-in type annotation. With `hypha-rpc>=0.20.12`, we also support type annotation for the service functions in JSON Schema format. In Python, you can use `Pydantic` or simple python type hint, or directly write the json schema for Javascript service functions. This allows you to specify the inputs spec for functions. - - Add type support for the `hypha` module. It allows you to register a type in the workspace using `register_service_type`, `get_service_type`, `list_service_types`. When registering a new service, you can specify the type and enable type check for the service. The type check will be performed when calling the service function. The type check is only available in Python. - - Fix reconnecton issue in the client. - - Support case conversion, which allows converting the service functions to snake_case or camelCase in `get_service` (Python) or `getService` (JavaScript). - - **Breaking Changes**: In Python, all the function names uses snake case, and in JavaScript, all the function names uses camel case. For example, you should call `server.getService` instead of `server.get_service` in JavaScript, and `server.get_service` instead of `server.getService` in Python. - - **Breaking Changes**: The new version of Hypha (0.20.0+) improves the RPC connection to make it more stable and secure, most importantly it supports automatic reconnection when the connection is lost. This also means breaking changes to the previous version. In the new version you will need a new library called `hypha-rpc` (instead of the hypha submodule in the `imjoy-rpc` module) to connect to the server. +## Acknowledgments +We thank the contributors and community members who have helped make Hypha what it is today. diff --git a/docs/getting-started.md b/docs/getting-started.md index f93d56a7..d7f41951 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -200,7 +200,7 @@ if __name__ == "__main__": Include the following script in your HTML file to load the `hypha-rpc` client: ```html - + ``` Use the following code in JavaScript to connect to the server and access an existing service: diff --git a/docs/hypha-quick-tour.ipynb b/docs/hypha-quick-tour.ipynb index b5935d5b..37863e4f 100644 --- a/docs/hypha-quick-tour.ipynb +++ b/docs/hypha-quick-tour.ipynb @@ -83,28 +83,28 @@ { "data": { "text/plain": [ - "[{'id': 'public/workspace-manager:triton-client',\n", + "[{'id': 'public/ws:triton-client',\n", " 'name': 'Triton Client',\n", " 'type': 'triton-client',\n", " 'config': {'visibility': 'public',\n", " 'require_context': True,\n", " 'workspace': 'public',\n", " 'flags': []}},\n", - " {'id': 'public/workspace-manager:s3-storage',\n", + " {'id': 'public/ws:s3-storage',\n", " 'name': 'S3 Storage',\n", " 'type': 's3-storage',\n", " 'config': {'visibility': 'public',\n", " 'require_context': True,\n", " 'workspace': 'public',\n", " 'flags': []}},\n", - " {'id': 'public/workspace-manager:card',\n", + " {'id': 'public/ws:card',\n", " 'name': 'Card',\n", " 'type': 'card',\n", " 'config': {'visibility': 'public',\n", " 'require_context': True,\n", " 'workspace': 'public',\n", " 'flags': []}},\n", - " {'id': 'public/workspace-manager:server-apps',\n", + " {'id': 'public/ws:server-apps',\n", " 'name': 'Server Apps',\n", " 'type': 'server-apps',\n", " 'config': {'visibility': 'public',\n", @@ -268,7 +268,7 @@ ], "source": [ "url = await s3.generate_presigned_url(\n", - " info[\"bucket\"], info[\"prefix\"] + \"hello.txt\"\n", + " \"hello.txt\"\n", ")\n", "print(url)\n", "response = await fetch(url, method=\"GET\")\n", diff --git a/docs/hypha_data_store.py b/docs/hypha_data_store.py index 477bd2d7..59fd3f0b 100644 --- a/docs/hypha_data_store.py +++ b/docs/hypha_data_store.py @@ -20,7 +20,7 @@ async def setup(self, server, service_id="data-store", visibility="public"): "config": {"visibility": visibility, "require_context": False}, "get": self.http_get, }, - overwrite=True, + {"overwrite": True}, ) def get_url(self, obj_id: str): diff --git a/docs/migration-guide.md b/docs/migration-guide.md index f27861ab..1556a1a9 100644 --- a/docs/migration-guide.md +++ b/docs/migration-guide.md @@ -15,7 +15,7 @@ To connect to the server, instead of installing the `imjoy-rpc` module, you will pip install -U hypha-rpc # new install ``` -We also changed our versioning strategy, we use the same version number for the server and client, so it's easier to match the client and server versions. For example, `hypha-rpc` version `0.20.14` is compatible with Hypha server version `0.20.14`. +We also changed our versioning strategy, we use the same version number for the server and client, so it's easier to match the client and server versions. For example, `hypha-rpc` version `0.20.19` is compatible with Hypha server version `0.20.19`. #### 2. Change the imports to use `hypha-rpc` @@ -124,10 +124,10 @@ loop.run_forever() To connect to the server, instead of using the `imjoy-rpc` module, you will need to use the `hypha-rpc` module. The `hypha-rpc` module is a standalone module that provides the RPC connection to the Hypha server. You can include it in your HTML using a script tag: ```html - + ``` -We also changed our versioning strategy, we use the same version number for the server and client, so it's easier to match the client and server versions. For example, `hypha-rpc` version `0.20.14` is compatible with Hypha server version `0.20.14`. +We also changed our versioning strategy, we use the same version number for the server and client, so it's easier to match the client and server versions. For example, `hypha-rpc` version `0.20.19` is compatible with Hypha server version `0.20.19`. #### 2. Change the connection method and use camelCase for service function names @@ -145,7 +145,7 @@ Here is a suggested list of search and replace operations to update your code: Here is an example of how the updated code might look: ```html - + + + - - - - - - \ No newline at end of file diff --git a/hypha/built-in/imjoy-plugin-parser.html b/hypha/built-in/imjoy-plugin-parser.html deleted file mode 100644 index 8b5c8612..00000000 --- a/hypha/built-in/imjoy-plugin-parser.html +++ /dev/null @@ -1,127 +0,0 @@ - - - - - - Hypha App Template - - - - - - - - - - \ No newline at end of file diff --git a/hypha/built-in/test.html b/hypha/built-in/test.html deleted file mode 100644 index b0d8d6b2..00000000 --- a/hypha/built-in/test.html +++ /dev/null @@ -1,52 +0,0 @@ - - - - - - - Hypha App Template - - - - - - - - - \ No newline at end of file diff --git a/hypha/card.py b/hypha/card.py index 0da7ca1e..b6940927 100644 --- a/hypha/card.py +++ b/hypha/card.py @@ -6,8 +6,6 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse -from hypha.core import UserInfo -from hypha.core.auth import login_optional from hypha.utils import ( safe_join, list_objects_sync, @@ -53,7 +51,7 @@ async def get_card_file( workspace: str, path: str, request: Request, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): try: path = safe_join(workspace, path) diff --git a/hypha/core/__init__.py b/hypha/core/__init__.py index 67e0feaf..1063800d 100644 --- a/hypha/core/__init__.py +++ b/hypha/core/__init__.py @@ -50,6 +50,7 @@ class ServiceConfig(BaseModel): require_context: Union[Tuple[str], List[str], bool] = False workspace: Optional[str] = None flags: List[str] = [] + singleton: Optional[bool] = False class ServiceInfo(BaseModel): @@ -104,6 +105,10 @@ def model_validate(cls, data): return super().model_validate(data) +class RemoteService(ServiceInfo): + pass + + class UserTokenInfo(BaseModel): """Represent user profile.""" @@ -148,6 +153,9 @@ class UserInfo(BaseModel): default_factory=lambda: {} ) # e.g. s3 credential + def get_workspace(self): + return f"ws-user-{self.id}" + def get_metadata(self, key=None) -> Dict[str, Any]: """Return the metadata.""" if key: @@ -257,7 +265,7 @@ class WorkspaceInfo(BaseModel): """Represent a workspace.""" name: str - description: str + description: Optional[str] = None persistent: Optional[bool] = False owners: Optional[List[str]] = [] read_only: Optional[bool] = False @@ -352,7 +360,7 @@ async def emit_message(self, data: Union[dict, bytes]): pos = unpacker.tell() target_id = message.get("to") if "/" not in target_id: - if "/workspace-manager-" in target_id: + if "/ws-" in target_id: raise ValueError( f"Invalid target ID: {target_id}, it appears that the target is a workspace manager (target_id should starts with */)" ) diff --git a/hypha/core/auth.py b/hypha/core/auth.py index 596167e3..27f66837 100644 --- a/hypha/core/auth.py +++ b/hypha/core/auth.py @@ -12,12 +12,12 @@ import shortuuid from dotenv import find_dotenv, load_dotenv -from fastapi import Header, HTTPException +from fastapi import HTTPException from jinja2 import Environment, PackageLoader, select_autoescape from jose import jwt from hypha.core import UserInfo, UserTokenInfo, ScopeInfo, UserPermission, WorkspaceInfo -from hypha.utils import AsyncTTLCache, random_id +from hypha.utils import random_id logging.basicConfig(stream=sys.stdout) logger = logging.getLogger("auth") @@ -34,6 +34,7 @@ AUTH0_ISSUER = env.get("AUTH0_ISSUER", "https://amun.ai/") AUTH0_NAMESPACE = env.get("AUTH0_NAMESPACE", "https://amun.ai/") JWT_SECRET = env.get("JWT_SECRET") +LOGIN_SERVICE_URL = "/public/services/hypha-login" if not JWT_SECRET: logger.warning( @@ -42,32 +43,6 @@ JWT_SECRET = shortuuid.ShortUUID().random(length=22) -def login_optional(authorization: str = Header(None)): - """Return user info or create an anonymouse user. - - If authorization code is valid the user info is returned, - If the code is invalid an an anonymouse user is created. - """ - if authorization: - return parse_token(authorization) - else: - return generate_anonymous_user() - - -def login_required(authorization: str = Header(None)): - """Return user info if authorization code is valid.""" - return parse_token(authorization) - - -def admin_required(authorization: str = Header(None)): - """Return user info if the authorization code has an admin role.""" - token = parse_token(authorization) - roles = token.credentials.get(AUTH0_NAMESPACE + "roles", []) - if "admin" not in roles: - raise HTTPException(status_code=401, detail="Admin required") - return token - - def get_user_email(token): """Return the user email from the token.""" return token.credentials.get(AUTH0_NAMESPACE + "email") @@ -314,20 +289,24 @@ def update_user_scope( """Update the user scope for a workspace.""" user_info.scope = user_info.scope or ScopeInfo() permission = user_info.get_permission(workspace_info.name) + ws_scopes = {} if not permission: # infer permission from workspace - if user_info.id == workspace_info.name: - permission = UserPermission.admin - elif "admin" in user_info.roles: - permission = UserPermission.admin - elif ( - user_info.email in workspace_info.owners + if ( + user_info.get_workspace() == workspace_info.name + or user_info.email in workspace_info.owners or user_info.id in workspace_info.owners ): permission = UserPermission.admin + if permission: + ws_scopes[workspace_info.name] = permission + + if "admin" in user_info.roles: + ws_scopes["*"] = UserPermission.admin + return create_scope( - workspaces={workspace_info.name: permission} if permission else {}, + workspaces=ws_scopes, client_id=client_id, extra_scopes=user_info.scope.extra_scopes, ) @@ -345,57 +324,19 @@ def generate_jwt_scope(scope: ScopeInfo) -> str: return ps -def parse_reconnection_token(token) -> UserInfo: - """Parse a reconnection token.""" - payload = jwt.decode( - token, - JWT_SECRET, - algorithms=["HS256"], - audience=AUTH0_AUDIENCE, - issuer=AUTH0_ISSUER, - ) - user_info = get_user_info(payload) - scope = user_info.scope - assert len(scope.workspaces) == 1, "Invalid scope, it must have only one workspace" - assert scope.client_id, "Invalid scope, client_id is required" - assert len(scope.workspaces) == 1, "Invalid scope, it must have only one workspace" - workspace = list(scope.workspaces.keys())[0] - client_id = scope.client_id - return user_info, workspace, client_id - - -def parse_user(token): - """Parse user info from a token.""" - if token: - user_info = parse_token(token) - uid = user_info.id - logger.info("User connected: %s", uid) - else: - user_info = generate_anonymous_user() - uid = user_info.id - logger.info("Anonymized User connected: %s", uid) - - if uid == "root": - logger.error("Root user is not allowed to connect remotely") - raise Exception("Root user is not allowed to connect remotely") - - return user_info - - -async def register_login_service(server): +def create_login_service(store): """Hypha startup function for registering additional services.""" - cache = AsyncTTLCache(ttl=int(MAXIMUM_LOGIN_TIME)) - server_url = server.config["public_base_url"] - login_url = f"{server_url}/{server.config['workspace']}/apps/hypha-login/" - login_service_url = ( - f"{server_url}/{server.config['workspace']}/services/hypha-login" - ) + redis = store.get_redis() + server_url = store.public_base_url + login_service_url = f"{server_url}{LOGIN_SERVICE_URL}" + generate_token_url = f"{server_url}/public/services/ws/generate_token" jinja_env = Environment( loader=PackageLoader("hypha"), autoescape=select_autoescape() ) - temp = jinja_env.get_template("login_template.html") + temp = jinja_env.get_template("apps/login_template.html") login_page = temp.render( login_service_url=login_service_url, + generate_token_url=generate_token_url, auth0_client_id=AUTH0_CLIENT_ID, auth0_domain=AUTH0_DOMAIN, auth0_audience=AUTH0_AUDIENCE, @@ -404,10 +345,11 @@ async def register_login_service(server): async def start_login(): """Start the login process.""" - key = str(random_id(readable=True)) - await cache.add(key, False) + key = "login_key:" + str(random_id(readable=True)) + # set the key and with expire time + await redis.setex(key, MAXIMUM_LOGIN_TIME, "") return { - "login_url": f"{login_url}?key={key}", + "login_url": f"{login_service_url.replace('/services/', '/apps/')}/?key={key}", "key": key, "report_url": f"{login_service_url}/report", "check_url": f"{login_service_url}/check", @@ -424,11 +366,13 @@ async def index(event): async def check_login(key, timeout=MAXIMUM_LOGIN_TIME, profile=False): """Check the status of a login session.""" - assert key in cache, "Invalid key, key does not exist" + assert await redis.exists(key), "Invalid key, key does not exist" if timeout <= 0: - user_info = await cache.get(key) + user_info = await redis.get(key) + user_info = json.loads(user_info) + user_info = UserTokenInfo.model_validate(user_info) if user_info: - del cache[key] + await redis.delete(key) return ( user_info.model_dump(mode="json") if profile @@ -436,18 +380,20 @@ async def check_login(key, timeout=MAXIMUM_LOGIN_TIME, profile=False): ) count = 0 while True: - user_info = await cache.get(key) + user_info = await redis.get(key) + user_info = json.loads(user_info) + user_info = UserTokenInfo.model_validate(user_info) if user_info is None: raise Exception( f"Login session expired, the maximum login time is {MAXIMUM_LOGIN_TIME} seconds" ) if user_info: - del cache[key] + await redis.delete(key) return user_info.model_dump(mode="json") if profile else user_info.token await asyncio.sleep(1) count += 1 if count > timeout: - raise Exception("Login timeout") + raise Exception(f"Login timeout, waited for {timeout} seconds") async def report_login( key, @@ -460,7 +406,7 @@ async def report_login( picture=None, ): """Report a token associated with a login session.""" - assert key in cache, "Invalid key, key does not exist or expired" + assert await redis.exists(key), "Invalid key, key does not exist or expired" kwargs = { "token": token, "email": email, @@ -471,31 +417,18 @@ async def report_login( "picture": picture, } user_info = UserTokenInfo.model_validate(kwargs) - await cache.update(key, user_info) - - async def generate_token(token: str, expires_in: int): - """Generate a new user token.""" - # limit the expiration time to 1 year - expires_in = int(expires_in) - if expires_in > 31536000: - raise ValueError("The maximum expiration time is 1 year (31536000 seconds)") - user_info = parse_token(token) - return generate_presigned_token(user_info, expires_in=expires_in) - - svc = await server.register_service( - { - "name": "Hypha Login", - "id": "hypha-login", - "type": "functions", - "description": "Login service for Hypha", - "config": {"visibility": "public"}, - "index": index, - "start": start_login, - "check": check_login, - "report": report_login, - "generate": generate_token, - } - ) - - logger.info("Login service is available at: %s", svc.id) - logger.info(f"To preview the login page, visit: {login_url}") + user_info = user_info.model_dump(mode="json") + await redis.setex(key, MAXIMUM_LOGIN_TIME, json.dumps(user_info)) + + logger.info(f"To preview the login page, visit: {login_service_url}") + return { + "name": "Hypha Login", + "id": "hypha-login", + "type": "functions", + "description": "Login service for Hypha", + "config": {"visibility": "public"}, + "index": index, + "start": start_login, + "check": check_login, + "report": report_login, + } diff --git a/hypha/core/store.py b/hypha/core/store.py index 181a3679..0a29a2a0 100644 --- a/hypha/core/store.py +++ b/hypha/core/store.py @@ -5,10 +5,12 @@ import sys from typing import List, Union from pydantic import BaseModel +from fastapi import Header, Cookie from hypha_rpc import RPC from starlette.routing import Mount +from hypha import __version__ from hypha.core import ( RedisEventBus, RedisRPCConnection, @@ -17,7 +19,17 @@ UserInfo, WorkspaceInfo, ) -from hypha.core.auth import create_scope +from hypha.core.auth import ( + create_scope, + parse_token, + generate_anonymous_user, + UserPermission, + AUTH0_CLIENT_ID, + AUTH0_DOMAIN, + AUTH0_AUDIENCE, + AUTH0_ISSUER, + LOGIN_SERVICE_URL, +) from hypha.core.workspace import WorkspaceManager from hypha.startup import run_startup_function from hypha.utils import random_id @@ -50,7 +62,7 @@ async def _get_workspace_manager(self): # Check if workspace exists if not await self._redis.hexists("workspaces", self._workspace): raise KeyError(f"Workspace {self._workspace} does not exist") - self._wm = await self._rpc.get_manager_service(self._timeout) + self._wm = await self._rpc.get_manager_service({"timeout": self._timeout}) self._wm.rpc = self._rpc self._wm.disconnect = self._rpc.disconnect self._wm.register_codec = self._rpc.register_codec @@ -91,8 +103,14 @@ def __init__( self.manager_id = None self.reconnection_token_life_time = reconnection_token_life_time self._server_info = { + "hypha_version": __version__, "public_base_url": self.public_base_url, "local_base_url": self.local_base_url, + "auth0_client_id": AUTH0_CLIENT_ID, + "auth0_domain": AUTH0_DOMAIN, + "auth0_audience": AUTH0_AUDIENCE, + "auth0_issuer": AUTH0_ISSUER, + "login_service_url": f"{self.public_base_url}{LOGIN_SERVICE_URL}", } if redis_uri and redis_uri.startswith("redis://"): @@ -138,6 +156,37 @@ async def setup_root_user(self) -> UserInfo: ) return self._root_user + async def load_or_create_workspace(self, user_info: UserInfo, workspace: str): + """Setup the workspace.""" + if workspace is None: + workspace = user_info.get_workspace() + + assert workspace != "*", "Dynamic workspace is not allowed for this endpoint" + # Anonymous and Temporary users are not allowed to create persistant workspaces + persistent = ( + not user_info.is_anonymous and "temporary-test-user" not in user_info.roles + ) + + # Ensure calls to store for workspace existence and permissions check + workspace_info = await self.get_workspace_info(workspace, load=True) + if not workspace_info: + if workspace != user_info.get_workspace(): + raise KeyError( + f"User can only connect to a pre-existing workspace or their own workspace: {workspace}" + ) + # Simplified logic for workspace creation, ensure this matches the actual store method signatures + workspace_info = await self.register_workspace( + { + "name": workspace, + "description": f"Default user workspace for {user_info.id}", + "persistent": persistent, + "owners": [user_info.id], + "read_only": user_info.is_anonymous, + } + ) + logger.info(f"Created workspace: {workspace}") + return workspace_info + def get_root_user(self): """Get the root user.""" return self._root_user @@ -169,7 +218,7 @@ async def init(self, reset_redis, startup_functions=None): await self.register_workspace( WorkspaceInfo.model_validate( { - "name": "root", + "name": self._root_user.get_workspace(), "description": "Root workspace", "persistent": True, "owners": ["root"], @@ -212,7 +261,7 @@ async def init(self, reset_redis, startup_functions=None): try: await api.register_service( service.model_dump(), - notify=True, + {"notify": True}, ) except Exception: # pylint: disable=broad-except logger.exception("Failed to register public service: %s", service) @@ -226,7 +275,7 @@ async def init(self, reset_redis, startup_functions=None): async def _register_root_services(self): """Register root services.""" self._root_workspace_interface = await self.get_workspace_interface( - "root", self._root_user + self._root_user.get_workspace(), self._root_user ) await self._root_workspace_interface.register_service( { @@ -237,6 +286,7 @@ async def _register_root_services(self): "require_context": False, }, "kickout_client": self.kickout_client, + "list_workspaces": self.list_all_workspaces, } ) @@ -273,6 +323,36 @@ async def get_user_workspace(self, user_id: str): ) return workspace_info + async def parse_user_token(self, token): + """Parse a client token.""" + user_info = parse_token(token) + key = "revoked_token:" + token + if await self._redis.exists(key): + raise Exception("Token has been revoked") + # automatically add user's own workspace to the scope + if not user_info.scope.workspaces: + user_info.scope.workspaces = { + user_info.get_workspace(): UserPermission.admin + } + if "admin" in user_info.roles: + user_info.scope.workspaces["*"] = UserPermission.admin + return user_info + + async def login_optional( + self, authorization: str = Header(None), access_token: str = Cookie(None) + ): + """Return user info or create an anonymouse user. + + If authorization code is valid the user info is returned, + If the code is invalid an an anonymouse user is created. + """ + token = authorization or access_token + if token: + user_info = await self.parse_user_token(token) + return user_info + else: + return generate_anonymous_user() + async def get_all_workspace(self): """Get all workspaces.""" workspaces = await self._redis.hgetall("workspaces") @@ -323,6 +403,10 @@ async def register_workspace_manager(self): await manager.setup() return manager + def get_server_info(self): + """Get the server information.""" + return self._server_info + def get_workspace_interface( self, workspace: str, @@ -346,7 +430,14 @@ def get_workspace_interface( async def list_all_workspaces(self): """List all workspaces.""" workspace_keys = await self._redis.hkeys("workspaces") - return [k.decode() for k in workspace_keys] + workspaces = [] + for k in workspace_keys: + workspace_info = await self._redis.hget("workspaces", k) + workspace_info = WorkspaceInfo.model_validate( + json.loads(workspace_info.decode()) + ) + workspaces.append(workspace_info) + return [workspace for workspace in workspaces] def create_rpc( self, @@ -380,7 +471,7 @@ def create_rpc( ) return rpc - async def get_workspace(self, workspace: str, load: bool = False): + async def get_workspace_info(self, workspace: str, load: bool = False): """Return the workspace information.""" try: return await self._workspace_manager.load_workspace_info( diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py index fe7b51c2..c2d01adb 100644 --- a/hypha/core/workspace.py +++ b/hypha/core/workspace.py @@ -1,11 +1,11 @@ -import asyncio import re import json import logging -import random +import time import sys from typing import Optional, Union, List, Any from contextlib import asynccontextmanager +import random from fakeredis import aioredis from hypha_rpc import RPC @@ -18,11 +18,12 @@ UserInfo, WorkspaceInfo, ServiceInfo, + RemoteService, TokenConfig, UserPermission, ServiceTypeInfo, ) -from hypha.core.auth import generate_presigned_token, create_scope +from hypha.core.auth import generate_presigned_token, create_scope, valid_token from hypha.utils import EventBus, random_id logging.basicConfig(stream=sys.stdout) @@ -32,7 +33,7 @@ SERVICE_SUMMARY_FIELD = ["id", "name", "type", "description", "config"] # Ensure the client_id is safe -_allowed_characters = re.compile(r"^[a-zA-Z0-9-_/*]*$") +_allowed_characters = re.compile(r"^[a-zA-Z0-9-_/|*]*$") def validate_key_part(key_part: str): @@ -41,6 +42,21 @@ def validate_key_part(key_part: str): raise ValueError(f"Invalid characters in query part: {key_part}") +class GetServiceConfig(BaseModel): + mode: Optional[str] = Field( + None, + description="Mode for selecting the service, it can be 'random', 'first', 'last' or 'exact'", + ) + timeout: Optional[float] = Field( + 10.0, + description="The timeout duration in seconds for fetching the service. This determines how long the function will wait for a service to respond before considering it a timeout.", + ) + case_conversion: Optional[str] = Field( + None, + description="The case conversion for service keys, can be 'camel', 'snake' or None, default is None.", + ) + + class WorkspaceManager: def __init__( self, @@ -71,13 +87,13 @@ async def setup( """Setup the workspace manager.""" if self._initialized: return self._rpc - self._client_id = "workspace-manager-" + random_id(readable=False) + self._client_id = "ws-" + random_id(readable=False) rpc = self._create_rpc(self._client_id) self._rpc = rpc management_service = self.create_service(service_id, service_name) await rpc.register_service( management_service, - notify=False, + {"notify": False}, ) self._initialized = True return rpc @@ -110,22 +126,18 @@ def _validate_workspace_name(self, name): """Validate the workspace name.""" if not name: raise ValueError("Workspace name must not be empty.") - # only allow numbers, letters in lower case, hyphens, underscores and | + + if "-" not in name: + raise ValueError( + "Workspace name must contain at least one hyphen (e.g. my-workspace)." + ) + # only allow numbers, letters in lower case and hyphens (no underscore) # use a regex to validate the workspace name - pattern = re.compile(r"^[a-z0-9-_|]*$") + pattern = re.compile(r"^[a-z0-9-]*$") if not pattern.match(name): - raise ValueError(f"Invalid workspace name: {name}, must match {pattern}") - if name in [ - "protected", - "private", - "default", - "built-in", - "all", - "admin", - "system", - "server", - ]: - raise ValueError("Invalid workspace name: " + name) + raise ValueError( + f"Invalid workspace name: {name}, only lowercase letters, numbers and hyphens are allowed." + ) return name async def _bookmark_workspace( @@ -136,7 +148,7 @@ async def _bookmark_workspace( ): """Bookmark the workspace for the user.""" assert isinstance(workspace, WorkspaceInfo) and isinstance(user_info, UserInfo) - user_workspace = await self.load_workspace_info(user_info.id) + user_workspace = await self.load_workspace_info(user_info.get_workspace()) user_workspace.config = user_workspace.config or {} if "bookmarks" not in user_workspace.config: user_workspace.config["bookmarks"] = [] @@ -154,7 +166,7 @@ async def _get_bookmarked_workspaces( ) -> List[dict]: """Get the bookmarked workspaces for the user.""" try: - user_workspace = await self.load_workspace_info(user_info.id) + user_workspace = await self.load_workspace_info(user_info.get_workspace()) return user_workspace.config.get("bookmarks", []) except KeyError: return [] @@ -185,7 +197,8 @@ async def create_workspace( workspace = WorkspaceInfo.model_validate(config) if user_info.id not in workspace.owners: workspace.owners.append(user_info.id) - self._validate_workspace_name(workspace.name) + if user_info.id != "root": + self._validate_workspace_name(workspace.name) # make sure we add the user's email to owners _id = user_info.email or user_info.id if _id not in workspace.owners: @@ -193,7 +206,7 @@ async def create_workspace( workspace.owners = [o.strip() for o in workspace.owners if o.strip()] # user workspace, let's store all the created workspaces - if user_info.id == workspace.name: + if user_info.get_workspace() == workspace.name: workspace.config = workspace.config or {} workspace.config["bookmarks"] = [ { @@ -206,7 +219,7 @@ async def create_workspace( "workspaces", workspace.name, workspace.model_dump_json() ) await self._event_bus.emit("workspace_loaded", workspace.model_dump()) - if user_info.id != workspace.name: + if user_info.get_workspace() != workspace.name: await self._bookmark_workspace(workspace, user_info, context=context) return workspace.model_dump() @@ -270,6 +283,24 @@ async def register_service_type( await self._update_workspace(workspace_info, user_info) return type_info + @schema_method + async def revoke_token( + self, + token: str = Field(..., description="token to be revoked"), + context: dict = None, + ): + """Revoke a token by storing it in Redis with a prefix and expiration time.""" + self.validate_context(context, UserPermission.admin) + try: + payload = valid_token(token) + except Exception as e: + raise ValueError(f"Invalid token: {e}") + expiration = int(payload.get("exp") - time.time()) + if expiration > 0: + await self._redis.setex("revoked_token:" + token, expiration, "revoked") + else: + raise ValueError("Token has already expired") + @schema_method async def get_service_type( self, @@ -577,6 +608,9 @@ async def list_services( async def register_service( self, service: ServiceInfo = Field(..., description="Service info"), + config: Optional[dict] = Field( + None, description="Options for registering service" + ), context: Optional[dict] = None, ): """Register a new service.""" @@ -594,11 +628,23 @@ async def register_service( assert ":" in service.id, "Service id info must contain ':'" service.app_id = service.app_id or "*" + service_name = service.id.split(":")[1] + workspace = service.id.split("/")[0] + key = f"services:*:{workspace}/*:{service_name}@*" + peer_keys = await self._redis.keys(key) + if len(peer_keys) > 0: + for peer_key in peer_keys: + peer_service = await self._redis.hgetall(peer_key) + peer_service = ServiceInfo.from_redis_dict(peer_service) + if peer_service.config.singleton: + raise ValueError( + f"A singleton service with the same name ({service_name}) has already exists in the workspace ({workspace}), please remove it first or use a different name." + ) key = ( f"services:{service.config.visibility.value}:{service.id}@{service.app_id}" ) # Check if the clients exists if not a built-in service - if ":built-in" not in service.id and ws not in ["root", "public"]: + if ":built-in" not in service.id and ws not in ["ws-user-root", "public"]: builtins = await self._redis.keys(f"services:*:{client_id}:built-in@*") if not builtins: logger.warning( @@ -650,6 +696,10 @@ async def get_service_info( ..., description="Service id, it can be the service id or the full service id with workspace: `workspace/client_id:service_id`", ), + config: Optional[dict] = Field( + None, + description="Options for getting service, the only available config is `mode`, for selecting the service, it can be 'random', 'first', 'last' or 'exact'", + ), context: Optional[dict] = None, ): """Get the service info.""" @@ -681,7 +731,27 @@ async def get_service_info( keys = await self._redis.keys(key) if not keys: raise KeyError(f"Service not found: {service_id}@{app_id}") - key = keys[0] + config = config or {} + mode = config.get("mode") + if mode is None: + # Set random mode for public services, since there can be many hypha servers + if workspace == "public": + mode = "random" + else: + mode = "exact" + if mode == "exact": + assert len(keys) == 1, f"Multiple services found for {service_id}" + key = keys[0] + elif mode == "random": + key = random.choice(keys) + elif mode == "first": + key = keys[0] + elif mode == "last": + key = keys[-1] + else: + raise ValueError( + f"Invalid mode: {mode}, the mode must be 'random', 'first', 'last' or 'exact'" + ) # if it's a public service or the user has read permission if not key.startswith(b"services:public:") and not user_info.check_permission( workspace, UserPermission.read @@ -706,7 +776,9 @@ async def unregister_service( if not user_info.check_permission(ws, UserPermission.read_write): raise PermissionError(f"Permission denied for workspace {ws}") - service = await self.get_service_info(service_id, context=context) + service = await self.get_service_info( + service_id, {"mode": "exact"}, context=context + ) service.config.workspace = ws if "/" not in service.id: service.id = f"{ws}/{service.id}" @@ -849,7 +921,7 @@ async def load_workspace_info(self, workspace: str, load=True) -> WorkspaceInfo: return workspace_info elif load and not self._workspace_loader: raise KeyError( - "Workspace not found and the workspace loader is not configured." + f"Workspace ({workspace}) not found and the workspace loader is not configured (requires s3 enabled)." ) else: raise KeyError(f"Workspace not found: {workspace}") @@ -933,8 +1005,7 @@ async def _launch_application_for_service( ) return await self.get_service( f"{client_info['id']}:{service_id}", # should not contain @app_id - timeout=timeout, - case_conversion=case_conversion, + dict(timeout=timeout, case_conversion=case_conversion), context=context, ) @@ -968,13 +1039,8 @@ async def get_service( ..., description="Service ID. This should be a service id in the format: 'workspace/service_id', 'workspace/client_id:service_id' or 'workspace/client_id:service_id@app_id'", ), - timeout: float = Field( - 10.0, - description="The timeout duration in seconds for fetching the service. This determines how long the function will wait for a service to respond before considering it a timeout.", - ), - case_conversion: str = Field( - None, - description="The case conversion for service keys, can be 'camel', 'snake' or None, default is None.", + config: Optional[GetServiceConfig] = Field( + None, description="Get service config" ), context=None, ): @@ -985,13 +1051,21 @@ async def get_service( # no need to validate the context # self.validate_context(context, permission=UserPermission.read) try: + config = config or GetServiceConfig() # Permission check will be handled by the get_service_api function - svc_info = await self.get_service_info(service_id, context=context) + svc_info = await self.get_service_info( + service_id, {"mode": config.mode}, context=context + ) service_api = await self._rpc.get_remote_service( - svc_info.id, timeout=timeout, case_conversion=case_conversion + svc_info.id, + {"timeout": config.timeout, "case_conversion": config.case_conversion}, ) assert service_api, f"Failed to get service: {service_id}" - service_api["config"]["workspace"] = svc_info.id.split("/")[0] + workspace = service_id.split("/")[0] + service_api["config"]["workspace"] = workspace + service_api["config"][ + "url" + ] = f"{self._server_info['public_base_url']}/{workspace}/services/{service_id}" return service_api except KeyError as exp: if "@" in service_id: @@ -1000,16 +1074,14 @@ async def get_service( workspace = ( service_id.split("/")[0] if "/" in service_id else context["ws"] ) - service_api = await self._launch_application_for_service( + return await self._launch_application_for_service( app_id, service_id, workspace=workspace, - timeout=timeout, - case_conversion=case_conversion, + timeout=config.timeout, + case_conversion=config.case_conversion, context=context, ) - # No need to patch the service config because the service is already patched - return service_api else: raise exp @@ -1046,18 +1118,6 @@ async def list_workspaces( ] return workspaces - async def get_workspace(self, workspace: str = None): - """Get the service api of the workspace manager.""" - assert workspace is not None - rpc = self._rpc - wm = await rpc.get_remote_service( - f"{workspace}/{self._client_id}:default", timeout=10 - ) - wm.rpc = rpc - wm.disconnect = rpc.disconnect - wm.register_codec = rpc.register_codec - return wm - async def _update_workspace( self, workspace: WorkspaceInfo, user_info: UserInfo, overwrite=False ): @@ -1112,7 +1172,7 @@ async def delete_client( await self._redis.delete(key) if await self._redis.hexists("workspaces", cws): - if user_info.is_anonymous and cws == user_info.id: + if user_info.is_anonymous and cws == user_info.get_workspace(): logger.info(f"Unloading workspace {cws} for anonymous user.") # unload temporary workspace if the user exits await self.unload(context=context) @@ -1183,6 +1243,7 @@ def create_service(self, service_id, service_name=None): "get_service_info": self.get_service_info, "get_service": self.get_service, "generate_token": self.generate_token, + "revoke_token": self.revoke_token, "create_workspace": self.create_workspace, "get_workspace_info": self.get_workspace_info, "install_application": self.install_application, diff --git a/hypha/http.py b/hypha/http.py index b69d7e63..45901834 100644 --- a/hypha/http.py +++ b/hypha/http.py @@ -2,19 +2,35 @@ import inspect import json import traceback -from typing import Any, Optional +from typing import Any +import asyncio +from pathlib import Path +import requests import httpx import msgpack from fastapi import APIRouter, Depends, Request -from fastapi.responses import JSONResponse, Response, RedirectResponse +from jinja2 import Environment, PackageLoader, select_autoescape +from fastapi.responses import ( + JSONResponse, + Response, + RedirectResponse, + StreamingResponse, + FileResponse, +) +import jose +import os from hypha_rpc import RPC +from hypha import __version__ from hypha.core import UserPermission -from hypha.core.auth import login_optional, AUTH0_DOMAIN +from hypha.core.auth import AUTH0_DOMAIN from hypha.core.store import RedisStore -from hypha.utils import GzipRoute -from hypha import __version__ as VERSION +from hypha.utils import GzipRoute, safe_join, is_safe_path +from hypha.s3 import FSFileResponse +from starlette.datastructures import Headers +from aiobotocore.session import get_session + class MsgpackResponse(Response): """Response class for msgpack encoding.""" @@ -143,14 +159,46 @@ def detected_response_type(request: Request): class HTTPProxy: """A proxy for accessing services from HTTP.""" - def __init__(self, store: RedisStore) -> None: + def __init__( + self, + store: RedisStore, + endpoint_url=None, + access_key_id=None, + secret_access_key=None, + region_name=None, + workspace_bucket="hypha-workspaces", + workspace_etc_dir="etc", + base_path="/", + ) -> None: """Initialize the http proxy.""" # pylint: disable=broad-except router = APIRouter() router.route_class = GzipRoute self.store = store - - @router.get("/authorize") + self.endpoint_url = endpoint_url + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.region_name = region_name + self.s3_enabled = endpoint_url is not None + self.workspace_bucket = workspace_bucket + self.workspace_etc_dir = workspace_etc_dir + self.ws_apps_dir = Path(__file__).parent / "templates/ws" + self.ws_app_files = os.listdir(self.ws_apps_dir) + self.templates_dir = Path(__file__).parent / "templates" + self.templates_files = os.listdir(self.templates_dir) + self.jinja_env = Environment( + loader=PackageLoader("hypha"), autoescape=select_autoescape() + ) + self.server_info = self.store.get_server_info() + + def norm_url(url): + return base_path.rstrip("/") + url + + # download the hypha-rpc-websocket.js file from the CDN + self.rpc_lib_esm_content = None + self.rpc_lib_umd_content = None + + @router.get(norm_url("/authorize")) async def auth_proxy(request: Request): # Construct the full URL for the Auth0 authorize endpoint with the query parameters auth0_authorize_url = ( @@ -160,7 +208,39 @@ async def auth_proxy(request: Request): # Redirect the client to the constructed URL return RedirectResponse(url=auth0_authorize_url) - @router.post("/oauth/token") + @router.get(norm_url("/hypha-rpc-websocket.mjs")) + async def hypha_rpc_websocket_mjs(request: Request): + if not self.rpc_lib_esm_content: + _rpc_lib_script = f"https://cdn.jsdelivr.net/npm/hypha-rpc@{__version__}/dist/hypha-rpc-websocket.mjs" + response = requests.get(_rpc_lib_script) + response.raise_for_status() + self.rpc_lib_esm_content = response.content + return Response( + content=self.rpc_lib_esm_content, media_type="application/javascript" + ) + + @router.get(norm_url("/hypha-rpc-websocket.js")) + async def hypha_rpc_websocket_js(request: Request): + if not self.rpc_lib_umd_content: + _rpc_lib_script = f"https://cdn.jsdelivr.net/npm/hypha-rpc@{__version__}/dist/hypha-rpc-websocket.js" + response = requests.get(_rpc_lib_script) + response.raise_for_status() + self.rpc_lib_umd_content = response.content + return Response( + content=self.rpc_lib_umd_content, media_type="application/javascript" + ) + + @router.get(norm_url("/config.json")) + async def get_config( + request: Request, + user_info: store.login_optional = Depends(store.login_optional), + ): + return JSONResponse( + status_code=200, + content={"user": user_info.model_dump(), **self.server_info}, + ) + + @router.post(norm_url("/oauth/token")) async def token_proxy(request: Request): form_data = await request.form() async with httpx.AsyncClient() as client: @@ -172,7 +252,7 @@ async def token_proxy(request: Request): return JSONResponse(status_code=200, content=auth0_response.json()) - @router.get("/authorize") + @router.get(norm_url("/authorize")) async def auth_proxy(request: Request): # Construct the full URL for the Auth0 authorize endpoint with the query parameters auth0_authorize_url = ( @@ -182,7 +262,7 @@ async def auth_proxy(request: Request): # Redirect the client to the constructed URL return RedirectResponse(url=auth0_authorize_url) - @router.post("/oauth/token") + @router.post(norm_url("/oauth/token")) async def token_proxy(request: Request): form_data = await request.form() async with httpx.AsyncClient() as client: @@ -194,70 +274,40 @@ async def token_proxy(request: Request): return JSONResponse(status_code=200, content=auth0_response.json()) - @router.get("/workspaces") - async def list_all_workspaces( - user_info: login_optional = Depends(login_optional), - ): - """Route for listing all the workspaces.""" - try: - workspaces = await store.list_all_workspaces() - return JSONResponse( - status_code=200, - content=workspaces, - ) - except Exception as exp: - return JSONResponse( - status_code=500, - content={"success": False, "detail": str(exp)}, - ) - - @router.get("/workspaces/{workspace}") + @router.get(norm_url("/{workspace}/info")) async def get_workspace_info( workspace: str, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): - """Route for get detailed info of a workspace.""" + """Route for checking details of a workspace.""" try: if not user_info.check_permission(workspace, UserPermission.read): return JSONResponse( status_code=403, content={ "success": False, - "detail": f"Permission denied to workspace: {workspace}", - }, - ) - workspace_info = await store.get_workspace(workspace, load=True) - if workspace_info is None: - return JSONResponse( - status_code=404, - content={ - "success": False, - "detail": f"Workspace not found: {workspace}", + "detail": ( + f"{user_info.id} has no" + f" permission to access {workspace}" + ), }, ) + + info = await self.store.load_or_create_workspace(user_info, workspace) return JSONResponse( status_code=200, - content=workspace_info.model_dump(), + content=info.model_dump(), ) except Exception as exp: return JSONResponse( - status_code=400, + status_code=500, content={"success": False, "detail": str(exp)}, ) - @router.get("/services") - async def list_services( - workspace: str, - user_info: login_optional = Depends(login_optional), - ): - """List services under a workspace.""" - return await get_workspace_services(workspace, user_info) - - @router.get("/services//{workspace}") - @router.get("/{workspace}/services") + @router.get(norm_url("/{workspace}/services")) async def get_workspace_services( workspace: str, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): """Route for get services under a workspace.""" try: @@ -279,19 +329,23 @@ async def get_workspace_services( content={"success": False, "detail": str(exp)}, ) - @router.get("/services/{workspace}/{service_id}") - @router.get("/{workspace}/services/{service_id}") + @router.get(norm_url("/{workspace}/services/{service_id}")) async def get_service_info( workspace: str, service_id: str, - user_info: login_optional = Depends(login_optional), + mode: str = "exact", + user_info: store.login_optional = Depends(store.login_optional), ): """Route for checking details of a service.""" try: async with self.store.get_workspace_interface( workspace, user_info ) as api: - service_info = await api.get_service_info(service_id) + if service_id == "ws": + return serialize(api) + service_info = await api.get_service_info( + service_id, {"mode": mode} + ) return JSONResponse( status_code=200, content=serialize(service_info), @@ -313,16 +367,282 @@ async def _call_service_function(func, kwargs): results = _rpc.encode(results) return results - @router.get("/services/{workspace}/{service_id}/{function_key}") - @router.post("/services/{workspace}/{service_id}/{function_key}") - @router.get("/{workspace}/services/{service_id}/{function_key}") - @router.post("/{workspace}/services/{service_id}/{function_key}") + @router.get(norm_url("/{workspace}/apps")) + async def get_workspace_apps( + workspace: str, + user_info: store.login_optional = Depends(store.login_optional), + ): + """Route for get apps under a workspace.""" + try: + async with self.store.get_workspace_interface( + workspace, user_info + ) as manager: + try: + controller = await manager.get_service("public/server-apps") + except KeyError: + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": "Server Apps service is not enabled.", + }, + ) + apps = await controller.list_apps(workspace) + return JSONResponse( + status_code=200, + content=serialize(apps), + ) + except Exception as exp: + return JSONResponse( + status_code=404, + content={"success": False, "detail": str(exp)}, + ) + + @router.get(norm_url("/{workspace}/apps/{service_id}")) + async def get_app_info( + workspace: str, + service_id: str, + request: Request, + path: str = None, + user_info: store.login_optional = Depends(store.login_optional), + ): + """Route for checking details of an app.""" + if not path: + path = "/" + return await run_app( + workspace=workspace, + service_id=service_id, + request=request, + path=path, + user_info=user_info, + ) + + @router.get(norm_url("/{workspace}/apps/{service_id}/{path:path}")) + async def run_app( + workspace: str, + service_id: str, + request: Request, + path: str = None, + mode: str = "exact", + user_info: store.login_optional = Depends(store.login_optional), + ) -> Response: + if service_id == "ws": + if not path: + template = self.jinja_env.get_template("ws/index.html") + return FileResponse( + safe_join(str(self.ws_apps_dir), "ws/index.html") + ) + + # check if the path is inside the built-in apps dir + # get the jinja template from the built-in apps dir + dir_path = path.split("/")[0] + if dir_path in self.ws_app_files: + file_path = safe_join(str(self.ws_apps_dir), path) + if not is_safe_path(str(self.ws_apps_dir), file_path): + return JSONResponse( + status_code=403, + content={ + "success": False, + "detail": f"Unsafe path: {file_path}", + }, + ) + if not os.path.exists(file_path): + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"File not found: {path}", + }, + ) + return FileResponse(safe_join(str(self.ws_apps_dir), "ws", path)) + else: + key = safe_join(workspace, path) + assert self.s3_enabled, "S3 is not enabled." + s3_client = get_session().create_client( + "s3", + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key, + region_name=self.region_name, + ) + return FSFileResponse(s3_client, self.workspace_bucket, key) + + # Serve dynamic apps + try: + scope = request.scope + scope = { + k: scope[k] + for k in scope + if isinstance( + scope[k], (str, int, float, bool, tuple, list, dict, bytes) + ) + } + if not path.startswith("/"): + path = "/" + path + scope["path"] = path + scope["raw_path"] = path.encode("latin-1") + send_queue = asyncio.Queue() + async with self.store.get_workspace_interface( + workspace, user_info + ) as api: + info = await api.get_service_info(service_id, {"mode": mode}) + if info.type == "ASGI": + service = await api.get_service(info.id) + # Call the ASGI app with manually provided receive and send + await service.serve( + { + "scope": scope, + "receive": request.body, + "send": send_queue.put, + } + ) + elif info.type == "functions": + func_name = path.split("/", 1)[-1] or "index" + func_name = func_name.rstrip("/") + service = await api.get_service(info.id) + if func_name in service and callable(service[func_name]): + scope["query_string"] = scope["query_string"].decode( + "utf-8" + ) + scope["raw_path"] = scope["raw_path"].decode("latin-1") + scope["headers"] = dict(Headers(scope=scope).items()) + event = await request.receive() + body = event["body"] + while event.get("more_body"): + body += await request.receive()["body"] + scope["body"] = body or None + func = service[func_name] + try: + result = await func(scope) + headers = Headers(headers=result.get("headers")) + body = result.get("body") + status = result.get("status", 200) + assert isinstance(status, int) + start = { + "type": "http.response.start", + "status": status, + "headers": headers.raw, + } + if not body: + start["more_body"] = False + await send_queue.put(start) + if body: + if not isinstance(body, bytes): + body = body.encode() + await send_queue.put( + { + "type": "http.response.body", + "body": body, + "more_body": False, + } + ) + except Exception: + await send_queue.put( + { + "type": "http.response.start", + "status": 500, + "headers": [ + [b"content-type", b"text/plain"], + ], + } + ) + + else: + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"Function not found: {func_name}", + }, + ) + else: + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"Service cannot be run as an app: {service_id}", + }, + ) + + async def send_response(send_queue): + # This function will be called to send data as it's received from the ASGI app + while True: + message = await send_queue.get() + if message["type"] == "http.response.body": + yield message["body"] + if not message.get("more_body", False): + break + + # Extract initial status and headers from the queue + initial_message = await send_queue.get() + assert initial_message["type"] == "http.response.start" + status = initial_message.get("status", 200) + raw_headers = initial_message.get("headers", []) + headers = { + k.decode("latin-1"): v.decode("latin-1") for k, v in raw_headers + } + + return StreamingResponse( + send_response(send_queue), status_code=status, headers=headers + ) + except KeyError: + return Response(status_code=404) + + @router.get(norm_url("/{workspace}/browser-apps/{app_id}/{path:path}")) + async def get_browser_app_file( + workspace: str, app_id: str, path: str, token: str = None + ) -> Response: + """Route for getting browser app files.""" + if token is None: + return JSONResponse( + status_code=403, + content={ + "success": False, + "detail": (f"Token not provided for {workspace}/{path}"), + }, + ) + try: + user_info = await store.parse_user_token(token) + except jose.exceptions.JWTError: + return JSONResponse( + status_code=403, + content={ + "success": False, + "detail": ( + f"Invalid token not provided for {workspace}/{path}" + ), + }, + ) + if not user_info.check_permission(workspace, UserPermission.read): + return JSONResponse( + status_code=403, + content={ + "success": False, + "detail": ( + f"{user_info['username']} has no" + f" permission to access {workspace}" + ), + }, + ) + key = safe_join(self.workspace_etc_dir, workspace, app_id, path) + assert self.s3_enabled, "S3 is not enabled." + s3_client = get_session().create_client( + "s3", + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key, + region_name=self.region_name, + ) + return FSFileResponse(s3_client, self.workspace_bucket, key) + + @router.get(norm_url("/{workspace}/services/{service_id}/{function_key}")) + @router.post(norm_url("/{workspace}/services/{service_id}/{function_key}")) async def call_service_function( workspace: str, service_id: str, function_key: str, request: Request, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): """Run service function by keys.""" function_kwargs = await extracted_kwargs(request, use_function_kwargs=False) @@ -344,7 +664,7 @@ async def service_function( function_info: extracted_call_info = Depends(extracted_call_info), function_kwargs: extracted_kwargs = Depends(extracted_kwargs), response_type: detected_response_type = Depends(detected_response_type), - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): """Run service function by keys. @@ -355,7 +675,10 @@ async def service_function( async with self.store.get_workspace_interface( workspace, user_info ) as api: - service = await api.get_service(service_id) + if service_id == "ws": + service = api + else: + service = await api.get_service(service_id) func = get_value(function_key, service) if not func: return JSONResponse( @@ -396,4 +719,87 @@ async def service_function( content={"success": False, "detail": traceback.format_exc()}, ) + @router.get(norm_url("/health/readiness")) + async def readiness(req: Request) -> JSONResponse: + """Used for readiness probe. + t determines whether the application inside the container is ready to accept traffic or requests. + """ + if store.is_ready(): + return JSONResponse({"status": "OK"}) + + return JSONResponse({"status": "DOWN"}, status_code=503) + + @router.get(norm_url("/health/liveness")) + async def liveness(req: Request) -> JSONResponse: + """Used for liveness probe. + If the liveness probe fails, it means the app is in a failed state and restarts it. + """ + return JSONResponse({"status": "OK"}) + + @router.get(norm_url("/{page:path}")) + async def get_pages( + page: str, + user_info: store.login_optional = Depends(store.login_optional), + ): + """Route for getting pages, if other matches not found, this will be the fallback.""" + if not page or page == "/": + page = "index.html" + dir_path = page.split("/")[0] + if dir_path not in self.templates_files and dir_path.startswith("ws-"): + workspace = dir_path + try: + await self.store.load_or_create_workspace(user_info, workspace) + except KeyError: + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"Workspace not found: {workspace}", + }, + ) + inner_path = "/".join(page.split("/")[1:]) + if not inner_path: + inner_path = "index.html" + if not os.path.exists(safe_join(str(self.ws_apps_dir), inner_path)): + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"File not found: {inner_path}", + }, + ) + return FileResponse(safe_join(str(self.ws_apps_dir), inner_path)) + file_path = safe_join(str(self.templates_dir), page) + if not is_safe_path(str(self.templates_dir), file_path): + return JSONResponse( + status_code=403, + content={"success": False, "detail": f"Unsafe path: {file_path}"}, + ) + if not os.path.exists(file_path): + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"File not found: {file_path}", + }, + ) + assert os.path.basename(file_path) not in [ + "apps", + "ws", + ], f"Invalid page name: {page}" + # compile the jinja template + file_path = safe_join(str(self.templates_dir), page) + assert is_safe_path( + str(self.templates_dir), file_path + ), f"Unsafe path: {page}" + if not os.path.exists(file_path): + return JSONResponse( + status_code=404, + content={ + "success": False, + "detail": f"File not found: {inner_path}", + }, + ) + return FileResponse(file_path) + store.register_router(router) diff --git a/hypha/runner/__init__.py b/hypha/runner/__init__.py index d08232b8..7648e02f 100644 --- a/hypha/runner/__init__.py +++ b/hypha/runner/__init__.py @@ -29,7 +29,7 @@ async def export_service(app_api, config, hypha_rpc): # Copy the app name as the default name app_api["id"] = "default" app_api["name"] = config.get("name", "default") - svc = await rpc.register_service(app_api, overwrite=True, notify=True) + svc = await rpc.register_service(app_api, {"overwrite": True, "notify": True}) svc = await rpc.get_remote_service(svc["id"]) if svc.setup: await svc.setup() @@ -101,9 +101,7 @@ def done_callback(fut): if quit_on_ready: hypha_rpc.ready.add_done_callback(lambda fut: loop.stop()) else: - raise RuntimeError( - f"Invalid script type ({found[0]}) in file {app_file}" - ) + raise RuntimeError(f"Invalid script type ({found[0]}) in file {app_file}") else: raise RuntimeError(f"Invalid script file type ({app_file})") diff --git a/hypha/s3.py b/hypha/s3.py index 48f928a7..100c70f8 100644 --- a/hypha/s3.py +++ b/hypha/s3.py @@ -18,7 +18,6 @@ from starlette.types import Receive, Scope, Send from hypha.core import UserInfo, WorkspaceInfo, UserPermission -from hypha.core.auth import login_optional from hypha.core.store import RedisStore from hypha.minio import MinioClient from hypha.utils import ( @@ -153,6 +152,7 @@ def __init__( endpoint_url=None, access_key_id=None, secret_access_key=None, + region_name=None, endpoint_url_public=None, s3_admin_type="generic", workspace_bucket="hypha-workspaces", @@ -164,6 +164,7 @@ def __init__( self.endpoint_url = endpoint_url self.access_key_id = access_key_id self.secret_access_key = secret_access_key + self.region_name = region_name self.s3_admin_type = s3_admin_type if self.s3_admin_type == "minio": self.minio_client = MinioClient( @@ -217,10 +218,10 @@ async def upload_file( workspace: str, path: str, request: Request, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): """Upload file.""" - ws = await store.get_workspace(workspace, load=True) + ws = await store.get_workspace_info(workspace, load=True) if not ws: return JSONResponse( status_code=404, @@ -264,10 +265,10 @@ async def get_or_delete_file( path: str, request: Request, max_length: int = 1000, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): """Get or delete file.""" - ws = await store.get_workspace(workspace, load=True) + ws = await store.get_workspace_info(workspace, load=True) if not ws: return JSONResponse( status_code=404, @@ -484,7 +485,7 @@ def create_client_sync(self): endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key_id, aws_secret_access_key=self.secret_access_key, - region_name="EU", + region_name=self.region_name, ) def create_client_async(self, public=False): @@ -494,7 +495,7 @@ def create_client_async(self, public=False): endpoint_url=self.endpoint_url_public if public else self.endpoint_url, aws_access_key_id=self.access_key_id, aws_secret_access_key=self.secret_access_key, - region_name="EU", + region_name=self.region_name, ) async def list_users( @@ -624,7 +625,7 @@ async def generate_credential(self, context: dict = None): """Generate credential.""" assert self.minio_client, "Minio client is not available" workspace = context["ws"] - ws = await self.store.get_workspace(workspace, load=True) + ws = await self.store.get_workspace_info(workspace, load=True) assert ws, f"Workspace {workspace} not found." if ws.read_only: raise Exception("Permission denied: workspace is read-only") @@ -642,6 +643,7 @@ async def generate_credential(self, context: dict = None): "endpoint_url": self.endpoint_url_public, # Return the public endpoint "access_key_id": user_info.id, "secret_access_key": password, + "region_name": self.region_name, "bucket": self.workspace_bucket, "prefix": workspace + "/", # important to have the trailing slash } @@ -654,7 +656,15 @@ async def list_files( ) -> Dict[str, Any]: """List files in the folder.""" workspace = context["ws"] - path = safe_join(workspace, path) + if path.startswith("/"): + user_info = UserInfo.model_validate(context["user"]) + assert user_info.check_permission( + "*", UserPermission.admin + ), "Permission denied: only admin can access the root folder." + # remove the leading slash + path = path[1:] + else: + path = safe_join(workspace, path) async with self.create_client_async() as s3_client: # List files in the folder if not path.endswith("/"): @@ -669,8 +679,7 @@ async def list_files( async def generate_presigned_url( self, - bucket_name, - object_name, + path: str, client_method="get_object", expiration=3600, context: dict = None, @@ -678,21 +687,23 @@ async def generate_presigned_url( """Generate presigned url.""" try: workspace = context["ws"] - ws = await self.store.get_workspace(workspace, load=True) + ws = await self.store.get_workspace_info(workspace, load=True) assert ws, f"Workspace {workspace} not found." if ws.read_only and client_method != "get_object": raise Exception("Permission denied: workspace is read-only") - if bucket_name != self.workspace_bucket or not object_name.startswith( - workspace + "/" - ): - raise Exception( - f"Permission denied: bucket name must be {self.workspace_bucket} " - "and the object name should be prefixed with workspace name + '/'." - ) + if path.startswith("/"): + user_info = UserInfo.model_validate(context["user"]) + assert user_info.check_permission( + "*", UserPermission.admin + ), "Permission denied: only admin can access the root folder." + # remove the leading slash + path = path[1:] + else: + path = safe_join(workspace, path) async with self.create_client_async(public=True) as s3_client: url = await s3_client.generate_presigned_url( client_method, - Params={"Bucket": bucket_name, "Key": object_name}, + Params={"Bucket": self.workspace_bucket, "Key": path}, ExpiresIn=expiration, ) return url diff --git a/hypha/server.py b/hypha/server.py index b1f88c0f..08643663 100644 --- a/hypha/server.py +++ b/hypha/server.py @@ -9,11 +9,9 @@ from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles -from starlette.requests import Request -from starlette.responses import JSONResponse from hypha import __version__ as VERSION -from hypha.asgi import ASGIGateway +from hypha.core.auth import create_login_service from hypha.core.store import RedisStore from hypha.core.queue import create_queue_service from hypha.http import HTTPProxy @@ -74,37 +72,15 @@ def start_builtin_services( """Set up the builtin services.""" # pylint: disable=too-many-arguments,too-many-locals - HTTPProxy(store) if args.triton_servers: TritonProxy( store, triton_servers=args.triton_servers.split(","), allow_origins=args.allow_origins, ) - ASGIGateway( - store, - allow_origins=args.allow_origins, - allow_methods=ALLOW_METHODS, - allow_headers=ALLOW_HEADERS, - expose_headers=EXPOSE_HEADERS, - ) - - @app.get(args.base_path) - async def home(): - return { - "name": "Hypha", - "version": VERSION, - } - - @app.get(norm_url(args.base_path, "/api/stats")) - async def stats(): - all_workspaces = await store.get_all_workspace() - return { - "workspace_count": len(all_workspaces), - "workspaces": [w.model_dump() for w in all_workspaces], - } store.register_public_service(create_queue_service(store)) + store.register_public_service(create_login_service(store)) if args.enable_s3: # pylint: disable=import-outside-toplevel @@ -117,6 +93,7 @@ async def stats(): access_key_id=args.access_key_id, secret_access_key=args.secret_access_key, endpoint_url_public=args.endpoint_url_public, + region_name=args.region_name, s3_admin_type=args.s3_admin_type, workspace_bucket=args.workspace_bucket, executable_path=args.executable_path, @@ -141,12 +118,13 @@ async def stats(): workspace_bucket=args.workspace_bucket, ) - @app.get(norm_url(args.base_path, "/health/liveness")) - async def liveness(req: Request) -> JSONResponse: - if store.is_ready(): - return JSONResponse({"status": "OK"}) - - return JSONResponse({"status": "DOWN"}, status_code=503) + HTTPProxy( + store, + endpoint_url=args.endpoint_url, + access_key_id=args.access_key_id, + secret_access_key=args.secret_access_key, + base_path=args.base_path, + ) def mount_static_files(app, new_route, directory, name="static"): @@ -193,7 +171,6 @@ def create_application(args): async def lifespan(app: FastAPI): # Here we can register all the startup functions args.startup_functions = args.startup_functions or [] - args.startup_functions.append("hypha.core.auth:register_login_service") await store.init(args.reset_redis, startup_functions=args.startup_functions) yield await websocket_server.stop() @@ -242,7 +219,6 @@ async def lifespan(app: FastAPI): ) websocket_server = WebsocketServer(store, path=norm_url(args.base_path, "/ws")) - start_builtin_services(application, store, args) static_folder = str(Path(__file__).parent / "static_files") mount_static_files(application, "/static", directory=static_folder, name="static") @@ -254,6 +230,8 @@ async def lifespan(app: FastAPI): application, mountpath, localdir, name=f"static-mount-{index}" ) + start_builtin_services(application, store, args) + if args.host in ("127.0.0.1", "localhost"): logger.info( "***Note: If you want to enable access from another host, " @@ -369,6 +347,12 @@ def get_argparser(add_help=True): help="set public endpoint URL for S3" "(if different from the local endpoint one)", ) + parser.add_argument( + "--region-name", + type=str, + default="EU", + help="set region name for S3", + ) parser.add_argument( "--access-key-id", type=str, diff --git a/hypha/templates/apps/imjoy-plugin-parser.html b/hypha/templates/apps/imjoy-plugin-parser.html new file mode 100644 index 00000000..30ea3839 --- /dev/null +++ b/hypha/templates/apps/imjoy-plugin-parser.html @@ -0,0 +1,127 @@ + + + + + + Hypha App Template + + + + + + + + + \ No newline at end of file diff --git a/hypha/templates/login_template.html b/hypha/templates/apps/login_template.html similarity index 84% rename from hypha/templates/login_template.html rename to hypha/templates/apps/login_template.html index cfe0b616..e05b09c8 100644 --- a/hypha/templates/login_template.html +++ b/hypha/templates/apps/login_template.html @@ -47,6 +47,10 @@
+
+ + +
@@ -100,10 +104,20 @@ window.generateToken = async () => { const expires_in = document.getElementById("expiry-time").value || 3600; // Default to 1 hour if not specified + const workspace = document.getElementById("workspace").value || ""; const accessToken = await auth0.getTokenSilently({ audience: AUTH0_AUDIENCE }); - const generate_token_url = "{{ login_service_url }}/generate"; + const generate_token_url = "{{ generate_token_url }}"; if (generate_token_url.startsWith('http')) { - const response = await fetch(`${generate_token_url}?token=${accessToken}&expires_in=${expires_in}`); + // make a post request with {config: {expires_in, workspace}} + // add auth barier token to the request + const response = await fetch(generate_token_url, { + method: "POST", + headers: { + "Content-Type": "application/json", + "Authorization": `Bearer ${accessToken}`, + }, + body: JSON.stringify({ config: { expires_in, workspace } }), + }); if (response.ok) { const token = await response.json(); document.getElementById("access-token").innerHTML = token; @@ -123,6 +137,16 @@ } }; + function parseJwt(token) { + const base64Url = token.split('.')[1]; + const base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/'); + const jsonPayload = decodeURIComponent(atob(base64).split('').map(function(c) { + return '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2); + }).join('')); + + return JSON.parse(jsonPayload); + } + window.showUserProfile = async () => { let user = await auth0.getUser(); if (user) { @@ -135,7 +159,10 @@ throw new Error("User not found"); } const accessToken = await auth0.getTokenSilently({ audience: AUTH0_AUDIENCE }); - + const decodedToken = parseJwt(accessToken); + const exp = decodedToken.exp; + const maxAge = exp - Math.floor(Date.now() / 1000); + document.cookie = `access_token=${accessToken}; path=/; max-age=${maxAge}; samesite=lax`; document.getElementById("login-btn").style.display = "none"; document.getElementById("name").innerHTML = user.name; document.getElementById("email").innerHTML = user.email; @@ -147,6 +174,7 @@ document.getElementById("user-id").innerHTML = user.sub; document.getElementById("access-token").innerHTML = accessToken; document.getElementById("info-panel").style.display = "block"; + document.getElementById("workspace").value = user.sub; const key = window.location.search.split('key=')[1]; const report_url = "{{ login_service_url }}/report"; diff --git a/hypha/templates/web-python-app.html b/hypha/templates/apps/web-python-app.html similarity index 97% rename from hypha/templates/web-python-app.html rename to hypha/templates/apps/web-python-app.html index fc4bd30f..f87755a7 100644 --- a/hypha/templates/web-python-app.html +++ b/hypha/templates/apps/web-python-app.html @@ -35,7 +35,7 @@ # Copy the plugin name as the default name plugin_api["id"] = "default" plugin_api["name"] = config.get("name", "default") - await rpc.register_service(plugin_api, overwrite=True) + await rpc.register_service(plugin_api, {"overwrite": True}) console.log("Application initialized successfully.") except Exception as exp: console.error("Failed to export service: ", traceback.format_exc()) @@ -90,7 +90,7 @@ async def run(): try: - await micropip.install(["numpy", "pydantic", "hypha-rpc==0.20.14", {% for req in requirements %}"{{req}}", {% endfor %}]) + await micropip.install(["numpy", "pydantic", "hypha-rpc=={{ hypha_version }}", {% for req in requirements %}"{{req}}", {% endfor %}]) js.__resolve() except Exception as e: console.error("Failed to install packages: ", traceback.format_exc()) diff --git a/hypha/templates/web-worker-app.html b/hypha/templates/apps/web-worker-app.html similarity index 79% rename from hypha/templates/web-worker-app.html rename to hypha/templates/apps/web-worker-app.html index bb3585ed..462be777 100644 --- a/hypha/templates/web-worker-app.html +++ b/hypha/templates/apps/web-worker-app.html @@ -10,16 +10,14 @@ - \ No newline at end of file diff --git a/hypha/templates/index.html b/hypha/templates/index.html new file mode 100644 index 00000000..93e1647e --- /dev/null +++ b/hypha/templates/index.html @@ -0,0 +1,291 @@ + + + + + + + Hypha + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + + diff --git a/hypha/templates/ws/index.html b/hypha/templates/ws/index.html new file mode 100644 index 00000000..f220bb80 --- /dev/null +++ b/hypha/templates/ws/index.html @@ -0,0 +1,428 @@ + + + + + + + Hypha Workspace Management + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + + + diff --git a/hypha/triton.py b/hypha/triton.py index 6666d160..15ff1c11 100644 --- a/hypha/triton.py +++ b/hypha/triton.py @@ -6,7 +6,6 @@ from fastapi import APIRouter, Depends, Request, Response from pyotritonclient import execute, get_config -from hypha.core.auth import login_optional from hypha.core.store import RedisStore @@ -29,7 +28,7 @@ async def triton_proxy( path: str, request: Request, response: Response, - user_info: login_optional = Depends(login_optional), + user_info: store.login_optional = Depends(store.login_optional), ): """Route for listing all the models.""" headers = dict(request.headers.items()) diff --git a/hypha/utils/__init__.py b/hypha/utils/__init__.py index 55e9629f..247a92ef 100644 --- a/hypha/utils/__init__.py +++ b/hypha/utils/__init__.py @@ -35,7 +35,7 @@ def random_id(readable=True): + f"{int((time.time() % 1) * 100000000):08d}" ) else: - return shortuuid.uuid() + return shortuuid.ShortUUID().random(length=22) PLUGIN_CONFIG_FIELDS = [ @@ -148,6 +148,16 @@ def generate_password(length=20): return "".join(secrets.choice(alphabet) for i in range(length)) +def is_safe_path(basedir: str, path: str, follow_symlinks: bool = True) -> bool: + """Check if the file path is safe.""" + # resolves symbolic links + if follow_symlinks: + matchpath = os.path.realpath(path) + else: + matchpath = os.path.abspath(path) + return basedir == os.path.commonpath((basedir, matchpath)) + + def safe_join(directory: str, *pathnames: str) -> Optional[str]: """Safely join zero or more untrusted path components to a base directory. @@ -581,56 +591,3 @@ async def _example_hypha_startup(server): "test": lambda x: x + 22, } ) - - -class AsyncTTLCache: - """Asynchronous Time-to-Live (TTL) Cache.""" - - def __init__(self, ttl=10): - """Initialize the AsyncTTLCache.""" - self.ttl = ttl - self.cache = {} - - async def add(self, key, value): - """Add a key-value pair to the cache.""" - self.cache[key] = (value, time.monotonic()) - await self.cleanup() - - async def update(self, key, value): - """Update the value associated with a key in the cache.""" - if key not in self.cache: - raise Exception("Invalid key") - - self.cache[key] = (value, time.monotonic()) - await self.cleanup() - - async def get(self, key): - """Retrieve the value associated with a key from the cache.""" - if key not in self.cache: - return None - - value, timestamp = self.cache[key] - if time.monotonic() - timestamp > self.ttl: - del self.cache[key] - return None - - return value - - async def cleanup(self): - """Clean up expired items from the cache.""" - keys_to_delete = [ - key - for key, (_, timestamp) in self.cache.items() - if time.monotonic() - timestamp > self.ttl - ] - for key in keys_to_delete: - del self.cache[key] - - def __contains__(self, key): - """Check if a key exists in the cache.""" - return key in self.cache - - def __delitem__(self, key): - """Delete a key-value pair from the cache.""" - if key in self.cache: - del self.cache[key] diff --git a/hypha/websocket.py b/hypha/websocket.py index bc8ef192..066e7bb2 100644 --- a/hypha/websocket.py +++ b/hypha/websocket.py @@ -11,9 +11,7 @@ from hypha.core import UserInfo, UserPermission from hypha.core.store import RedisRPCConnection, RedisStore from hypha.core.auth import ( - parse_reconnection_token, generate_reconnection_token, - parse_token, generate_anonymous_user, create_scope, update_user_scope, @@ -79,11 +77,11 @@ async def websocket_endpoint( else: user_info = generate_anonymous_user() user_info.scope = create_scope( - workspaces={user_info.id: UserPermission.admin}, + workspaces={user_info.get_workspace(): UserPermission.admin}, client_id=client_id, ) - workspace_info, user_info = await self.setup_workspace_and_permissions( - user_info, workspace, client_id + workspace_info = await self.store.load_or_create_workspace( + user_info, workspace ) user_info.scope = update_user_scope( user_info, workspace_info, client_id @@ -184,13 +182,24 @@ async def check_client(self, client_id, workspace, user_info): # remove dead client await self.store.remove_client(client_id, workspace, user_info) - async def authenticate_user(self, token, reconnection_token, client_id, workspace): + async def authenticate_user( + self, token: str, reconnection_token: str, client_id: str, workspace: str + ): """Authenticate user and handle reconnection or token authentication.""" - # Ensure actual implementation calls for parse_reconnection_token and parse_token user_info = None try: if reconnection_token: - user_info, ws, cid = parse_reconnection_token(reconnection_token) + user_info = await self.store.parse_user_token(reconnection_token) + scope = user_info.scope + assert ( + len(scope.workspaces) == 1 + ), "Invalid scope, it must have only one workspace" + assert scope.client_id, "Invalid scope, client_id is required" + assert ( + len(scope.workspaces) == 1 + ), "Invalid scope, it must have only one workspace" + ws = list(scope.workspaces.keys())[0] + cid = scope.client_id if workspace and workspace != ws: logger.error("Workspace mismatch, disconnecting") raise RuntimeError("Workspace mismatch, disconnecting") @@ -199,7 +208,7 @@ async def authenticate_user(self, token, reconnection_token, client_id, workspac raise RuntimeError("Client id mismatch, disconnecting") logger.info(f"Client reconnected: {ws}/{cid} using reconnection token") elif token: - user_info = parse_token(token) + user_info = await self.store.parse_user_token(token) # user token doesn't have client id, so we add that user_info.scope.client_id = client_id else: @@ -212,38 +221,6 @@ async def authenticate_user(self, token, reconnection_token, client_id, workspac logger.error(f"Authentication error: {str(e)}") raise RuntimeError(f"Authentication error: {str(e)}") - async def setup_workspace_and_permissions( - self, user_info: UserInfo, workspace, client_id - ): - """Setup workspace and check permissions.""" - if workspace is None: - workspace = user_info.id - - assert workspace != "*", "Dynamic workspace is not allowed for this endpoint" - # Anonymous and Temporary users are not allowed to create persistant workspaces - persistent = ( - not user_info.is_anonymous and "temporary-test-user" not in user_info.roles - ) - - # Ensure calls to store for workspace existence and permissions check - workspace_info = await self.store.get_workspace(workspace, load=True) - if not workspace_info: - assert ( - workspace == user_info.id - ), "User can only connect to a pre-existing workspace or their own workspace" - # Simplified logic for workspace creation, ensure this matches the actual store method signatures - workspace_info = await self.store.register_workspace( - { - "name": workspace, - "description": f"Default user workspace for {user_info.id}", - "persistent": persistent, - "owners": [user_info.id], - "read_only": user_info.is_anonymous, - } - ) - logger.info(f"Created workspace: {workspace}") - return workspace_info, user_info - async def establish_websocket_communication( self, websocket, workspace, client_id, user_info ): @@ -267,6 +244,9 @@ async def force_disconnect(_): async def send_bytes(data): try: await websocket.send_bytes(data) + except RuntimeError: + logger.warning("Failed to send message, closing") + await conn.disconnect("disconnected") except ConnectionClosedOK: logger.warning("Failed to send message, closing redis connection") await conn.disconnect("disconnected") diff --git a/requirements.txt b/requirements.txt index b86ee989..048e93a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ aioboto3==11.2.0 aiofiles==23.2.1 base58==2.1.1 fastapi==0.106.0 -hypha-rpc==0.20.14 +hypha-rpc==0.20.19 jinja2==3.1.2 lxml==4.9.3 msgpack==1.0.5 diff --git a/setup.py b/setup.py index 40f1dcd4..53066063 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ REQUIREMENTS = [ "aiofiles", "fastapi>=0.70.0,<=0.106.0", - "hypha-rpc>=0.20.14", + "hypha-rpc>=0.20.19", "msgpack>=1.0.2", "numpy", "pydantic[email]>=2.6.1", diff --git a/tests/example_plugin.py b/tests/example_plugin.py index ba32f345..7516dba9 100644 --- a/tests/example_plugin.py +++ b/tests/example_plugin.py @@ -19,7 +19,7 @@ async def setup(self): "type": "echo", "echo": lambda x: print("echo: " + str(x)), }, - overwrite=True, + {"overwrite": True}, ) service = await api.get_service(service_info.id) await service.echo("a message") diff --git a/tests/testFunctionsPlugin.imjoy.html b/tests/testFunctionsPlugin.imjoy.html index f5d21b20..a9aa8a85 100644 --- a/tests/testFunctionsPlugin.imjoy.html +++ b/tests/testFunctionsPlugin.imjoy.html @@ -29,15 +29,16 @@ "type": "functions", "config": { "visibility": "public", + "require_context": true }, - "hello-world": async function(event) { + "hello-world": async function(event, context) { return { status: 200, headers: {'Content-Type': 'application/json'}, - body: JSON.stringify({message: "Hello World", context: event.context}) + body: JSON.stringify({message: "Hello World", context}) }; }, - "index": async function(event) { + "index": async function(event, context) { return { status: 200, body: `Home page` diff --git a/tests/testWebWorkerPlugin.imjoy.html b/tests/testWebWorkerPlugin.imjoy.html index 39da78cc..b1fa8caa 100644 --- a/tests/testWebWorkerPlugin.imjoy.html +++ b/tests/testWebWorkerPlugin.imjoy.html @@ -35,7 +35,7 @@ echo(data){ return data } - }) + }, true) } async add2(a){ diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 84247865..05ace41d 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -5,7 +5,7 @@ import requests from hypha_rpc.websocket_client import connect_to_server -from . import WS_SERVER_URL, SERVER_URL +from . import WS_SERVER_URL, SERVER_URL, find_item # All test coroutines will be treated as marked. pytestmark = pytest.mark.asyncio @@ -67,6 +67,16 @@ async def test_functions(fastapi_server, test_user_token): service = await api.get_service(f"{config.workspace}/hello-functions") assert "hello-world" in service + response = requests.get( + f"{SERVER_URL}/{workspace}/apps", + headers={"Authorization": f"Bearer {token}"}, + ) + assert response.ok + cards = response.json() + card = find_item(cards, "name", "FunctionsPlugin") + svc = find_item(card["services"], "name", "hello-functions") + assert svc + response = requests.get( f"{SERVER_URL}/{workspace}/apps/hello-functions/hello-world", headers={"Authorization": f"Bearer {token}"}, @@ -76,7 +86,7 @@ async def test_functions(fastapi_server, test_user_token): assert ret["message"] == "Hello World" assert "user" in ret["context"] user_info = ret["context"]["user"] - assert user_info["scope"]["workspaces"]["user-1"] == "rw" + assert user_info["scope"]["workspaces"]["ws-user-user-1"] == "rw" response = requests.get( f"{SERVER_URL}/{workspace}/apps/hello-functions/hello-world/" diff --git a/tests/test_http.py b/tests/test_http.py index 39cfae0c..3e796a74 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -82,7 +82,7 @@ async def test_http_services(minio_server, fastapi_server, test_user_token): assert data.status_code == 200 assert data.json() == "123" - data = await client.get(f"{SERVER_URL}/services?workspace={workspace}") + data = await client.get(f"{SERVER_URL}/{workspace}/services") # [{'config': {'visibility': 'public', 'require_context': False, 'workspace': 'VRRVEdTF9of2y4cLmepzBw', 'flags': []}, 'id': '5XCPAyZrW72oBzywEk2oxP:test_service', 'name': 'test_service', 'type': 'test_service', 'description': '', 'docs': {}}] assert data.status_code == 200 assert find_item(data.json(), "name", "test_service") @@ -91,7 +91,9 @@ async def test_http_services(minio_server, fastapi_server, test_user_token): # pylint: disable=too-many-statements -async def test_http_proxy(minio_server, fastapi_server, test_user_token): +async def test_http_proxy( + minio_server, fastapi_server, test_user_token, root_user_token +): """Test http proxy.""" # WS_SERVER_URL = "http://127.0.0.1:9527" api = await connect_to_server( @@ -110,7 +112,7 @@ async def test_http_proxy(minio_server, fastapi_server, test_user_token): config = await controller.launch( source=TEST_APP_CODE, config={"type": "window"}, - wait_for_service=None, + wait_for_service=True, ) app = await api.get_app(config.id) assert "setup" in app and "register_services" in app @@ -124,17 +126,23 @@ async def test_http_proxy(minio_server, fastapi_server, test_user_token): service = await api.get_service(svc2["id"]) assert await service.echo("22") == "22" - response = requests.get(f"{SERVER_URL}/workspaces") - assert response.ok, response.json()["detail"] - response = response.json() - assert workspace in response + async with connect_to_server( + { + "name": "root client", + "server_url": WS_SERVER_URL, + "method_timeout": 30, + "token": root_user_token, + } + ) as root_api: + workspaces = await root_api.list_workspaces() + assert workspace in [w.name for w in workspaces] response = requests.get( - f"{SERVER_URL}/workspaces/{workspace}", headers={"Authorization": f"Bearer {token}"} + f"{SERVER_URL}/{workspace}", + headers={"Authorization": f"Bearer {token}"}, ) assert response.ok, response.json()["detail"] - response = response.json() - assert response["name"] == workspace + assert "hypha-rpc-websocket.js" in response.text # Without the token, we can only access to the public service response = requests.get(f"{SERVER_URL}/{service_ws}/services") @@ -159,12 +167,9 @@ async def test_http_proxy(minio_server, fastapi_server, test_user_token): assert response.ok, response.json()["detail"] assert find_item(response.json(), "name", "test_service") - response = requests.get(f"{SERVER_URL}/{service_ws}/services/test_service") - assert response.ok, response.json()["detail"] - service_info = response.json() - assert service_info["name"] == "test_service" - - response = requests.get(f"{SERVER_URL}/services/{svc1.id}") + response = requests.get( + f"{SERVER_URL}/{service_ws}/services/{svc1.id.split('/')[-1]}" + ) assert response.ok, response.json()["detail"] service_info = response.json() assert service_info["name"] == "test_service" @@ -209,15 +214,6 @@ async def test_http_proxy(minio_server, fastapi_server, test_user_token): result = msgpack.loads(response.content) assert result["data"] == 123 - response = requests.post( - f"{SERVER_URL}/services/{svc1.id}/echo", - data=msgpack.dumps({"data": 123}), - headers={"Content-type": "application/msgpack"}, - ) - assert response.ok - result = msgpack.loads(response.content) - assert result["data"] == 123 - response = requests.post( f"{SERVER_URL}/{service_ws}/services/test_service/echo", data=json.dumps({"data": 123}), diff --git a/tests/test_redis.py b/tests/test_redis.py index 9770746c..4983dcd3 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -35,7 +35,8 @@ async def test_redis_store(redis_store): ), overwrite=True, ) - assert "test" in await redis_store.list_all_workspaces() + wss = await redis_store.list_all_workspaces() + assert find_item(wss, "name", "test") api = await redis_store.connect_to_workspace("test", client_id="test-app-99") clients = set(await api.list_clients()) @@ -155,7 +156,7 @@ def echo(data): "echo": echo, "square": lambda x: x**2, }, - overwrite=True, + {"overwrite": True}, ) # It should fail due to permission error diff --git a/tests/test_s3.py b/tests/test_s3.py index 4103f603..4efa5347 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -29,9 +29,7 @@ async def test_s3(minio_server, fastapi_server, test_user_token): with pytest.raises( Exception, match=r".*Permission denied: workspace is read-only.*" ): - info = await s3controller.generate_presigned_url( - "", "", client_method="put_object" - ) + info = await s3controller.generate_presigned_url("", client_method="put_object") content = os.urandom(1024) response = requests.put( f"{SERVER_URL}/{workspace}/files/my-data-small.txt", @@ -139,7 +137,7 @@ def test_file_requests(): endpoint_url=info["endpoint_url"], aws_access_key_id=info["access_key_id"], aws_secret_access_key=info["secret_access_key"], - region_name="EU", + region_name=info["region_name"], ) as s3_client: bucket = await s3_client.Bucket(info["bucket"]) @@ -163,9 +161,7 @@ def test_file_requests(): f"{workspace}/hello.txt", ) - url = await s3controller.generate_presigned_url( - info["bucket"], info["prefix"] + "hello.txt" - ) + url = await s3controller.generate_presigned_url("hello.txt") assert url.startswith("http") and "X-Amz-Algorithm" in url response = requests.get(url) diff --git a/tests/test_server.py b/tests/test_server.py index 77707979..0680a3e7 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -227,7 +227,7 @@ async def test_services(fastapi_server): "type": "#test", "idx": 2, }, - overwrite=True, + {"overwrite": True}, ) # It should be overwritten because it's from the same provider assert ( @@ -274,6 +274,7 @@ async def test_services(fastapi_server): == 2 ) + async def test_server_scalability( fastapi_server_redis_1, fastapi_server_redis_2, test_user_token ): @@ -331,7 +332,7 @@ async def test_server_scalability( "id": "test-service", "add77": lambda x: x + 77, }, - overwrite=True, + {"overwrite": True}, ) svc = await api88.get_service("my-app-77:test-service") diff --git a/tests/test_server_apps.py b/tests/test_server_apps.py index e15fd859..fdbed393 100644 --- a/tests/test_server_apps.py +++ b/tests/test_server_apps.py @@ -37,7 +37,7 @@ """ -async def test_server_apps_unauthorized(fastapi_server): +async def test_server_apps_unauthorized(fastapi_server, root_user_token): """Test the server apps.""" api = await connect_to_server( {"name": "test client", "server_url": WS_SERVER_URL, "method_timeout": 30} @@ -51,22 +51,26 @@ async def test_server_apps_unauthorized(fastapi_server): ) # the workspace should exist in the stats - response = requests.get(f"{SERVER_URL}/api/stats") - assert response.status_code == 200 - stats = response.json() - workspace_info = find_item(stats["workspaces"], "name", api.config["workspace"]) - assert workspace_info + async with connect_to_server( + {"server_url": WS_SERVER_URL, "client_id": "admin", "token": root_user_token} + ) as root: + admin = await root.get_service("admin-utils") + workspaces = await admin.list_workspaces() + workspace_info = find_item(workspaces, "name", api.config["workspace"]) + assert workspace_info # Now disconnect it await api.disconnect() await asyncio.sleep(0.1) # now it should disappear from the stats - response = requests.get(f"{SERVER_URL}/api/stats") - assert response.status_code == 200 - stats = response.json() - workspace_info = find_item(stats["workspaces"], "name", api.config["workspace"]) - assert workspace_info is None + async with connect_to_server( + {"server_url": WS_SERVER_URL, "client_id": "admin", "token": root_user_token} + ) as root: + admin = await root.get_service("admin-utils") + workspaces = await admin.list_workspaces() + workspace_info = find_item(workspaces, "name", api.config["workspace"]) + assert workspace_info is None async def test_server_apps(fastapi_server, test_user_token): @@ -186,7 +190,7 @@ async def test_web_python_apps(fastapi_server, test_user_token): assert find_item(apps, "id", config.id) -async def test_non_persistent_workspace(fastapi_server): +async def test_non_persistent_workspace(fastapi_server, root_user_token): """Test non-persistent workspace.""" api = await connect_to_server( { @@ -214,11 +218,13 @@ async def test_non_persistent_workspace(fastapi_server): assert app is not None # It should exist in the stats - response = requests.get(f"{SERVER_URL}/api/stats") - assert response.status_code == 200 - stats = response.json() - workspace_info = find_item(stats["workspaces"], "name", workspace) - assert workspace_info is not None + async with connect_to_server( + {"server_url": WS_SERVER_URL, "client_id": "admin", "token": root_user_token} + ) as root: + admin = await root.get_service("admin-utils") + workspaces = await admin.list_workspaces() + workspace_info = find_item(workspaces, "name", workspace) + assert workspace_info is not None # We don't need to stop manually, since it should be removed # when the parent client exits @@ -228,11 +234,13 @@ async def test_non_persistent_workspace(fastapi_server): await asyncio.sleep(0.1) # now it should disappear from the stats - response = requests.get(f"{SERVER_URL}/api/stats") - assert response.status_code == 200 - stats = response.json() - workspace_info = find_item(stats["workspaces"], "name", workspace) - assert workspace_info is None + async with connect_to_server( + {"server_url": WS_SERVER_URL, "client_id": "admin", "token": root_user_token} + ) as root: + admin = await root.get_service("admin-utils") + workspaces = await admin.list_workspaces() + workspace_info = find_item(workspaces, "name", workspace) + assert workspace_info is None async def test_lazy_plugin(fastapi_server, test_user_token): diff --git a/tests/test_server_disconnection.py b/tests/test_server_disconnection.py index ef5528cb..e7cea4ed 100644 --- a/tests/test_server_disconnection.py +++ b/tests/test_server_disconnection.py @@ -23,46 +23,56 @@ async def test_server_reconnection(fastapi_server, root_user_token): """Test the server reconnection.""" - root = await connect_to_server( + async with connect_to_server( {"server_url": WS_SERVER_URL, "client_id": "admin", "token": root_user_token} - ) - admin = await root.get_service("admin-utils") - - api = await connect_to_server({"server_url": WS_SERVER_URL, "client_id": "client1"}) - assert api.config["client_id"] == "client1" - await admin.kickout_client( - api.config.workspace, api.config.client_id, 1008, "simulated abnormal closure" - ) - await asyncio.sleep(1) + ) as root: + admin = await root.get_service("admin-utils") - # It should reconnect - assert await api.echo("hi") == "hi" - await api.disconnect() + api = await connect_to_server( + {"server_url": WS_SERVER_URL, "client_id": "client1"} + ) + assert api.config["client_id"] == "client1" + await admin.kickout_client( + api.config.workspace, + api.config.client_id, + 1008, + "simulated abnormal closure", + ) + await asyncio.sleep(1) - api = await connect_to_server({"server_url": WS_SERVER_URL, "client_id": "client1"}) - assert api.config["client_id"] == "client1" - await admin.kickout_client( - api.config.workspace, api.config.client_id, 1000, "normal closure" - ) - await asyncio.sleep(1) - try: + # It should reconnect assert await api.echo("hi") == "hi" - except Exception as e: - assert "Connection is closed" in str(e) await api.disconnect() + api = await connect_to_server( + {"server_url": WS_SERVER_URL, "client_id": "client1"} + ) + assert api.config["client_id"] == "client1" + await admin.kickout_client( + api.config.workspace, api.config.client_id, 1000, "normal closure" + ) + await asyncio.sleep(1) + try: + assert await api.echo("hi") == "hi" + except Exception as e: + assert "Connection is closed" in str(e) + await api.disconnect() + async def test_server_reconnection_by_workspace_unload(fastapi_server): """Test the server reconnection.""" # connect to the server with a user - api = await connect_to_server( - {"server_url": WS_SERVER_URL, "client_id": "client1"} - ) + api = await connect_to_server({"server_url": WS_SERVER_URL, "client_id": "client1"}) token = await api.generate_token() # connect to the server with the same user, to the same workspace api2 = await connect_to_server( - {"server_url": WS_SERVER_URL, "client_id": "client2", "workspace": api.config["workspace"], "token": token} + { + "server_url": WS_SERVER_URL, + "client_id": "client2", + "workspace": api.config["workspace"], + "token": token, + } ) # force a server side disconnect to the second client await api.disconnect() diff --git a/tests/test_services.py b/tests/test_services.py index 5de0d9ee..685c10e3 100644 --- a/tests/test_services.py +++ b/tests/test_services.py @@ -14,6 +14,41 @@ pytestmark = pytest.mark.asyncio +async def test_singleton_service(fastapi_server): + """Test a singleton service.""" + async with connect_to_server( + {"name": "test client", "server_url": SERVER_URL} + ) as api: + await api.register_service( + { + "id": "test-service", + "name": "Test Service", + "config": {"singleton": True}, + "description": "A test service", + "tools": { + "add": lambda a, b: a + b, + "sub": lambda a, b: a - b, + }, + } + ) + + # Registering the same service again should raise an error + with pytest.raises(Exception, match=".*Failed to notify workspace manager.*"): + await api.register_service( + { + "id": "test-service", + "name": "Test Service", + "config": {"singleton": True}, + "description": "A test service", + "tools": { + "add": lambda a, b: a + b, + "sub": lambda a, b: a - b, + }, + }, + {"overwrite": True}, + ) + + async def test_typed_service(fastapi_server): """Test a typed service.""" async with connect_to_server( @@ -47,7 +82,7 @@ async def add( assert svc_type["id"] == f"{api.config.workspace}/test-service-type" service["type"] = "test-service-type" - svc_info = await api.register_service(service, check_type=True) + svc_info = await api.register_service(service, {"check_type": True}) assert svc_info["id"].endswith(":test-service") assert svc_info["type"] == "test-service-type" @@ -62,7 +97,7 @@ def add2(a: int, b: int) -> int: "type": api.config.workspace + "/test-service-type", "tools": {"add": add2}, }, - check_type=True, + {"check_type": True}, ) assert svc_info2["id"].endswith(":test-service2") assert svc_info2["service_schema"] @@ -85,7 +120,6 @@ async def callback(context): assert resp.status_code == 200, resp.text assert "Hypha Account" in resp.text assert "{{ report_url }}" not in resp.text - assert context["report_url"] in resp.text resp = await client.get( context["report_url"] + "?key=" + context["key"] + "&token=" + TOKEN ) diff --git a/tests/test_token.py b/tests/test_token.py index df74719e..0487ff1f 100644 --- a/tests/test_token.py +++ b/tests/test_token.py @@ -34,6 +34,39 @@ async def test_generate_token(fastapi_server): await api1.disconnect() +async def test_revoke_token(fastapi_server): + """Test connecting to the server with a revoked token.""" + async with connect_to_server( + {"name": "my app", "server_url": WS_SERVER_URL, "client_id": "my-app"} + ) as api1: + token = await api1.generate_token() + + async with connect_to_server( + { + "name": "my app", + "server_url": WS_SERVER_URL, + "token": token, + "workspace": api1.config.workspace, + } + ) as api2: + assert await api2.echo("hello") == "hello" + + await api1.revoke_token(token) + + try: + async with connect_to_server( + { + "name": "my app", + "server_url": WS_SERVER_URL, + "token": token, + "workspace": api1.config.workspace, + } + ): + assert False, "Should have raised an exception" + except Exception as e: + assert "Token has been revoked" in str(e) + + async def test_generate_token_with_custom_scope(fastapi_server): """Test token with custom scope.""" async with connect_to_server(