Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: client driven timeout #822

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Write the date in place of the "Unreleased" in the case a new version is release
optionally _extending_ an existing array.
- Add associated Python client method `ArrayClient.patch`.
- Hook to authentication prompt to make password login available without TTY.
- Expanded auth routes to manually reduce refresh token lifetime.

### Fixed

Expand Down
15 changes: 15 additions & 0 deletions tiled/_tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,21 @@ def test_password_auth_hook(config):
assert "authenticated as 'alice'" in repr(context)


def test_refresh_expiration(config):
"""Ensure we can force an early expiration of refresh tokens"""

with Context.from_app(build_app_from_config(config)) as context:
# Log in as Alice with a 1 sec refresh token expiration
spec, username = context.authenticate(
username="alice", password="secret1", refresh_token_max_age=1
)
assert "authenticated as 'alice'" in repr(context)
time.sleep(1.5)
# Attempt to refresh the token should fail
with pytest.raises(CannotRefreshAuthentication):
context.force_auth_refresh()


def test_logout(enter_username_password, config, tmpdir):
"""
Logging out revokes the session, such that it cannot be refreshed.
Expand Down
3 changes: 3 additions & 0 deletions tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def authenticate(
set_default=True,
*,
password=UNSET,
refresh_token_max_age: Optional[int] = None,
):
"""
See login. This is for programmatic use.
Expand Down Expand Up @@ -577,6 +578,8 @@ def authenticate(
"username": username,
"password": password,
}
if refresh_token_max_age is not None:
form_data["refresh_token_max_age"] = refresh_token_max_age
token_response = self.http_client.post(
auth_endpoint, data=form_data, auth=None
)
Expand Down
38 changes: 32 additions & 6 deletions tiled/server/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@ async def create_session(
return fully_loaded_session


async def create_tokens_from_session(settings, db, session, provider):
async def create_tokens_from_session(
settings, db, session, provider, refresh_token_max_age: Optional[timedelta] = None
):
# Provide enough information in the access token to reconstruct Principal
# and its Identities sufficient for access policy enforcement without a
# database hit.
Expand All @@ -483,9 +485,12 @@ async def create_tokens_from_session(settings, db, session, provider):
expires_delta=settings.access_token_max_age,
secret_key=settings.secret_keys[0], # Use the *first* secret key to encode.
)
refresh_token_max_age = settings.get_refresh_token_max_age(
refresh_token_max_age or settings.refresh_token_max_age
)
refresh_token = create_refresh_token(
session_id=session.uuid.hex,
expires_delta=settings.refresh_token_max_age,
expires_delta=refresh_token_max_age,
secret_key=settings.secret_keys[0], # Use the *first* secret key to encode.
)
# Include the identity. This is not stored as part of the session.
Expand All @@ -503,7 +508,7 @@ async def create_tokens_from_session(settings, db, session, provider):
"access_token": access_token,
"expires_in": settings.access_token_max_age / UNIT_SECOND,
"refresh_token": refresh_token,
"refresh_token_expires_in": settings.refresh_token_max_age / UNIT_SECOND,
"refresh_token_expires_in": refresh_token_max_age / UNIT_SECOND,
"token_type": "bearer",
"identity": {"id": identity.id, "provider": provider},
"principal": principal.uuid.hex,
Expand All @@ -517,6 +522,7 @@ async def route(
request: Request,
settings: BaseSettings = Depends(get_settings),
db=Depends(get_database_session),
refresh_token_max_age: Optional[int] = Form[None],
):
request.state.endpoint = "auth"
user_session_state = await authenticator.authenticate(request)
Expand All @@ -531,7 +537,13 @@ async def route(
user_session_state.user_name,
user_session_state.state,
)
tokens = await create_tokens_from_session(settings, db, session, provider)
tokens = await create_tokens_from_session(
settings,
db,
session,
provider,
timedelta(seconds=refresh_token_max_age) if refresh_token_max_age else None,
)
return tokens

return route
Expand Down Expand Up @@ -678,6 +690,7 @@ async def route(
body: schemas.DeviceCode,
settings: BaseSettings = Depends(get_settings),
db=Depends(get_database_session),
refresh_token_max_age: Optional[int] = Form[None],
):
request.state.endpoint = "auth"
device_code_hex = body.device_code
Expand All @@ -704,7 +717,13 @@ async def route(
# The pending session can only be used once.
await db.delete(pending_session)
await db.commit()
tokens = await create_tokens_from_session(settings, db, session, provider)
tokens = await create_tokens_from_session(
settings,
db,
session,
provider,
timedelta(seconds=refresh_token_max_age) if refresh_token_max_age else None,
)
return tokens

return route
Expand All @@ -720,6 +739,7 @@ async def route(
form_data: OAuth2PasswordRequestForm = Depends(),
settings: BaseSettings = Depends(get_settings),
db=Depends(get_database_session),
refresh_token_max_age: Optional[int] = Form(None),
):
request.state.endpoint = "auth"
user_session_state = await authenticator.authenticate(
Expand All @@ -738,7 +758,13 @@ async def route(
user_session_state.user_name,
state=user_session_state.state,
)
tokens = await create_tokens_from_session(settings, db, session, provider)
tokens = await create_tokens_from_session(
settings,
db,
session,
provider,
timedelta(seconds=refresh_token_max_age) if refresh_token_max_age else None,
)
return tokens

return route
Expand Down
3 changes: 3 additions & 0 deletions tiled/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def database_settings(self):
max_overflow=self.database_max_overflow,
)

def get_refresh_token_max_age(self, requested_max_age: timedelta) -> timedelta:
return min(requested_max_age, self.refresh_token_max_age)


@lru_cache()
def get_settings():
Expand Down
Loading