Skip to content

Commit

Permalink
Multiple providers and unit tests (GH-13)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArtyomVancyan authored Aug 11, 2023
2 parents c6b4c73 + ff8385f commit 17eb243
Show file tree
Hide file tree
Showing 18 changed files with 250 additions and 149 deletions.
6 changes: 0 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@
FastAPI OAuth2 is a middleware-based social authentication mechanism supporting several auth providers. It depends on
the [social-core](https://github.com/python-social-auth/social-core) authentication backends.

## Features to be implemented

- Use multiple OAuth2 providers at the same time
* There need to be provided a way to configure the OAuth2 for multiple providers
- Customizable OAuth2 routes

## Installation

```shell
Expand Down
9 changes: 7 additions & 2 deletions examples/demonstration/.env
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
OAUTH2_CLIENT_ID=eccd08d6736b7999a32a
OAUTH2_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31
# These id and secret are generated especially for testing purposes,
# if you have your own, please use them, otherwise you can use these.
OAUTH2_GITHUB_CLIENT_ID=eccd08d6736b7999a32a
OAUTH2_GITHUB_CLIENT_SECRET=642999c1c5f2b3df8b877afdc78252ef5b594d31

OAUTH2_GOOGLE_CLIENT_ID=105851609656-uueuan570963mnnf4288nv40eieh9f5l.apps.googleusercontent.com
OAUTH2_GOOGLE_CLIENT_SECRET=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

JWT_SECRET=secret
JWT_ALGORITHM=HS256
Expand Down
15 changes: 12 additions & 3 deletions examples/demonstration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from dotenv import load_dotenv
from social_core.backends.github import GithubOAuth2
from social_core.backends.google import GoogleOAuth2

from fastapi_oauth2.claims import Claims
from fastapi_oauth2.client import OAuth2Client
Expand All @@ -17,14 +18,22 @@
clients=[
OAuth2Client(
backend=GithubOAuth2,
client_id=os.getenv("OAUTH2_CLIENT_ID"),
client_secret=os.getenv("OAUTH2_CLIENT_SECRET"),
# redirect_uri="http://127.0.0.1:8000/",
client_id=os.getenv("OAUTH2_GITHUB_CLIENT_ID"),
client_secret=os.getenv("OAUTH2_GITHUB_CLIENT_SECRET"),
scope=["user:email"],
claims=Claims(
picture="avatar_url",
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("id")),
),
),
OAuth2Client(
backend=GoogleOAuth2,
client_id=os.getenv("OAUTH2_GOOGLE_CLIENT_ID"),
client_secret=os.getenv("OAUTH2_GOOGLE_CLIENT_SECRET"),
scope=["openid", "profile", "email"],
claims=Claims(
identity=lambda user: "%s:%s" % (user.get("provider"), user.get("sub")),
),
),
]
)
13 changes: 8 additions & 5 deletions examples/demonstration/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from sqlalchemy.orm import Session

from config import oauth2_config
Expand All @@ -24,16 +25,18 @@ async def on_auth(auth: Auth, user: User):
db: Session = next(get_db())
query = db.query(UserModel)
if user.identity and not query.filter_by(identity=user.identity).first():
# create a local user by OAuth2 user's data if it does not exist yet
UserModel(**{
"identity": user.get("identity"),
"username": user.get("username"),
"image": user.get("image"),
"email": user.get("email"),
"name": user.get("name"),
"identity": user.identity, # User property
"username": user.get("username"), # custom attribute
"name": user.display_name, # User property
"image": user.picture, # User property
"email": user.email, # User property
}).save(db)


app = FastAPI()
app.include_router(app_router)
app.include_router(oauth2_router)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(OAuth2Middleware, config=oauth2_config, callback=on_auth)
5 changes: 5 additions & 0 deletions examples/demonstration/static/github.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 6 additions & 0 deletions examples/demonstration/static/google-oauth2.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
22 changes: 17 additions & 5 deletions examples/demonstration/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,30 @@
<a href="/auth" style="display: flex; align-items: center; color: #dfdfd6; margin-right: 1rem; text-decoration: none;">
Simulate Login
</a>
<a href="/oauth2/github/auth" style="display: flex; align-items: center;">
<svg style="height: 50px; width: 50px;" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16">
<path fill="#dfdfd6" d="M7.499,1C3.91,1,1,3.906,1,7.49c0,2.867,1.862,5.299,4.445,6.158C5.77,13.707,6,13.375,6,13.125 c0-0.154,0.003-0.334,0-0.875c-1.808,0.392-2.375-0.875-2.375-0.875c-0.296-0.75-0.656-0.963-0.656-0.963 c-0.59-0.403,0.044-0.394,0.044-0.394C3.666,10.064,4,10.625,4,10.625c0.5,0.875,1.63,0.791,2,0.625 c0-0.397,0.044-0.688,0.154-0.873C4.111,10.02,2.997,8.84,3,7.208c0.002-0.964,0.335-1.715,0.876-2.269 C3.639,4.641,3.479,3.625,3.961,3c1.206,0,1.927,0.873,1.927,0.873s0.565-0.248,1.61-0.248c1.045,0,1.608,0.234,1.608,0.234 S9.829,3,11.035,3c0.482,0.625,0.322,1.641,0.132,1.918C11.684,5.461,12,6.21,12,7.208c0,1.631-1.11,2.81-3.148,3.168 C8.982,10.572,9,10.842,9,11.25c0,0.867,0,1.662,0,1.875c0,0.25,0.228,0.585,0.558,0.522C12.139,12.787,14,10.356,14,7.49 C14,3.906,11.09,1,7.499,1z"></path>
</svg>
</a>
{% for provider in request.auth.clients %}
<a href="/oauth2/{{ provider }}/auth" style="display: flex; align-items: center;">
<img
alt="{{ provider }} icon"
src="/static/{{ provider }}.svg"
style="width: 50px; height: 50px; margin-right: 1rem;"
>
</a>
{% endfor %}
{% endif %}
</div>
</header>
<section
style="display: flex; flex-direction: column; align-items: center; justify-content: center; height: calc(100vh - 70px);">
{% if request.user.is_authenticated %}
<h1>Hi, {{ request.user.display_name }}</h1>
<h3>
You're signed in using
{% if request.auth.provider %}
external {{ request.auth.provider.provider }} OAuth2 provider.
{% else %}
local authentication system.
{% endif %}
</h3>
<p>This is what your JWT contains currently</p>
<pre>{{ json.dumps(request.user, indent=4) }}</pre>
{% else %}
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ license_files = LICENSE
platforms = unix, linux, osx, win32
classifiers =
Operating System :: OS Independent
Development Status :: 2 - Pre-Alpha
Development Status :: 3 - Alpha
Framework :: FastAPI
Programming Language :: Python
Programming Language :: Python :: 3
Expand Down
2 changes: 1 addition & 1 deletion src/fastapi_oauth2/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.0.0-alpha.1"
__version__ = "1.0.0-alpha.2"
13 changes: 9 additions & 4 deletions src/fastapi_oauth2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import httpx
from oauthlib.oauth2 import WebApplicationClient
from oauthlib.oauth2.rfc6749.errors import CustomOAuth2Error
from social_core.backends.oauth import BaseOAuth2
from social_core.strategy import BaseStrategy
from starlette.exceptions import HTTPException
Expand Down Expand Up @@ -46,9 +47,10 @@ class OAuth2Core:

client_id: str = None
client_secret: str = None
callback_url: Optional[str] = None
scope: Optional[List[str]] = None
claims: Optional[Claims] = None
provider: str = None
redirect_uri: str = None
backend: BaseOAuth2 = None
_oauth_client: Optional[WebApplicationClient] = None

Expand Down Expand Up @@ -108,9 +110,12 @@ async def token_redirect(self, request: Request) -> RedirectResponse:
auth = httpx.BasicAuth(self.client_id, self.client_secret)
async with httpx.AsyncClient() as session:
response = await session.post(token_url, headers=headers, content=content, auth=auth)
token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
token_data = self.standardize(self.backend.user_data(token.get("access_token")))
access_token = request.auth.jwt_create(token_data)
try:
token = self.oauth_client.parse_request_body_response(json.dumps(response.json()))
token_data = self.standardize(self.backend.user_data(token.get("access_token")))
access_token = request.auth.jwt_create(token_data)
except (CustomOAuth2Error, Exception) as e:
raise OAuth2LoginError(400, str(e))

response = RedirectResponse(self.redirect_uri or request.base_url)
response.set_cookie(
Expand Down
30 changes: 13 additions & 17 deletions src/fastapi_oauth2/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union

Expand Down Expand Up @@ -39,16 +38,15 @@ class Auth(AuthCredentials):
scopes: List[str]
clients: Dict[str, OAuth2Core] = {}

provider: str
default_provider: str = "local"
_provider: OAuth2Core = None

def __init__(
self,
scopes: Optional[Sequence[str]] = None,
provider: str = default_provider,
) -> None:
super().__init__(scopes)
self.provider = provider
@property
def provider(self) -> Union[OAuth2Core, None]:
return self._provider

@provider.setter
def provider(self, identifier) -> None:
self._provider = self.clients.get(identifier)

@classmethod
def set_http(cls, http: bool) -> None:
Expand Down Expand Up @@ -145,18 +143,16 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
return Auth(), User()

user = User(Auth.jwt_decode(param))
user.update(provider=user.get("provider", Auth.default_provider))
auth = Auth(user.pop("scope", []), user.get("provider"))
client = Auth.clients.get(auth.provider)
claims = client.claims if client else Claims()
user = user.use_claims(claims)
auth = Auth(user.pop("scope", []))
auth.provider = user.get("provider")
claims = auth.provider.claims if auth.provider else {}

# Call the callback function on authentication
if callable(self.callback):
coroutine = self.callback(auth, user)
coroutine = self.callback(auth, user.use_claims(claims))
if issubclass(type(coroutine), Awaitable):
await coroutine
return auth, user
return auth, user.use_claims(claims)


class OAuth2Middleware:
Expand Down
29 changes: 11 additions & 18 deletions src/fastapi_oauth2/security.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Type

from fastapi.security import OAuth2 as FastAPIOAuth2
Expand All @@ -12,32 +8,29 @@
from starlette.requests import Request


def use_cookies(cls: Type[FastAPIOAuth2]) -> Callable[[Tuple[Any], Dict[str, Any]], FastAPIOAuth2]:
"""OAuth2 classes wrapped with this decorator will use cookies for the Authorization header."""
class OAuth2Cookie(type):
"""OAuth2 classes using this metaclass will use cookies for the Authorization header."""

def __new__(metacls, name, bases, attrs) -> Type:
instance = super().__new__(metacls, name, bases, attrs)

def _use_cookies(*args, **kwargs) -> FastAPIOAuth2:
async def __call__(self: FastAPIOAuth2, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization", request.cookies.get("Authorization"))
if authorization:
request._headers = Headers({**request.headers, "Authorization": authorization})
return await super(cls, self).__call__(request)

cls.__call__ = __call__
return cls(*args, **kwargs)
return await instance.__base__.__call__(self, request)

return _use_cookies
instance.__call__ = __call__
return instance


@use_cookies
class OAuth2(FastAPIOAuth2):
class OAuth2(FastAPIOAuth2, metaclass=OAuth2Cookie):
"""Wrapper class of the `fastapi.security.OAuth2` class."""


@use_cookies
class OAuth2PasswordBearer(FastAPIPasswordBearer):
class OAuth2PasswordBearer(FastAPIPasswordBearer, metaclass=OAuth2Cookie):
"""Wrapper class of the `fastapi.security.OAuth2PasswordBearer` class."""


@use_cookies
class OAuth2AuthorizationCodeBearer(FastAPICodeBearer):
class OAuth2AuthorizationCodeBearer(FastAPICodeBearer, metaclass=OAuth2Cookie):
"""Wrapper class of the `fastapi.security.OAuth2AuthorizationCodeBearer` class."""
60 changes: 60 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@

import pytest
import social_core.backends as backends
from fastapi import APIRouter
from fastapi import Depends
from fastapi import FastAPI
from fastapi import Request
from social_core.backends.github import GithubOAuth2
from social_core.backends.oauth import BaseOAuth2
from starlette.responses import Response

from fastapi_oauth2.client import OAuth2Client
from fastapi_oauth2.middleware import OAuth2Middleware
from fastapi_oauth2.router import router as oauth2_router
from fastapi_oauth2.security import OAuth2

package_path = backends.__path__[0]

Expand All @@ -24,3 +35,52 @@ def backends():
except ImportError:
continue
return backend_instances


@pytest.fixture
def get_app():
def fixture_wrapper(authentication: OAuth2 = None):
if not authentication:
authentication = OAuth2()

oauth2 = authentication
application = FastAPI()
app_router = APIRouter()

@app_router.get("/user")
def user(request: Request, _: str = Depends(oauth2)):
return request.user

@app_router.get("/auth")
def auth(request: Request):
access_token = request.auth.jwt_create({
"name": "test",
"sub": "test",
"id": "test",
})
response = Response()
response.set_cookie(
"Authorization",
value=f"Bearer {access_token}",
max_age=request.auth.expires,
expires=request.auth.expires,
httponly=request.auth.http,
)
return response

application.include_router(app_router)
application.include_router(oauth2_router)
application.add_middleware(OAuth2Middleware, config={
"allow_http": True,
"clients": [
OAuth2Client(
backend=GithubOAuth2,
client_id="test_id",
client_secret="test_secret",
),
],
})

return application

return fixture_wrapper
17 changes: 17 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from fastapi_oauth2.client import OAuth2Client
from fastapi_oauth2.core import OAuth2Core


@pytest.mark.anyio
async def test_core_init_with_all_backends(backends):
for backend in backends:
try:
OAuth2Core(OAuth2Client(
backend=backend,
client_id="test_client_id",
client_secret="test_client_secret",
))
except (NotImplementedError, Exception):
assert False
Loading

0 comments on commit 17eb243

Please sign in to comment.