From 3e03eeab0025f712f0923e3d88743457b5d4266c Mon Sep 17 00:00:00 2001 From: favilo Date: Tue, 2 Jul 2024 09:15:02 -0700 Subject: [PATCH] I went a little buckwild while fixing type annotations. I turned on a stricter mode than what we have configured in pyproject.toml --- esrally/mechanic/launcher.py | 2 +- esrally/mechanic/team.py | 159 +++++++++++++++++---------- esrally/utils/io.py | 204 +++++++++++++++++++++-------------- esrally/utils/modules.py | 6 +- esrally/utils/process.py | 19 ++-- pyproject.toml | 20 ++-- 6 files changed, 249 insertions(+), 161 deletions(-) diff --git a/esrally/mechanic/launcher.py b/esrally/mechanic/launcher.py index 18d21f90b..b23a0682f 100644 --- a/esrally/mechanic/launcher.py +++ b/esrally/mechanic/launcher.py @@ -242,7 +242,7 @@ def stop(self, nodes, metrics_store): stop_watch.start() try: es.terminate() - es.wait(10.0) + es.wait(10) stopped_nodes.append(node) except psutil.NoSuchProcess: self.logger.warning("No process found with PID [%s] for node [%s].", es.pid, node_name) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 6023e4698..9b1fa678c 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -20,7 +20,18 @@ import os from enum import Enum from types import ModuleType -from typing import Any, Collection, Mapping, Optional, Union +from typing import ( + Any, + Callable, + Collection, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) import tabulate @@ -30,14 +41,14 @@ TEAM_FORMAT_VERSION = 1 -def _path_for(team_root_path, team_member_type): +def _path_for(team_root_path: str, team_member_type: str) -> str: root_path = os.path.join(team_root_path, team_member_type, f"v{TEAM_FORMAT_VERSION}") if not os.path.exists(root_path): raise exceptions.SystemSetupError(f"Path {root_path} for {team_member_type} does not exist.") return root_path -def list_cars(cfg: types.Config): +def list_cars(cfg: types.Config) -> None: loader = CarLoader(team_path(cfg)) cars = [] for name in loader.car_names(): @@ -51,15 +62,15 @@ def list_cars(cfg: types.Config): def load_car(repo: str, name: Collection[str], car_params: Optional[Mapping] = None) -> "Car": class Component: - def __init__(self, root_path, entry_point): + def __init__(self, root_path: str, entry_point: str): self.root_path = root_path self.entry_point = entry_point root_paths = [] # preserve order as we append to existing config files later during provisioning. 
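+ # collect root paths, config paths and variables from every car listed in "name"; variable precedence is resolved after the loop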
all_config_paths = [] - all_config_base_vars = {} - all_car_vars = {} + all_config_base_vars: MutableMapping[str, str] = {} + all_car_vars: MutableMapping[str, str] = {} for n in name: descriptor = CarLoader(repo).load_car(n, car_params) @@ -76,7 +87,7 @@ def __init__(self, root_path, entry_point): if len(all_config_paths) == 0: raise exceptions.SystemSetupError(f"At least one config base is required for car {name}") - variables = {} + variables: MutableMapping[str, str] = {} # car variables *always* take precedence over config base variables variables.update(all_config_base_vars) variables.update(all_car_vars) @@ -84,7 +95,7 @@ def __init__(self, root_path, entry_point): return Car(name, root_paths, all_config_paths, variables) -def list_plugins(cfg: types.Config): +def list_plugins(cfg: types.Config) -> None: plugins = PluginLoader(team_path(cfg)).plugins() if plugins: console.println("Available Elasticsearch plugins:\n") @@ -93,12 +104,16 @@ def list_plugins(cfg: types.Config): console.println("No Elasticsearch plugins are available.\n") -def load_plugin(repo, name, config_names, plugin_params=None): +def load_plugin( + repo: str, name: str, config_names: Optional[Collection[str]], plugin_params: Optional[Mapping[str, str]] = None +) -> "PluginDescriptor": return PluginLoader(repo).load_plugin(name, config_names, plugin_params) -def load_plugins(repo, plugin_names, plugin_params=None): - def name_and_config(p): +def load_plugins( + repo: str, plugin_names: Collection[str], plugin_params: Optional[Mapping[str, str]] = None +) -> Collection["PluginDescriptor"]: + def name_and_config(p: str) -> Tuple[str, Optional[Collection[str]]]: plugin_spec = p.split(":") if len(plugin_spec) == 1: return plugin_spec[0], None @@ -115,7 +130,7 @@ def name_and_config(p): return plugins -def team_path(cfg: types.Config): +def team_path(cfg: types.Config) -> str: root_path = cfg.opts("mechanic", "team.path", mandatory=False) if root_path: return root_path @@ -140,35 +155,38 @@ def team_path(cfg: types.Config): class CarLoader: - def __init__(self, team_root_path): + def __init__(self, team_root_path: str): self.cars_dir = _path_for(team_root_path, "cars") self.logger = logging.getLogger(__name__) - def car_names(self): - def __car_name(path): + def car_names(self) -> Iterator[str]: + def __car_name(path: str) -> str: p, _ = io.splitext(path) return io.basename(p) - def __is_car(path): + def __is_car(path: str) -> bool: _, extension = io.splitext(path) return extension == ".ini" return map(__car_name, filter(__is_car, os.listdir(self.cars_dir))) - def _car_file(self, name): + def _car_file(self, name: str) -> str: return os.path.join(self.cars_dir, f"{name}.ini") - def load_car(self, name, car_params=None): + def load_car(self, name: str, car_params: Optional[Mapping[str, Any]] = None) -> "CarDescriptor": car_config_file = self._car_file(name) if not io.exists(car_config_file): raise exceptions.SystemSetupError(f"Unknown car [{name}]. 
List the available cars with {PROGRAM_NAME} list cars.")
 config = self._config_loader(car_config_file)
- root_paths = []
- config_paths = []
+ root_paths: List[str] = []
+ config_paths: List[str] = []
 config_base_vars: Mapping[str, Any] = {}
 description = self._value(config, ["meta", "description"], default="")
 car_type = self._value(config, ["meta", "type"], default="car")
- config_bases = self._value(config, ["config", "base"], default="").split(",")
+ config_base = self._value(config, ["config", "base"], default="")
+ assert config_base is not None, f"Car [{name}] does not define a config base."
+ assert isinstance(config_base, str), f"Car [{name}] defines an invalid config base [{config_base}]."
+ config_bases = config_base.split(",")
 for base in config_bases:
 if base:
 root_path = os.path.join(self.cars_dir, base)
@@ -189,24 +207,27 @@ def load_car(self, name, car_params=None):
 return CarDescriptor(name, description, car_type, root_paths, config_paths, config_base_vars, variables)
- def _config_loader(self, file_name):
+ def _config_loader(self, file_name: str) -> "configparser.ConfigParser":
 config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation())
 # Do not modify the case of option keys but read them as is
 config.optionxform = lambda optionstr: optionstr # type: ignore[method-assign]
 config.read(file_name)
 return config
- def _value(self, cfg, section_path, default=None):
- path = [section_path] if (isinstance(section_path, str)) else section_path
+ def _value(
+ self, cfg: "configparser.ConfigParser", section_path: Union[str, Collection[str]], default: Optional[str] = None
+ ) -> Optional[Mapping[str, Any]]:
+ path: Collection[str] = [section_path] if (isinstance(section_path, str)) else section_path
 current_cfg = cfg
 for k in path:
+ assert isinstance(current_cfg, Mapping), f"Expected a mapping but got [{current_cfg}] instead."
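+ # ConfigParser and its SectionProxy sections implement the Mapping protocol but are not plain dicts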
if k in current_cfg: current_cfg = current_cfg[k] else: return default return current_cfg - def _copy_section(self, cfg, section, target): + def _copy_section(self, cfg: "configparser.ConfigParser", section: str, target: MutableMapping[str, Any]) -> MutableMapping[str, Any]: if section in cfg.sections(): for k, v in cfg[section].items(): target[k] = v @@ -214,7 +235,16 @@ def _copy_section(self, cfg, section, target): class CarDescriptor: - def __init__(self, name, description, type, root_paths, config_paths, config_base_variables, variables): + def __init__( + self, + name: str, + description: str, + type: str, + root_paths: Collection[str], + config_paths: Collection[str], + config_base_variables: Mapping[str, str], + variables: Mapping[str, str], + ): self.name = name self.description = description self.type = type @@ -223,10 +253,10 @@ def __init__(self, name, description, type, root_paths, config_paths, config_bas self.config_base_variables = config_base_variables self.variables = variables - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and self.name == other.name @@ -265,40 +295,40 @@ def __init__( self.config_paths = config_paths self.variables = variables - def mandatory_var(self, name): + def mandatory_var(self, name: str) -> str: try: return self.variables[name] except KeyError: raise exceptions.SystemSetupError(f'Car "{self.name}" requires config key "{name}"') @property - def name(self): + def name(self) -> str: return "+".join(self.names) # Adapter method for BootstrapHookHandler @property - def config(self): + def config(self) -> str: return self.name @property - def safe_name(self): + def safe_name(self) -> str: return "_".join(self.names) - def __str__(self): + def __str__(self) -> str: return self.name class PluginLoader: - def __init__(self, team_root_path): + def __init__(self, team_root_path: str): self.plugins_root_path = _path_for(team_root_path, "plugins") self.logger = logging.getLogger(__name__) - def plugins(self, variables=None): + def plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: known_plugins = self._core_plugins(variables) + self._configured_plugins(variables) sorted(known_plugins, key=lambda p: p.name) return known_plugins - def _core_plugins(self, variables=None): + def _core_plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: core_plugins = [] core_plugins_path = os.path.join(self.plugins_root_path, "core-plugins.txt") if os.path.exists(core_plugins_path): @@ -310,7 +340,7 @@ def _core_plugins(self, variables=None): core_plugins.append(PluginDescriptor(name=values[0], core_plugin=True, variables=variables)) return core_plugins - def _configured_plugins(self, variables=None): + def _configured_plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: configured_plugins = [] # each directory is a plugin, each .ini is a config (just go one level deep) for entry in os.listdir(self.plugins_root_path): @@ -324,10 +354,10 @@ def _configured_plugins(self, variables=None): configured_plugins.append(PluginDescriptor(name=plugin_name, config=config, variables=variables)) return configured_plugins - def _plugin_file(self, name, config): + def _plugin_file(self, name: str, config: str) -> str: return os.path.join(self._plugin_root_path(name), "%s.ini" % config) - def _plugin_root_path(self, name): + def _plugin_root_path(self, 
name: str) -> str: return os.path.join(self.plugins_root_path, self._plugin_name_to_file(name)) # As we allow to store Python files in the plugin directory and the plugin directory also serves as the root path of the corresponding @@ -335,16 +365,18 @@ def _plugin_root_path(self, name): # need to switch from underscores to hyphens and vice versa. # # We are implicitly assuming that plugin names stick to the convention of hyphen separation to simplify implementation and usage a bit. - def _file_to_plugin_name(self, file_name): + def _file_to_plugin_name(self, file_name: str) -> str: return file_name.replace("_", "-") - def _plugin_name_to_file(self, plugin_name): + def _plugin_name_to_file(self, plugin_name: str) -> str: return plugin_name.replace("-", "_") - def _core_plugin(self, name, variables=None): + def _core_plugin(self, name: str, variables: Optional[Mapping[str, str]] = None) -> Optional["PluginDescriptor"]: return next((p for p in self._core_plugins(variables) if p.name == name and p.config is None), None) - def load_plugin(self, name, config_names, plugin_params=None): + def load_plugin( + self, name: str, config_names: Optional[Collection[str]], plugin_params: Optional[Mapping[str, str]] = None + ) -> "PluginDescriptor": if config_names is not None: self.logger.info("Loading plugin [%s] with configuration(s) [%s].", name, config_names) else: @@ -426,7 +458,15 @@ class PluginDescriptor: # name of the initial Python file to load for plugins. entry_point = "plugin" - def __init__(self, name, core_plugin=False, config=None, root_path=None, config_paths=None, variables=None): + def __init__( + self, + name: str, + core_plugin: bool = False, + config: Optional[Collection[str]] = None, + root_path: Optional[str] = None, + config_paths: Optional[Collection[str]] = None, + variables: Optional[Mapping[str, Any]] = None, + ): if config_paths is None: config_paths = [] if variables is None: @@ -438,27 +478,27 @@ def __init__(self, name, core_plugin=False, config=None, root_path=None, config_ self.config_paths = config_paths self.variables = variables - def __str__(self): - return "Plugin descriptor for [%s]" % self.name + def __str__(self) -> str: + return f"Plugin descriptor for [{self.name}]" - def __repr__(self): + def __repr__(self) -> str: r = [] for prop, value in vars(self).items(): r.append("%s = [%s]" % (prop, repr(value))) return ", ".join(r) @property - def moved_to_module(self): + def moved_to_module(self) -> bool: # For a BWC escape hatch we first check if the plugin is listed in rally-teams' "core-plugin.txt", # thus allowing users to override the teams path or revision to include the repository-s3/azure/gcs plugins in # "core-plugin.txt" # TODO: https://github.com/elastic/rally/issues/1622 return self.name in ["repository-s3", "repository-gcs", "repository-azure"] and not self.core_plugin - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) ^ hash(self.config) ^ hash(self.core_plugin) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and (self.name, self.config, self.core_plugin) == (other.name, other.config, other.core_plugin) @@ -466,14 +506,14 @@ class BootstrapPhase(Enum): post_install = 10 @classmethod - def valid(cls, name): + def valid(cls, name: str) -> bool: for n in BootstrapPhase.names(): if n == name: return True return False @classmethod - def names(cls): + def names(cls) -> Collection[str]: return [p.name for p in list(BootstrapPhase)] @@ -482,7 +522,7 @@ class BootstrapHookHandler: 
Responsible for loading and executing component-specific initialization code.
"""
- def __init__(self, component, loader_class=modules.ComponentLoader):
+ def __init__(self, component: Any, loader_class: Callable = modules.ComponentLoader):
"""
Creates a new BootstrapHookHandler.
@@ -497,13 +537,13 @@ def __init__(self, component, loader_class=modules.ComponentLoader):
 else:
 root_path = [self.component.root_path]
 self.loader = loader_class(root_path=root_path, component_entry_point=self.component.entry_point, recurse=False)
- self.hooks = {}
+ self.hooks: MutableMapping[str, List[Callable]] = {}
 self.logger = logging.getLogger(__name__)
- def can_load(self):
+ def can_load(self) -> bool:
 return self.loader.can_load()
- def load(self):
+ def load(self) -> None:
 root_modules: Collection[ModuleType] = self.loader.load()
 try:
 # every module needs to have a register() method
@@ -517,15 +557,16 @@ def load(self):
 self.logger.exception(msg)
 raise exceptions.SystemSetupError(msg)
- def register(self, phase, hook):
+ def register(self, phase: str, hook: Callable) -> None:
 self.logger.info("Registering bootstrap hook [%s] for phase [%s] in component [%s]", hook.__name__, phase, self.component.name)
 if not BootstrapPhase.valid(phase):
 raise exceptions.SystemSetupError(f"Unknown bootstrap phase [{phase}]. Valid phases are: {BootstrapPhase.names()}.")
 if phase not in self.hooks:
- self.hooks[phase] = []
+ empty: List[Callable] = []
+ self.hooks[phase] = empty
 self.hooks[phase].append(hook)
- def invoke(self, phase, **kwargs):
+ def invoke(self, phase: str, **kwargs: Mapping[str, Any]) -> None:
 if phase in self.hooks:
 self.logger.info("Invoking phase [%s] for component [%s] in config [%s]", phase, self.component.name, self.component.config)
 for hook in self.hooks[phase]:
diff --git a/esrally/utils/io.py b/esrally/utils/io.py
index 6e4ff4f74..47bcafdda 100644
--- a/esrally/utils/io.py
+++ b/esrally/utils/io.py
@@ -26,7 +26,22 @@
 import subprocess
 import tarfile
 import zipfile
-from typing import AnyStr
+from types import TracebackType
+from typing import (
+ IO,
+ Any,
+ AnyStr,
+ Callable,
+ Collection,
+ List,
+ Literal,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+)
 import zstandard
@@ -40,27 +55,31 @@ class FileSource:
 FileSource is a wrapper around a plain file which simplifies testing of file I/O calls.
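 The source must be opened before reading, either via open() or by using it as a context manager.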
""" - def __init__(self, file_name, mode, encoding="utf-8"): + def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f = None + self.f: Optional[IO[Any]] = None - def open(self): + def open(self) -> "FileSource": self.f = open(self.file_name, mode=self.mode, encoding=self.encoding) # allow for chaining return self - def seek(self, offset): + def seek(self, offset: int) -> None: + assert self.f is not None, "File is not open" self.f.seek(offset) - def read(self): + def read(self) -> bytes: + assert self.f is not None, "File is not open" return self.f.read() - def readline(self): + def readline(self) -> bytes: + assert self.f is not None, "File is not open" return self.f.readline() - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[bytes]: + assert self.f is not None, "File is not open" lines = [] f = self.f for _ in range(num_lines): @@ -70,19 +89,22 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: + assert self.f is not None, "File is not open" self.f.close() self.f = None - def __enter__(self): + def __enter__(self) -> "FileSource": self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return self.file_name @@ -91,14 +113,14 @@ class MmapSource: MmapSource is a wrapper around a memory-mapped file which simplifies testing of file I/O calls. """ - def __init__(self, file_name, mode, encoding="utf-8"): + def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f = None - self.mm = None + self.f: Optional[IO[Any]] = None + self.mm: Optional[mmap.mmap] = None - def open(self): + def open(self) -> "MmapSource": self.f = open(self.file_name, mode="r+b") self.mm = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_READ) self.mm.madvise(mmap.MADV_SEQUENTIAL) @@ -106,16 +128,20 @@ def open(self): # allow for chaining return self - def seek(self, offset): + def seek(self, offset: int) -> None: + assert self.mm is not None, "Source is not open" self.mm.seek(offset) - def read(self): + def read(self) -> bytes: + assert self.mm is not None, "Source is not open" return self.mm.read() - def readline(self): + def readline(self) -> bytes: + assert self.mm is not None, "Source is not open" return self.mm.readline() - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[bytes]: + assert self.mm is not None, "Source is not open" lines = [] mm = self.mm for _ in range(num_lines): @@ -125,21 +151,25 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: + assert self.mm is not None, "Source is not open" self.mm.close() self.mm = None + assert self.f is not None, "File is not open" self.f.close() self.f = None - def __enter__(self): + def __enter__(self) -> "MmapSource": self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, 
*args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return self.file_name @@ -150,10 +180,10 @@ class DictStringFileSourceFactory: It is intended for scenarios where multiple files may be read by client code. """ - def __init__(self, name_to_contents): + def __init__(self, name_to_contents: Mapping[str, Sequence[str]]): self.name_to_contents = name_to_contents - def __call__(self, name, mode, encoding="utf-8"): + def __call__(self, name: str, mode: str, encoding: str = "utf-8") -> "StringAsFileSource": return StringAsFileSource(self.name_to_contents[name], mode, encoding) @@ -163,7 +193,7 @@ class StringAsFileSource: be used in production code. """ - def __init__(self, contents, mode, encoding="utf-8"): + def __init__(self, contents: Sequence[str], mode: str, encoding: str = "utf-8"): """ :param contents: The file contents as an array of strings. Each item in the array should correspond to one line. :param mode: The file mode. It is ignored in this implementation but kept to implement the same interface as ``FileSource``. @@ -173,20 +203,20 @@ def __init__(self, contents, mode, encoding="utf-8"): self.current_index = 0 self.opened = False - def open(self): + def open(self) -> "StringAsFileSource": self.opened = True return self - def seek(self, offset): + def seek(self, offset: int) -> None: self._assert_opened() if offset != 0: raise AssertionError("StringAsFileSource does not support random seeks") - def read(self): + def read(self) -> str: self._assert_opened() return "\n".join(self.contents) - def readline(self): + def readline(self) -> str: self._assert_opened() if self.current_index >= len(self.contents): return "" @@ -194,7 +224,7 @@ def readline(self): self.current_index += 1 return line - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[str]: lines = [] for _ in range(num_lines): line = self.readline() @@ -203,23 +233,25 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: self._assert_opened() - self.contents = None + self.contents = [] self.opened = False - def _assert_opened(self): + def _assert_opened(self) -> None: assert self.opened - def __enter__(self): + def __enter__(self) -> "StringAsFileSource": self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return "StringAsFileSource" @@ -228,20 +260,20 @@ class ZstAdapter: Adapter class to make the zstandard API work with Rally's decompression abstractions """ - def __init__(self, path): + def __init__(self, path: str): self.fh = open(path, "rb") self.dctx = zstandard.ZstdDecompressor() self.reader = self.dctx.stream_reader(self.fh) - def read(self, size): + def read(self, size: int) -> bytes: return self.reader.read(size) - def close(self): + def close(self) -> None: self.reader.close() self.fh.close() -def ensure_dir(directory, mode=0o777): +def ensure_dir(directory: str, mode: int = 0o777) -> None: """ Ensure that the provided directory and all of its parent directories exist. This function is safe to execute on existing directories (no op). 
@@ -253,7 +285,7 @@ def ensure_dir(directory, mode=0o777): os.makedirs(directory, mode, exist_ok=True) -def _zipdir(source_directory, archive): +def _zipdir(source_directory: str, archive: zipfile.ZipFile) -> None: for root, _, files in os.walk(source_directory): for file in files: archive.write( @@ -262,7 +294,7 @@ def _zipdir(source_directory, archive): ) -def is_archive(name): +def is_archive(name: str) -> bool: """ :param name: File name to check. Can be either just the file name or optionally also an absolute path. :return: True iff the given file name is an archive that is also recognized for decompression by Rally. @@ -271,7 +303,7 @@ def is_archive(name): return ext in SUPPORTED_ARCHIVE_FORMATS -def is_executable(name): +def is_executable(name: str) -> bool: """ :param name: File name to check. :return: True iff given file name is executable and in PATH, all other cases False. @@ -280,7 +312,7 @@ def is_executable(name): return shutil.which(name) is not None -def compress(source_directory, archive_name): +def compress(source_directory: str, archive_name: str) -> None: """ Compress a directory tree. @@ -291,7 +323,7 @@ def compress(source_directory, archive_name): _zipdir(source_directory, archive) -def decompress(zip_name, target_directory): +def decompress(zip_name: str, target_directory: str) -> None: """ Decompresses the provided archive to the target directory. The following file extensions are supported: @@ -315,23 +347,23 @@ def decompress(zip_name, target_directory): _do_decompress(target_directory, zipfile.ZipFile(zip_name)) elif extension == ".bz2": decompressor_args = ["pbzip2", "-d", "-k", "-m10000", "-c"] - decompressor_lib = bz2.open - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_bz2 = bz2.open + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_bz2) elif extension == ".zst": decompressor_args = ["pzstd", "-f", "-d", "-c"] - decompressor_lib = ZstAdapter - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_zst = ZstAdapter + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_zst) elif extension == ".gz": decompressor_args = ["pigz", "-d", "-k", "-c"] - decompressor_lib = gzip.open - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_gzip = gzip.open + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_gzip) elif extension in [".tar", ".tar.gz", ".tgz", ".tar.bz2"]: _do_decompress(target_directory, tarfile.open(zip_name)) else: raise RuntimeError("Unsupported file extension [%s]. 
Cannot decompress [%s]" % (extension, zip_name)) -def _do_decompress_manually(target_directory, filename, decompressor_args, decompressor_lib): +def _do_decompress_manually(target_directory: str, filename: str, decompressor_args: List[str], decompressor_lib: Callable) -> None: decompressor_bin = decompressor_args[0] base_path_without_extension = basename(splitext(filename)[0]) @@ -346,7 +378,9 @@ def _do_decompress_manually(target_directory, filename, decompressor_args, decom _do_decompress_manually_with_lib(target_directory, filename, decompressor_lib(filename)) -def _do_decompress_manually_external(target_directory, filename, base_path_without_extension, decompressor_args): +def _do_decompress_manually_external( + target_directory: str, filename: str, base_path_without_extension: str, decompressor_args: List[str] +) -> bool: with open(os.path.join(target_directory, base_path_without_extension), "wb") as new_file: try: subprocess.run(decompressor_args + [filename], stdout=new_file, stderr=subprocess.PIPE, check=True) @@ -358,7 +392,7 @@ def _do_decompress_manually_external(target_directory, filename, base_path_witho return True -def _do_decompress_manually_with_lib(target_directory, filename, compressed_file): +def _do_decompress_manually_with_lib(target_directory: str, filename: str, compressed_file: IO[Any]) -> None: path_without_extension = basename(splitext(filename)[0]) ensure_dir(target_directory) @@ -370,29 +404,34 @@ def _do_decompress_manually_with_lib(target_directory, filename, compressed_file compressed_file.close() -def _do_decompress(target_directory, compressed_file): +def _do_decompress(target_directory: str, compressed_file: Union[zipfile.ZipFile, tarfile.TarFile]) -> None: try: compressed_file.extractall(path=target_directory) except BaseException: - raise RuntimeError("Could not decompress provided archive [%s]" % compressed_file.filename) + if isinstance(compressed_file, zipfile.ZipFile): + raise RuntimeError( + f"Could not decompress provided archive [{compressed_file.filename}]. Please check if it is a valid zip file." + ) + if isinstance(compressed_file, tarfile.TarFile): + raise RuntimeError(f"Could not decompress provided archive [{compressed_file.name!r}]. Please check if it is a valid tar file.") finally: compressed_file.close() # just in a dedicated method to ease mocking -def dirname(path: AnyStr): +def dirname(path: AnyStr) -> AnyStr: return os.path.dirname(path) -def basename(path: AnyStr): +def basename(path: AnyStr) -> AnyStr: return os.path.basename(path) -def exists(path: AnyStr): +def exists(path: AnyStr) -> bool: return os.path.exists(path) -def normalize_path(path, cwd="."): +def normalize_path(path: AnyStr, cwd: Any = ".") -> AnyStr: """ Normalizes a path by removing redundant "../" and also expanding the "~" character to the user home directory. :param path: A possibly non-normalized path. @@ -407,7 +446,7 @@ def normalize_path(path, cwd="."): return normalized -def escape_path(path): +def escape_path(path: str) -> str: """ Escapes any characters that might be problematic in shell interactions. 
@@ -417,7 +456,7 @@ def escape_path(path): return path.replace("\\", "\\\\") -def splitext(file_name): +def splitext(file_name: str) -> Tuple[str, str]: if file_name.endswith(".tar.gz"): return file_name[0:-7], file_name[-7:] elif file_name.endswith(".tar.bz2"): @@ -426,7 +465,7 @@ def splitext(file_name): return os.path.splitext(file_name) -def has_extension(file_name, extension): +def has_extension(file_name: str, extension: str) -> bool: """ Checks whether the given file name has the given extension. @@ -444,7 +483,7 @@ class FileOffsetTable: data file. This helps bulk-indexing clients to advance quickly to a certain position in a large data file. """ - def __init__(self, data_file_path, offset_table_path, mode): + def __init__(self, data_file_path: str, offset_table_path: str, mode: str): """ Creates a new FileOffsetTable instance. The constructor should not be called directly but instead the respective factory methods should be used. @@ -457,34 +496,35 @@ def __init__(self, data_file_path, offset_table_path, mode): self.data_file_path = data_file_path self.offset_table_path = offset_table_path self.mode = mode - self.offset_file = None + self.offset_file: Optional[IO[Any]] = None - def exists(self): + def exists(self) -> bool: """ :return: True iff the file offset table already exists. """ return os.path.exists(self.offset_table_path) - def is_valid(self): + def is_valid(self) -> bool: """ :return: True iff the file offset table exists and it is up-to-date. """ return self.exists() and os.path.getmtime(self.offset_table_path) >= os.path.getmtime(self.data_file_path) - def __enter__(self): + def __enter__(self) -> "FileOffsetTable": self.offset_file = open(self.offset_table_path, self.mode) return self - def add_offset(self, line_number, offset): + def add_offset(self, line_number: int, offset: int) -> None: """ Adds a new offset mapping to the file offset table. This method has to be called inside a context-manager block. :param line_number: A line number to add. :param offset: The corresponding offset in bytes. """ + assert self.offset_file is not None, "File offset table must be opened in a context manager block." print(f"{line_number};{offset}", file=self.offset_file) - def find_closest_offset(self, target_line_number): + def find_closest_offset(self, target_line_number: int) -> Tuple[int, int]: """ Determines the offset in bytes for the line L in the corresponding data file with the following properties: @@ -498,6 +538,7 @@ def find_closest_offset(self, target_line_number): prior_offset = 0 prior_remaining_lines = target_line_number + assert self.offset_file is not None, "File offset table must be opened in a context manager block." for line in self.offset_file: line_number, offset_in_bytes = (int(i) for i in line.strip().split(";")) if line_number <= target_line_number: @@ -508,13 +549,16 @@ def find_closest_offset(self, target_line_number): return prior_offset, prior_remaining_lines - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: + assert self.offset_file is not None, "File offset table must be opened in a context manager block." self.offset_file.close() self.offset_file = None return False @classmethod - def create_for_data_file(cls, data_file_path): + def create_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": """ Factory method to create a new file offset table. 
@@ -523,7 +567,7 @@ def create_for_data_file(cls, data_file_path): return cls(data_file_path, f"{data_file_path}.offset", "wt") @classmethod - def read_for_data_file(cls, data_file_path): + def read_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": """ Factory method to read from an existing file offset table. @@ -533,7 +577,7 @@ def read_for_data_file(cls, data_file_path): return cls(data_file_path, f"{data_file_path}.offset", "rt") @staticmethod - def remove(data_file_path): + def remove(data_file_path: str) -> None: """ Removes a file offset table for the provided data path. @@ -542,7 +586,7 @@ def remove(data_file_path): os.remove(f"{data_file_path}.offset") -def prepare_file_offset_table(data_file_path): +def prepare_file_offset_table(data_file_path: str) -> Optional[int]: """ Creates a file that contains a mapping from line numbers to file offsets for the provided path. This file is used internally by #skip_lines(data_file_path, data_file) to speed up line skipping. @@ -569,7 +613,7 @@ def prepare_file_offset_table(data_file_path): return None -def remove_file_offset_table(data_file_path): +def remove_file_offset_table(data_file_path: str) -> None: """ Attempts to remove the file offset table for the provided data path. @@ -579,7 +623,7 @@ def remove_file_offset_table(data_file_path): FileOffsetTable.remove(data_file_path) -def skip_lines(data_file_path, data_file, number_of_lines_to_skip): +def skip_lines(data_file_path: str, data_file: IO[Any], number_of_lines_to_skip: int) -> None: """ Skips the first `number_of_lines_to_skip` lines in `data_file` as a side effect. @@ -607,7 +651,7 @@ def skip_lines(data_file_path, data_file, number_of_lines_to_skip): data_file.readline() -def get_size(start_path="."): +def get_size(start_path: str = ".") -> int: total_size = 0 for dirpath, _, filenames in os.walk(start_path): for f in filenames: diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index a7b0bc83a..757445767 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -20,7 +20,7 @@ import os import sys from types import ModuleType -from typing import Collection, Union +from typing import Collection, Generator, Tuple, Union from esrally import exceptions from esrally.utils import io @@ -50,7 +50,7 @@ def __init__(self, root_path: Union[str, Collection[str]], component_entry_point self.recurse = recurse self.logger = logging.getLogger(__name__) - def _modules(self, module_paths: Collection[str], component_name: str, root_path: str): + def _modules(self, module_paths: Collection[str], component_name: str, root_path: str) -> Generator[Tuple[str, str], None, None]: for path in module_paths: for filename in os.listdir(path): name, ext = os.path.splitext(filename) @@ -61,7 +61,7 @@ def _modules(self, module_paths: Collection[str], component_name: str, root_path module_name = "%s.%s" % (component_name, root_relative_path.replace(os.path.sep, ".")) yield module_name, file_absolute_path - def _load_component(self, component_name: str, module_dirs: Collection[str], root_path: str): + def _load_component(self, component_name: str, module_dirs: Collection[str], root_path: str) -> ModuleType: # precondition: A module with this name has to exist provided that the caller has called #can_load() before. 
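+ # e.g. component_name "my_plugin" with entry point "plugin" yields the root module name "my_plugin.plugin"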
root_module_name = "%s.%s" % (component_name, self.component_entry_point) for name, p in self._modules(module_dirs, component_name, root_path): diff --git a/esrally/utils/process.py b/esrally/utils/process.py index c26c4c0f2..76dfb5d42 100644 --- a/esrally/utils/process.py +++ b/esrally/utils/process.py @@ -20,7 +20,7 @@ import shlex import subprocess import time -from typing import IO, Callable, Dict, List, Optional, Union +from typing import IO, Callable, List, Mapping, Optional, Union import psutil @@ -38,7 +38,7 @@ def run_subprocess(command_line: str) -> int: return subprocess.call(command_line, shell=True) -def run_subprocess_with_output(command_line: str, env: Dict[str, str] = None) -> List[str]: +def run_subprocess_with_output(command_line: str, env: Optional[Mapping[str, str]] = None) -> List[str]: logger = logging.getLogger(__name__) logger.debug("Running subprocess [%s] with output.", command_line) command_line_args = shlex.split(command_line) @@ -46,6 +46,7 @@ def run_subprocess_with_output(command_line: str, env: Dict[str, str] = None) -> has_output = True lines = [] while has_output: + assert command_line_process.stdout is not None, "stdout is None" line = command_line_process.stdout.readline() if line: lines.append(line.decode("UTF-8").strip()) @@ -72,10 +73,10 @@ def exit_status_as_bool(runnable: Callable[[], int], quiet: bool = False) -> boo def run_subprocess_with_logging( command_line: str, - header: str = None, + header: Optional[str] = None, level: LogLevel = logging.INFO, stdin: Optional[Union[FileId, IO[bytes]]] = None, - env: Dict[str, str] = None, + env: Optional[Mapping[str, str]] = None, detach: bool = False, ) -> int: """ @@ -117,10 +118,10 @@ def run_subprocess_with_logging( def run_subprocess_with_logging_and_output( command_line: str, - header: str = None, + header: Optional[str] = None, level: LogLevel = logging.INFO, - stdin: FileId = None, - env: Dict[str, str] = None, + stdin: Optional[FileId] = None, + env: Optional[Mapping[str, str]] = None, detach: bool = False, ) -> subprocess.CompletedProcess: """ @@ -173,7 +174,7 @@ def is_rally_process(p: psutil.Process) -> bool: def find_all_other_rally_processes() -> List[psutil.Process]: - others = [] + others: List[psutil.Process] = [] for_all_other_processes(is_rally_process, others.append) return others @@ -187,7 +188,7 @@ def redact_cmdline(cmdline: list) -> List[str]: def kill_all(predicate: Callable[[psutil.Process], bool]) -> None: - def kill(p: psutil.Process): + def kill(p: psutil.Process) -> None: logging.getLogger(__name__).info( "Killing lingering process with PID [%s] and command line [%s].", p.pid, redact_cmdline(p.cmdline()) ) diff --git a/pyproject.toml b/pyproject.toml index 627e9ecd3..0cf898796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,13 +10,13 @@ path = "esrally/_version.py" name = "esrally" dynamic = ["version"] authors = [ - {name="Daniel Mitterdorfer", email="daniel.mitterdorfer@gmail.com"}, + { name = "Daniel Mitterdorfer", email = "daniel.mitterdorfer@gmail.com" }, ] description = "Macrobenchmarking framework for Elasticsearch" readme = "README.md" -license = {text = "Apache License 2.0"} +license = { text = "Apache License 2.0" } requires-python = ">=3.8" -classifiers=[ +classifiers = [ "Topic :: System :: Benchmark", "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", @@ -81,7 +81,7 @@ dependencies = [ # License: Apache 2.0 "google-auth==1.22.1", # License: BSD - "zstandard==0.21.0" + "zstandard==0.21.0", ] 
[project.optional-dependencies] @@ -112,6 +112,9 @@ develop = [ "pylint==3.1.0", "trustme==0.9.0", "GitPython==3.1.30", + # mypy + "types-psutil==5.9.4", + "types-tabulate==0.8.9", ] [project.scripts] @@ -181,17 +184,16 @@ disable_error_code = [ "union-attr", "var-annotated", ] -files = [ - "esrally/", - "it/", - "tests/", -] +files = ["esrally/", "it/", "tests/"] [[tool.mypy.overrides]] module = [ "esrally.mechanic.team", "esrally.utils.modules", + "esrally.utils.io", + "esrally.utils.process", ] +disallow_incomplete_defs = true # this should be a copy of disabled_error_code from above enable_error_code = [ "assignment",