From 79127763eeee35476ea1ec09e5a49c82e0cd4dd7 Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 4 Oct 2024 09:44:29 -0400 Subject: [PATCH] fix: use new default ssl context in all aiohttp requests (#271) Adds the missing ssl kwarg usage in other parts of the code, missed in #268. --- .../ai_horde_client/image/async_manual_client_example.py | 4 ++-- horde_sdk/__init__.py | 5 ++++- horde_sdk/ai_horde_api/ai_horde_clients.py | 7 +++---- horde_sdk/generic_api/apimodels.py | 5 +++-- horde_sdk/generic_api/generic_clients.py | 4 +--- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/examples/ai_horde_client/image/async_manual_client_example.py b/examples/ai_horde_client/image/async_manual_client_example.py index fe236e5..237771c 100644 --- a/examples/ai_horde_client/image/async_manual_client_example.py +++ b/examples/ai_horde_client/image/async_manual_client_example.py @@ -8,7 +8,7 @@ import aiohttp from loguru import logger -from horde_sdk import ANON_API_KEY +from horde_sdk import ANON_API_KEY, _default_sslcontext from horde_sdk.ai_horde_api import AIHordeAPIAsyncManualClient from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest from horde_sdk.generic_api.apimodels import RequestErrorResponse @@ -90,7 +90,7 @@ async def main(apikey: str = ANON_API_KEY) -> None: image_bytes = None # image_gen.img is a url, download it using aiohttp. - async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp: + async with aiohttp.ClientSession() as session, session.get(image_gen.img, ssl=_default_sslcontext) as resp: image_bytes = await resp.read() if image_bytes is None: diff --git a/horde_sdk/__init__.py b/horde_sdk/__init__.py index afcbdbe..2c56a33 100644 --- a/horde_sdk/__init__.py +++ b/horde_sdk/__init__.py @@ -2,6 +2,8 @@ # isort: off # We import dotenv first so that we can use it to load environment variables before importing anything else. +import ssl +import certifi import dotenv # If the current working directory contains a `.env` file, import the environment variables from it. @@ -59,7 +61,7 @@ def _dev_env_var_warnings() -> None: # pragma: no cover _dev_env_var_warnings() - +_default_sslcontext = ssl.create_default_context(cafile=certifi.where()) from horde_sdk.consts import ( PAYLOAD_HTTP_METHODS, @@ -109,4 +111,5 @@ def _dev_env_var_warnings() -> None: # pragma: no cover "PROGRESS_LOGGER_LABEL", "COMPLETE_LOGGER_LABEL", "HordeException", + "_default_sslcontext", ] diff --git a/horde_sdk/ai_horde_api/ai_horde_clients.py b/horde_sdk/ai_horde_api/ai_horde_clients.py index 9e72099..8d92b64 100644 --- a/horde_sdk/ai_horde_api/ai_horde_clients.py +++ b/horde_sdk/ai_horde_api/ai_horde_clients.py @@ -18,7 +18,7 @@ import requests from loguru import logger -from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL +from horde_sdk import COMPLETE_LOGGER_LABEL, PROGRESS_LOGGER_LABEL, _default_sslcontext from horde_sdk.ai_horde_api.apimodels import ( AIHordeHeartbeatRequest, AIHordeHeartbeatResponse, @@ -79,7 +79,6 @@ GenericAsyncHordeAPISession, GenericHordeAPIManualClient, GenericHordeAPISession, - _default_sslcontext, ) @@ -1290,7 +1289,7 @@ async def download_image_from_generation(self, generation: ImageGeneration) -> t image_bytes: bytes | None = None if urllib.parse.urlparse(generation.img).scheme in ["http", "https"]: - async with self._aiohttp_session.get(generation.img) as response: + async with self._aiohttp_session.get(generation.img, ssl=_default_sslcontext) as response: if response.status != 200: # pragma: no cover logger.error(f"Error downloading image: {response.status}") response.raise_for_status() @@ -1326,7 +1325,7 @@ async def download_image_from_url(self, url: str) -> PIL.Image.Image: if self._aiohttp_session is None: raise RuntimeError("No aiohttp session provided but an async request was made.") - async with self._aiohttp_session.get(url) as response: + async with self._aiohttp_session.get(url, ssl=_default_sslcontext) as response: if response.status != 200: # pragma: no cover logger.error(f"Error downloading image: {response.status}") response.raise_for_status() diff --git a/horde_sdk/generic_api/apimodels.py b/horde_sdk/generic_api/apimodels.py index 28cc740..8d1cab7 100644 --- a/horde_sdk/generic_api/apimodels.py +++ b/horde_sdk/generic_api/apimodels.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator from typing_extensions import override +from horde_sdk import _default_sslcontext from horde_sdk.consts import HTTPMethod, HTTPStatusCode from horde_sdk.generic_api.consts import ANON_API_KEY from horde_sdk.generic_api.endpoints import GENERIC_API_ENDPOINT_SUBPATH, url_with_path @@ -256,7 +257,7 @@ class ResponseRequiringDownloadMixin(HordeAPIDataObject): async def download_file_as_base64(self, client_session: aiohttp.ClientSession, url: str) -> str: """Download a file and return the value as a base64 string.""" - async with client_session.get(url) as response: + async with client_session.get(url, ssl=_default_sslcontext) as response: response.raise_for_status() return base64.b64encode(await response.read()).decode("utf-8") @@ -273,7 +274,7 @@ async def download_file_to_field_as_base64( url (str): The URL to download the file from. field_name (str): The name of the field to save the file to. """ - async with client_session.get(url) as response: + async with client_session.get(url, ssl=_default_sslcontext) as response: response.raise_for_status() setattr(self, field_name, base64.b64encode(await response.read()).decode("utf-8")) diff --git a/horde_sdk/generic_api/generic_clients.py b/horde_sdk/generic_api/generic_clients.py index 74da092..76fa66b 100644 --- a/horde_sdk/generic_api/generic_clients.py +++ b/horde_sdk/generic_api/generic_clients.py @@ -4,19 +4,18 @@ import asyncio import os -import ssl from abc import ABC from ssl import SSLContext from typing import Any, TypeVar import aiohttp -import certifi import requests from loguru import logger from pydantic import BaseModel, ValidationError from strenum import StrEnum from typing_extensions import override +from horde_sdk import _default_sslcontext from horde_sdk.ai_horde_api.exceptions import AIHordePayloadValidationError from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import ( @@ -35,7 +34,6 @@ GenericQueryFields, ) -_default_sslcontext = ssl.create_default_context(cafile=certifi.where()) """The default SSL context to use for aiohttp requests."""