Skip to content

Commit

Permalink
FCM v1: use async version of google-auth and add HTTP proxy support
Browse files Browse the repository at this point in the history
  • Loading branch information
MatMaul committed May 15, 2024
1 parent fdae42a commit fe0770f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 35 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ dev = [
"mypy-zope==1.0.1",
"towncrier",
"tox",
"google-auth-stubs==0.2.0",
"types-opentracing>=2.4.2",
"types-pyOpenSSL",
"types-PyYAML",
Expand Down
84 changes: 52 additions & 32 deletions sygnal/gcmpushkin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import logging
import os
import time
from enum import Enum
from io import BytesIO
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple

import google.auth.transport.requests
from google.oauth2 import service_account
import aiohttp
import google.auth.transport._aiohttp_requests
from google.auth._default_async import load_credentials_from_file
from google.oauth2._credentials_async import Credentials
from opentracing import Span, logs, tags
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.defer import DeferredSemaphore
from twisted.internet.defer import Deferred, DeferredSemaphore
from twisted.web.client import FileBodyProducer, HTTPConnectionPool, readBody
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
Expand Down Expand Up @@ -153,6 +157,15 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None:
proxy_url_str=proxy_url,
)

# Use the fcm_options config dictionary as a foundation for the body;
# this lets the Sygnal admin choose custom FCM options
# (e.g. content_available).
self.base_request_body = self.get_config("fcm_options", dict, {})
if not isinstance(self.base_request_body, dict):
raise PushkinSetupException(
"Config field fcm_options, if set, must be a dictionary of options"
)

self.api_version = APIVersion.Legacy
version_str = self.get_config("api_version", str)
if not version_str:
Expand Down Expand Up @@ -180,19 +193,31 @@ def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]) -> None:
"Must configure `project_id` when using FCM api v1",
)

self.service_account_file = self.get_config("service_account_file", str)
if self.api_version is APIVersion.V1 and not self.service_account_file:
raise PushkinSetupException(
"Must configure `service_account_file` when using FCM api v1",
)
self.credentials: Credentials = None # type: ignore

# Use the fcm_options config dictionary as a foundation for the body;
# this lets the Sygnal admin choose custom FCM options
# (e.g. content_available).
self.base_request_body = self.get_config("fcm_options", dict, {})
if not isinstance(self.base_request_body, dict):
raise PushkinSetupException(
"Config field fcm_options, if set, must be a dictionary of options"
if self.api_version is APIVersion.V1:
self.service_account_file = self.get_config("service_account_file", str)
if self.service_account_file:
try:
self.credentials, _ = load_credentials_from_file(
str(self.service_account_file),
scopes=AUTH_SCOPES,
)
except google.auth.exceptions.DefaultCredentialsError:
pass

if not self.credentials:
raise PushkinSetupException(
"Must configure valid `service_account_file` when using FCM api v1",
)

session = None
if proxy_url:
os.environ["HTTPS_PROXY"] = proxy_url
session = aiohttp.ClientSession(trust_env=True, auto_decompress=False)

self.request = google.auth.transport._aiohttp_requests.Request(
session=session
)

@classmethod
Expand Down Expand Up @@ -464,21 +489,19 @@ def _handle_v1_response(
f"Unknown GCM response code {response.code}"
)

def _get_access_token(self) -> str:
"""Retrieve a valid access token that can be used to authorize requests.
async def _get_auth_header(self) -> str:
"""Retrieve the auth header that can be used to authorize requests.
:return: Access token.
:return: Needed content of the `Authorization` header
"""
# TODO: Should we use the environment variable approach instead?
# export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json
# credentials, project = google.auth.default(scopes=AUTH_SCOPES)
credentials = service_account.Credentials.from_service_account_file(
str(self.service_account_file),
scopes=AUTH_SCOPES,
)
request = google.auth.transport.requests.Request()
credentials.refresh(request)
return credentials.token
if self.api_version is APIVersion.Legacy:
return "key=%s" % (self.api_key,)
else:
if not self.credentials.valid:
await Deferred.fromFuture(
asyncio.ensure_future(self.credentials.refresh(self.request))
)
return "Bearer %s" % self.credentials.token

async def _dispatch_notification_unlimited(
self, n: Notification, device: Device, context: NotificationContext
Expand Down Expand Up @@ -532,10 +555,7 @@ async def _dispatch_notification_unlimited(
"Content-Type": ["application/json"],
}

if self.api_version == APIVersion.Legacy:
headers["Authorization"] = ["key=%s" % (self.api_key,)]
elif self.api_version is APIVersion.V1:
headers["Authorization"] = ["Bearer %s" % (self._get_access_token(),)]
headers["Authorization"] = [await self._get_auth_header()]

body = self.base_request_body.copy()
body["data"] = data
Expand Down
9 changes: 6 additions & 3 deletions tests/test_gcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple
from unittest.mock import MagicMock

from sygnal.gcmpushkin import GcmPushkin
from sygnal.gcmpushkin import GcmPushkin, PushkinSetupException

from tests import testutils
from tests.testutils import DummyResponse
Expand Down Expand Up @@ -86,12 +86,15 @@ class TestGcmPushkin(GcmPushkin):
"""

def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]):
super().__init__(name, sygnal, config)
self.preloaded_response = DummyResponse(0)
self.preloaded_response_payload: Dict[str, Any] = {}
self.last_request_body: Dict[str, Any] = {}
self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type]
self.num_requests = 0
try:
super().__init__(name, sygnal, config)
except PushkinSetupException:
pass

def preload_with_response(
self, code: int, response_payload: Dict[str, Any]
Expand All @@ -110,7 +113,7 @@ async def _perform_http_request( # type: ignore[override]
self.num_requests += 1
return self.preloaded_response, json.dumps(self.preloaded_response_payload)

def _get_access_token(self) -> str:
async def _get_access_token(self) -> str:
return "token"


Expand Down

0 comments on commit fe0770f

Please sign in to comment.