Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sessions cookie changes to support persistent cookies and partitioned cookies #2527

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ request through as normal, but will include appropriate CORS headers on the resp

## SessionMiddleware

Adds signed cookie-based HTTP sessions. Session information is readable but not modifiable.
Adds signed cookie-based HTTP sessions. Session cookie information is user readable but not user modifiable, the data stored is ***not*** encrypted.

Access or modify the session data using the `request.session` dictionary interface.

Expand All @@ -103,10 +103,11 @@ The following arguments are supported:
* `secret_key` - Should be a random string.
* `session_cookie` - Defaults to "session".
* `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session.
* `refresh_window` - Refresh window in seconds before max_age. If set the cookie will automatically refresh with in that timeframe when used to a new max_age. Defaults to `None`.
* `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`.
* `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`.
* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute).

* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [reference](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute).
* `partitioned` - A [partitioned cookie](https://developer.mozilla.org/en-US/docs/Web/Privacy/Privacy_sandbox/Partitioned_cookies) is a type of cookie that can only be accessed by the same third-party service within the context of the top-level site where it was initially set, preventing cross-site tracking and improving user privacy. Defaults to `False`.

```python
from starlette.applications import Starlette
Expand Down
125 changes: 99 additions & 26 deletions starlette/middleware/sessions.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,156 @@
import json
import typing
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone

import itsdangerous
from itsdangerous.exc import BadSignature
from itsdangerous.exc import BadSignature, SignatureExpired

from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
from starlette.types import ASGIApp, Message, Receive, Scope, Send


# mutable mapping that keeps track of whether it has been modified
class ModifiedDict(typing.Dict[str, typing.Any]):
def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().__init__(*args, **kwargs)
self.modify = False
self.invalid = False

def __setitem__(self, key: str, value: typing.Any) -> None: # pragma: no cover
super().__setitem__(key, value)
self.modify = True

def __delitem__(self, key: str) -> None: # pragma: no cover
super().__delitem__(key)
self.modify = True

def clear(self) -> None:
super().clear()
self.invalid = True
self.modify = True

def pop(
self, key: str, default: typing.Any = None
) -> typing.Any: # pragma: no cover
value = super().pop(key, default)
self.modify = True
return value

def popitem(self) -> typing.Any: # pragma: no cover
value = super().popitem()
self.modify = True
return value

def setdefault(
self, key: str, default: typing.Any = None
) -> typing.Any: # pragma: no cover
value = super().setdefault(key, default)
self.modify = True
return value

def update(self, *args: typing.Any, **kwargs: typing.Any) -> None:
super().update(*args, **kwargs)
self.modify = True


class SessionMiddleware:
def __init__(
self,
app: ASGIApp,
secret_key: typing.Union[str, Secret],
session_cookie: str = "session",
max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds
refresh_window: typing.Optional[int] = None,
path: str = "/",
same_site: typing.Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
domain: typing.Optional[str] = None,
partitioned: typing.Optional[bool] = False,
) -> None:
self.app = app
self.signer = itsdangerous.TimestampSigner(str(secret_key))
self.session_cookie = session_cookie
self.max_age = max_age
self.refresh_window = refresh_window
self.path = path
self.security_flags = "httponly; samesite=" + same_site
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
if domain is not None:
self.security_flags += f"; domain={domain}"
if partitioned:
self.security_flags += "; partitioned"

# Decode and validate cookie
def decode_cookie(self, cookie: bytes) -> ModifiedDict:
result: ModifiedDict = ModifiedDict()
try:
data = self.signer.unsign(
cookie, max_age=self.max_age, return_timestamp=True
)
result = ModifiedDict(json.loads(b64decode(data[0])))
except (BadSignature, SignatureExpired):
result.invalid = True
return result

# data[1] is the datetime when signed from itsdangerous
if self.refresh_window and self.max_age:
now = datetime.now(timezone.utc)
expiration = data[1] + timedelta(seconds=self.max_age)
# The cookie is with in the refresh window, trigger a refresh.
if (
now >= (expiration - timedelta(seconds=self.refresh_window))
and now <= expiration
): # noqa E501
result.modify = True
return result

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
return

connection = HTTPConnection(scope)
initial_session_was_empty = True

if self.session_cookie in connection.cookies:
data = connection.cookies[self.session_cookie].encode("utf-8")
try:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = json.loads(b64decode(data))
initial_session_was_empty = False
except BadSignature:
scope["session"] = {}
scope["session"] = self.decode_cookie(
connection.cookies[self.session_cookie].encode("utf-8")
) # noqa E501
else:
scope["session"] = {}
scope["session"] = ModifiedDict()

