Skip to content
This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Introduce mypy #388

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/cibuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ jobs:
- run: pip install tox
- run: tox -e flake8

mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
- run: python -m pip install --upgrade pip
- run: pip install tox
- run: tox -e mypy

unit_tests:
strategy:
fail-fast: false
Expand Down
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,9 @@ repos:
args: [--config=.flake8]
language: system
files: \.py$
- id: mypy
name: mypy
entry: mypy
stages: [commit]
language: system
files: \.py$
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ line_length = 119
multi_line_output = 3
use_parentheses = true
include_trailing_comma = true


[tool.mypy]
exclude = "thirdparty"

12 changes: 8 additions & 4 deletions seleniumwire/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import signal
from argparse import RawDescriptionHelpFormatter
from typing import Callable, Dict

from seleniumwire import backend, utils

Expand All @@ -23,10 +24,13 @@ def standalone_proxy(port=0, addr='127.0.0.1'):
signal.signal(signal.SIGINT, lambda *_: b.shutdown())


# Mapping of command names to the command callables
COMMANDS: Dict[str, Callable] = {'extractcert': utils.extract_cert, 'standaloneproxy': standalone_proxy}


if __name__ == '__main__':
commands = {'extractcert': utils.extract_cert, 'standaloneproxy': standalone_proxy}
parser = argparse.ArgumentParser(
description='\n\nsupported commands: \n %s' % '\n '.join(sorted(commands)),
description='\n\nsupported commands: \n %s' % '\n '.join(sorted(COMMANDS)),
formatter_class=RawDescriptionHelpFormatter,
usage='python -m seleniumwire <command>',
)
Expand All @@ -40,10 +44,10 @@ def standalone_proxy(port=0, addr='127.0.0.1'):

args = parser.parse_args()
pargs = [arg for arg in args.args if '=' not in arg and arg is not args.command]
kwargs = dict([tuple(arg.split('=')) for arg in args.args if '=' in arg])
kwargs: Dict[str, str] = dict([arg.split('=') for arg in args.args if '=' in arg])

try:
commands[args.command](*pargs, **kwargs)
COMMANDS[args.command](*pargs, **kwargs)
except KeyError:
print("Unsupported command '{}' (use --help for list of commands)".format(args.command))
except TypeError as e:
Expand Down
74 changes: 37 additions & 37 deletions seleniumwire/inspect.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import inspect
import time
from typing import Iterator, List, Optional, Union
from typing import Callable, Iterator, List, Optional, Union

from selenium.common.exceptions import TimeoutException
from selenium.common.exceptions import TimeoutException # type: ignore

from seleniumwire import har
from seleniumwire.request import Request
Expand All @@ -23,18 +23,18 @@ def requests(self) -> List[Request]:
A list of Request instances representing the requests made
between the browser and server.
"""
return self.backend.storage.load_requests()
return self.backend.storage.load_requests() # type: ignore

@requests.deleter
def requests(self):
self.backend.storage.clear_requests()
self.backend.storage.clear_requests() # type: ignore

def iter_requests(self) -> Iterator[Request]:
"""Return an iterator of requests.

Returns: An iterator.
"""
yield from self.backend.storage.iter_requests()
yield from self.backend.storage.iter_requests() # type: ignore

@property
def last_request(self) -> Optional[Request]:
Expand All @@ -46,7 +46,7 @@ def last_request(self) -> Optional[Request]:
A Request instance representing the last request made, or
None if no requests have been made.
"""
return self.backend.storage.load_last_request()
return self.backend.storage.load_last_request() # type: ignore

def wait_for_request(self, pat: str, timeout: Union[int, float] = 10) -> Request:
"""Wait up to the timeout period for a request matching the specified
Expand All @@ -73,7 +73,7 @@ def wait_for_request(self, pat: str, timeout: Union[int, float] = 10) -> Request
start = time.time()

while time.time() - start < timeout:
request = self.backend.storage.find(pat)
request = self.backend.storage.find(pat) # type: ignore

if request is None:
time.sleep(1 / 5)
Expand All @@ -91,7 +91,7 @@ def har(self) -> str:

Returns: A JSON string of HAR data.
"""
return har.generate_har(self.backend.storage.load_har_entries())
return har.generate_har(self.backend.storage.load_har_entries()) # type: ignore

@property
def header_overrides(self):
Expand Down Expand Up @@ -119,7 +119,7 @@ def header_overrides(self):
('*.somewhere-else.com.*', {'User-Agent': 'Chrome'})
]
"""
return self.backend.modifier.headers
return self.backend.modifier.headers # type: ignore

@header_overrides.setter
def header_overrides(self, headers):
Expand All @@ -129,16 +129,16 @@ def header_overrides(self, headers):
else:
self._validate_headers(headers)

self.backend.modifier.headers = headers
self.backend.modifier.headers = headers # type: ignore

def _validate_headers(self, headers):
for v in headers.values():
if v is not None:
assert isinstance(v, str), 'Header values must be strings'

@header_overrides.deleter
@header_overrides.deleter # type: ignore
def header_overrides(self):
del self.backend.modifier.headers
del self.backend.modifier.headers # type: ignore

@property
def param_overrides(self):
Expand All @@ -164,15 +164,15 @@ def param_overrides(self):
('*.somewhere-else.com.*', {'x': 'y'}),
]
"""
return self.backend.modifier.params
return self.backend.modifier.params # type: ignore

@param_overrides.setter
def param_overrides(self, params):
self.backend.modifier.params = params
self.backend.modifier.params = params # type: ignore

