Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

refactor: Refactors GitHub interactions & adds unit tests #27

Merged
merged 5 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 16 additions & 36 deletions src/app/api/api_v1/endpoints/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import secrets

import requests
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
Expand All @@ -16,8 +15,10 @@
from app.core.security import create_access_token, hash_password, verify_password
from app.crud import UserCRUD
from app.models import UserScope
from app.schemas.login import GHAccessToken, GHToken, GHTokenRequest, Token, TokenRequest
from app.schemas.login import GHAccessToken, Token, TokenRequest
from app.schemas.services import GHToken
from app.schemas.users import UserCreation
from app.services.github import gh_client
from app.services.telemetry import telemetry_client

router = APIRouter()
Expand All @@ -35,23 +36,9 @@

@router.post("/github", status_code=status.HTTP_200_OK, summary="Request a GitHub token from authorization code")
async def request_github_token_from_code(
payload: GHTokenRequest,
payload: TokenRequest,
) -> GHToken:
gh_payload = TokenRequest(
client_id=settings.GH_OAUTH_ID,
client_secret=settings.GH_OAUTH_SECRET,
redirect_uri=payload.redirect_uri,
code=payload.code,
)
response = requests.post(
settings.GH_TOKEN_ENDPOINT,
json=gh_payload.dict(),
headers={"Accept": "application/json"},
timeout=5,
)
if response.status_code != status.HTTP_200_OK or isinstance(response.json().get("error"), str):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authorization code.")
return GHToken(**response.json())
return gh_client.get_token_from_code(payload.code, payload.redirect_uri)