async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start":
if scope["session"]:
# We have session data to persist.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
if scope["session"] and not scope["session"].invalid:
# Scope has session data and is valid.
if scope["session"].modify:
# Scope has updated data or needs refreshing.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
# If the session cookie is invalid for any reason
elif scope["session"].invalid: # Clear the cookie.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {max_age}{security_flags}".format( # noqa E501
session_cookie=self.session_cookie,
data=data.decode("utf-8"),
path=self.path,
max_age=f"Max-Age={self.max_age}; " if self.max_age else "",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif not initial_session_was_empty:
# The session has been cleared.
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
session_cookie=self.session_cookie,
data="null",
path=self.path,
expires="expires=Thu, 01 Jan 1970 00:00:00 GMT; ",
max_age="Max-Age=-1; ",
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
# No session cookie was present, or it isn't modified,
# don't modify or delete the cookie.
await send(message)

await self.app(scope, receive, send_wrapper)
104 changes: 94 additions & 10 deletions tests/middleware/test_session.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,34 @@
import re
import time
from typing import Callable

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Mount, Route
from starlette.testclient import TestClient

TestClientFactory = Callable[..., TestClient]

def view_session(request):

def view_session(request: Request) -> JSONResponse:
return JSONResponse({"session": request.session})


async def update_session(request):
async def update_session(request: Request) -> JSONResponse:
data = await request.json()
request.session.update(data)
return JSONResponse({"session": request.session})


async def clear_session(request):
async def clear_session(request: Request) -> JSONResponse:
request.session.clear()
return JSONResponse({"session": request.session})


def test_session(test_client_factory):
def test_session(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Expand Down Expand Up @@ -56,7 +61,7 @@ def test_session(test_client_factory):
assert response.json() == {"session": {}}


def test_session_expires(test_client_factory):
def test_session_expires(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Expand All @@ -80,7 +85,7 @@ def test_session_expires(test_client_factory):
assert response.json() == {"session": {}}


def test_secure_session(test_client_factory):
def test_secure_session(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Expand Down Expand Up @@ -119,7 +124,7 @@ def test_secure_session(test_client_factory):
assert response.json() == {"session": {}}


def test_session_cookie_subpath(test_client_factory):
def test_session_cookie_subpath(test_client_factory: TestClientFactory) -> None:
second_app = Starlette(
routes=[
Route("/update_session", endpoint=update_session, methods=["POST"]),
Expand All @@ -139,7 +144,7 @@ def test_session_cookie_subpath(test_client_factory):
assert cookie_path == "/second_app"


def test_invalid_session_cookie(test_client_factory):
def test_invalid_session_cookie(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Expand All @@ -158,7 +163,7 @@ def test_invalid_session_cookie(test_client_factory):
assert response.json() == {"session": {}}


def test_session_cookie(test_client_factory):
def test_session_cookie(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Expand All @@ -180,7 +185,7 @@ def test_session_cookie(test_client_factory):
assert response.json() == {"session": {}}


def test_domain_cookie(test_client_factory):
def test_domain_cookie(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Expand All @@ -202,3 +207,82 @@ def test_domain_cookie(test_client_factory):
client.cookies.delete("session")
response = client.get("/view_session")
assert response.json() == {"session": {}}


def test_session_refresh(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[
Middleware(
SessionMiddleware, refresh_window=100, secret_key="example", max_age=100
)
],
)

client = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}
original_cookie_header = response.headers["set-cookie"]

# itsdangerous only signs with seconds resolution, no milliseconds.
time.sleep(1)

response = client.get("/view_session")
assert response.json() == {"session": {"some": "data"}}

second_cookie_header = response.headers["set-cookie"]
# second cookie data should match what was set and the signature is differnt.
assert original_cookie_header != second_cookie_header


def test_session_persistence(test_client_factory: TestClientFactory) -> None:
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[Middleware(SessionMiddleware, secret_key="example", max_age=100)],
)

client = test_client_factory(app)
response = client.post("/update_session", json={"some": "data"})

assert response.json() == {"session": {"some": "data"}}

response = client.get("/view_session")
# response includes the cookie data, and there's no new set-cookie
assert response.json() == {
"session": {"some": "data"}
} and not response.headers.get("set-cookie")


def test_partitioned_session(test_client_factory: TestClientFactory) -> None:
session_cookie = "__Host-session"
app = Starlette(
routes=[
Route("/view_session", endpoint=view_session),
Route("/update_session", endpoint=update_session, methods=["POST"]),
],
middleware=[
Middleware(
SessionMiddleware,
secret_key="example",
https_only=True,
partitioned=True,
session_cookie=session_cookie,
same_site="none",
)
],
)

secure_client = test_client_factory(app, base_url="https://testserver")

response = secure_client.post("/update_session", json={"some": "data"})
assert response.json() == {"session": {"some": "data"}}

cookie = response.headers["set-cookie"]
cookie_partition_match = re.search(rf"{session_cookie}.*; partitioned", cookie)
assert cookie_partition_match is not None