Skip to content

Commit

Permalink
Merge pull request #270 from Haidra-Org/main
Browse files Browse the repository at this point in the history
fix; use a default certifi based ssl context
  • Loading branch information
tazlin authored Oct 3, 2024
2 parents 441430f + 63a3917 commit f9df079
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 9 deletions.
4 changes: 0 additions & 4 deletions .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,5 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install --upgrade -r requirements.dev.txt
- name: Run pre-commit
uses: pre-commit/[email protected]
with:
extra_args: --all-files
- name: Run unit tests
run: tox -e tests-no-api-calls
4 changes: 0 additions & 4 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,5 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install --upgrade -r requirements.dev.txt
- name: Run pre-commit
uses: pre-commit/[email protected]
with:
extra_args: --all-files
- name: Run unit tests
run: tox -e tests-no-api-calls
12 changes: 11 additions & 1 deletion horde_sdk/ai_horde_api/ai_horde_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import urllib.parse
from abc import ABC, abstractmethod
from collections.abc import Callable, Coroutine
from ssl import SSLContext
from typing import cast

import aiohttp
Expand Down Expand Up @@ -78,6 +79,7 @@
GenericAsyncHordeAPISession,
GenericHordeAPIManualClient,
GenericHordeAPISession,
_default_sslcontext,
)


Expand Down Expand Up @@ -291,12 +293,18 @@ def delete_pending_image(
class AIHordeAPIAsyncManualClient(GenericAsyncHordeAPIManualClient, BaseAIHordeClient):
"""An asyncio based API client specifically configured for the AI-Horde API."""

def __init__(self, aiohttp_session: aiohttp.ClientSession) -> None:
def __init__(
self,
aiohttp_session: aiohttp.ClientSession,
*,
ssl_context: SSLContext = _default_sslcontext,
) -> None:
"""Create a new instance of the RatingsAPIClient."""
super().__init__(
aiohttp_session=aiohttp_session,
path_fields=AIHordePathData,
query_fields=AIHordeQueryData,
ssl_context=ssl_context,
)

async def get_generate_check(
Expand Down Expand Up @@ -395,12 +403,14 @@ class AIHordeAPIAsyncClientSession(GenericAsyncHordeAPISession):
def __init__(
self,
aiohttp_session: aiohttp.ClientSession,
ssl_context: SSLContext = _default_sslcontext,
) -> None:
"""Create a new instance of the RatingsAPIClient."""
super().__init__(
aiohttp_session=aiohttp_session,
path_fields=AIHordePathData,
query_fields=AIHordeQueryData,
ssl_context=ssl_context,
)


Expand Down
19 changes: 19 additions & 0 deletions horde_sdk/generic_api/generic_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@

import asyncio
import os
import ssl
from abc import ABC
from ssl import SSLContext
from typing import Any, TypeVar

import aiohttp
import certifi
import requests
from loguru import logger
from pydantic import BaseModel, ValidationError
Expand All @@ -32,6 +35,9 @@
GenericQueryFields,
)

_default_sslcontext = ssl.create_default_context(cafile=certifi.where())
"""The default SSL context to use for aiohttp requests."""


class ParsedRawRequest(BaseModel):
"""A helper class for passing around the data needed to make an actual web request."""
Expand Down Expand Up @@ -59,6 +65,7 @@ class BaseHordeAPIClient(ABC):

# region Private Fields
_aiohttp_session: aiohttp.ClientSession
_ssl_context: SSLContext

_apikey: str | None

Expand All @@ -82,6 +89,7 @@ def __init__(
path_fields: type[GenericPathFields] = GenericPathFields,
query_fields: type[GenericQueryFields] = GenericQueryFields,
accept_types: type[GenericAcceptTypes] = GenericAcceptTypes,
ssl_context: SSLContext = _default_sslcontext,
**kwargs: Any, # noqa: ANN401
) -> None:
"""Initialize a new `GenericHordeAPIClient` instance.
Expand All @@ -97,13 +105,20 @@ def __init__(
Defaults to GenericQueryFields.
accept_types (type[GenericAcceptTypes], optional): Pass this to define the API's accept types.
Defaults to GenericAcceptTypes.
ssl_context (SSLContext, optional): The SSL context to use for aiohttp requests.
Defaults to using `certifi`.
kwargs: Any additional keyword arguments are ignored.
Raises:
TypeError: If any of the passed types are not subclasses of their respective `Generic*` class.
"""
self._apikey = apikey

if not isinstance(ssl_context, SSLContext):
raise TypeError("`ssl_context` must be of type `SSLContext`!")

self._ssl_context = ssl_context

if not self._apikey:
self._apikey = ANON_API_KEY

Expand Down Expand Up @@ -394,6 +409,7 @@ def __init__( # noqa: D107
path_fields: type[GenericPathFields] = GenericPathFields,
query_fields: type[GenericQueryFields] = GenericQueryFields,
accept_types: type[GenericAcceptTypes] = GenericAcceptTypes,
ssl_context: SSLContext = _default_sslcontext,
**kwargs: Any, # noqa: ANN401
) -> None:
super().__init__(
Expand Down Expand Up @@ -445,6 +461,7 @@ async def submit_request(
params=parsed_request.request_queries,
json=parsed_request.request_body,
allow_redirects=True,
ssl=self._ssl_context,
) as response,
):
raw_response_json = await response.json()
Expand Down Expand Up @@ -657,13 +674,15 @@ def __init__( # noqa: D107
path_fields: type[GenericPathFields] = GenericPathFields,
query_fields: type[GenericQueryFields] = GenericQueryFields,
accept_types: type[GenericAcceptTypes] = GenericAcceptTypes,
ssl_context: SSLContext = _default_sslcontext,
) -> None:
super().__init__(
aiohttp_session=aiohttp_session,
header_fields=header_fields,
path_fields=path_fields,
query_fields=query_fields,
accept_types=accept_types,
ssl_context=ssl_context,
)
self._pending_follow_ups = []
self._awaiting_requests = []
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pydantic==2.9.2
requests
StrEnum
loguru
certifi
aiohttp
aiofiles
aiodns
Expand Down

0 comments on commit f9df079

Please sign in to comment.