diff --git a/arango/client.py b/arango/client.py index a6bcab6d..13d53cb1 100644 --- a/arango/client.py +++ b/arango/client.py @@ -171,6 +171,7 @@ def db( password: str = "", verify: bool = False, auth_method: str = "basic", + user_token: Optional[str] = None, superuser_token: Optional[str] = None, verify_certificate: bool = True, ) -> StandardDatabase: @@ -189,9 +190,17 @@ def db( refreshed automatically using ArangoDB username and password. This assumes that the clocks of the server and client are synchronized. :type auth_method: str + :param user_token: User generated token for user access. + If set, parameters **username**, **password** and **auth_method** + are ignored. This token is not refreshed automatically. If automatic + token refresh is required, consider setting **auth_method** to "jwt" + and using the **username** and **password** parameters instead. Token + expiry will be checked. + :type user_token: str :param superuser_token: User generated token for superuser access. If set, parameters **username**, **password** and **auth_method** - are ignored. This token is not refreshed automatically. + are ignored. This token is not refreshed automatically. Token + expiry will not be checked. :type superuser_token: str :param verify_certificate: Verify TLS certificates. :type verify_certificate: bool @@ -213,6 +222,17 @@ def db( deserializer=self._deserializer, superuser_token=superuser_token, ) + elif user_token is not None: + connection = JwtConnection( + hosts=self._hosts, + host_resolver=self._host_resolver, + sessions=self._sessions, + db_name=name, + http_client=self._http, + serializer=self._serializer, + deserializer=self._deserializer, + user_token=user_token, + ) elif auth_method.lower() == "basic": connection = BasicConnection( hosts=self._hosts, diff --git a/arango/connection.py b/arango/connection.py index 49aa7b67..3daa4585 100644 --- a/arango/connection.py +++ b/arango/connection.py @@ -13,10 +13,16 @@ from typing import Any, Callable, Optional, Sequence, Set, Tuple, Union import jwt +from jwt.exceptions import ExpiredSignatureError from requests import ConnectionError, Session from requests_toolbelt import MultipartEncoder -from arango.exceptions import JWTAuthError, ServerConnectionError +from arango.exceptions import ( + JWTAuthError, + JWTExpiredError, + JWTRefreshError, + ServerConnectionError, +) from arango.http import HTTPClient from arango.request import Request from arango.resolver import HostResolver @@ -203,7 +209,7 @@ def ping(self) -> int: request = Request(method="get", endpoint="/_api/collection") resp = self.send_request(request) if resp.status_code in {401, 403}: - raise ServerConnectionError("bad username and/or password") + raise ServerConnectionError("bad username/password or token is expired") if not resp.is_success: # pragma: no cover raise ServerConnectionError(resp.error_message or "bad server response") return resp.status_code @@ -300,11 +306,12 @@ def __init__( host_resolver: HostResolver, sessions: Sequence[Session], db_name: str, - username: str, - password: str, http_client: HTTPClient, serializer: Callable[..., str], deserializer: Callable[[str], Any], + username: Optional[str] = None, + password: Optional[str] = None, + user_token: Optional[str] = None, ) -> None: super().__init__( hosts, @@ -323,7 +330,13 @@ def __init__( self._token: Optional[str] = None self._token_exp: int = sys.maxsize - self.refresh_token() + if user_token is not None: + self.set_token(user_token) + elif username is not None and password is not None: + self.refresh_token() + else: + m = "Either **user_token** or **username** & **password** must be set" + raise ValueError(m) def send_request(self, request: Request) -> Response: """Send an HTTP request to ArangoDB server. @@ -360,7 +373,12 @@ def refresh_token(self) -> None: :return: JWT token. :rtype: str + :raise arango.exceptions.JWTRefreshError: If missing username & password. + :raise arango.exceptions.JWTAuthError: If token retrieval fails. """ + if self._username is None or self._password is None: + raise JWTRefreshError("username and password must be set") + request = Request( method="post", endpoint="/_open/auth", @@ -374,21 +392,34 @@ def refresh_token(self) -> None: if not resp.is_success: raise JWTAuthError(resp, request) - self._token = resp.body["jwt"] - assert self._token is not None - - jwt_payload = jwt.decode( - self._token, - issuer="arangodb", - algorithms=["HS256"], - options={ - "require_exp": True, - "require_iat": True, - "verify_iat": True, - "verify_exp": True, - "verify_signature": False, - }, - ) + self.set_token(resp.body["jwt"]) + + def set_token(self, token: str) -> None: + """Set the JWT token. + + :param token: JWT token. + :type token: str + :raise arango.exceptions.JWTExpiredError: If the token is expired. + """ + assert token is not None + + try: + jwt_payload = jwt.decode( + token, + issuer="arangodb", + algorithms=["HS256"], + options={ + "require_exp": True, + "require_iat": True, + "verify_iat": True, + "verify_exp": True, + "verify_signature": False, + }, + ) + except ExpiredSignatureError: + raise JWTExpiredError("JWT token is expired") + + self._token = token self._token_exp = jwt_payload["exp"] self._auth_header = f"bearer {self._token}" @@ -444,3 +475,30 @@ def send_request(self, request: Request) -> Response: request.headers["Authorization"] = self._auth_header return self.process_request(host_index, request) + + def set_token(self, token: str) -> None: + """Set the JWT token. + + :param token: JWT token. + :type token: str + :raise arango.exceptions.JWTExpiredError: If the token is expired. + """ + assert token is not None + + try: + jwt.decode( + token, + issuer="arangodb", + algorithms=["HS256"], + options={ + "require_exp": True, + "require_iat": True, + "verify_iat": True, + "verify_exp": True, + "verify_signature": False, + }, + ) + except ExpiredSignatureError: + raise JWTExpiredError("JWT token is expired") + + self._auth_header = f"bearer {token}" diff --git a/arango/exceptions.py b/arango/exceptions.py index 8998f6c5..fb11f8d5 100644 --- a/arango/exceptions.py +++ b/arango/exceptions.py @@ -1014,3 +1014,11 @@ class JWTSecretListError(ArangoServerError): class JWTSecretReloadError(ArangoServerError): """Failed to reload JWT secrets.""" + + +class JWTRefreshError(ArangoClientError): + """Failed to refresh JWT token.""" + + +class JWTExpiredError(ArangoClientError): + """JWT token has expired.""" diff --git a/docs/auth.rst b/docs/auth.rst index a0cd9ac6..11f62985 100644 --- a/docs/auth.rst +++ b/docs/auth.rst @@ -59,7 +59,7 @@ to work correctly. # compensate for out-of-sync clocks between the client and server. db.conn.ext_leeway = 2 -User generated JWT token can be used for superuser access. +User generated JWT token can be used for user and superuser access. **Example:** @@ -89,3 +89,29 @@ User generated JWT token can be used for superuser access. # Connect to "test" database as superuser using the token. db = client.db('test', superuser_token=token) + + # Connect to "test" database as user using the token. + db = client.db('test', user_token=token) + +User and superuser tokens can be set on the connection object as well. + +**Example:** + +.. code-block:: python + + from arango import ArangoClient + + # Initialize the ArangoDB client. + client = ArangoClient() + + # Connect to "test" database as superuser using the token. + db = client.db('test', user_token='token') + + # Set the user token on the connection object. + db.conn.set_token('new token') + + # Connect to "test" database as superuser using the token. + db = client.db('test', superuser_token='superuser token') + + # Set the user token on the connection object. + db.conn.set_token('new superuser token') diff --git a/tests/test_auth.py b/tests/test_auth.py index 9688799a..0f747563 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -2,8 +2,10 @@ from arango.errno import FORBIDDEN, HTTP_UNAUTHORIZED from arango.exceptions import ( JWTAuthError, + JWTExpiredError, JWTSecretListError, JWTSecretReloadError, + ServerConnectionError, ServerEncryptionError, ServerTLSError, ServerTLSReloadError, @@ -37,7 +39,8 @@ def test_auth_basic(client, db_name, username, password): assert isinstance(db.properties(), dict) -def test_auth_jwt(client, db_name, username, password): +def test_auth_jwt(client, db_name, username, password, secret): + # Test JWT authentication with username and password. db = client.db( name=db_name, username=username, @@ -54,6 +57,13 @@ def test_auth_jwt(client, db_name, username, password): client.db(db_name, username, bad_password, auth_method="jwt") assert err.value.error_code == HTTP_UNAUTHORIZED + # Test JWT authentication with user token. + token = generate_jwt(secret) + db = client.db("_system", user_token=token) + assert isinstance(db.conn, JwtConnection) + assert isinstance(db.version(), str) + assert isinstance(db.properties(), dict) + # TODO re-examine commented out code def test_auth_superuser_token(client, db_name, root_password, secret): @@ -116,13 +126,32 @@ def test_auth_superuser_token(client, db_name, root_password, secret): def test_auth_jwt_expiry(client, db_name, root_password, secret): # Test automatic token refresh on expired token. db = client.db("_system", "root", root_password, auth_method="jwt") + valid_token = generate_jwt(secret) expired_token = generate_jwt(secret, exp=-1000) db.conn._token = expired_token db.conn._auth_header = f"bearer {expired_token}" assert isinstance(db.version(), str) - # Test correct error on token expiry. + # Test expiry error on db instantiation (superuser) + with assert_raises(ServerConnectionError) as err: + client.db("_system", superuser_token=expired_token, verify=True) + + # Test expiry error on db version (superuser) db = client.db("_system", superuser_token=expired_token) with assert_raises(ServerVersionError) as err: db.version() assert err.value.error_code == FORBIDDEN + + # Test expiry error on set_token (superuser). + db = client.db("_system", superuser_token=valid_token) + with assert_raises(JWTExpiredError) as err: + db.conn.set_token(expired_token) + + # Test expiry error on db instantiation (user) + with assert_raises(JWTExpiredError) as err: + db = client.db("_system", user_token=expired_token) + + # Test expiry error on set_token (user). + db = client.db("_system", user_token=valid_token) + with assert_raises(JWTExpiredError) as err: + db.conn.set_token(expired_token)