Skip to content

Commit

Permalink
OIDC: Refactor to expose OIDC client object (#971)
Browse files Browse the repository at this point in the history
  • Loading branch information
psrok1 authored Aug 19, 2024
1 parent 92adddc commit 2723ff9
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 69 deletions.
113 changes: 67 additions & 46 deletions mwdb/core/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,74 +7,91 @@

from authlib.common.security import generate_token
from authlib.integrations.requests_client import OAuth2Session
from authlib.jose import JsonWebKey, JsonWebToken, jwt
from authlib.jose import JsonWebKey, JsonWebToken
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo


class OpenIDSession(OAuth2Session):
def fetch_jwk_set(self, force=True):
jwk_set = self.metadata.get("jwks")
if jwk_set and not force:
return jwk_set
class OpenIDClient:
supported_algorithms = ["HS256", "HS384", "HS512", "RS256", "RS384", "RS512"]

def __init__(
self,
client_id,
client_secret,
authorization_endpoint,
token_endpoint,
userinfo_endpoint,
jwks_uri,
**kwargs,
):
self.client_id = client_id
self.client_secret = client_secret
self.authorization_endpoint = authorization_endpoint
self.token_endpoint = token_endpoint
self.userinfo_endpoint = userinfo_endpoint
self.jwks_uri = jwks_uri

self.session = OAuth2Session(
client_id=client_id,
client_secret=client_secret,
authorization_endpoint=authorization_endpoint,
token_endpoint=token_endpoint,
userinfo_endpoint=userinfo_endpoint,
jwks_uri=jwks_uri,
**kwargs,
)

def create_authorization_url(self, redirect_uri):
nonce = generate_token()
return (
*self.session.create_authorization_url(
self.authorization_endpoint, nonce=nonce, redirect_uri=redirect_uri
),
nonce,
)

uri = self.metadata.get("jwks_uri")
if not uri:
def load_key(self, header, _):
alg = header.get("alg")
if alg in ["HS256", "HS384", "HS512"]:
# For HS256: client secret is used for id_token signing
return self.client_secret
elif alg in ["RS256", "RS384", "RS512"]:
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
return jwk_set.find_by_kid(header.get("kid"))
else:
raise RuntimeError(f"Unsupported id_token algorithm: '{alg}'")

def fetch_jwk_set(self):
if not self.jwks_uri:
raise RuntimeError('Missing "jwks_uri" in metadata')

resp = self.request("GET", uri, withhold_token=True)
resp = self.session.request("GET", self.jwks_uri, withhold_token=True)
resp.raise_for_status()
jwk_set = resp.json()

self.metadata["jwks"] = jwk_set
return jwk_set

def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
resp = self.get(self.metadata["userinfo_endpoint"], **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
def fetch_id_token(self, code, state, nonce, redirect_uri):
token = self.session.fetch_token(
code=code, state=state, redirect_uri=redirect_uri
)
return self.parse_id_token(token, nonce)

def parse_id_token(self, token, nonce, claims_options=None, leeway=120):
"""Return an instance of UserInfo from token's ``id_token``."""
if "id_token" not in token:
return None

def load_key(header, _):
alg = header.get("alg")
if alg in ["HS256", "HS384", "HS512"]:
# For HS256: client secret is used for id_token signing
return self.client_secret
elif alg in ["RS256", "RS384", "RS512"]:
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set())
try:
return jwk_set.find_by_kid(header.get("kid"))
except ValueError:
# re-try with new jwk set
jwk_set = JsonWebKey.import_key_set(self.fetch_jwk_set(force=True))
return jwk_set.find_by_kid(header.get("kid"))
else:
raise RuntimeError(f"Unsupported id_token algorithm: '{alg}'")

claims_params = dict(nonce=nonce, client_id=self.client_id)
if "access_token" in token:
claims_params["access_token"] = token["access_token"]
claims_cls = CodeIDToken
else:
claims_cls = ImplicitIDToken

if claims_options is None and "issuer" in self.metadata:
claims_options = {"iss": {"values": [self.metadata["issuer"]]}}

alg_values = self.metadata.get("id_token_signing_alg_values_supported")
if alg_values:
_jwt = JsonWebToken(alg_values)
else:
_jwt = jwt

claims = _jwt.decode(
jwt = JsonWebToken(self.supported_algorithms)
claims = jwt.decode(
token["id_token"],
key=load_key,
key=self.load_key,
claims_cls=claims_cls,
claims_options=claims_options,
claims_params=claims_params,
Expand All @@ -86,5 +103,9 @@ def load_key(header, _):
claims.validate(leeway=leeway)
return UserInfo(claims)

def generate_nonce(self):
return generate_token()
def userinfo(self, **kwargs):
"""Fetch user info from ``userinfo_endpoint``."""
resp = self.session.get(self.userinfo_endpoint, **kwargs)
resp.raise_for_status()
data = resp.json()
return UserInfo(data)
24 changes: 5 additions & 19 deletions mwdb/model/oauth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from werkzeug.exceptions import NotFound

from mwdb.core.oauth import OpenIDSession
from mwdb.core.oauth import OpenIDClient
from mwdb.model import Group

from . import db
Expand All @@ -25,34 +25,20 @@ class OpenIDProvider(db.Model):
cascade="all, delete-orphan",
)

def _get_client(self, state=None):
return OpenIDSession(
def get_oidc_client(self):
return OpenIDClient(
client_id=self.client_id,
client_secret=self.client_secret,
scope="openid profile email",
grant_type="authorization_code",
response_type="code",
authorization_endpoint=self.authorization_endpoint,
token_endpoint=self.token_endpoint,
userinfo_endpoint=self.userinfo_endpoint,
jwks_uri=self.jwks_endpoint,
state=state,
state=None,
)

def create_authorization_url(self, redirect_uri):
client = self._get_client()
nonce = client.generate_nonce()
return (
*client.create_authorization_url(
self.authorization_endpoint, nonce=nonce, redirect_uri=redirect_uri
),
nonce,
)

def fetch_id_token(self, code, state, nonce, redirect_uri):
client = self._get_client()
token = client.fetch_token(code=code, state=state, redirect_uri=redirect_uri)
return client.parse_id_token(token, nonce)

def get_group(self):
group_name = ("OpenID_" + self.name)[:32]
group = db.session.query(Group).filter(Group.name == group_name).first()
Expand Down
12 changes: 8 additions & 4 deletions mwdb/resources/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,8 @@ def post(self, provider_name):
raise NotFound(f"Requested provider name '{provider_name}' not found")

redirect_uri = f"{app_config.mwdb.base_url}/oauth/callback"
url, state, nonce = provider.create_authorization_url(redirect_uri)
oidc_client = provider.get_oidc_client()
url, state, nonce = oidc_client.create_authorization_url(redirect_uri)

schema = OpenIDLoginResponseSchema()
return schema.dump({"authorization_url": url, "state": state, "nonce": nonce})
Expand All @@ -370,7 +371,8 @@ def post(self, provider_name):
schema = OpenIDAuthorizeRequestSchema()
obj = loads_schema(request.get_data(as_text=True), schema)
redirect_uri = f"{app_config.mwdb.base_url}/oauth/callback"
userinfo = provider.fetch_id_token(
oidc_client = provider.get_oidc_client()
userinfo = oidc_client.fetch_id_token(
obj["code"], obj["state"], obj["nonce"], redirect_uri
)
# 'sub' bind should be used instead of 'name'
Expand Down Expand Up @@ -432,7 +434,8 @@ def post(self, provider_name):
schema = OpenIDAuthorizeRequestSchema()
obj = loads_schema(request.get_data(as_text=True), schema)
redirect_uri = f"{app_config.mwdb.base_url}/oauth/callback"
userinfo = provider.fetch_id_token(
oidc_client = provider.get_oidc_client()
userinfo = oidc_client.fetch_id_token(
obj["code"], obj["state"], obj["nonce"], redirect_uri
)
# register user with information from provider
Expand Down Expand Up @@ -563,7 +566,8 @@ def post(self, provider_name):
schema = OpenIDAuthorizeRequestSchema()
obj = loads_schema(request.get_data(as_text=True), schema)
redirect_uri = f"{app_config.mwdb.base_url}/oauth/callback"
userinfo = provider.fetch_id_token(
oidc_client = provider.get_oidc_client()
userinfo = oidc_client.fetch_id_token(
obj["code"], obj["state"], obj["nonce"], redirect_uri
)
if db.session.query(
Expand Down

0 comments on commit 2723ff9

Please sign in to comment.