From 55c9d9ec431575341f1b4ab2b82394b047c41411 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Tue, 14 May 2024 16:00:51 +0200 Subject: [PATCH] toto --- pyproject.toml | 1 + sygnal/gcmpushkin.py | 78 ++++++++++++++++++++++++++++---------------- tests/test_gcm.py | 9 +++-- 3 files changed, 57 insertions(+), 31 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a4351ffd..01292f55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,6 +104,7 @@ dev = [ "mypy-zope==1.0.1", "towncrier", "tox", + "google-auth-stubs==0.2.0", "types-opentracing>=2.4.2", "types-pyOpenSSL", "types-PyYAML", diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py index 621c2fd7..88ae12e4 100644 --- a/sygnal/gcmpushkin.py +++ b/sygnal/gcmpushkin.py @@ -14,18 +14,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import json import logging +import os import time from enum import Enum from io import BytesIO from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple -import google.auth.transport.requests -from google.oauth2 import service_account +import aiohttp +import google.auth.transport._aiohttp_requests +from google.auth._default_async import load_credentials_from_file +from google.oauth2._credentials_async import Credentials from opentracing import Span, logs, tags from prometheus_client import Counter, Gauge, Histogram -from twisted.internet.defer import DeferredSemaphore +from twisted.internet.defer import Deferred, DeferredSemaphore from twisted.web.client import FileBodyProducer, HTTPConnectionPool, readBody from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse @@ -153,6 +157,15 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None: proxy_url_str=proxy_url, ) + # Use the fcm_options config dictionary as a foundation for the body; + # this lets the Sygnal admin choose custom FCM options + # (e.g. content_available). + self.base_request_body = self.get_config("fcm_options", dict, {}) + if not isinstance(self.base_request_body, dict): + raise PushkinSetupException( + "Config field fcm_options, if set, must be a dictionary of options" + ) + self.api_version = APIVersion.Legacy version_str = self.get_config("api_version", str) if not version_str: @@ -180,19 +193,31 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None: "Must configure `project_id` when using FCM api v1", ) - self.service_account_file = self.get_config("service_account_file", str) - if self.api_version is APIVersion.V1 and not self.service_account_file: - raise PushkinSetupException( - "Must configure `service_account_file` when using FCM api v1", - ) + self.credentials = None - # Use the fcm_options config dictionary as a foundation for the body; - # this lets the Sygnal admin choose custom FCM options - # (e.g. content_available). - self.base_request_body = self.get_config("fcm_options", dict, {}) - if not isinstance(self.base_request_body, dict): - raise PushkinSetupException( - "Config field fcm_options, if set, must be a dictionary of options" + if self.api_version is APIVersion.V1: + self.service_account_file = self.get_config("service_account_file", str) + if self.service_account_file: + try: + self.credentials, _ = load_credentials_from_file( + str(self.service_account_file), + scopes=AUTH_SCOPES, + ) + except google.auth.exceptions.DefaultCredentialsError: + pass + + if not self.credentials: + raise PushkinSetupException( + "Must configure valid `service_account_file` when using FCM api v1", + ) + + session = None + if proxy_url: + os.environ["HTTPS_PROXY"] = proxy_url + session = aiohttp.ClientSession(trust_env=True, auto_decompress=False) + + self.request = google.auth.transport._aiohttp_requests.Request( + session=session ) @classmethod @@ -464,21 +489,16 @@ def _handle_v1_response( f"Unknown GCM response code {response.code}" ) - def _get_access_token(self) -> str: + async def _get_access_token(self) -> str: """Retrieve a valid access token that can be used to authorize requests. :return: Access token. """ - # TODO: Should we use the environment variable approach instead? - # export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json - # credentials, project = google.auth.default(scopes=AUTH_SCOPES) - credentials = service_account.Credentials.from_service_account_file( - str(self.service_account_file), - scopes=AUTH_SCOPES, - ) - request = google.auth.transport.requests.Request() - credentials.refresh(request) - return credentials.token + if not self.credentials.valid: + await Deferred.fromFuture( + asyncio.ensure_future(self.credentials.refresh(self.request)) + ) + return self.credentials.token async def _dispatch_notification_unlimited( self, n: Notification, device: Device, context: NotificationContext @@ -532,10 +552,12 @@ async def _dispatch_notification_unlimited( "Content-Type": ["application/json"], } - if self.api_version == APIVersion.Legacy: + if self.api_version is APIVersion.Legacy: headers["Authorization"] = ["key=%s" % (self.api_key,)] elif self.api_version is APIVersion.V1: - headers["Authorization"] = ["Bearer %s" % (self._get_access_token(),)] + headers["Authorization"] = [ + "Bearer %s" % (await self._get_access_token(),) + ] body = self.base_request_body.copy() body["data"] = data diff --git a/tests/test_gcm.py b/tests/test_gcm.py index a5454937..13ec1905 100644 --- a/tests/test_gcm.py +++ b/tests/test_gcm.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple from unittest.mock import MagicMock -from sygnal.gcmpushkin import GcmPushkin +from sygnal.gcmpushkin import GcmPushkin, PushkinSetupException from tests import testutils from tests.testutils import DummyResponse @@ -86,12 +86,15 @@ class TestGcmPushkin(GcmPushkin): """ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]): - super().__init__(name, sygnal, config) self.preloaded_response = DummyResponse(0) self.preloaded_response_payload: Dict[str, Any] = {} self.last_request_body: Dict[str, Any] = {} self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type] self.num_requests = 0 + try: + super().__init__(name, sygnal, config) + except PushkinSetupException: + pass def preload_with_response( self, code: int, response_payload: Dict[str, Any] @@ -110,7 +113,7 @@ async def _perform_http_request( # type: ignore[override] self.num_requests += 1 return self.preloaded_response, json.dumps(self.preloaded_response_payload) - def _get_access_token(self) -> str: + async def _get_access_token(self) -> str: return "token"