@router.post("/creds", status_code=status.HTTP_200_OK, summary="Request an access token using credentials")
Expand Down Expand Up @@ -89,27 +76,22 @@
By default, the token expires after 1 hour.
"""
# Fetch GitHub
response = requests.get(
"https://api.github.com/user",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {payload.github_token}"},
timeout=5,
gh_user = gh_client.get_my_user(payload.github_token)
telemetry_client.capture(gh_user["id"], event="user-login", properties={"login": gh_user["login"]})
telemetry_client.identify(

Check warning on line 81 in src/app/api/api_v1/endpoints/login.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/login.py#L80-L81

Added lines #L80 - L81 were not covered by tests
gh_user["id"],
properties={
"login": gh_user["login"],
"name": gh_user["name"],
"email": gh_user["email"],
"twitter_username": gh_user["twitter_username"],
},
)
if response.status_code != 200:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid GitHub token.")
gh_user = response.json()
# Verify credentials
user = await users.get(gh_user["id"], strict=False)
# Register if non existing
if user is None:
telemetry_client.identify(
gh_user["id"],
properties={
"login": gh_user["login"],
"name": gh_user["name"],
"email": gh_user["email"],
"twitter_username": gh_user["twitter_username"],
},
)
telemetry_client.capture(gh_user["id"], event="user-creation", properties={"login": gh_user["login"]})

Check warning on line 94 in src/app/api/api_v1/endpoints/login.py

View check run for this annotation

Codecov / codecov/patch

src/app/api/api_v1/endpoints/login.py#L94

Added line #L94 was not covered by tests
user = await users.create(
UserCreation(
id=gh_user["id"],
Expand All @@ -118,11 +100,9 @@
scope=UserScope.USER,
)
)
telemetry_client.capture(user.id, event="user-creation", properties={"login": gh_user["login"]})

# create access token using user user_id/user_scopes
token_data = {"sub": str(user.id), "scopes": user.scope.split()}
token = await create_access_token(token_data, settings.ACCESS_TOKEN_UNLIMITED_MINUTES)
telemetry_client.capture(user.id, event="user-login", properties={"login": user.login})

return Token(access_token=token, token_type="bearer") # nosec B106 # noqa S106
5 changes: 4 additions & 1 deletion src/app/api/api_v1/endpoints/repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,19 @@ async def fetch_repos(


@router.put("/{repo_id}/guidelines/order", status_code=status.HTTP_200_OK)
async def reorder_guidelines(
async def reorder_repo_guidelines(
payload: GuidelineOrder,
repo_id: int = Path(..., gt=0),
guidelines: GuidelineCRUD = Depends(get_guideline_crud),
repos: RepositoryCRUD = Depends(get_repo_crud),
user=Security(get_current_user, scopes=[UserScope.USER, UserScope.ADMIN]),
) -> List[Guideline]:
telemetry_client.capture(user.id, event="guideline-order", properties={"repo_id": repo_id})
# Ensure all IDs are unique
if len(payload.guideline_ids) != len(set(payload.guideline_ids)):
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Duplicate IDs were passed.")
# Check the repo
await repos.get(repo_id, strict=True)
# Ensure all IDs are valid
guideline_ids = [elt.id for elt in await guidelines.fetch_all(("repo_id", repo_id))]
if set(payload.guideline_ids) != set(guideline_ids):
Expand Down
1 change: 0 additions & 1 deletion src/app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class Settings(BaseSettings):
CORS_ORIGIN: str = "*"
# Ext API endpoints
GH_AUTHORIZE_ENDPOINT: str = "https://github.com/login/oauth/authorize"
GH_TOKEN_ENDPOINT: str = "https://github.com/login/oauth/access_token"
GH_OAUTH_ID: str = os.environ["GH_OAUTH_ID"]
GH_OAUTH_SECRET: str = os.environ["GH_OAUTH_SECRET"]
GH_TOKEN: Union[str, None] = os.environ.get("GH_TOKEN")
Expand Down
13 changes: 0 additions & 13 deletions src/app/schemas/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,6 @@ class TokenPayload(BaseModel):
scopes: List[UserScope] = []


class GHTokenRequest(BaseModel):
code: str
redirect_uri: HttpUrl


class TokenRequest(BaseModel):
client_id: str
client_secret: str
code: str
redirect_uri: HttpUrl


class GHToken(BaseModel):
access_token: str
token_type: str
scope: str
15 changes: 14 additions & 1 deletion src/app/schemas/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from typing import Any, Dict, List

from pydantic import BaseModel
from pydantic import BaseModel, HttpUrl

__all__ = ["ChatCompletion"]

Expand Down Expand Up @@ -59,3 +59,16 @@ class ChatCompletion(BaseModel):
function_call: Dict[str, str]
temperature: float
frequency_penalty: float


class GHTokenRequest(BaseModel):
client_id: str
client_secret: str
code: str
redirect_uri: HttpUrl


class GHToken(BaseModel):
access_token: str
token_type: str
scope: str
73 changes: 48 additions & 25 deletions src/app/services/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,73 @@
from typing import Any, Dict, Union

import requests
from fastapi import HTTPException
from fastapi import HTTPException, status
from pydantic import HttpUrl

from app.core.config import settings
from app.schemas.services import GHToken, GHTokenRequest

__all__ = ["gh_client"]


class GitHubClient:
ENDPOINT: str = "https://api.github.com"
OAUTH_ENDPOINT: str = "https://github.com/login/oauth/access_token"

def __init__(self, token: Union[str, None] = None) -> None:
self.headers = (
{"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
if token
else {"Content-Type": "application/json"}
self.headers = self._get_headers(token)

def _get_headers(self, token: Union[str, None] = None) -> Dict[str, str]:
if isinstance(token, str):
return {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
return {"Content-Type": "application/json"}

def _get(self, route: str, token: Union[str, None] = None, timeout: int = 5) -> Dict[str, Any]:
response = requests.get(
f"{self.ENDPOINT}/{route}",
headers=self._get_headers(token) if isinstance(token, str) else self.headers,
timeout=timeout,
)
json_response = response.json()
if response.status_code != status.HTTP_200_OK:
raise HTTPException(
status_code=response.status_code, detail=json_response.get("error", json_response["message"])
)
return json_response

Check warning on line 41 in src/app/services/github.py

View check run for this annotation

Codecov / codecov/patch

src/app/services/github.py#L41

Added line #L41 was not covered by tests

def _get(self, endpoint: str, entry_id: int) -> Dict[str, Any]:
response = requests.get(f"{self.ENDPOINT}/{endpoint}{entry_id}", headers=self.headers, timeout=2)
if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail=response.json()["error"]["message"])
return response.json()
def get_repo(self, repo_id: int, **kwargs: Any) -> Dict[str, Any]:
return self._get(f"repositories/{repo_id}", **kwargs)

Check warning on line 44 in src/app/services/github.py

View check run for this annotation

Codecov / codecov/patch

src/app/services/github.py#L44

Added line #L44 was not covered by tests

def get_repo(self, repo_id: int) -> Dict[str, Any]:
return self._get("repositories/", repo_id)
def get_user(self, user_id: int, **kwargs: Any) -> Dict[str, Any]:
return self._get(f"user/{user_id}", **kwargs)

Check warning on line 47 in src/app/services/github.py

View check run for this annotation

Codecov / codecov/patch

src/app/services/github.py#L47

Added line #L47 was not covered by tests

def get_user(self, user_id: int) -> Dict[str, Any]:
return self._get("user/", user_id)
def get_my_user(self, token: str) -> Dict[str, Any]:
return self._get("user", token)

def get_permission(self, repo_name: str, user_name: str, github_token: str) -> str:
response = requests.get(
f"{self.ENDPOINT}/repos/{repo_name}/collaborators/{user_name}/permission",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {github_token}",
},
timeout=5,
)
if response.status_code != 200:
raise HTTPException(status_code=response.status_code, detail=response.json()["error"]["message"])
return response.json()["role_name"]
return self._get(f"repos/{repo_name}/collaborators/{user_name}/permission", github_token)["role_name"]

Check warning on line 53 in src/app/services/github.py

View check run for this annotation

Codecov / codecov/patch

src/app/services/github.py#L53

Added line #L53 was not covered by tests

def has_valid_permission(self, repo_name: str, user_name: str, github_token: str) -> bool:
return self.get_permission(repo_name, user_name, github_token) in ("maintain", "admin")

def get_token_from_code(self, code: str, redirect_uri: HttpUrl, timeout: int = 5) -> GHToken:
gh_payload = GHTokenRequest(
client_id=settings.GH_OAUTH_ID,
client_secret=settings.GH_OAUTH_SECRET,
redirect_uri=redirect_uri,
code=code,
)
response = requests.post(
self.OAUTH_ENDPOINT,
json=gh_payload.dict(),
headers={"Accept": "application/json"},
timeout=timeout,
)
if response.status_code != status.HTTP_200_OK:
raise HTTPException(status_code=response.status_code, detail=response.json()["error"])

Check warning on line 72 in src/app/services/github.py

View check run for this annotation

Codecov / codecov/patch

src/app/services/github.py#L72

Added line #L72 was not covered by tests
if response.status_code == status.HTTP_200_OK and isinstance(response.json().get("error_description"), str):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=response.json()["error_description"])
return GHToken(**response.json())

Check warning on line 75 in src/app/services/github.py

View check run for this annotation

Codecov / codecov/patch

src/app/services/github.py#L75

Added line #L75 was not covered by tests


gh_client = GitHubClient(settings.GH_TOKEN)
18 changes: 10 additions & 8 deletions src/tests/endpoints/test_login.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from httpx import AsyncClient
from sqlmodel.ext.asyncio.session import AsyncSession

from app.core import security
from app.api.api_v1.endpoints import login
from app.models import User

USER_TABLE = [
Expand All @@ -20,15 +20,16 @@ async def login_session(async_session: AsyncSession, monkeypatch):
for entry in USER_TABLE:
async_session.add(User(**entry))
await async_session.commit()
monkeypatch.setattr(security, "verify_password", pytest.mock_verify_password)
monkeypatch.setattr(login, "verify_password", pytest.mock_verify_password)
monkeypatch.setattr(login, "hash_password", pytest.mock_hash_password)
yield async_session


@pytest.mark.parametrize(
("payload", "status_code", "status_detail"),
[
({"username": "foo"}, 422, None),
({"github_token": "foo"}, 401, None),
({"github_token": "foo"}, 401, "Bad credentials"),
],
)
@pytest.mark.asyncio()
Expand All @@ -50,7 +51,8 @@ async def test_login_with_github_token(
[
({"username": "foo"}, 422, None),
({"username": "foo", "password": "bar"}, 401, None),
# ({"username": "first_login", "password": "first_pwd"}, 200, None),
({"username": "first_login", "password": "pwd"}, 401, None),
({"username": "first_login", "password": "first_pwd"}, 200, None),
],
)
@pytest.mark.asyncio()
Expand All @@ -67,9 +69,9 @@ async def test_login_with_creds(
assert response.json()["detail"] == status_detail
if response.status_code // 100 == 2:
response_json = response.json()
assert response_json["token_type"] == "Bearer" # noqa: S105
assert response_json["token_type"] == "bearer" # noqa: S105
assert isinstance(response_json["access_token"], str)
assert len(response_json["access_token"]) == 10
assert len(response_json["access_token"]) == 143


@pytest.mark.parametrize(
Expand All @@ -78,7 +80,7 @@ async def test_login_with_creds(
({"code": "foo", "redirect_uri": 0}, 422, None, None),
# Github 422
({"code": "foo", "redirect_uri": ""}, 422, None, None),
({"code": "foo", "redirect_uri": "https://quackai.com"}, 401, None, None),
({"code": "foo", "redirect_uri": "https://quackai.com"}, 400, None, None),
],
)
@pytest.mark.asyncio()
Expand All @@ -91,7 +93,7 @@ async def test_request_github_token_from_code(
expected_response: Union[Dict[str, Any], None],
):
response = await async_client.post("/login/github", json=payload)
assert response.status_code == status_code
assert response.status_code == status_code, print(response.json(), isinstance(response.json()["detail"], str))
if isinstance(status_detail, str):
assert response.json()["detail"] == status_detail
if isinstance(expected_response, dict):
Expand Down
39 changes: 39 additions & 0 deletions src/tests/endpoints/test_repos.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,3 +251,42 @@ async def test_fetch_guidelines_from_repo(
assert response.json()["detail"] == status_detail
if response.status_code // 100 == 2:
assert response.json() == expected_response


@pytest.mark.parametrize(
("user_idx", "repo_id", "payload", "status_code", "status_detail"),
[
(None, 12345, {"guideline_ids": [1]}, 401, "Not authenticated"),
(0, 100, {"guideline_ids": [1]}, 404, "Table Repository has no corresponding entry."),
(0, 12345, {"guideline_ids": [1, 2]}, 422, None),
(0, 12345, {"guideline_ids": [1, 1]}, 422, None),
(0, 12345, {"guideline_ids": [1]}, 200, None),
(0, 123456, {"guideline_ids": [1]}, 422, None),
(1, 12345, {"guideline_ids": [1]}, 200, None),
],
)
@pytest.mark.asyncio()
async def test_reorder_repo_guidelines(
async_client: AsyncClient,
guideline_session: AsyncSession,
user_idx: Union[int, None],
repo_id: int,
payload: Dict[str, Any],
status_code: int,
status_detail: Union[str, None],
):
auth = None
if isinstance(user_idx, int):
auth = await pytest.get_token(USER_TABLE[user_idx]["id"], USER_TABLE[user_idx]["scope"].split())

response = await async_client.put(f"/repos/{repo_id}/guidelines/order", json=payload, headers=auth)
assert response.status_code == status_code, print(response.json())
if isinstance(status_detail, str):
assert response.json()["detail"] == status_detail
if response.status_code // 100 == 2:
assert [{k: v for k, v in entry.items() if k not in {"updated_at", "order"}} for entry in response.json()] == [
{k: v for k, v in entry.items() if k not in {"updated_at", "order"}} for entry in GUIDELINE_TABLE
]
assert [entry["order"] for entry in response.json()] == [
payload["guideline_ids"].index(entry["id"]) for entry in GUIDELINE_TABLE
]
Loading