@param_overrides.deleter
def param_overrides(self):
del self.backend.modifier.params
del self.backend.modifier.params # type: ignore

@property
def body_overrides(self):
Expand All @@ -194,15 +194,15 @@ def body_overrides(self):
('*.somewhere-else.com.*', '{"x":"y"}'),
]
"""
return self.backend.modifier.bodies
return self.backend.modifier.bodies # type: ignore

@body_overrides.setter
def body_overrides(self, bodies):
self.backend.modifier.bodies = bodies
self.backend.modifier.bodies = bodies # type: ignore

@body_overrides.deleter
def body_overrides(self):
del self.backend.modifier.bodies
del self.backend.modifier.bodies # type: ignore

@property
def querystring_overrides(self):
Expand All @@ -223,15 +223,15 @@ def querystring_overrides(self):
('*.somewhere-else.com.*', 'a=b&c=d'),
]
"""
return self.backend.modifier.querystring
return self.backend.modifier.querystring # type: ignore

@querystring_overrides.setter
def querystring_overrides(self, querystrings):
self.backend.modifier.querystring = querystrings
self.backend.modifier.querystring = querystrings # type: ignore

@querystring_overrides.deleter
def querystring_overrides(self):
del self.backend.modifier.querystring
del self.backend.modifier.querystring # type: ignore

@property
def rewrite_rules(self):
Expand All @@ -248,15 +248,15 @@ def rewrite_rules(self):
(r'https://docs.python.org/2/', r'https://docs.python.org/3/'),
]
"""
return self.backend.modifier.rewrite_rules
return self.backend.modifier.rewrite_rules # type: ignore

@rewrite_rules.setter
def rewrite_rules(self, rewrite_rules):
self.backend.modifier.rewrite_rules = rewrite_rules
self.backend.modifier.rewrite_rules = rewrite_rules # type: ignore

@rewrite_rules.deleter
def rewrite_rules(self):
del self.backend.modifier.rewrite_rules
del self.backend.modifier.rewrite_rules # type: ignore

@property
def scopes(self) -> List[str]:
Expand All @@ -271,48 +271,48 @@ def scopes(self) -> List[str]:
'.*github.*'
]
"""
return self.backend.scopes
return self.backend.scopes # type: ignore

@scopes.setter
def scopes(self, scopes: List[str]):
self.backend.scopes = scopes
self.backend.scopes = scopes # type: ignore

@scopes.deleter
def scopes(self):
self.backend.scopes = []
self.backend.scopes = [] # type: ignore

@property
def request_interceptor(self) -> callable:
def request_interceptor(self) -> Callable:
"""A callable that will be used to intercept/modify requests.

The callable must accept a single argument for the request
being intercepted.
"""
return self.backend.request_interceptor
return self.backend.request_interceptor # type: ignore

@request_interceptor.setter
def request_interceptor(self, interceptor: callable):
self.backend.request_interceptor = interceptor
def request_interceptor(self, interceptor: Callable):
self.backend.request_interceptor = interceptor # type: ignore

@request_interceptor.deleter
def request_interceptor(self):
self.backend.request_interceptor = None
self.backend.request_interceptor = None # type: ignore

@property
def response_interceptor(self) -> callable:
def response_interceptor(self) -> Callable:
"""A callable that will be used to intercept/modify responses.

The callable must accept two arguments: the response being
intercepted and the originating request.
"""
return self.backend.response_interceptor
return self.backend.response_interceptor # type: ignore

@response_interceptor.setter
def response_interceptor(self, interceptor: callable):
def response_interceptor(self, interceptor: Callable):
if len(inspect.signature(interceptor).parameters) != 2:
raise RuntimeError('A response interceptor takes two parameters: the request and response')
self.backend.response_interceptor = interceptor
self.backend.response_interceptor = interceptor # type: ignore

@response_interceptor.deleter
def response_interceptor(self):
self.backend.response_interceptor = None
self.backend.response_interceptor = None # type: ignore
4 changes: 3 additions & 1 deletion seleniumwire/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def __repr__(self):
class Request:
"""Represents an HTTP request."""

_body: bytes

def __init__(self, *, method: str, url: str, headers: Iterable[Tuple[str, str]], body: bytes = b''):
"""Initialise a new Request object.

Expand Down Expand Up @@ -119,7 +121,7 @@ def host(self) -> str:
"""
return urlsplit(self.url).netloc

@path.setter
@path.setter # type: ignore
def path(self, p: str):
parts = list(urlsplit(self.url))
parts[2] = p
Expand Down
4 changes: 2 additions & 2 deletions seleniumwire/undetected_chromedriver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
try:
import undetected_chromedriver as uc
import undetected_chromedriver as uc # type: ignore
except ImportError as e:
raise ImportError(
'undetected_chromedriver not found. ' 'Install it with `pip install undetected_chromedriver`.'
Expand All @@ -8,5 +8,5 @@
from seleniumwire.webdriver import Chrome

uc._Chrome = Chrome
Chrome = uc.Chrome
Chrome = uc.Chrome # type: ignore
ChromeOptions = uc.ChromeOptions # noqa: F811
4 changes: 2 additions & 2 deletions seleniumwire/undetected_chromedriver/v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

import undetected_chromedriver.v2 as uc
from selenium.webdriver import DesiredCapabilities
import undetected_chromedriver.v2 as uc # type: ignore
from selenium.webdriver import DesiredCapabilities # type: ignore

from seleniumwire import backend
from seleniumwire.inspect import InspectRequestsMixin
Expand Down
Loading