diff --git a/note_scraper.py b/note_scraper.py
index dc0c96a..fb1fff4 100644
--- a/note_scraper.py
+++ b/note_scraper.py
@@ -11,7 +11,8 @@
 
 from bs4 import BeautifulSoup
 
-from util import ConnectionFile, is_dns_working, to_bytes, to_native_str
+from util import (ConnectionFile, URLLIB3_FROM_PIP, is_dns_working, make_requests_session, setup_urllib3_ssl, to_bytes,
+                  to_native_str)
 
 try:
     from typing import TYPE_CHECKING
@@ -22,9 +23,19 @@
     from typing import List, Text
 
 try:
-    from http.cookiejar import MozillaCookieJar
+    from urllib.parse import quote, urlparse, urlsplit, urlunsplit
 except ImportError:
-    from cookielib import MozillaCookieJar  # type: ignore[no-redef]
+    from urllib import quote  # type: ignore[attr-defined,no-redef]
+    from urlparse import urlparse, urlsplit, urlunsplit  # type: ignore[no-redef]
+
+if URLLIB3_FROM_PIP:
+    from pip._vendor.urllib3 import Retry, Timeout
+    from pip._vendor.urllib3.exceptions import HTTPError, InsecureRequestWarning
+else:
+    from urllib3 import Retry, Timeout
+    from urllib3.exceptions import HTTPError, InsecureRequestWarning
+
+setup_urllib3_ssl()
 
 try:
     import requests
@@ -35,27 +46,11 @@
         raise RuntimeError('The requests module is required for note scraping. '
                            'Please install it with pip or your package manager.')
 
-try:
-    from urllib.parse import quote, urlparse, urlsplit, urlunsplit
-except ImportError:
-    from urllib import quote  # type: ignore[attr-defined,no-redef]
-    from urlparse import urlparse, urlsplit, urlunsplit  # type: ignore[no-redef]
-
-try:
-    from urllib3 import Retry
-    from urllib3.exceptions import HTTPError, InsecureRequestWarning
-except ImportError:
-    try:
-        # pip includes urllib3
-        from pip._vendor.urllib3 import Retry
-        from pip._vendor.urllib3.exceptions import HTTPError, InsecureRequestWarning
-    except ImportError:
-        raise RuntimeError('The urllib3 module is required. Please install it with pip or your package manager.')
-
 EXIT_SUCCESS = 0
 EXIT_SAFE_MODE = 2
 EXIT_NO_INTERNET = 3
 
+HTTP_TIMEOUT = Timeout(90)
 HTTP_RETRY = Retry(3, connect=False)
 
 # Globals
@@ -82,27 +77,12 @@ class WebCrawler(object):
 
     TRY_LIMIT = 2  # For code 429, only give it one extra try
 
-    def __init__(self, noverify, cookiefile, notes_limit):
+    def __init__(self, noverify, user_agent, cookiefile, notes_limit):
        self.notes_limit = notes_limit
         self.lasturl = None
-
-        self.session = requests.Session()
-        self.session.verify = not noverify
-        for adapter in self.session.adapters.values():
-            adapter.max_retries = HTTP_RETRY
-
-        if cookiefile is not None:
-            cookies = MozillaCookieJar(cookiefile)
-            cookies.load()
-
-            # Session cookies are denoted by either `expires` field set to an empty string or 0. MozillaCookieJar only
-            # recognizes the former (see https://bugs.python.org/issue17164).
-            for cookie in cookies:
-                if cookie.expires == 0:
-                    cookie.expires = None
-                    cookie.discard = True
-
-            self.session.cookies = cookies  # type: ignore[assignment]
+        self.session = make_requests_session(
+            requests.Session, HTTP_RETRY, HTTP_TIMEOUT, not noverify, user_agent, cookiefile,
+        )
 
     @classmethod
     def quote_unsafe(cls, string):
@@ -232,7 +212,7 @@ def get_notes(self, post_url):
         return u''.join(notes_list)
 
 
-def main(stdout_conn, msg_conn, post_url_, ident_, noverify, notes_limit, cookiefile):
+def main(stdout_conn, msg_conn, post_url_, ident_, noverify, user_agent, cookiefile, notes_limit):
     global post_url, ident, msg_pipe
 
     post_url, ident = post_url_, ident_
@@ -241,7 +221,7 @@
         warnings.filterwarnings('ignore', category=InsecureRequestWarning)
 
     with ConnectionFile(msg_conn, 'w') as msg_pipe:
-        crawler = WebCrawler(noverify, cookiefile, notes_limit)
+        crawler = WebCrawler(noverify, user_agent, cookiefile, notes_limit)
 
         try:
             notes = crawler.get_notes(post_url)
diff --git a/tumblr_backup.py b/tumblr_backup.py
index 739d4bc..d7446a0 100755
--- a/tumblr_backup.py
+++ b/tumblr_backup.py
@@ -24,8 +24,9 @@
 from posixpath import basename as urlbasename, join as urlpathjoin, splitext as urlsplitext
 from xml.sax.saxutils import escape
 
-from util import ConnectionFile, LockedQueue, PY3, no_internet, nullcontext, path_is_on_vfat, to_bytes, to_unicode
-from wget import HTTPError, WGError, WgetRetrieveWrapper, set_ssl_verify, urlopen
+from util import (ConnectionFile, LockedQueue, PY3, is_dns_working, make_requests_session, no_internet, nullcontext,
+                  path_is_on_vfat, to_bytes, to_unicode)
+from wget import HTTPError, HTTP_RETRY, HTTP_TIMEOUT, WGError, WgetRetrieveWrapper, setup_wget, urlopen
 
 try:
     from typing import TYPE_CHECKING
@@ -89,6 +90,17 @@
     except ImportError:
         scandir = None  # type: ignore[assignment,no-redef]
 
+# NB: setup_urllib3_ssl has already been called by wget
+
+try:
+    import requests
+except ImportError:
+    if not TYPE_CHECKING:
+        try:
+            from pip._vendor import requests  # type: ignore[no-redef]
+        except ImportError:
+            raise RuntimeError('The requests module is required. Please install it with pip or your package manager.')
+
 # These builtins have new names in Python 3
 try:
     long, xrange  # type: ignore[has-type]
@@ -268,9 +280,23 @@ def mktime(tml):
     options.p_stop = int(mktime(tm))
 
 
-def initial_apiparse(base, prev_archive):
-    prev_resps = None
-    if prev_archive:
+class ApiParser(object):
+    session = None  # type: Optional[requests.Session]
+
+    def __init__(self, base, account):
+        self.base = base
+        self.account = account
+        self.prev_resps = None  # type: Optional[Tuple[str, ...]]
+        self.dashboard_only_blog = None  # type: Optional[bool]
+
+    @classmethod
+    def setup(cls):
+        cls.session = make_requests_session(
+            requests.Session, HTTP_RETRY, HTTP_TIMEOUT,
+            not options.no_ssl_verify, options.user_agent, options.cookiefile,
+        )
+
+    def read_archive(self, prev_archive):
         def read_resp(path):
             with io.open(path, encoding=FILE_ENCODING) as jf:
                 return json.load(jf)
@@ -278,7 +304,7 @@ def read_resp(path):
         if options.likes:
             log('Reading liked timestamps from saved responses (may take a while)\n', account=True)
 
-        prev_resps = tuple(
+        self.prev_resps = tuple(
             e.path for e in sorted(
                 (e for e in scandir(join(prev_archive, 'json')) if (e.name.endswith('.json') and e.is_file())),
                 key=lambda e: read_resp(e)['liked_timestamp'] if options.likes else long(e.name[:-5]),
@@ -286,78 +312,108 @@
             )
         )
 
-    return prev_resps, apiparse(base, prev_resps, 1)
-
-
-def apiparse(base, prev_resps, count, start=0, before=None):
-    # type: (...) -> Optional[JSONDict]
-    if prev_resps is not None:
-        # Reconstruct the API response
-        def read_post(prf):
-            with io.open(prf, encoding=FILE_ENCODING) as f:
-                try:
-                    post = json.load(f)
-                except ValueError as e:
-                    f.seek(0)
-                    log('{}: {}\n{!r}\n'.format(e.__class__.__name__, e, f.read()))
-                    return None
-            return prf, post
-        posts = map(read_post, prev_resps)  # type: Iterable[Tuple[DirEntry[str], JSONDict]]
-        if before is not None:
-            posts = itertools.dropwhile(
-                lambda pp: pp[1]['liked_timestamp' if options.likes else 'timestamp'] >= before,
-                posts,
-            )
-        posts = list(itertools.islice(posts, start, start + count))
-        return {'posts': [post for prf, post in posts],
-                'post_respfiles': [prf for prf, post in posts],
-                'blog': dict(posts[0][1]['blog'] if posts else {}, posts=len(prev_resps))}
-
-    params = {'api_key': API_KEY, 'limit': count, 'reblog_info': 'true'}
-    if before:
-        params['before'] = before
-    if start > 0 and not options.likes:
-        params['offset'] = start
-
-    def get_resp():
-        try:
-            resp = urlopen(base, fields=params)
-        except (EnvironmentError, HTTPError) as e:
-            log('URL is {}?{}\nError retrieving API repsonse: {}\n'.format(base, urlencode(params), e))
-            return None
-        if not (200 <= resp.status < 300 or 400 <= resp.status < 500):
-            log('URL is {}?{}\nError retrieving API repsonse: HTTP {} {}\n'.format(
-                base, urlencode(params), resp.status, resp.reason,
-            ))
-            return None
-        ctype = resp.headers.get('Content-Type')
-        if ctype and ctype.split(';', 1)[0].strip() != 'application/json':
-            log("Unexpected Content-Type: '{}'\n".format(ctype))
-            return None
-        data = resp.data.decode('utf-8')
-        try:
-            doc = json.loads(data)
-        except ValueError as e:
-            log('{}: {}\n{} {} {}\n{!r}\n'.format(
-                e.__class__.__name__, e, resp.status, resp.reason, ctype, data,
-            ))
-            return None
-        return doc
-
-    sleep_dur = 30  # in seconds
-    while True:
-        doc = get_resp()
-        if doc is None:
-            return None
-        status = doc['meta']['status']
-        if status == 429:
+    def apiparse(self, count, start=0, before=None):
+        # type: (...) -> Optional[JSONDict]
+        assert self.session is not None
+        if self.prev_resps is not None:
+            # Reconstruct the API response
+            def read_post(prf):
+                with io.open(prf, encoding=FILE_ENCODING) as f:
+                    try:
+                        post = json.load(f)
+                    except ValueError as e:
+                        f.seek(0)
+                        log('{}: {}\n{!r}\n'.format(e.__class__.__name__, e, f.read()))
+                        return None
+                return prf, post
+            posts = map(read_post, self.prev_resps)  # type: Iterable[Tuple[DirEntry[str], JSONDict]]
+            if before is not None:
+                posts = itertools.dropwhile(
+                    lambda pp: pp[1]['liked_timestamp' if options.likes else 'timestamp'] >= before,
+                    posts,
+                )
+            posts = list(itertools.islice(posts, start, start + count))
+            return {'posts': [post for prf, post in posts],
+                    'post_respfiles': [prf for prf, post in posts],
+                    'blog': dict(posts[0][1]['blog'] if posts else {}, posts=len(self.prev_resps))}
+
+        if self.dashboard_only_blog:
+            base = 'https://www.tumblr.com/svc/indash_blog'
+            params = {'tumblelog_name_or_id': self.account, 'post_id': '', 'limit': count,
+                      'should_bypass_safemode': 'true', 'should_bypass_tagfiltering': 'true'}
+            headers = {
+                'Referer': 'https://www.tumblr.com/dashboard/blog/' + self.account,
+                'X-Requested-With': 'XMLHttpRequest',
+            }  # type: Optional[Dict[str, str]]
+        else:
+            base = self.base
+            params = {'api_key': API_KEY, 'limit': count, 'reblog_info': 'true'}
+            headers = None
+        if before:
+            params['before'] = before
+        if start > 0 and not options.likes:
+            params['offset'] = start
+
+        sleep_dur = 30  # in seconds
+        while True:
+            doc = self._get_resp(base, params, headers)
+            if doc is None:
+                return None
+            status = doc['meta']['status']
+            if status != 429:
+                break
             time.sleep(sleep_dur)
             sleep_dur *= 2
-            continue
         if status != 200:
+            # Detect dashboard-only blogs by the error codes
+            if self.dashboard_only_blog is None and status == 404:
+                errors = doc.get('errors', ())
+                if len(errors) == 1 and errors[0].get('code') == 4012:
+                    self.dashboard_only_blog = True
+                    log('Found dashboard-only blog, trying svc API\n', account=True)
+                    return self.apiparse(count, start)  # Recurse once
             log('API response has non-200 status:\n{}\n'.format(doc))
+            if status == 401 and self.dashboard_only_blog:
+                log("This is a dashboard-only blog, so you probably don't have the right cookies.{}\n".format(
+                    '' if options.cookiefile else ' Try --cookiefile.',
+                ))
             return None
-        return doc.get('response')
+        # If the first API request succeeds, it's a public blog
+        if self.dashboard_only_blog is None:
+            self.dashboard_only_blog = False
+
+        resp = doc.get('response')
+        if resp is not None and self.dashboard_only_blog:
+            # svc API doesn't return blog info, steal it from the first post
+            resp['blog'] = resp['posts'][0]['blog'] if resp['posts'] else {}
+        return resp
+
+    def _get_resp(self, base, params, headers):
+        assert self.session is not None
+        while True:
+            try:
+                with self.session.get(base, params=params, headers=headers) as resp:
+                    if not (200 <= resp.status_code < 300 or 400 <= resp.status_code < 500):
+                        log('URL is {}?{}\nError retrieving API response: HTTP {} {}\n'.format(
+                            base, urlencode(params), resp.status_code, resp.reason,
+                        ))
+                        return None
+                    ctype = resp.headers.get('Content-Type')
+                    if ctype and ctype.split(';', 1)[0].strip() != 'application/json':
+                        log("Unexpected Content-Type: '{}'\n".format(ctype))
+                        return None
+                    try:
+                        return resp.json()
+                    except ValueError as e:
+                        log('{}: {}\n{} {} {}\n{!r}\n'.format(
+                            e.__class__.__name__, e, resp.status_code, resp.reason, ctype, resp.content.decode('utf-8'),
+                        ))
+                        return None
+            except (EnvironmentError, HTTPError) as e:
+                if isinstance(e, HTTPError) and not is_dns_working(timeout=5):
+                    no_internet.signal()
+                    continue
+                log('URL is {}?{}\nError retrieving API response: {}\n'.format(base, urlencode(params), e))
+                return None
 
 
 def add_exif(image_name, tags):
@@ -766,7 +822,10 @@ def backup(self, account, prev_archive):
 
         log.status('Getting basic information\r')
 
-        prev_resps, resp = initial_apiparse(base, prev_archive)
+        api_parser = ApiParser(base, account)
+        if prev_archive:
+            api_parser.read_archive(prev_archive)
+        resp = api_parser.apiparse(1)
         if not resp:
             self.errors = True
             return
@@ -782,9 +841,8 @@ def backup(self, account, prev_archive):
             count_estimate = resp['liked_count']
         else:
             posts_key = 'posts'
-            blog = resp['blog']
-            count_estimate = blog['posts']
-            assert isinstance(count_estimate, int)
+            blog = resp.get('blog', {})
+            count_estimate = blog.get('posts')
 
         self.title = escape(blog.get('title', account))
         self.subtitle = blog.get('description', '')
@@ -852,11 +910,12 @@ def _backup(posts, post_respfiles):
         before = options.p_stop if options.period else None
         while True:
             # find the upper bound
-            log.status('Getting {}posts {} to {} (of {} expected)\r'.format(
-                'liked ' if options.likes else '', i, i + MAX_POSTS - 1, count_estimate,
+            log.status('Getting {}posts {} to {}{}\r'.format(
+                'liked ' if options.likes else '', i, i + MAX_POSTS - 1,
+                '' if count_estimate is None else ' (of {} expected)'.format(count_estimate),
             ))
 
-            resp = apiparse(base, prev_resps, MAX_POSTS, i, before)
+            resp = api_parser.apiparse(MAX_POSTS, i, before)
             if resp is None:
                 self.errors = True
                 break
@@ -918,7 +977,7 @@ def __init__(self, post, backup_account, respfile, prev_archive):
         self.backup_account = backup_account
         self.respfile = respfile
         self.prev_archive = prev_archive
-        self.creator = post['blog_name']
+        self.creator = post.get('blog_name') or post['tumblelog']
         self.ident = str(post['id'])
         self.url = post['post_url']
         self.shorturl = post['short_url']
@@ -928,7 +987,11 @@ def __init__(self, post, backup_account, respfile, prev_archive):
         self.tm = time.localtime(self.date)
         self.title = u''
         self.tags = post['tags']
-        self.note_count = post.get('note_count', 0)
+        self.note_count = post.get('note_count')
+        if self.note_count is None:
+            self.note_count = post.get('notes', {}).get('count')
+        if self.note_count is None:
+            self.note_count = 0
         self.reblogged_from = post.get('reblogged_from_url')
         self.reblogged_root = post.get('reblogged_root_url')
         self.source_title = post.get('source_title', '')
@@ -1260,8 +1323,7 @@ def get_post(self):
             ns_msg_rd, ns_msg_wr = multiprocessing.Pipe(duplex=False)
             try:
                 args = (ns_stdout_wr, ns_msg_wr, self.url, self.ident,
-                        options.no_ssl_verify, options.notes_limit,
-                        options.cookiefile)
+                        options.no_ssl_verify, options.user_agent, options.cookiefile, options.notes_limit)
                 process = multiprocessing.Process(target=note_scraper.main, args=args)
                 process.start()
             except:
@@ -1515,7 +1577,7 @@ def __call__(self, parser, namespace, values, option_string=None):
     parser.add_argument('--save-notes', action='store_true', help='save a list of notes for each post')
     parser.add_argument('--copy-notes', action='store_true', help='copy the notes list from a previous archive')
     parser.add_argument('--notes-limit', type=int, metavar='COUNT', help='limit requested notes to COUNT, per-post')
-    parser.add_argument('--cookiefile', help='cookie file for youtube-dl and --save-notes')
+    parser.add_argument('--cookiefile', help='cookie file for youtube-dl, --save-notes, and svc API')
     parser.add_argument('-j', '--json', action='store_true', help='save the original JSON source')
     parser.add_argument('-b', '--blosxom', action='store_true', help='save the posts in blosxom format')
     parser.add_argument('-r', '--reverse-month', action='store_false',
@@ -1559,6 +1621,7 @@ def __call__(self, parser, namespace, values, option_string=None):
     parser.add_argument('--mtime-postfix', action='store_true',
                         help="timestamping: work around low-precision mtime on FAT filesystems")
     parser.add_argument('--hostdirs', action='store_true', help='Generate host-prefixed directories for media')
+    parser.add_argument('--user-agent', help='User agent string to use with HTTP requests')
    parser.add_argument('blogs', nargs='*')
 
     options = parser.parse_args()
@@ -1575,7 +1638,7 @@ def __call__(self, parser, namespace, values, option_string=None):
         set_period()
 
     wget_retrieve = WgetRetrieveWrapper(options, log)
-    set_ssl_verify(not options.no_ssl_verify)
+    setup_wget(not options.no_ssl_verify, options.user_agent)
 
     blogs = options.blogs or DEFAULT_BLOGS
     if not blogs:
@@ -1637,6 +1700,8 @@ def __call__(self, parser, namespace, values, option_string=None):
             print('Warning: FAT filesystem detected, enabling --mtime-postfix', file=sys.stderr)
             options.mtime_postfix = True
 
+    ApiParser.setup()
+
     global backup_account
     tb = TumblrBackup()
     try:
diff --git a/tumblr_login.py b/tumblr_login.py
new file mode 100755
index 0000000..4f7e700
--- /dev/null
+++ b/tumblr_login.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Credit to johanneszab for the C# implementation in TumblThree.
+# Credit to MrEldritch for the initial Python port.
+# Cleaned up and split off by Cebtenzzre.
+
+"""
+This script works in both Python 2 & 3.
+It uses Tumblr's internal SVC API to access a hidden or explicit blog,
+and retrieves a JSON of very similar (but not quite identical) format to the
+normal API.
+"""
+
+import sys
+from getpass import getpass
+
+from bs4 import BeautifulSoup
+
+try:
+    from http.cookiejar import MozillaCookieJar
+except ImportError:
+    from cookielib import MozillaCookieJar  # type: ignore[no-redef]
+
+try:
+    import requests
+except ImportError:
+    try:
+        from pip._vendor import requests  # type: ignore[no-redef]
+    except ImportError:
+        raise RuntimeError('The requests module is required. Please install it with pip or your package manager.')
+
+# This builtin has a new name in Python 3
+try:
+    raw_input  # type: ignore[has-type]
+except NameError:
+    raw_input = input
+
+
+def get_tumblr_key():
+    r = session.get('https://www.tumblr.com/login')
+    if r.status_code != 200:
+        raise ValueError('Response has non-200 status: HTTP {} {}'.format(r.status_code, r.reason))
+    soup = BeautifulSoup(r.text, 'lxml')
+    head, = soup.find_all('head')
+    key_meta, = soup.find_all('meta', attrs={'name': 'tumblr-form-key'})
+    return key_meta['content']
+
+
+def tumblr_login(session, login, password):
+    tumblr_key = get_tumblr_key()
+
+    # You need to make these two requests in order to pick up the proper cookies
+    # in order to access login-required blogs (both dash-only & explicit)
+
+    common_headers = {
+        'Authority': 'www.tumblr.com',
+        'Referer': 'https://www.tumblr.com/login',
+        'Origin': 'https://www.tumblr.com',
+    }
+    common_params = {
+        'tumblelog[name]': '',
+        'user[age]': '',
+        'context': 'no_referer',
+        'version': 'STANDARD',
+        'follow': '',
+        'form_key': tumblr_key,
+        'seen_suggestion': '0',
+        'used_suggestion': '0',
+        'used_auto_suggestion': '0',
+        'about_tumblr_slide': '',
+        'random_username_suggestions': '["KawaiiBouquetStranger","KeenTravelerFury","RainyMakerTastemaker"'
+                                       ',"SuperbEnthusiastCollective","TeenageYouthFestival"]',
+        'action': 'signup_determine',
+    }
+
+    # Register
+    headers = common_headers.copy()
+    headers.update({
+        'Accept': 'application/json, text/javascript, */*; q=0.01',
+        'Content-Type': 'application/x-www-form-urlencoded; charset=UTF-8',
+        'X-Requested-With': 'XMLHttpRequest',
+    })
+    parameters = common_params.copy()
+    parameters.update({
+        'determine_email': login,
+        'user[email]': '',
+        'user[password]': '',
+        'tracking_url': '/login',
+        'tracking_version': 'modal',
+    })
+    r = session.post('https://www.tumblr.com/svc/account/register', data=parameters, headers=headers)
+    if r.status_code != 200:
+        raise ValueError('Response has non-200 status: HTTP {} {}'.format(r.status_code, r.reason))
+
+    # Authenticate
+    headers = common_headers.copy()
+    headers.update({
+        'Content-Type': 'application/x-www-form-urlencoded',
+    })
+    parameters = common_params.copy()
+    parameters.update({
+        'determine_email': login,
+        'user[email]': login,
+        'user[password]': password,
+    })
+    r = session.post('https://www.tumblr.com/login', data=parameters, headers=headers)
+    if r.status_code != 200:
+        raise ValueError('Response has non-200 status: HTTP {} {}'.format(r.status_code, r.reason))
+
+    # We now have the necessary cookies loaded into our session.
+
+
+if __name__ == '__main__':
+    cookiefile, = sys.argv[1:]
+
+    print('Enter the credentials for your Tumblr account.')
+    login = raw_input('Email: ')
+    password = getpass()
+
+    # Create a requests session with cookies
+    session = requests.Session()
+    session.cookies = MozillaCookieJar(cookiefile)  # type: ignore[assignment]
+    session.headers['User-Agent'] = ('Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
+                                     'AppleWebKit/537.36 (KHTML, like Gecko) '
+                                     'Chrome/85.0.4183.121 '
+                                     'Safari/537.36')
+
+    # Log into Tumblr
+    tumblr_login(session, login, password)
+
+    # Save the cookies
+    session.cookies.save(ignore_discard=True)  # type: ignore[attr-defined]
diff --git a/util.py b/util.py
index 3786038..faf95a2 100644
--- a/util.py
+++ b/util.py
@@ -8,6 +8,7 @@
 import sys
 import threading
 import time
+import warnings
 
 try:
     from typing import TYPE_CHECKING
@@ -22,6 +23,11 @@
 except ImportError:
     import Queue as queue  # type: ignore[no-redef]
 
+try:
+    from http.cookiejar import MozillaCookieJar
+except ImportError:
+    from cookielib import MozillaCookieJar  # type: ignore[no-redef]
+
 _PATH_IS_ON_VFAT_WORKS = True
 
 try:
@@ -37,6 +43,17 @@
     _getvolumepathname = None  # type: ignore[no-redef]
     _PATH_IS_ON_VFAT_WORKS = False
 
+try:
+    from urllib3.exceptions import DependencyWarning
+    URLLIB3_FROM_PIP = False
+except ImportError:
+    try:
+        # pip includes urllib3
+        from pip._vendor.urllib3.exceptions import DependencyWarning
+        URLLIB3_FROM_PIP = True
+    except ImportError:
+        raise RuntimeError('The urllib3 module is required. Please install it with pip or your package manager.')
+
 # This builtin has a new name in Python 3
 try:
     raw_input  # type: ignore[has-type]
@@ -45,6 +62,11 @@
 
 PY3 = sys.version_info[0] >= 3
 
+try:
+    from ssl import HAS_SNI as SSL_HAS_SNI
+except ImportError:
+    SSL_HAS_SNI = False
+
 
 def to_unicode(string, encoding='utf-8', errors='strict'):
     if isinstance(string, bytes):
@@ -245,3 +267,77 @@ def _wait():
 
 
 no_internet = NoInternet()
+
+
+# Set up ssl for urllib3. This should be called before using urllib3 or importing requests.
+def setup_urllib3_ssl():
+    # Don't complain about missing SOCKS dependencies
+    warnings.filterwarnings('ignore', category=DependencyWarning)
+
+    try:
+        import ssl
+    except ImportError:
+        return
+
+    # Inject SecureTransport on macOS if the linked OpenSSL is too old to handle TLSv1.2
+    if sys.platform == 'darwin' and ssl.OPENSSL_VERSION_NUMBER < 0x1000100F:
+        try:
+            if URLLIB3_FROM_PIP:
+                from pip._vendor.urllib3.contrib import securetransport
+            else:
+                from urllib3.contrib import securetransport
+        except (ImportError, EnvironmentError):
+            pass
+        else:
+            securetransport.inject_into_urllib3()
+
+    # Inject PyOpenSSL if the linked OpenSSL has no SNI
+    if not SSL_HAS_SNI:
+        try:
+            if URLLIB3_FROM_PIP:
+                from pip._vendor.urllib3.contrib import pyopenssl
+            else:
+                from urllib3.contrib import pyopenssl
+        except ImportError:
+            pass
+        else:
+            pyopenssl.inject_into_urllib3()
+
+
+def get_supported_encodings():
+    encodings = ['deflate', 'gzip']
+    try:
+        from brotli import brotli
+    except ImportError:
+        pass
+    else:
+        encodings.insert(0, 'br')  # brotli takes priority if available
+    return encodings
+
+
+def make_requests_session(session_type, retry, timeout, verify, user_agent, cookiefile):
+    class SessionWithTimeout(session_type):  # type: ignore[misc,valid-type]
+        def request(self, method, url, **kwargs):
+            kwargs.setdefault('timeout', timeout)
+            return super(SessionWithTimeout, self).request(method, url, **kwargs)
+
+    session = SessionWithTimeout()
+    session.verify = verify
+    session.headers['Accept-Encoding'] = ', '.join(get_supported_encodings())
+    if user_agent is not None:
+        session.headers['User-Agent'] = user_agent
+    for adapter in session.adapters.values():
+        adapter.max_retries = retry
+    if cookiefile is not None:
+        cookies = MozillaCookieJar(cookiefile)
+        cookies.load()
+
+        # Session cookies are denoted by either `expires` field set to an empty string or 0. MozillaCookieJar only
+        # recognizes the former (see https://bugs.python.org/issue17164).
+        for cookie in cookies:
+            if cookie.expires == 0:
+                cookie.expires = None
+                cookie.discard = True
+
+        session.cookies = cookies  # type: ignore[assignment]
+    return session
diff --git a/wget.py b/wget.py
index f7a9b2c..4537f42 100644
--- a/wget.py
+++ b/wget.py
@@ -7,69 +7,31 @@
 import io
 import itertools
 import os
-import sys
 import time
 import warnings
 from email.utils import mktime_tz, parsedate_tz
 from tempfile import NamedTemporaryFile
 from wsgiref.handlers import format_date_time
 
-from util import PY3, is_dns_working, no_internet
+from util import PY3, URLLIB3_FROM_PIP, get_supported_encodings, is_dns_working, no_internet, setup_urllib3_ssl
 
 try:
     from urllib.parse import urljoin, urlsplit
 except ImportError:
     from urlparse import urljoin, urlsplit  # type: ignore[no-redef]
 
-try:
+if URLLIB3_FROM_PIP:
+    from pip._vendor.urllib3 import HTTPConnectionPool, HTTPResponse, HTTPSConnectionPool, PoolManager, Retry, Timeout
+    from pip._vendor.urllib3.exceptions import ConnectTimeoutError, InsecureRequestWarning, MaxRetryError
+    from pip._vendor.urllib3.exceptions import HTTPError as HTTPError
+    from pip._vendor.urllib3.util import make_headers
+else:
     from urllib3 import HTTPConnectionPool, HTTPResponse, HTTPSConnectionPool, PoolManager, Retry, Timeout
-    from urllib3.exceptions import ConnectTimeoutError, DependencyWarning, InsecureRequestWarning, MaxRetryError
+    from urllib3.exceptions import ConnectTimeoutError, InsecureRequestWarning, MaxRetryError
     from urllib3.exceptions import HTTPError as HTTPError
     from urllib3.util import make_headers
-    URLLIB3_FROM_PIP = False
-except ImportError:
-    try:
-        # pip includes urllib3
-        from pip._vendor.urllib3 import (HTTPConnectionPool, HTTPResponse, HTTPSConnectionPool, PoolManager, Retry,
-                                         Timeout)
-        from pip._vendor.urllib3.exceptions import (ConnectTimeoutError, DependencyWarning, InsecureRequestWarning,
-                                                    MaxRetryError)
-        from pip._vendor.urllib3.exceptions import HTTPError as HTTPError
-        from pip._vendor.urllib3.util import make_headers
-        URLLIB3_FROM_PIP = True
-    except ImportError:
-        raise RuntimeError('The urllib3 module is required. Please install it with pip or your package manager.')
-
-# Don't complain about missing socks
-warnings.filterwarnings('ignore', category=DependencyWarning)
 
-try:
-    import ssl as ssl
-except ImportError:
-    ssl = None  # type: ignore[assignment,no-redef]
-
-# Inject SecureTransport on macOS if the linked OpenSSL is too old to handle TLSv1.2
-if ssl is not None and sys.platform == 'darwin' and ssl.OPENSSL_VERSION_NUMBER < 0x1000100F:
-    try:
-        if URLLIB3_FROM_PIP:
-            from pip._vendor.urllib3.contrib import securetransport
-        else:
-            from urllib3.contrib import securetransport
-    except (ImportError, EnvironmentError):
-        pass
-    else:
-        securetransport.inject_into_urllib3()
-
-if ssl is not None and not getattr(ssl, 'HAS_SNI', False):
-    try:
-        if URLLIB3_FROM_PIP:
-            from pip._vendor.urllib3.contrib import pyopenssl
-        else:
-            from urllib3.contrib import pyopenssl
-    except ImportError:
-        pass
-    else:
-        pyopenssl.inject_into_urllib3()
+setup_urllib3_ssl()
 
 # long is int in Python 3
 try:
@@ -81,15 +43,7 @@
 HTTP_RETRY = Retry(3, connect=False)
 HTTP_CHUNK_SIZE = 1024 * 1024
 
-try:
-    from brotli import brotli
-    have_brotlipy = True
-except ImportError:
-    have_brotlipy = False
-
-supported_encodings = (('br',) if have_brotlipy else ()) + ('deflate', 'gzip')
-
-base_headers = make_headers(keep_alive=True, accept_encoding=list(supported_encodings))
+base_headers = make_headers(keep_alive=True, accept_encoding=list(get_supported_encodings()))
 
 
 # Document type flags
@@ -787,11 +741,13 @@ def try_unlink(path):
             raise
 
 
-def set_ssl_verify(verify):
-    if not verify:
+def setup_wget(ssl_verify, user_agent):
+    if not ssl_verify:
         # Hide the InsecureRequestWarning from urllib3
         warnings.filterwarnings('ignore', category=InsecureRequestWarning)
-    poolman.connection_pool_kw['cert_reqs'] = 'CERT_REQUIRED' if verify else 'CERT_NONE'
+    poolman.connection_pool_kw['cert_reqs'] = 'CERT_REQUIRED' if ssl_verify else 'CERT_NONE'
+    if user_agent is not None:
+        base_headers['User-Agent'] = user_agent
 
 
 # This is a simple urllib3-based urlopen function.
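Reviewer note, not part of the patch: the new plumbing is easiest to see end to end with a minimal sketch that drives the svc endpoint the same way ApiParser._get_resp does, using the helpers this diff adds to util.py and wget.py. The blog name 'myblog', the placeholder user agent, and the cookies.txt path (as written by tumblr_login.py) are illustrative assumptions only.

    #!/usr/bin/env python3
    # Sketch under the assumptions above -- not shipped with the patch.
    import requests

    from util import make_requests_session
    from wget import HTTP_RETRY, HTTP_TIMEOUT

    account = 'myblog'  # hypothetical dashboard-only blog
    session = make_requests_session(
        requests.Session, HTTP_RETRY, HTTP_TIMEOUT,
        True,           # verify TLS certificates
        'Mozilla/5.0',  # user agent (placeholder)
        'cookies.txt',  # cookie file produced by tumblr_login.py
    )
    params = {'tumblelog_name_or_id': account, 'post_id': '', 'limit': 1,
              'should_bypass_safemode': 'true', 'should_bypass_tagfiltering': 'true'}
    headers = {'Referer': 'https://www.tumblr.com/dashboard/blog/' + account,
               'X-Requested-With': 'XMLHttpRequest'}
    with session.get('https://www.tumblr.com/svc/indash_blog',
                     params=params, headers=headers) as resp:
        print(resp.json()['meta'])  # expect status 200 with valid cookies, 401 otherwise

A 401 here corresponds to the hint apiparse now logs: the file passed via --cookiefile is missing or stale and should be regenerated with tumblr_login.py.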