diff --git a/api/.env.example b/api/.env.example index 7b5e4950c82dd..245c0c376178b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -20,6 +20,9 @@ FILES_URL=http://127.0.0.1:5001 # The time in seconds after the signature is rejected FILES_ACCESS_TIMEOUT=300 +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 + # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1 diff --git a/api/app.py b/api/app.py index a251ef5f0f72c..52dd492225339 100644 --- a/api/app.py +++ b/api/app.py @@ -183,7 +183,7 @@ def load_user_from_request(request_from_flask_login): decoded = PassportService().verify(auth_token) user_id = decoded.get("user_id") - logged_in_account = AccountService.load_logged_in_account(account_id=user_id, token=auth_token) + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) if logged_in_account: contexts.tenant_id.set(logged_in_account.current_tenant_id) return logged_in_account diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 93dbc1367f394..a3334d16345e9 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -360,9 +360,9 @@ class WorkflowConfig(BaseSettings): ) -class OAuthConfig(BaseSettings): +class AuthConfig(BaseSettings): """ - Configuration for OAuth authentication + Configuration for authentication and OAuth """ OAUTH_REDIRECT_PATH: str = Field( @@ -371,7 +371,7 @@ class OAuthConfig(BaseSettings): ) GITHUB_CLIENT_ID: Optional[str] = Field( - description="GitHub OAuth client secret", + description="GitHub OAuth client ID", default=None, ) @@ -390,6 +390,11 @@ class OAuthConfig(BaseSettings): default=None, ) + ACCESS_TOKEN_EXPIRE_MINUTES: PositiveInt = Field( + description="Expiration time for access tokens in minutes", + default=60, + ) + class ModerationConfig(BaseSettings): """ @@ -607,6 +612,7 @@ def POSITION_TOOL_EXCLUDES_SET(self) -> set[str]: class FeatureConfig( # place the configs in alphabet order AppExecutionConfig, + AuthConfig, # Changed from OAuthConfig to AuthConfig BillingConfig, CodeExecutionSandboxConfig, DataSetConfig, @@ -621,14 +627,13 @@ class FeatureConfig( MailConfig, ModelLoadBalanceConfig, ModerationConfig, - OAuthConfig, + PositionConfig, RagEtlConfig, SecurityConfig, ToolConfig, UpdateConfig, WorkflowConfig, WorkspaceConfig, - PositionConfig, # hosted services config HostedServiceConfig, CeleryBeatConfig, diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 62837af2b9b0e..18a7b2316660c 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -7,7 +7,7 @@ import services from controllers.console import api from controllers.console.setup import setup_required -from libs.helper import email, get_remote_ip +from libs.helper import email, extract_remote_ip from libs.password import valid_password from models.account import Account from services.account_service import AccountService, TenantService @@ -40,17 +40,16 @@ def post(self): "data": "workspace not found, please contact system admin to invite you to join in a workspace", } - token = AccountService.login(account, ip_address=get_remote_ip(request)) + token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) - return {"result": "success", "data": token} + return {"result": "success", "data": token_pair.model_dump()} class LogoutApi(Resource): @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get("Authorization", "").split(" ")[1] - AccountService.logout(account=account, token=token) + AccountService.logout(account=account) flask_login.logout_user() return {"result": "success"} @@ -106,5 +105,19 @@ def get(self): return {"result": "success"} +class RefreshTokenApi(Resource): + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("refresh_token", type=str, required=True, location="json") + args = parser.parse_args() + + try: + new_token_pair = AccountService.refresh_token(args["refresh_token"]) + return {"result": "success", "data": new_token_pair.model_dump()} + except Exception as e: + return {"result": "fail", "data": str(e)}, 401 + + api.add_resource(LoginApi, "/login") api.add_resource(LogoutApi, "/logout") +api.add_resource(RefreshTokenApi, "/refresh-token") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index ad0c0580aeaba..c5909b8c1092e 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -9,7 +9,7 @@ from configs import dify_config from constants.languages import languages from extensions.ext_database import db -from libs.helper import get_remote_ip +from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus from services.account_service import AccountService, RegisterService, TenantService @@ -81,9 +81,14 @@ def get(self, provider: str): TenantService.create_owner_tenant_if_not_exist(account) - token = AccountService.login(account, ip_address=get_remote_ip(request)) + token_pair = AccountService.login( + account=account, + ip_address=extract_remote_ip(request), + ) - return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") + return redirect( + f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}" + ) def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 46b4ef5d87a8a..15a4af118b05b 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -4,7 +4,7 @@ from flask_restful import Resource, reqparse from configs import dify_config -from libs.helper import StrLen, email, get_remote_ip +from libs.helper import StrLen, email, extract_remote_ip from libs.password import valid_password from models.model import DifySetup from services.account_service import RegisterService, TenantService @@ -46,7 +46,7 @@ def post(self): # setup RegisterService.setup( - email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=extract_remote_ip(request) ) return {"result": "success"}, 201 diff --git a/api/libs/helper.py b/api/libs/helper.py index 9c3a1ff04d1f8..d8a8e7f4118d8 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -162,7 +162,7 @@ def generate_string(n): return result -def get_remote_ip(request) -> str: +def extract_remote_ip(request) -> str: if request.headers.get("CF-Connecting-IP"): return request.headers.get("Cf-Connecting-Ip") elif request.headers.getlist("X-Forwarded-For"): diff --git a/api/services/account_service.py b/api/services/account_service.py index 05b505f8a6218..4614775f23afa 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -7,6 +7,7 @@ from hashlib import sha256 from typing import Any, Optional +from pydantic import BaseModel from sqlalchemy import func from werkzeug.exceptions import Unauthorized @@ -49,9 +50,39 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task +class TokenPair(BaseModel): + access_token: str + refresh_token: str + + +REFRESH_TOKEN_PREFIX = "refresh_token:" +ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:" +REFRESH_TOKEN_EXPIRY = timedelta(days=30) + + class AccountService: reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) + @staticmethod + def _get_refresh_token_key(refresh_token: str) -> str: + return f"{REFRESH_TOKEN_PREFIX}{refresh_token}" + + @staticmethod + def _get_account_refresh_token_key(account_id: str) -> str: + return f"{ACCOUNT_REFRESH_TOKEN_PREFIX}{account_id}" + + @staticmethod + def _store_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.setex(AccountService._get_refresh_token_key(refresh_token), REFRESH_TOKEN_EXPIRY, account_id) + redis_client.setex( + AccountService._get_account_refresh_token_key(account_id), REFRESH_TOKEN_EXPIRY, refresh_token + ) + + @staticmethod + def _delete_refresh_token(refresh_token: str, account_id: str) -> None: + redis_client.delete(AccountService._get_refresh_token_key(refresh_token)) + redis_client.delete(AccountService._get_account_refresh_token_key(account_id)) + @staticmethod def load_user(user_id: str) -> None | Account: account = Account.query.filter_by(id=user_id).first() @@ -61,9 +92,7 @@ def load_user(user_id: str) -> None | Account: if account.status in {AccountStatus.BANNED.value, AccountStatus.CLOSED.value}: raise Unauthorized("Account is banned or closed.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( - account_id=account.id, current=True - ).first() + current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: @@ -84,10 +113,13 @@ def load_user(user_id: str) -> None | Account: return account @staticmethod - def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): + def get_account_jwt_token(account: Account) -> str: payload = { "user_id": account.id, - "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, + "exp": int( + datetime.now(timezone.utc).timestamp() + + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES).total_seconds() + ), "iss": dify_config.EDITION, "sub": "Console API Passport", } @@ -213,7 +245,7 @@ def update_account(account, **kwargs): return account @staticmethod - def update_last_login(account: Account, *, ip_address: str) -> None: + def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) account.last_login_ip = ip_address @@ -221,22 +253,45 @@ def update_last_login(account: Account, *, ip_address: str) -> None: db.session.commit() @staticmethod - def login(account: Account, *, ip_address: Optional[str] = None): + def login(account: Account, *, ip_address: Optional[str] = None) -> TokenPair: if ip_address: - AccountService.update_last_login(account, ip_address=ip_address) - exp = timedelta(days=30) - token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) - return token + AccountService.update_login_info(account=account, ip_address=ip_address) + + access_token = AccountService.get_account_jwt_token(account=account) + refresh_token = _generate_refresh_token() + + AccountService._store_refresh_token(refresh_token, account.id) + + return TokenPair(access_token=access_token, refresh_token=refresh_token) @staticmethod - def logout(*, account: Account, token: str): - redis_client.delete(_get_login_cache_key(account_id=account.id, token=token)) + def logout(*, account: Account) -> None: + refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) + if refresh_token: + AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @staticmethod - def load_logged_in_account(*, account_id: str, token: str): - if not redis_client.get(_get_login_cache_key(account_id=account_id, token=token)): - return None + def refresh_token(refresh_token: str) -> TokenPair: + # Verify the refresh token + account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) + if not account_id: + raise ValueError("Invalid refresh token") + + account = AccountService.load_user(account_id.decode("utf-8")) + if not account: + raise ValueError("Invalid account") + + # Generate new access token and refresh token + new_access_token = AccountService.get_account_jwt_token(account) + new_refresh_token = _generate_refresh_token() + + AccountService._delete_refresh_token(refresh_token, account.id) + AccountService._store_refresh_token(new_refresh_token, account.id) + + return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token) + + @staticmethod + def load_logged_in_account(*, account_id: str): return AccountService.load_user(account_id) @classmethod @@ -258,10 +313,6 @@ def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") -def _get_login_cache_key(*, account_id: str, token: str): - return f"account_login:{account_id}:{token}" - - class TenantService: @staticmethod def create_tenant(name: str) -> Tenant: @@ -698,3 +749,8 @@ def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> invitation = json.loads(data) return invitation + + +def _generate_refresh_token(length: int = 64): + token = secrets.token_hex(length) + return token diff --git a/docker/.env.example b/docker/.env.example index 87d7709a1830a..3a75bd0d74111 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -91,6 +91,9 @@ MIGRATION_ENABLED=true # The default value is 300 seconds. FILES_ACCESS_TIMEOUT=300 +# Access token expiration time in minutes +ACCESS_TOKEN_EXPIRE_MINUTES=60 + # The maximum number of active requests for the application, where 0 means unlimited, should be a non-negative integer. APP_MAX_ACTIVE_REQUESTS=0 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 62d798a695967..43cb4a1097ee8 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -47,6 +47,7 @@ x-shared-env: &shared-api-worker-env REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} + ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false}