From 38bdbf12b428a3919a46aa1cecf8a10833871f48 Mon Sep 17 00:00:00 2001 From: Cebtenzzre Date: Thu, 21 May 2020 23:17:14 -0400 Subject: [PATCH] tumblr_backup: Implement wget module Uses urllib3 and a temporary file to make downloading media smoother. Also supports various modes of timestamping, inspired by wget. Included revisions: - Fix urllib3 status code handling - Cleanup fsync logic - Specific type: ignore comments - Cleanup file_isfat and merge into util.py - note_scraper: Smarter handling of cookies and error messages - Remove enospc handler - Use os.supports_dir_fd - No more log_queue - Global, abstract no_internet - Support "canceled" WaitOnMainThread status - Fix "TypeError: 'bytes' object does not support item assignment" Fixes #201 --- note_scraper.py | 86 +++-- tumblr_backup.py | 308 ++++++++--------- util.py | 148 ++++++++- wget.py | 839 +++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1188 insertions(+), 193 deletions(-) create mode 100644 wget.py diff --git a/note_scraper.py b/note_scraper.py index dc656c6..dc0c96a 100644 --- a/note_scraper.py +++ b/note_scraper.py @@ -2,17 +2,16 @@ from __future__ import absolute_import, division, print_function, with_statement -import contextlib import itertools import re -import ssl import sys import time import traceback +import warnings from bs4 import BeautifulSoup -from util import ConnectionFile, HAVE_SSL_CTX, HTTP_TIMEOUT, to_bytes, to_native_str +from util import ConnectionFile, is_dns_working, to_bytes, to_native_str try: from typing import TYPE_CHECKING @@ -23,27 +22,41 @@ from typing import List, Text try: - from http.client import HTTPException + from http.cookiejar import MozillaCookieJar except ImportError: - from httplib import HTTPException # type: ignore[no-redef] + from cookielib import MozillaCookieJar # type: ignore[no-redef] try: - from http.cookiejar import MozillaCookieJar + import requests except ImportError: - from cookielib import MozillaCookieJar # type: ignore[no-redef] + try: + from pip._vendor import requests # type: ignore[no-redef] + except ImportError: + raise RuntimeError('The requests module is required for note scraping. ' + 'Please install it with pip or your package manager.') try: - from urllib.error import HTTPError, URLError from urllib.parse import quote, urlparse, urlsplit, urlunsplit - from urllib.request import BaseHandler, HTTPCookieProcessor, HTTPSHandler, build_opener except ImportError: from urllib import quote # type: ignore[attr-defined,no-redef] - from urllib2 import (BaseHandler, HTTPCookieProcessor, HTTPSHandler, HTTPError, URLError, # type: ignore[no-redef] - build_opener) from urlparse import urlparse, urlsplit, urlunsplit # type: ignore[no-redef] +try: + from urllib3 import Retry + from urllib3.exceptions import HTTPError, InsecureRequestWarning +except ImportError: + try: + # pip includes urllib3 + from pip._vendor.urllib3 import Retry + from pip._vendor.urllib3.exceptions import HTTPError, InsecureRequestWarning + except ImportError: + raise RuntimeError('The urllib3 module is required. 
Please install it with pip or your package manager.') + EXIT_SUCCESS = 0 EXIT_SAFE_MODE = 2 +EXIT_NO_INTERNET = 3 + +HTTP_RETRY = Retry(3, connect=False) # Globals post_url = None @@ -73,10 +86,11 @@ def __init__(self, noverify, cookiefile, notes_limit): self.notes_limit = notes_limit self.lasturl = None - handlers = [] # type: List[BaseHandler] # pytype: disable=invalid-annotation - if HAVE_SSL_CTX: - context = ssl._create_unverified_context() if noverify else ssl.create_default_context() - handlers.append(HTTPSHandler(context=context)) + self.session = requests.Session() + self.session.verify = not noverify + for adapter in self.session.adapters.values(): + adapter.max_retries = HTTP_RETRY + if cookiefile is not None: cookies = MozillaCookieJar(cookiefile) cookies.load() @@ -88,9 +102,7 @@ def __init__(self, noverify, cookiefile, notes_limit): cookie.expires = None cookie.discard = True - handlers.append(HTTPCookieProcessor(cookies)) - - self.opener = build_opener(*handlers) + self.session.cookies = cookies # type: ignore[assignment] @classmethod def quote_unsafe(cls, string): @@ -140,19 +152,23 @@ def ratelimit_sleep(self, headers): def urlopen(self, iri): self.lasturl = iri uri = self.iri_to_uri(iri) + try_count = 0 while True: - try: - with contextlib.closing(self.opener.open(uri, timeout=HTTP_TIMEOUT)) as resp: - try_count += 1 - parsed_uri = urlparse(resp.geturl()) - if re.match(r'(www\.)?tumblr\.com', parsed_uri.netloc) and parsed_uri.path == '/safe-mode': - sys.exit(EXIT_SAFE_MODE) - return resp.read().decode('utf-8', errors='ignore') - except HTTPError as e: - if e.code == 429 and try_count < self.TRY_LIMIT and self.ratelimit_sleep(e.headers): + with self.session.get(uri) as resp: + try_count += 1 + parsed_uri = urlparse(resp.url) + if re.match(r'(www\.)?tumblr\.com', parsed_uri.netloc) and parsed_uri.path == '/safe-mode': + sys.exit(EXIT_SAFE_MODE) + if resp.status_code == 429 and try_count < self.TRY_LIMIT and self.ratelimit_sleep(resp.headers): continue - raise + if 200 <= resp.status_code < 300: + return resp.content.decode('utf-8', errors='ignore') + log(iri, 'Unexpected response status: HTTP {} {}{}'.format( + resp.status_code, resp.reason, + '' if resp.status_code == 404 else '\nHeaders: {}'.format(resp.headers), + )) + return None @staticmethod def get_more_link(soup, base, notes_url): @@ -220,6 +236,10 @@ def main(stdout_conn, msg_conn, post_url_, ident_, noverify, notes_limit, cookie global post_url, ident, msg_pipe post_url, ident = post_url_, ident_ + if noverify: + # Hide the InsecureRequestWarning from urllib3 + warnings.filterwarnings('ignore', category=InsecureRequestWarning) + with ConnectionFile(msg_conn, 'w') as msg_pipe: crawler = WebCrawler(noverify, cookiefile, notes_limit) @@ -228,13 +248,9 @@ def main(stdout_conn, msg_conn, post_url_, ident_, noverify, notes_limit, cookie except KeyboardInterrupt: sys.exit() # Ignore these so they don't propogate into the parent except HTTPError as e: - log(crawler.lasturl, 'HTTP Error {} {}'.format(e.code, e.reason)) - sys.exit() - except URLError as e: - log(crawler.lasturl, 'URL Error: {}'.format(e.reason)) - sys.exit() - except HTTPException as e: - log(crawler.lasturl, 'HTTP Exception: {}'.format(e)) + if not is_dns_working(timeout=5): + sys.exit(EXIT_NO_INTERNET) + log(crawler.lasturl, e) sys.exit() except Exception: log(crawler.lasturl, 'Caught an exception') diff --git a/tumblr_backup.py b/tumblr_backup.py index 6375e2d..0179f57 100755 --- a/tumblr_backup.py +++ b/tumblr_backup.py @@ -8,13 +8,11 @@ import hashlib 
import imghdr import io -import itertools import locale import multiprocessing import os import re import shutil -import ssl import sys import threading import time @@ -25,7 +23,8 @@ from posixpath import basename as urlbasename, join as urlpathjoin, splitext as urlsplitext from xml.sax.saxutils import escape -from util import ConnectionFile, HAVE_SSL_CTX, HTTP_TIMEOUT, LockedQueue, PY3, nullcontext, to_bytes, to_unicode +from util import ConnectionFile, LockedQueue, PY3, no_internet, nullcontext, path_is_on_vfat, to_bytes, to_unicode +from wget import HTTPError, WGError, WgetRetrieveWrapper, set_ssl_verify, urlopen try: from typing import TYPE_CHECKING @@ -38,11 +37,6 @@ JSONDict = Dict[str, Any] -try: - from http.client import HTTPException -except ImportError: - from httplib import HTTPException # type: ignore[no-redef] - try: import json except ImportError: @@ -54,14 +48,10 @@ import Queue as queue # type: ignore[no-redef] try: - from urllib.request import urlopen from urllib.parse import urlencode, urlparse, quote - NEW_URLLIB = True except ImportError: - from urllib2 import urlopen # type: ignore[no-redef] - from urlparse import urlparse # type: ignore[no-redef] from urllib import urlencode, quote # type: ignore[attr-defined,no-redef] - NEW_URLLIB = False + from urlparse import urlparse # type: ignore[no-redef] try: from settings import DEFAULT_BLOGS @@ -156,8 +146,6 @@ def test_jpg(h, f): MAX_POSTS = 50 REM_POST_INC = 10 -HTTP_CHUNK_SIZE = 1024 * 1024 - # get your own API key at https://www.tumblr.com/oauth/apps API_KEY = '' @@ -169,15 +157,6 @@ def test_jpg(h, f): FILE_ENCODING = 'utf-8' TIME_ENCODING = locale.getlocale(locale.LC_TIME)[1] or FILE_ENCODING - -if HAVE_SSL_CTX: - ssl_ctx = ssl.create_default_context() - def tb_urlopen(url): - return urlopen(url, timeout=HTTP_TIMEOUT, context=ssl_ctx) -else: - def tb_urlopen(url): - return urlopen(url, timeout=HTTP_TIMEOUT) - disable_note_scraper = set() # type: Set[str] disablens_lock = threading.Lock() prev_resps = None # type: Optional[Tuple[str, ...]] @@ -250,10 +229,6 @@ def open_text(*parts): ) -def open_media(*parts): - return open_file(lambda f: io.open(f, 'wb'), parts) - - def strftime(fmt, t=None): if t is None: t = time.localtime() @@ -329,28 +304,28 @@ def apiparse(base, prev_resps, count, start=0, before=None): params['before'] = before if start > 0 and not options.likes: params['offset'] = start - url = base + '?' 
+ urlencode(params) def get_resp(): - for _ in range(10): - try: - resp = tb_urlopen(url) - data = resp.read() - except (EnvironmentError, HTTPException) as e: - log('URL is {}\nError retrieving API repsonse: {}\n'.format(url, e)) - continue - info = resp.info() - if (info.get_content_type() if NEW_URLLIB else info.gettype()) == 'application/json': - break - log("Unexpected Content-Type: '{}'\n".format(resp.info().gettype())) + try: + resp = urlopen(base, fields=params) + except (EnvironmentError, HTTPError) as e: + log('URL is {}?{}\nError retrieving API repsonse: {}\n'.format(base, urlencode(params), e)) return None - else: + if not (200 <= resp.status < 300 or 400 <= resp.status < 500): + log('URL is {}?{}\nError retrieving API repsonse: HTTP {} {}\n'.format( + base, urlencode(params), resp.status, resp.reason, + )) + return None + ctype = resp.headers.get('Content-Type') + if ctype and ctype.split(';', 1)[0].strip() != 'application/json': + log("Unexpected Content-Type: '{}'\n".format(ctype)) return None + data = resp.data.decode('utf-8') try: doc = json.loads(data) except ValueError as e: log('{}: {}\n{} {} {}\n{!r}\n'.format( - e.__class__.__name__, e, resp.getcode(), resp.msg, resp.info().gettype(), data, + e.__class__.__name__, e, resp.status, resp.reason, ctype, data, )) return None return doc @@ -412,28 +387,40 @@ def save_style(): def get_avatar(prev_archive): - path_parts = (theme_dir, avatar_base) if prev_archive is not None: # Copy old avatar, if present - cpy_res = maybe_copy_media(prev_archive, path_parts, known_extension=False) - if cpy_res: - return # We got the avatar - - try: - resp = tb_urlopen('https://api.tumblr.com/v2/blog/%s/avatar' % blog_name) - avatar_data = resp.read() - except (EnvironmentError, HTTPException): - return - avatar_file = avatar_base - ext = imghdr.what(None, avatar_data[:32]) - if ext is not None: - avatar_file += '.' + ext - # Remove avatars with a different extension - for old_avatar in glob(join(theme_dir, avatar_base + '.*')): - if split(old_avatar)[-1] != avatar_file: + avatar_glob = glob(join(prev_archive, theme_dir, avatar_base + '.*')) + if avatar_glob: + src = avatar_glob[0] + path_parts = (theme_dir, split(src)[-1]) + cpy_res = maybe_copy_media(prev_archive, path_parts) + if cpy_res: + return # We got the avatar + + url = 'https://api.tumblr.com/v2/blog/%s/avatar' % blog_name + avatar_dest = avatar_fpath = open_file(lambda f: f, (theme_dir, avatar_base)) + + # Remove old avatars + old_avatars = glob(join(theme_dir, avatar_base + '.*')) + if len(old_avatars) > 1: + for old_avatar in old_avatars: os.unlink(old_avatar) - with open_media(theme_dir, avatar_file) as f: - f.write(avatar_data) + elif len(old_avatars) == 1: + # Use the old avatar for timestamping + avatar_dest, = old_avatars + + def adj_bn(old_bn, f): + # Give it an extension + image_type = imghdr.what(f) + if image_type: + return avatar_fpath + '.' 
+ image_type + return avatar_fpath + + # Download the image + try: + wget_retrieve(url, avatar_dest, adjust_basename=adj_bn) + except WGError as e: + e.log() def get_style(prev_archive): @@ -443,14 +430,16 @@ def get_style(prev_archive): if prev_archive is not None: # Copy old style, if present path_parts = (theme_dir, 'style.css') - cpy_res = maybe_copy_media(prev_archive, path_parts, known_extension=False) + cpy_res = maybe_copy_media(prev_archive, path_parts) if cpy_res: return # We got the style + url = 'https://%s/' % blog_name try: - resp = tb_urlopen('https://%s/' % blog_name) - page_data = resp.read() - except (EnvironmentError, HTTPException): + resp = urlopen(url) + page_data = resp.data + except HTTPError as e: + log('URL is {}\nError retrieving style: {}\n'.format(url, e)) return for match in re.findall(br'(?s)', page_data): css = match.strip().decode('utf-8', errors='replace') @@ -463,18 +452,11 @@ def get_style(prev_archive): # Copy media file, if present in prev_archive -def maybe_copy_media(prev_archive, path_parts, known_extension): +def maybe_copy_media(prev_archive, path_parts): if prev_archive is None: return False # Source does not exist - if known_extension: - srcpath = join(prev_archive, *path_parts) - else: - image_glob = glob(join(*itertools.chain((prev_archive,), path_parts[:-1], ('{}.*'.format(path_parts[-1]),)))) - if not image_glob: - return False # Source does not exist - srcpath = image_glob[0] - path_parts = tuple(itertools.chain(path_parts[:-1], (split(srcpath)[-1],))) + srcpath = join(prev_archive, *path_parts) dstpath = open_file(lambda f: f, path_parts) if PY3: @@ -521,7 +503,6 @@ def dup(fd): return fd return True # Either we copied it or we didn't need to - class Index(object): def __init__(self, blog, body_class='index'): self.blog = blog @@ -804,6 +785,7 @@ def _backup(posts, post_respfiles): sorted_posts = sorted(zip(posts, post_respfiles), key=lambda x: x[0]['liked_timestamp' if options.likes else 'id'], reverse=True) for p, prf in sorted_posts: + no_internet.check() post = post_class(p, account, prf, prev_archive) if ident_max is None: pass # No limit @@ -835,7 +817,14 @@ def _backup(posts, post_respfiles): self.filter_skipped += 1 continue - backup_pool.add_work(post.save_content) + while True: + try: + backup_pool.add_work(post.save_content, timeout=0.1) + break + except queue.Full: + pass + no_internet.check() + self.post_count += 1 return True @@ -1184,44 +1173,32 @@ def get_filename(self, url, offset=''): return re.sub(r'[:<>"/\\|*?]', '', fname) if os.name == 'nt' else fname def download_media(self, url, filename): - # check if a file with this name already exists - known_extension = '.' 
in filename[-5:] - image_glob = glob(path_to(self.media_dir, - filename + ('' if known_extension else '.*') - )) - if image_glob: - return split(image_glob[0])[1] - - path_parts = (self.media_dir, filename) + parsed_url = urlparse(url, 'http') + if parsed_url.scheme not in ('http', 'https') or not parsed_url.hostname: + return None # This URL does not follow our basic assumptions - cpy_res = maybe_copy_media(self.prev_archive, path_parts, known_extension) + # Make a sane directory to represent the host + try: + hostdir = parsed_url.hostname.encode('idna').decode('ascii') + except UnicodeError: + hostdir = parsed_url.hostname + if hostdir in ('.', '..'): + hostdir = hostdir.replace('.', '%2E') + if parsed_url.port not in (None, (80 if parsed_url.scheme == 'http' else 443)): + hostdir += '{}{}'.format('+' if os.name == 'nt' else ':', parsed_url.port) + + path_parts = [self.media_dir, filename] + if options.hostdirs: + path_parts.insert(1, hostdir) + + cpy_res = maybe_copy_media(self.prev_archive, path_parts) if not cpy_res: - # download the media data try: - resp = tb_urlopen(url) - with open_media(*path_parts) as dest: - data = resp.read(HTTP_CHUNK_SIZE) - hdr = data[:32] # save the first few bytes - while data: - dest.write(data) - data = resp.read(HTTP_CHUNK_SIZE) - except (EnvironmentError, ValueError, HTTPException) as e: - sys.stderr.write('%s downloading %s\n' % (e, url)) - - try: - os.unlink(path_to(self.media_dir, filename)) - except EnvironmentError as ee: - if getattr(ee, 'errno', None) != errno.ENOENT: - raise - + wget_retrieve(url, open_file(lambda f: f, path_parts)) + except WGError as e: + e.log() return None - # determine the file type if it's unknown - if not known_extension: - image_type = imghdr.what(None, hdr) - if image_type: - oldname = path_to(self.media_dir, filename) - filename += '.' 
+ image_type.replace('jpeg', 'jpg') - os.rename(oldname, path_to(self.media_dir, filename)) + return filename def get_post(self): @@ -1262,46 +1239,51 @@ def get_post(self): if options.save_notes and self.backup_account not in disable_note_scraper and not notes_html.strip(): # Scrape and save notes - ns_stdout_rd, ns_stdout_wr = multiprocessing.Pipe(duplex=False) - ns_msg_rd, ns_msg_wr = multiprocessing.Pipe(duplex=False) - try: - args = (ns_stdout_wr, ns_msg_wr, self.url, self.ident, - options.no_ssl_verify, options.notes_limit, - options.cookiefile) - process = multiprocessing.Process(target=note_scraper.main, args=args) - process.start() - except: - ns_stdout_rd.close() - ns_msg_rd.close() - raise - finally: - ns_stdout_wr.close() - ns_msg_wr.close() - - try: - with ConnectionFile(ns_msg_rd) as msg_pipe: - for line in msg_pipe: - log(line) - - with ConnectionFile(ns_stdout_rd) as stdout: - notes_html = stdout.read() - - process.join() - except: - process.terminate() - process.join() - raise + while True: + ns_stdout_rd, ns_stdout_wr = multiprocessing.Pipe(duplex=False) + ns_msg_rd, ns_msg_wr = multiprocessing.Pipe(duplex=False) + try: + args = (ns_stdout_wr, ns_msg_wr, self.url, self.ident, + options.no_ssl_verify, options.notes_limit, + options.cookiefile) + process = multiprocessing.Process(target=note_scraper.main, args=args) + process.start() + except: + ns_stdout_rd.close() + ns_msg_rd.close() + raise + finally: + ns_stdout_wr.close() + ns_msg_wr.close() - if process.exitcode == 2: # EXIT_SAFE_MODE - # Safe mode is blocking us, disable note scraping for this blog - notes_html = u'' - with disablens_lock: - # Check if another thread already set this - if self.backup_account not in disable_note_scraper: - disable_note_scraper.add(self.backup_account) - log('[Note Scraper] Blocked by safe mode - scraping disabled for {}\n'.format( - self.backup_account - )) + try: + with ConnectionFile(ns_msg_rd) as msg_pipe: + for line in msg_pipe: + log(line) + + with ConnectionFile(ns_stdout_rd) as stdout: + notes_html = stdout.read() + + process.join() + except: + process.terminate() + process.join() + raise + + if process.exitcode == 2: # EXIT_SAFE_MODE + # Safe mode is blocking us, disable note scraping for this blog + notes_html = u'' + with disablens_lock: + # Check if another thread already set this + if self.backup_account not in disable_note_scraper: + disable_note_scraper.add(self.backup_account) + log('[Note Scraper] Blocked by safe mode - scraping disabled for {}\n'.format( + self.backup_account + )) + elif process.exitcode == 3: # EXIT_NO_INTERNET + no_internet.signal() + continue + break notes_str = u'{} note{}'.format(self.note_count, 's'[self.note_count == 1:]) if notes_html.strip(): @@ -1421,7 +1403,12 @@ def add_work(self, *args, **kwargs): def wait(self): log.status('{} remaining posts to save\r'.format(self.queue.qsize())) self.quit.set() - self.queue.join() + while True: + with self.queue.all_tasks_done: + if not self.queue.unfinished_tasks: + break + self.queue.all_tasks_done.wait(timeout=0.1) + no_internet.check() def cancel(self): self.abort.set() @@ -1463,6 +1450,8 @@ def handler(self): else: multiprocessing.set_start_method('spawn') # Slow but safe + no_internet.setup() + import argparse class CSVCallback(argparse.Action): @@ -1544,6 +1533,15 @@ def __call__(self, parser, namespace, values, option_string=None): parser.add_argument('-S', '--no-ssl-verify', action='store_true', help='ignore SSL verification errors') parser.add_argument('--prev-archives', 
action=CSVListCallback, default=[], metavar='DIRS', help='comma-separated list of directories (one per blog) containing previous blog archives') + parser.add_argument('-M', '--timestamping', action='store_true', + help="don't re-download files if the remote timestamp and size match the local file") + parser.add_argument('--no-if-modified-since', action='store_false', dest='if_modified_since', + help="timestamping: don't send If-Modified-Since header") + parser.add_argument('--no-server-timestamps', action='store_false', dest='use_server_timestamps', + help="don't set local timestamps from HTTP headers") + parser.add_argument('--mtime-postfix', action='store_true', + help="timestamping: work around low-precision mtime on FAT filesystems") + parser.add_argument('--hostdirs', action='store_true', help='Generate host-prefixed directories for media') parser.add_argument('blogs', nargs='*') options = parser.parse_args() @@ -1558,10 +1556,9 @@ def __call__(self, parser, namespace, values, option_string=None): if not re.match(r'^\d{4}(\d\d)?(\d\d)?$', options.period): parser.error("Period must be 'y', 'm', 'd' or YYYY[MM[DD]]") set_period() - if HAVE_SSL_CTX and options.no_ssl_verify: - ssl_ctx = ssl._create_unverified_context() - # Otherwise, it's an old Python version without SSL verification, - # so this is the default. + + wget_retrieve = WgetRetrieveWrapper(options, log) + set_ssl_verify(not options.no_ssl_verify) blogs = options.blogs or DEFAULT_BLOGS if not blogs: @@ -1613,6 +1610,9 @@ def __call__(self, parser, namespace, values, option_string=None): for d in options.prev_archives: if not os.access(d, os.R_OK | os.X_OK): parser.error("--prev-archives: directory '{}' cannot be read".format(d)) + if not options.mtime_postfix and path_is_on_vfat.works and path_is_on_vfat('.'): + print('Warning: FAT filesystem detected, enabling --mtime-postfix', file=sys.stderr) + options.mtime_postfix = True global backup_account tb = TumblrBackup() diff --git a/util.py b/util.py index 2aa1c92..e5eae8d 100644 --- a/util.py +++ b/util.py @@ -3,8 +3,11 @@ from __future__ import absolute_import, division, print_function, with_statement import io +import os +import socket import sys import threading +import time try: from typing import TYPE_CHECKING @@ -12,17 +15,35 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Generic, TypeVar + from typing import Generic, Optional, TypeVar try: import queue except ImportError: import Queue as queue # type: ignore[no-redef] -PY3 = sys.version_info[0] >= 3 -HAVE_SSL_CTX = sys.version_info >= (2, 7, 9) +_PATH_IS_ON_VFAT_WORKS = True -HTTP_TIMEOUT = 90 +try: + import psutil +except ImportError: + psutil = None # type: ignore[no-redef] + _PATH_IS_ON_VFAT_WORKS = False + +if os.name == 'nt': + try: + from nt import _getvolumepathname # type: ignore[no-redef] + except ImportError: + _getvolumepathname = None # type: ignore[no-redef] + _PATH_IS_ON_VFAT_WORKS = False + +# This builtin has a new name in Python 3 +try: + raw_input # type: ignore[has-type] +except NameError: + raw_input = input + +PY3 = sys.version_info[0] >= 3 def to_unicode(string, encoding='utf-8', errors='strict'): @@ -94,3 +115,122 @@ def __enter__(self): def __exit__(self, *excinfo): pass + + +KNOWN_GOOD_NAMESERVER = '8.8.8.8' +# DNS query for 'A' record of 'google.com'. 
+# Generated using python -c "import dnslib; print(bytes(dnslib.DNSRecord.question('google.com').pack()))" +DNS_QUERY = b'\xf1\xe1\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x06google\x03com\x00\x00\x01\x00\x01' + + +def is_dns_working(timeout=None): + sock = None + try: + # Would use a with statement, but that doesn't work on Python 2, mumble mumble + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + if timeout is not None: + sock.settimeout(timeout) + sock.sendto(DNS_QUERY, (KNOWN_GOOD_NAMESERVER, 53)) + sock.recvfrom(1) + except EnvironmentError: + return False + finally: + if sock is not None: + sock.close() + + return True + + +def rstrip_slashes(path): + return path.rstrip(b'\\/' if isinstance(path, bytes) else u'\\/') + + +class _Path_Is_On_VFat(object): + works = _PATH_IS_ON_VFAT_WORKS + + def __call__(self, path): + if not self.works: + raise RuntimeError('This function must not be called unless PATH_IS_ON_VFAT_WORKS is True') + + if os.name == 'nt': + # Compare normalized absolute path of volume + getdev = rstrip_slashes + path_dev = rstrip_slashes(_getvolumepathname(path)) + else: + # Compare device ID + def getdev(mount): return os.stat(mount).st_dev + path_dev = getdev(path) + + return any(part.fstype == 'vfat' and getdev(part.mountpoint) == path_dev + for part in psutil.disk_partitions(all=True)) + + +path_is_on_vfat = _Path_Is_On_VFat() + + +class WaitOnMainThread(object): + def __init__(self): + self.cond = None # type: Optional[threading.Condition] + self.flag = False # type: Optional[bool] + + def setup(self, lock=None): + self.cond = threading.Condition(lock) + + def signal(self): + assert self.cond is not None + if isinstance(threading.current_thread(), threading._MainThread): # type: ignore[attr-defined] + self._do_wait() + return + + with self.cond: + if self.flag is None: + sys.exit(1) + self.flag = True + self.cond.wait() + if self.flag is None: + sys.exit(1) + + # Call on main thread when signaled or idle. + def check(self): + assert self.cond is not None + if self.flag is False: + return + + self._do_wait() + + with self.cond: + self.flag = False + self.cond.notify_all() + + def _do_wait(self): + assert self.cond is not None + if self.flag is None: + raise RuntimeError('Broken WaitOnMainThread cannot be reused') + + try: + self._wait() + except: + with self.cond: + self.flag = None # Waiting never completed + raise + + @staticmethod + def _wait(): + raise NotImplementedError + + +class NoInternet(WaitOnMainThread): + @staticmethod + def _wait(): + # Having no internet is a temporary system error + # Wait 30 seconds at first, then exponential backoff up to 15 minutes + print('DNS probe finished: No internet. 
Waiting...', file=sys.stderr) + sleep_time = 30 + while True: + time.sleep(sleep_time) + if is_dns_working(): + break + sleep_time = min(sleep_time * 2, 900) + + +no_internet = NoInternet() diff --git a/wget.py b/wget.py new file mode 100644 index 0000000..3959075 --- /dev/null +++ b/wget.py @@ -0,0 +1,839 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, division, print_function, with_statement + +import errno +import functools +import io +import itertools +import os +import sys +import time +import warnings +from email.utils import mktime_tz, parsedate_tz +from tempfile import NamedTemporaryFile +from wsgiref.handlers import format_date_time + +from util import PY3, is_dns_working, no_internet + +try: + from urllib.parse import urlsplit, urljoin +except ImportError: + from urlparse import urlsplit, urljoin # type: ignore[no-redef] + +try: + from urllib3 import HTTPResponse, HTTPConnectionPool, HTTPSConnectionPool, PoolManager, Retry, Timeout + from urllib3.exceptions import DependencyWarning, InsecureRequestWarning, MaxRetryError, ConnectTimeoutError + from urllib3.exceptions import HTTPError as HTTPError + from urllib3.util import make_headers + URLLIB3_FROM_PIP = False +except ImportError: + try: + # pip includes urllib3 + from pip._vendor.urllib3 import (HTTPResponse, HTTPConnectionPool, HTTPSConnectionPool, PoolManager, + Retry, Timeout) + from pip._vendor.urllib3.exceptions import (DependencyWarning, InsecureRequestWarning, MaxRetryError, + ConnectTimeoutError) + from pip._vendor.urllib3.exceptions import HTTPError as HTTPError + from pip._vendor.urllib3.util import make_headers + URLLIB3_FROM_PIP = True + except ImportError: + raise RuntimeError('The urllib3 module is required. Please install it with pip or your package manager.') + +# Don't complain about missing socks +warnings.filterwarnings('ignore', category=DependencyWarning) + +try: + import ssl as ssl +except ImportError: + ssl = None # type: ignore[assignment,no-redef] + +# Inject SecureTransport on macOS if the linked OpenSSL is too old to handle TLSv1.2 +if ssl is not None and sys.platform == 'darwin' and ssl.OPENSSL_VERSION_NUMBER < 0x1000100F: + try: + if URLLIB3_FROM_PIP: + from pip._vendor.urllib3.contrib import securetransport + else: + from urllib3.contrib import securetransport + except (ImportError, EnvironmentError): + pass + else: + securetransport.inject_into_urllib3() + +if ssl is not None and not getattr(ssl, 'HAS_SNI', False): + try: + if URLLIB3_FROM_PIP: + from pip._vendor.urllib3.contrib import pyopenssl + else: + from urllib3.contrib import pyopenssl + except ImportError: + pass + else: + pyopenssl.inject_into_urllib3() + +# long is int in Python 3 +try: + long # type: ignore[has-type] +except NameError: + long = int + +HTTP_TIMEOUT = Timeout(90) +HTTP_RETRY = Retry(3, connect=False) +HTTP_CHUNK_SIZE = 1024 * 1024 + +try: + from brotli import brotli + have_brotlipy = True +except ImportError: + have_brotlipy = False + +supported_encodings = (('br',) if have_brotlipy else ()) + ('deflate', 'gzip') + +base_headers = make_headers(keep_alive=True, accept_encoding=list(supported_encodings)) + + +# Document type flags +RETROKF = 0x2 # retrieval was OK +HEAD_ONLY = 0x4 # only send the HEAD request +IF_MODIFIED_SINCE = 0x80 # use If-Modified-Since header + + +# Error statuses +class UErr(object): + RETRUNNEEDED = 0 + RETRINCOMPLETE = 1 # Not part of wget + RETRFINISHED = 2 + HEADUNSUPPORTED = 3 + + +class HttpStat(object): + def __init__(self): + self.current_url = None # the most recent 
redirect, otherwise the initial url + self.bytes_read = 0 # received length + self.bytes_written = 0 # written length + self.contlen = None # expected length + self.restval = 0 # the restart value + self.last_modified = None # Last-Modified header + self.remote_time = None # remote time-stamp + self.statcode = 0 # status code + self.dest_dir = None # handle to the directory containing part_file + self.part_file = None # handle to local file used to store in-progress download + self.orig_file_exists = False # if there is a local file to compare for time-stamping + self.orig_file_size = 0 # size of file to compare for time-stamping + self.orig_file_tstamp = 0 # time-stamp of file to compare for time-stamping + self.remote_encoding = None # the encoding of the remote file + self.enc_is_identity = None # whether the remote encoding is identity + self.decoder = None # saved decoder from the HTTPResponse + self._make_part_file = None # part_file supplier + + def set_part_file_supplier(self, value): + self._make_part_file = value + + def init_part_file(self): + if self._make_part_file is not None: + self.part_file = self._make_part_file() + self._make_part_file = None + + +class WGHTTPResponse(HTTPResponse): + REDIRECT_STATUSES = [300] + HTTPResponse.REDIRECT_STATUSES + + # Make decoder public for saving and restoring the decoder state + @property + def decoder(self): + return self._decoder + + @decoder.setter + def decoder(self, value): + self._decoder = value + + def __init__(self, *args, **kwargs): + self.current_url = kwargs.pop('current_url') + self.bytes_to_skip = 0 + self.last_read_length = 0 + super(WGHTTPResponse, self).__init__(*args, **kwargs) + + # Make _init_length publicly usable because its implementation is nice + def get_content_length(self, meth): + return self._init_length(meth) + + # Wrap _decode to do some extra processing of the content-encoded entity data. 
+ def _decode(self, data, decode_content, flush_decoder): + # Skip any data we don't need + data_len = len(data) + if self.bytes_to_skip >= data_len: + data = b'' + self.bytes_to_skip -= data_len + elif self.bytes_to_skip > 0: + data = data[self.bytes_to_skip:] + self.bytes_to_skip = 0 + + self.last_read_length = len(data) # Count only non-skipped data + if not data: + return b'' + return super(WGHTTPResponse, self)._decode(data, decode_content, flush_decoder) + + +class WGHTTPConnectionPool(HTTPConnectionPool): + ResponseCls = WGHTTPResponse + + def __init__(self, host, port=None, *args, **kwargs): + norm_host = normalized_host(self.scheme, host, port) + cfh_url = kwargs.pop('cfh_url', None) + if norm_host in unreachable_hosts: + raise WGUnreachableHostError(None, cfh_url, 'Host {} is ignored.'.format(norm_host)) + super(WGHTTPConnectionPool, self).__init__(host, port, *args, **kwargs) + + def urlopen(self, method, url, *args, **kwargs): + kwargs['current_url'] = url + return super(WGHTTPConnectionPool, self).urlopen(method, url, *args, **kwargs) + + +class WGHTTPSConnectionPool(HTTPSConnectionPool): + ResponseCls = WGHTTPResponse + + def __init__(self, host, port=None, *args, **kwargs): + norm_host = normalized_host(self.scheme, host, port) + cfh_url = kwargs.pop('cfh_url', None) + if norm_host in unreachable_hosts: + raise WGUnreachableHostError(None, cfh_url, 'Host {} is ignored.'.format(norm_host)) + super(WGHTTPSConnectionPool, self).__init__(host, port, *args, **kwargs) + + def urlopen(self, method, url, *args, **kwargs): + kwargs['current_url'] = url + return super(WGHTTPSConnectionPool, self).urlopen(method, url, *args, **kwargs) + + +class WGPoolManager(PoolManager): + def __init__(self, num_pools=10, headers=None, **connection_pool_kw): + super(WGPoolManager, self).__init__(num_pools, headers, **connection_pool_kw) + self.cfh_url = None + self.pool_classes_by_scheme = {'http': WGHTTPConnectionPool, 'https': WGHTTPSConnectionPool} + + def connection_from_url(self, url, pool_kwargs=None): + try: + self.cfh_url = url + return super(WGPoolManager, self).connection_from_url(url, pool_kwargs) + finally: + self.cfh_url = None + + def urlopen(self, method, url, redirect=True, **kw): + try: + self.cfh_url = url + return super(WGPoolManager, self).urlopen(method, url, redirect, **kw) + finally: + self.cfh_url = None + + def _new_pool(self, scheme, host, port, request_context=None): + if request_context is None: + request_context = self.connection_pool_kw.copy() + request_context['cfh_url'] = self.cfh_url + return super(WGPoolManager, self)._new_pool(scheme, host, port, request_context) + + +poolman = WGPoolManager(maxsize=20, timeout=HTTP_TIMEOUT, retries=HTTP_RETRY) + + +class Logger(object): + def __init__(self, original_url, log): + self.original_url = original_url + self.log_cb = log + self.prev_log_url = None + + def log(self, url, msg): + qmsg = u'' + if self.prev_log_url is None: + qmsg += u'[wget] {}URL is {}\n'.format('' if url == self.original_url else 'Original ', self.original_url) + self.prev_log_url = self.original_url + if url != self.prev_log_url: + qmsg += u'[wget] Current redirect URL is {}\n'.format(url) + self.prev_log_url = url + qmsg += u'[wget] {}\n'.format(msg) + self.log_cb(qmsg) + + +def gethttp(url, hstat, doctype, options, logger, retry_counter): + if hstat.current_url is not None: + url = hstat.current_url # The most recent location is cached + + hstat.bytes_read = 0 + hstat.contlen = None + hstat.remote_time = None + + # Initialize the request + meth = 'GET' 
+ if doctype & HEAD_ONLY: + meth = 'HEAD' + request_headers = {} + if doctype & IF_MODIFIED_SINCE: + request_headers['If-Modified-Since'] = format_date_time(hstat.orig_file_tstamp) + if hstat.restval: + request_headers['Range'] = 'bytes={}-'.format(hstat.restval) + + doctype &= ~RETROKF + + resp = urlopen(url, method=meth, headers=request_headers, preload_content=False, enforce_content_length=False) + url = hstat.current_url = urljoin(url, resp.current_url) + + try: + err, doctype = process_response(url, hstat, doctype, options, logger, retry_counter, meth, resp) + finally: + resp.release_conn() + + return err, doctype + + +def process_response(url, hstat, doctype, options, logger, retry_counter, meth, resp): + # RFC 7233 section 4.1 paragraph 6: + # "A server MUST NOT generate a multipart response to a request for a single range [...]" + conttype = resp.headers.get('Content-Type') + if conttype is not None and conttype.lower().split(';', 1)[0].strip() == 'multipart/byteranges': + raise WGBadResponseError(logger, url, 'Sever sent multipart response, but multiple ranges were not requested') + + contlen = resp.get_content_length(meth) + + crange_header = resp.headers.get('Content-Range') + crange_parsed = parse_content_range(crange_header) + if crange_parsed is not None: + first_bytep, last_bytep, _ = crange_parsed + contrange = first_bytep + contlen = last_bytep - first_bytep + 1 + else: + contrange = 0 + + hstat.last_modified = resp.headers.get('Last-Modified') + if hstat.last_modified is None: + hstat.last_modified = resp.headers.get('X-Archive-Orig-last-modified') + + if hstat.last_modified is None: + hstat.remote_time = None + else: + lmtuple = parsedate_tz(hstat.last_modified) + hstat.remote_time = -1 if lmtuple is None else mktime_tz(lmtuple) + + remote_encoding = resp.headers.get('Content-Encoding') + + def norm_enc(enc): + return None if enc is None else tuple(e.strip() for e in enc.split(',')) + + if hstat.restval > 0 and norm_enc(hstat.remote_encoding) != norm_enc(remote_encoding): + # Retry without restart + hstat.restval = 0 + retry_counter.increment(hstat, 'Inconsistent Content-Encoding, must start over') + return UErr.RETRINCOMPLETE, doctype + + hstat.remote_encoding = remote_encoding + hstat.enc_is_identity = remote_encoding in (None, '') or all( + enc.strip() == 'identity' for enc in remote_encoding.split(',') + ) + + # In some cases, httplib returns a status of _UNKNOWN + try: + hstat.statcode = long(resp.status) + except ValueError: + hstat.statcode = 0 + + # HTTP 500 Internal Server Error + # HTTP 501 Not Implemented + if hstat.statcode in (500, 501) and (doctype & HEAD_ONLY): + return UErr.HEADUNSUPPORTED, doctype + + # HTTP 20X + # HTTP 207 Multi-Status + if 200 <= hstat.statcode < 300 and hstat.statcode != 207: + doctype |= RETROKF + + # HTTP 204 No Content + if hstat.statcode == 204: + hstat.bytes_read = hstat.restval = 0 + return UErr.RETRFINISHED, doctype + + if doctype & IF_MODIFIED_SINCE: + # HTTP 304 Not Modified + if hstat.statcode == 304: + # File not modified on server according to If-Modified-Since, not retrieving. 
+ doctype |= RETROKF + return UErr.RETRUNNEEDED, doctype + if (hstat.statcode == 200 + and contlen in (None, hstat.orig_file_size) + and hstat.remote_time not in (None, -1) + and hstat.remote_time <= hstat.orig_file_tstamp + ): + logger.log(url, 'If-Modified-Since was ignored (file not actually modified), not retrieving.') + return UErr.RETRUNNEEDED, doctype + logger.log(url, 'Retrieving remote file because If-Modified-Since response indicates it was modified.') + + if not (doctype & RETROKF): + raise WGWrongCodeError(logger, url, hstat.statcode, resp.reason, resp.headers) + + shrunk = False + if hstat.statcode == 416: + shrunk = True # HTTP 416 Range Not Satisfiable + elif hstat.statcode != 200 or options.timestamping or contlen == 0: + pass # Only verify contlen if 200 OK (NOT 206 Partial Contents), not timestamping, and contlen is nonzero + elif contlen is not None and contrange == 0 and hstat.restval >= contlen: + shrunk = True # Got the whole content but it is known to be shorter than the restart point + + if shrunk: + # NB: Unlike wget, we will retry because restarts are expected to succeed (we do not support '-c') + # The remote file has shrunk, retry without restart + hstat.restval = 0 + retry_counter.increment(hstat, 'Resume with Range failed, must start over') + return UErr.RETRINCOMPLETE, doctype + + # The Range request was misunderstood. Bail out. + # Unlike wget, we bail hard with no retry, because this indicates a broken or unreasonable server. + if contrange not in (0, hstat.restval): + raise WGRangeError(logger, url, 'Server provided unexpected Content-Range: Requested {}, got {}' + .format(hstat.restval, contrange)) + # HTTP 206 Partial Contents + if hstat.statcode == 206 and hstat.restval > 0 and contrange == 0: + if crange_header is None: + crange_status = 'not provided' + elif crange_parsed is None: + crange_status = 'invalid' + else: # contrange explicitly zero + crange_status = 'zero' + raise WGRangeError(logger, url, 'Requested a Range and server sent HTTP 206 Partial Contents, ' + 'but Content-Range is {}!'.format(crange_status)) + + hstat.contlen = contlen + if hstat.contlen is not None: + hstat.contlen += contrange + + if (doctype & HEAD_ONLY) or not (doctype & RETROKF): + hstat.bytes_read = hstat.restval = 0 + return UErr.RETRFINISHED, doctype + + if hstat.restval > 0 and contrange == 0: + # If the server ignored our range request, skip the first RESTVAL bytes of the body. 
+ resp.bytes_to_skip = hstat.restval + else: + resp.bytes_to_skip = 0 + + hstat.bytes_read = hstat.restval + + assert resp.decoder is None + if hstat.restval > 0: + resp.decoder = hstat.decoder # Resume the previous decoder state -- Content-Encoding is weird + + hstat.init_part_file() # We're about to write to part_file, make sure it exists + + try: + for chunk in resp.stream(HTTP_CHUNK_SIZE): + hstat.bytes_read += resp.last_read_length + if not chunk: # May be possible if not resp.chunked due to implementation of _decode + continue + hstat.part_file.write(chunk) + except MaxRetryError: + raise + except (HTTPError, EnvironmentError) as e: + is_read_error = isinstance(e, HTTPError) + length_known = hstat.contlen is not None and (is_read_error or hstat.enc_is_identity) + logger.log(url, '{} error at byte {}{}'.format( + 'Read' if is_read_error else 'Write', + hstat.bytes_read if is_read_error else hstat.bytes_written, + '/{}'.format(hstat.contlen) if length_known else '', + )) + + if hstat.bytes_read == hstat.restval: + raise # No data read + if not retry_counter.should_retry(): + raise # This won't be retried + + # Grab the decoder state for next time + if resp.decoder is not None: + hstat.decoder = resp.decoder + + # We were able to read at least _some_ body data from the server. Keep trying. + raise # Jump to outer except block + + hstat.decoder = None + return UErr.RETRFINISHED, doctype + + +def parse_crange_num(hdrc, ci, postchar): + if not hdrc[ci].isdigit(): + raise ValueError('parse error') + num = long(0) + while hdrc[ci].isdigit(): + num = long(10) * num + long(hdrc[ci]) + ci += 1 + if hdrc[ci] != postchar: + raise ValueError('parse error') + ci += 1 + return ci, num + + +def parse_content_range(hdr): + if hdr is None: + return None + + # Ancient version of Netscape proxy server don't have the "bytes" specifier + if hdr.startswith('bytes'): + hdr = hdr[5:] + # JavaWebServer/1.1.1 sends "bytes: x-y/z" + if hdr.startswith(':'): + hdr = hdr[1:] + hdr = hdr.lstrip() + if not hdr: + return None + + ci = 0 + # Final string is a sentinel, equivalent to a null terminator + hdrc = tuple(itertools.chain((c for c in hdr), ('',))) + + try: + ci, first_bytep = parse_crange_num(hdrc, ci, '-') + ci, last_bytep = parse_crange_num(hdrc, ci, '/') + except ValueError: + return None + + if hdrc[ci] == '*': + entity_length = None + else: + num_ = long(0) + while hdrc[ci].isdigit(): + num_ = long(10) * num_ + long(hdrc[ci]) + ci += 1 + entity_length = num_ + + # A byte-content-range-spec whose last-byte-pos value is less than its first-byte-pos value, or whose entity-length + # value is less than or equal to its last-byte-pos value, is invalid. 
+ if last_bytep < first_bytep or (entity_length is not None and entity_length <= last_bytep): + return None + + return first_bytep, last_bytep, entity_length + + +def touch(fl, mtime, dir_fd=None): + atime = time.time() + if PY3 and os.utime in os.supports_dir_fd and dir_fd is not None: + os.utime(os.path.basename(fl), (atime, mtime), dir_fd=dir_fd) + else: + os.utime(fl, (atime, mtime)) + + +class WGError(Exception): + def __init__(self, logger, url, msg, cause=None): + super(WGError, self).__init__('Error retrieving resource: {}{}'.format( + msg, '' if cause is None else '\nCaused by: {}'.format(cause), + )) + self.logger = logger + self.url = url + + def log(self): + self.logger.log(self.url, self) + + +class WGMaxRetryError(WGError): + pass + + +class WGUnreachableHostError(WGError): + pass + + +class WGBadProtocolError(WGError): + pass + + +class WGBadResponseError(WGError): + pass + + +class WGWrongCodeError(WGBadResponseError): + def __init__(self, logger, url, statcode, statmsg, headers): + msg = 'Unexpected response status: HTTP {} {}{}'.format( + statcode, statmsg, '' if statcode in (403, 404) else '\nHeaders: {}'.format(headers), + ) + super(WGWrongCodeError, self).__init__(logger, url, msg) + + +class WGRangeError(WGBadResponseError): + pass + + +unreachable_hosts = set() + + +class RetryCounter(object): + TRY_LIMIT = 20 + MAX_RETRY_WAIT = 10 + + def __init__(self, logger): + self.logger = logger + self.count = 0 + + def reset(self): + self.count = 0 + + def should_retry(self): + return self.TRY_LIMIT is None or self.count < self.TRY_LIMIT + + def increment(self, url, hstat, cause): + self.count += 1 + status = 'incomplete' if hstat.bytes_read > hstat.restval else 'failed' + msg = 'because of {} retrieval: {}'.format(status, cause) + if not self.should_retry(): + self.logger.log(url, 'Gave up {}'.format(msg)) + raise WGMaxRetryError(self.logger, url, 'Retrieval failed after {} tries.'.format(self.TRY_LIMIT), cause) + trylim = '' if self.TRY_LIMIT is None else '/{}'.format(self.TRY_LIMIT) + self.logger.log(url, 'Retrying ({}{}) {}'.format(self.count, trylim, msg)) + time.sleep(min(self.count, self.MAX_RETRY_WAIT)) + + +def normalized_host_from_url(url): + split = urlsplit(url, 'http') + hostname = split.hostname + port = split.port + if port is None: + port = 80 if split.scheme == 'http' else 443 + return '{}:{}'.format(hostname, port) + + +def normalized_host(scheme, host, port): + if port is None: + port = 80 if scheme == 'http' else 443 + return '{}:{}'.format(host, port) + + +def _retrieve_loop(hstat, url, dest_file, adjust_basename, options, log): + if PY3 and (isinstance(url, bytes) or isinstance(dest_file, bytes)): + raise ValueError('This function does not support bytes arguments on Python 3') + + logger = Logger(url, log) + + if urlsplit(url, 'http').scheme not in ('http', 'https'): + raise WGBadProtocolError(logger, url, 'Non-HTTP(S) protocols are not implemented.') + + hostname = normalized_host_from_url(url) + if hostname in unreachable_hosts: + raise WGUnreachableHostError(logger, url, 'Host {} is ignored.'.format(hostname)) + + doctype = 0 + got_head = False # used for time-stamping + dest_dirname, dest_basename = os.path.split(dest_file) + + flags = os.O_RDONLY + try: + flags |= os.O_DIRECTORY + except AttributeError: + dest_dirname += os.path.sep # Fallback, some systems don't support O_DIRECTORY + + hstat.dest_dir = os.open(dest_dirname, flags) + hstat.set_part_file_supplier(functools.partial( + lambda pfx, dir_: NamedTemporaryFile('wb', prefix=pfx, dir=dir_, 
delete=False), + '.{}.'.format(dest_basename), dest_dirname, + )) + send_head_first = False + + if options.timestamping: + st = None + try: + if PY3 and os.stat in os.supports_dir_fd: + st = os.stat(dest_basename, dir_fd=hstat.dest_dir) + else: + st = os.stat(dest_file) + except EnvironmentError as e: + if getattr(e, 'errno', None) != errno.ENOENT: + raise # Not unusual + + if st is not None: + # Timestamping is enabled and the local file exists + hstat.orig_file_exists = True + hstat.orig_file_size = st.st_size + hstat.orig_file_tstamp = int(st.st_mtime) + if options.mtime_postfix: + hstat.orig_file_tstamp += 1 + + if options.if_modified_since: + doctype |= IF_MODIFIED_SINCE # Send a conditional GET request + else: + send_head_first = True # Send a preliminary HEAD request + doctype |= HEAD_ONLY + + # THE loop + + retry_counter = RetryCounter(logger) + while True: + # Behave as if force_full_retrieve is always enabled + hstat.restval = hstat.bytes_read + + try: + err, doctype = gethttp(url, hstat, doctype, options, logger, retry_counter) + except MaxRetryError as e: + raise WGMaxRetryError(logger, url, 'urllib3 reached a retry limit.', e) + except HTTPError as e: + if isinstance(e, ConnectTimeoutError): + # Host is unreachable (incl ETIMEDOUT, EHOSTUNREACH, and EAI_NONAME) - condemn it and don't retry + hostname = normalized_host_from_url(url) + unreachable_hosts.add(hostname) + msg = 'Error connecting to host {}. From now on it will be ignored.'.format(hostname) + raise WGUnreachableHostError(logger, url, msg, e) + + retry_counter.increment(url, hstat, repr(e)) + continue + except WGUnreachableHostError as e: + # Set the logger for unreachable host errors thrown from WGHTTP(S)ConnectionPool + if e.logger is None: + e.logger = logger + raise + finally: + if hstat.current_url is not None: + url = hstat.current_url + + if err == UErr.RETRINCOMPLETE: + continue # Non-fatal error, try again + if err == UErr.RETRUNNEEDED: + return + if err == UErr.HEADUNSUPPORTED: + # Fall back to GET if HEAD is unsupported. + send_head_first = False + doctype &= ~HEAD_ONLY + retry_counter.reset() + continue + assert err == UErr.RETRFINISHED + + # Did we get the time-stamp? + if not got_head: + got_head = True # no more time-stamping + + if (options.timestamping or options.use_server_timestamps) and hstat.remote_time in (None, -1): + logger.log(url, 'Warning: Last-Modified header is {}' + .format('missing' if hstat.remote_time is None + else 'invalid: {}'.format(hstat.last_modified))) + + if send_head_first: + if hstat.orig_file_exists and hstat.remote_time not in (None, -1): + # Now time-stamping can be used validly. Time-stamping means that if the sizes of the local and + # remote file match, and local file is newer than the remote file, it will not be retrieved. + # Otherwise, the normal download procedure is resumed. + if hstat.remote_time > hstat.orig_file_tstamp: + logger.log(url, 'Retrieving remote file because its mtime ({}) is newer than what we have ({}).' + .format(format_date_time(hstat.remote_time), + format_date_time(hstat.orig_file_tstamp))) + elif hstat.enc_is_identity and hstat.contlen not in (None, hstat.orig_file_size): + logger.log(url, + 'Retrieving remote file because its size ({}) is does not match what we have ({}).' + .format(hstat.contlen, hstat.orig_file_size)) + else: + # Remote file is no newer and has the same size, not retrieving. 
+ return + + doctype &= ~HEAD_ONLY + retry_counter.reset() + continue + + if hstat.contlen is not None and hstat.bytes_read < hstat.contlen: + # We lost the connection too soon + retry_counter.increment(url, hstat, 'Server closed connection before Content-Length was reached.') + continue + + # We shouldn't have read more than Content-Length bytes + assert hstat.contlen in (None, hstat.bytes_read) + + # Normal return path - we wrote a local file + pfname = hstat.part_file.name + + # NamedTemporaryFile is created 0600, set mode to the usual 0644 + os.fchmod(hstat.part_file.fileno(), 0o644) + + # Set the timestamp + if (options.use_server_timestamps + and hstat.remote_time not in (None, -1) + and hstat.contlen in (None, hstat.bytes_read) + ): + touch(pfname, hstat.remote_time, dir_fd=hstat.dest_dir) + + # Adjust the new name + if adjust_basename is None: + new_dest_basename = dest_basename + else: + # Give adjust_basename a read-only file handle + pf = io.open(hstat.part_file.fileno(), 'rb', closefd=False) + new_dest_basename = adjust_basename(dest_basename, pf) + + # Flush buffers and sync the inode + hstat.part_file.flush() + os.fsync(hstat.part_file) + try: + hstat.part_file.close() + finally: + hstat.part_file = None + + # Move to final destination + new_dest = os.path.join(dest_dirname, new_dest_basename) + if not PY3: + if os.name == 'nt': + try_unlink(new_dest) # Avoid potential FileExistsError + os.rename(pfname, new_dest) + elif os.rename not in os.supports_dir_fd: + os.replace(pfname, new_dest) + else: + os.replace(os.path.basename(pfname), new_dest_basename, + src_dir_fd=hstat.dest_dir, dst_dir_fd=hstat.dest_dir) + + # Sync the directory and return + os.fdatasync(hstat.dest_dir) + return + + +def try_unlink(path): + try: + os.unlink(path) + except EnvironmentError as e: + if getattr(e, 'errno', None) != errno.ENOENT: + raise + + +def set_ssl_verify(verify): + if not verify: + # Hide the InsecureRequestWarning from urllib3 + warnings.filterwarnings('ignore', category=InsecureRequestWarning) + poolman.connection_pool_kw['cert_reqs'] = 'CERT_REQUIRED' if verify else 'CERT_NONE' + + +# This is a simple urllib3-based urlopen function. +def urlopen(url, method='GET', headers=None, **kwargs): + req_headers = base_headers.copy() + if headers is not None: + req_headers.update(headers) + + while True: + try: + return poolman.request(method, url, headers=req_headers, **kwargs) + except HTTPError: + if is_dns_working(timeout=5): + raise + # Having no internet is a temporary system error + no_internet.signal() + + +# This functor is the primary API of this module. +class WgetRetrieveWrapper(object): + def __init__(self, options, log): + self.options = options + self.log = log + + def __call__(self, url, file, adjust_basename=None): + hstat = HttpStat() + try: + _retrieve_loop(hstat, url, file, adjust_basename, self.options, self.log) + finally: + if hstat.dest_dir is not None: + os.close(hstat.dest_dir) + hstat.dest_dir = None + # part_file may still be around if we didn't move it + if hstat.part_file is not None: + self._close_part(hstat) + + return hstat + + @staticmethod + def _close_part(hstat): + try: + hstat.part_file.close() + try_unlink(hstat.part_file.name) + finally: + hstat.part_file = None
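
For reference, below is a minimal, hedged sketch of how the interface added by this patch is driven, mirroring the calls made from tumblr_backup.py above (WgetRetrieveWrapper, set_ssl_verify, urlopen, WGError). The URL, destination path and API key are placeholders, and the argparse.Namespace carries only the option attributes that wget.py actually reads; this is an illustration of the intended usage, not part of the patch.

    import argparse
    import sys

    from wget import WGError, WgetRetrieveWrapper, set_ssl_verify, urlopen

    # Only the options consumed by wget.py's retrieve loop are set here.
    options = argparse.Namespace(
        timestamping=True,           # -M / --timestamping
        if_modified_since=True,      # cleared by --no-if-modified-since
        use_server_timestamps=True,  # cleared by --no-server-timestamps
        mtime_postfix=False,         # --mtime-postfix (FAT mtime workaround)
    )

    set_ssl_verify(True)  # CERT_REQUIRED; False corresponds to --no-ssl-verify
    wget_retrieve = WgetRetrieveWrapper(options, log=sys.stderr.write)

    # Simple request, as apiparse() does ('fields' is urlencoded by urllib3):
    resp = urlopen('https://api.tumblr.com/v2/blog/example.tumblr.com/posts',
                   fields={'api_key': 'PLACEHOLDER', 'limit': 50})
    print(resp.status, resp.headers.get('Content-Type'))

    # Timestamped media download, as download_media() does. The destination
    # directory ('media' here, a placeholder) must already exist, because the
    # retrieve loop opens it to fsync and to rename the part file into place.
    try:
        wget_retrieve('https://example.com/image.jpg', 'media/image.jpg')
    except WGError as e:
        e.log()

    # Content-Range parsing used for resumed downloads, for example:
    #   parse_content_range('bytes 500-999/1234') -> (500, 999, 1234)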