diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py index 621c2fd7..4b6b7b7b 100644 --- a/sygnal/gcmpushkin.py +++ b/sygnal/gcmpushkin.py @@ -14,14 +14,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import aiohttp +import asyncio import json import logging +import os import time from enum import Enum from io import BytesIO from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple -import google.auth.transport.requests +import google.auth.transport._aiohttp_requests from google.oauth2 import service_account from opentracing import Span, logs, tags from prometheus_client import Counter, Gauge, Histogram @@ -195,6 +198,15 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None: "Config field fcm_options, if set, must be a dictionary of options" ) + self._loop = asyncio.get_event_loop() + self._refresh_lock = asyncio.Lock() + self._session = None + + if proxy_url: + os.environ["HTTP_PROXY"] = proxy_url + os.environ["HTTPS_PROXY"] = proxy_url + self._session = aiohttp.ClientSession(trust_env=True) + @classmethod async def create( cls, name: str, sygnal: "Sygnal", config: Dict[str, Any] @@ -464,7 +476,7 @@ def _handle_v1_response( f"Unknown GCM response code {response.code}" ) - def _get_access_token(self) -> str: + async def _get_access_token(self) -> str: """Retrieve a valid access token that can be used to authorize requests. :return: Access token. @@ -476,8 +488,12 @@ def _get_access_token(self) -> str: str(self.service_account_file), scopes=AUTH_SCOPES, ) - request = google.auth.transport.requests.Request() - credentials.refresh(request) + request = google.auth.transport._aiohttp_requests.Request(session=self._session) + # This code is copied from https://github.com/googleapis/google-auth-library-python/blob/8cfc91db0861bc92374d708140656fb38e003ef6/google/auth/transport/_aiohttp_requests.py#L372 + async with self._refresh_lock: + await self._loop.run_in_executor( + None, credentials.refresh, request + ) return credentials.token async def _dispatch_notification_unlimited(