diff --git a/requests_html.py b/requests_html.py index cd341de..8be6a6e 100644 --- a/requests_html.py +++ b/requests_html.py @@ -1,19 +1,20 @@ import sys import asyncio +import pyppeteer +import requests +import http.cookiejar +import lxml + +from typing import Set, Union, List, MutableMapping, Optional + from urllib.parse import urlparse, urlunparse, urljoin from concurrent.futures import ThreadPoolExecutor from concurrent.futures._base import TimeoutError from functools import partial -from typing import Set, Union, List, MutableMapping, Optional - -import pyppeteer -import requests -import http.cookiejar from pyquery import PyQuery from fake_useragent import UserAgent from lxml.html.clean import Cleaner -import lxml from lxml import etree from lxml.html import HtmlElement from lxml.html import tostring as lxml_html_tostring @@ -771,7 +772,6 @@ def __init__(self, mock_browser : bool = True, verify : bool = True, self.__browser_args = browser_args - def response_hook(self, response, **kwargs) -> HTMLResponse: """ Change response encoding and replace it by a HTMLResponse. """ if not response.encoding: @@ -823,6 +823,12 @@ def __init__(self, loop=None, workers=None, self.loop = loop or asyncio.get_event_loop() self.thread_pool = ThreadPoolExecutor(max_workers=workers) + async def __aenter__(self) -> 'AsyncHTMLSession': + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() + def request(self, *args, **kwargs): """ Partial original request func and run it in a thread. """ func = partial(super().request, *args, **kwargs) diff --git a/tests/test_requests_html.py b/tests/test_requests_html.py index 5237a82..843c2bd 100644 --- a/tests/test_requests_html.py +++ b/tests/test_requests_html.py @@ -1,9 +1,9 @@ import os +import pytest + from functools import partial -import pytest from pyppeteer.browser import Browser -from pyppeteer.page import Page from requests_html import HTMLSession, AsyncHTMLSession, HTML from requests_file import FileAdapter @@ -322,3 +322,25 @@ async def test_async_browser_session(): browser = await session.browser assert isinstance(browser, Browser) await session.close() + + +@pytest.mark.asyncio +async def test_async_context_manager(): + """ + Test the behavior of the async context manager for AsyncHTMLSession. + + This test case validates that the AsyncHTMLSession instance can be used + as an asynchronous context manager, and the session can successfully make + an HTTP GET request within the context. + + Note: If the user has no connection, a ConnectionError may occur, and the + test will be skipped. + + """ + async with AsyncHTMLSession() as s: + try: + results = await s.get('https://www.google.com') + assert results.status_code == 200 + except ConnectionError: + # if the user has no connection skip this test + pass