From b00944367c8f42366031d6b75041b02a6cc58b4c Mon Sep 17 00:00:00 2001 From: Trish Gillett-Kawamoto Date: Wed, 31 Jul 2024 09:16:55 -0600 Subject: [PATCH] refactor: Authenticator refactoring (preparation for app token refreshing) (#281) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Note: In this PR I am not changing functionality but just refactoring to make functionality changes possible in a future PR.** Essentially: [make the change easy, then make the easy change.](https://x.com/KentBeck/status/250733358307500032?lang=en) My goal is to add github app token features, for example refreshing app tokens before they expire (they're only good for an hour and currently don't refresh), and allowing multiple app tokens Currently, personal and app tokens are stored in the same list and treated interchangeably, so some structural changes are needed in order to handle them differently. My plan is to develop PersonalTokenManager and AppTokenManager classes so the code for working with each token type can be built into its manager class. In this PR I start by proposing to convert the TokenRateLimit class to a more general TokenManager, and I move functionality common to both token types there. I've added unit tests for the methods in TokenManager. Follow up PRs will develop PersonalTokenManager and AppTokenManager and implement the new app token features. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Edgar Ramírez Mondragón <16805946+edgarrmondragon@users.noreply.github.com> --- tap_github/authenticator.py | 127 ++++++++++++++----------- tap_github/tests/test_authenticator.py | 114 ++++++++++++++++++++++ 2 files changed, 188 insertions(+), 53 deletions(-) create mode 100644 tap_github/tests/test_authenticator.py diff --git a/tap_github/authenticator.py b/tap_github/authenticator.py index 7c9528d6..13b20bde 100644 --- a/tap_github/authenticator.py +++ b/tap_github/authenticator.py @@ -2,10 +2,11 @@ import logging import time +from copy import deepcopy from datetime import datetime from os import environ from random import choice, shuffle -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import jwt import requests @@ -13,8 +14,8 @@ from singer_sdk.streams import RESTStream -class TokenRateLimit: - """A class to store token rate limiting information.""" +class TokenManager: + """A class to store a token's attributes and state.""" DEFAULT_RATE_LIMIT = 5000 # The DEFAULT_RATE_LIMIT_BUFFER buffer serves two purposes: @@ -22,9 +23,15 @@ class TokenRateLimit: # - not consume all available calls when we rare using an org or user token. DEFAULT_RATE_LIMIT_BUFFER = 1000 - def __init__(self, token: str, rate_limit_buffer: Optional[int] = None): - """Init TokenRateLimit info.""" + def __init__( + self, + token: str, + rate_limit_buffer: Optional[int] = None, + logger: Optional[Any] = None, + ): + """Init TokenManager info.""" self.token = token + self.logger = logger self.rate_limit = self.DEFAULT_RATE_LIMIT self.rate_limit_remaining = self.DEFAULT_RATE_LIMIT self.rate_limit_reset: Optional[int] = None @@ -41,7 +48,28 @@ def update_rate_limit(self, response_headers: Any) -> None: self.rate_limit_reset = int(response_headers["X-RateLimit-Reset"]) self.rate_limit_used = int(response_headers["X-RateLimit-Used"]) - def is_valid(self) -> bool: + def is_valid_token(self) -> bool: + """Try making a request with the current token. If the request succeeds return True, else False.""" + try: + response = requests.get( + url="https://api.github.com/rate_limit", + headers={ + "Authorization": f"token {self.token}", + }, + ) + response.raise_for_status() + return True + except requests.exceptions.HTTPError: + msg = ( + f"A token was dismissed. " + f"{response.status_code} Client Error: " + f"{str(response.content)} (Reason: {response.reason})" + ) + if self.logger is not None: + self.logger.warning(msg) + return False + + def has_calls_remaining(self) -> bool: """Check if token is valid. Returns: @@ -113,25 +141,37 @@ def generate_app_access_token( class GitHubTokenAuthenticator(APIAuthenticatorBase): """Base class for offloading API auth.""" - def prepare_tokens(self) -> Dict[str, TokenRateLimit]: + def prepare_tokens(self) -> List[TokenManager]: # Save GitHub tokens - available_tokens: List[str] = [] + rate_limit_buffer = self._config.get("rate_limit_buffer", None) + + personal_tokens: Set[str] = set() if "auth_token" in self._config: - available_tokens = available_tokens + [self._config["auth_token"]] + personal_tokens.add(self._config["auth_token"]) if "additional_auth_tokens" in self._config: - available_tokens = available_tokens + self._config["additional_auth_tokens"] + personal_tokens = personal_tokens.union( + self._config["additional_auth_tokens"] + ) else: # Accept multiple tokens using environment variables GITHUB_TOKEN* - env_tokens = [ + env_tokens = { value for key, value in environ.items() if key.startswith("GITHUB_TOKEN") - ] + } if len(env_tokens) > 0: self.logger.info( f"Found {len(env_tokens)} 'GITHUB_TOKEN' environment variables for authentication." ) - available_tokens = env_tokens + personal_tokens = env_tokens + + token_managers: List[TokenManager] = [] + for token in personal_tokens: + token_manager = TokenManager( + token, rate_limit_buffer=rate_limit_buffer, logger=self.logger + ) + if token_manager.is_valid_token(): + token_managers.append(token_manager) # Parse App level private key and generate a token if "GITHUB_APP_PRIVATE_KEY" in environ.keys(): @@ -152,39 +192,17 @@ def prepare_tokens(self) -> Dict[str, TokenRateLimit]: app_token = generate_app_access_token( github_app_id, github_private_key, github_installation_id or None ) - available_tokens = available_tokens + [app_token] - - # Get rate_limit_buffer - rate_limit_buffer = self._config.get("rate_limit_buffer", None) - - # Dedup tokens and test them - filtered_tokens = [] - for token in list(set(available_tokens)): - try: - response = requests.get( - url="https://api.github.com/rate_limit", - headers={ - "Authorization": f"token {token}", - }, - ) - response.raise_for_status() - filtered_tokens.append(token) - except requests.exceptions.HTTPError: - msg = ( - f"A token was dismissed. " - f"{response.status_code} Client Error: " - f"{str(response.content)} (Reason: {response.reason})" + token_manager = TokenManager( + app_token, rate_limit_buffer=rate_limit_buffer, logger=self.logger ) - self.logger.warning(msg) + if token_manager.is_valid_token(): + token_managers.append(token_manager) - self.logger.info(f"Tap will run with {len(filtered_tokens)} auth tokens") + self.logger.info(f"Tap will run with {len(token_managers)} auth tokens") - # Create a dict of TokenRateLimit - # TODO - separate app_token and add logic to refresh the token - # using generate_app_access_token. - return { - token: TokenRateLimit(token, rate_limit_buffer) for token in filtered_tokens - } + # Create a dict of TokenManager + # TODO - separate app_token and add logic to refresh the token using generate_app_access_token. + return token_managers def __init__(self, stream: RESTStream) -> None: """Init authenticator. @@ -196,18 +214,21 @@ def __init__(self, stream: RESTStream) -> None: self.logger: logging.Logger = stream.logger self.tap_name: str = stream.tap_name self._config: Dict[str, Any] = dict(stream.config) - self.tokens_map = self.prepare_tokens() - self.active_token: Optional[TokenRateLimit] = ( - choice(list(self.tokens_map.values())) if len(self.tokens_map) else None + self.token_managers = self.prepare_tokens() + self.active_token: Optional[TokenManager] = ( + choice(self.token_managers) if self.token_managers else None ) def get_next_auth_token(self) -> None: - tokens_list = list(self.tokens_map.items()) current_token = self.active_token.token if self.active_token else "" - shuffle(tokens_list) - for _, token_rate_limit in tokens_list: - if token_rate_limit.is_valid() and current_token != token_rate_limit.token: - self.active_token = token_rate_limit + token_managers = deepcopy(self.token_managers) + shuffle(token_managers) + for token_manager in token_managers: + if ( + token_manager.has_calls_remaining() + and current_token != token_manager.token + ): + self.active_token = token_manager self.logger.info(f"Switching to fresh auth token") return @@ -219,7 +240,7 @@ def update_rate_limit( self, response_headers: requests.models.CaseInsensitiveDict ) -> None: # If no token or only one token is available, return early. - if len(self.tokens_map) <= 1 or self.active_token is None: + if len(self.token_managers) <= 1 or self.active_token is None: return self.active_token.update_rate_limit(response_headers) @@ -236,7 +257,7 @@ def auth_headers(self) -> Dict[str, str]: result = super().auth_headers if self.active_token: # Make sure that our token is still valid or update it. - if not self.active_token.is_valid(): + if not self.active_token.has_calls_remaining(): self.get_next_auth_token() result["Authorization"] = f"token {self.active_token.token}" else: diff --git a/tap_github/tests/test_authenticator.py b/tap_github/tests/test_authenticator.py new file mode 100644 index 00000000..c63853ca --- /dev/null +++ b/tap_github/tests/test_authenticator.py @@ -0,0 +1,114 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from tap_github.authenticator import TokenManager + + +class TestTokenManager: + + def test_default_rate_limits(self): + token_manager = TokenManager("mytoken", rate_limit_buffer=700) + + assert token_manager.rate_limit == 5000 + assert token_manager.rate_limit_remaining == 5000 + assert token_manager.rate_limit_reset is None + assert token_manager.rate_limit_used == 0 + assert token_manager.rate_limit_buffer == 700 + + token_manager_2 = TokenManager("mytoken") + assert token_manager_2.rate_limit_buffer == 1000 + + def test_update_rate_limit(self): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "4999", + "X-RateLimit-Reset": "1372700873", + "X-RateLimit-Used": "1", + } + + token_manager = TokenManager("mytoken") + token_manager.update_rate_limit(mock_response_headers) + + assert token_manager.rate_limit == 5000 + assert token_manager.rate_limit_remaining == 4999 + assert token_manager.rate_limit_reset == 1372700873 + assert token_manager.rate_limit_used == 1 + + def test_is_valid_token_successful(self): + with patch("requests.get") as mock_get: + mock_response = mock_get.return_value + mock_response.raise_for_status.return_value = None + + token_manager = TokenManager("validtoken") + + assert token_manager.is_valid_token() + mock_get.assert_called_once_with( + url="https://api.github.com/rate_limit", + headers={"Authorization": "token validtoken"}, + ) + + def test_is_valid_token_failure(self): + with patch("requests.get") as mock_get: + # Setup for a failed request + mock_response = mock_get.return_value + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() + mock_response.status_code = 401 + mock_response.content = b"Unauthorized Access" + mock_response.reason = "Unauthorized" + + token_manager = TokenManager("invalidtoken") + token_manager.logger = MagicMock() + + assert not token_manager.is_valid_token() + token_manager.logger.warning.assert_called_once() + assert "401" in token_manager.logger.warning.call_args[0][0] + + def test_has_calls_remaining_succeeds_if_token_never_used(self): + token_manager = TokenManager("mytoken") + assert token_manager.has_calls_remaining() + + def test_has_calls_remaining_succeeds_if_lots_remaining(self): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "4999", + "X-RateLimit-Reset": "1372700873", + "X-RateLimit-Used": "1", + } + + token_manager = TokenManager("mytoken") + token_manager.update_rate_limit(mock_response_headers) + + assert token_manager.has_calls_remaining() + + def test_has_calls_remaining_succeeds_if_reset_time_reached(self): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "1", + "X-RateLimit-Reset": "1372700873", + "X-RateLimit-Used": "4999", + } + + token_manager = TokenManager("mytoken", rate_limit_buffer=1000) + token_manager.update_rate_limit(mock_response_headers) + + assert token_manager.has_calls_remaining() + + def test_has_calls_remaining_fails_if_few_calls_remaining_and_reset_time_not_reached( + self, + ): + mock_response_headers = { + "X-RateLimit-Limit": "5000", + "X-RateLimit-Remaining": "1", + "X-RateLimit-Reset": str( + int((datetime.now() + timedelta(days=100)).timestamp()) + ), + "X-RateLimit-Used": "4999", + } + + token_manager = TokenManager("mytoken", rate_limit_buffer=1000) + token_manager.update_rate_limit(mock_response_headers) + + assert not token_manager.has_calls_remaining()