From c3333a6d4a9aa0e8134643c27e3ff36990e41d66 Mon Sep 17 00:00:00 2001 From: Marvin van Aalst Date: Tue, 31 May 2022 16:07:59 +0200 Subject: [PATCH 1/2] added type hints --- .gitignore | 2 +- .mypy.ini | 7 + llvmlite/_version.py | 45 +- llvmlite/binding/__init__.py | 2 +- llvmlite/binding/analysis.py | 21 +- llvmlite/binding/common.py | 10 +- llvmlite/binding/context.py | 14 +- llvmlite/binding/dylib.py | 11 +- llvmlite/binding/executionengine.py | 55 +- llvmlite/binding/ffi.py | 146 ++-- llvmlite/binding/initfini.py | 20 +- llvmlite/binding/linker.py | 8 +- llvmlite/binding/module.py | 81 +- llvmlite/binding/object_file.py | 33 +- llvmlite/binding/options.py | 4 +- llvmlite/binding/passmanagers.py | 91 ++- llvmlite/binding/targets.py | 92 ++- llvmlite/binding/transforms.py | 53 +- llvmlite/binding/value.py | 117 +-- llvmlite/ir/__init__.py | 3 + llvmlite/ir/_utils.py | 63 +- llvmlite/ir/builder.py | 1164 ++++++++++++++++++--------- llvmlite/ir/context.py | 11 +- llvmlite/ir/instructions.py | 704 ++++++++++------ llvmlite/ir/module.py | 118 +-- llvmlite/ir/transforms.py | 35 +- llvmlite/ir/types.py | 298 ++++--- llvmlite/ir/values.py | 643 ++++++++------- llvmlite/llvmpy/core.py | 124 +-- llvmlite/llvmpy/passes.py | 59 +- llvmlite/py.typed | 1 + llvmlite/utils.py | 7 +- 32 files changed, 2492 insertions(+), 1550 deletions(-) create mode 100644 .mypy.ini create mode 100644 llvmlite/py.typed diff --git a/.gitignore b/.gitignore index d0b7baacd..fe424f697 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,4 @@ coverage.xml MANIFEST docs/_build/ docs/gh-pages/ - +.vscode/ diff --git a/.mypy.ini b/.mypy.ini new file mode 100644 index 000000000..b73a2cf6d --- /dev/null +++ b/.mypy.ini @@ -0,0 +1,7 @@ +[mypy] +python_version = 3.7 +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true +disallow_untyped_defs = true +exclude=(?x)(^llvmlite/test) diff --git a/llvmlite/_version.py b/llvmlite/_version.py index 01fc873f2..075ac4472 100644 --- a/llvmlite/_version.py +++ b/llvmlite/_version.py @@ -5,6 +5,8 @@ # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. +from __future__ import annotations + # This file is released into the public domain. Generated by # versioneer-0.12 (https://github.com/warner/python-versioneer) @@ -19,7 +21,13 @@ import os, sys, re, subprocess, errno -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): +def run_command( + commands: list[str], + args: list[str], + cwd: str | None = None, + verbose: bool = False, + hide_stderr: bool = False, +) -> str | None: assert isinstance(commands, list) p = None for c in commands: @@ -31,7 +39,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): break except EnvironmentError: e = sys.exc_info()[1] - if e.errno == errno.ENOENT: + if e.errno == errno.ENOENT: # type: ignore continue if verbose: print("unable to run %s" % args[0]) @@ -43,15 +51,17 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False): return None stdout = p.communicate()[0].strip() if sys.version >= '3': - stdout = stdout.decode() + stdout = stdout.decode() # type: ignore if p.returncode != 0: if verbose: print("unable to run %s (error)" % args[0]) return None - return stdout + return stdout # type: ignore -def versions_from_parentdir(parentdir_prefix, root, verbose=False): +def versions_from_parentdir( + parentdir_prefix: str, root: str, verbose: bool = False +) -> dict[str, str] | None: # Source tarballs conventionally unpack into a directory that includes # both the project name and a version string. dirname = os.path.basename(root) @@ -62,12 +72,12 @@ def versions_from_parentdir(parentdir_prefix, root, verbose=False): return None return {"version": dirname[len(parentdir_prefix):], "full": ""} -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> dict[str, str]: # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: dict[str, str] = {} try: f = open(versionfile_abs,"r") for line in f.readlines(): @@ -84,7 +94,10 @@ def git_get_keywords(versionfile_abs): pass return keywords -def git_versions_from_keywords(keywords, tag_prefix, verbose=False): + +def git_versions_from_keywords( + keywords: dict[str, str], tag_prefix: str, verbose: bool = False +) -> dict[str, str] | None: if not keywords: return {} # keyword-finding function failed to find keywords refnames = keywords["refnames"].strip() @@ -125,7 +138,9 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose=False): "full": keywords["full"].strip() } -def git_versions_from_vcs(tag_prefix, root, verbose=False): +def git_versions_from_vcs( + tag_prefix: str, root: str, verbose: bool = False +) -> dict[str, str]: # this runs 'git' from the root of the source tree. This only gets called # if the git-archive 'subst' keywords were *not* expanded, and # _version.py hasn't already been rewritten with a short version string, @@ -143,7 +158,7 @@ def git_versions_from_vcs(tag_prefix, root, verbose=False): cwd=root) if stdout is None: return {} - if not stdout.startswith(tag_prefix): + if not stdout.startswith(tag_prefix): # type: ignore if verbose: print("tag '%s' doesn't start with prefix '%s'" % (stdout, tag_prefix)) return {} @@ -152,12 +167,14 @@ def git_versions_from_vcs(tag_prefix, root, verbose=False): if stdout is None: return {} full = stdout.strip() - if tag.endswith("-dirty"): - full += "-dirty" - return {"version": tag, "full": full} + if tag.endswith("-dirty"): # type: ignore + full += "-dirty" # type: ignore + return {"version": tag, "full": full} # type: ignore -def get_versions(default={"version": "unknown", "full": ""}, verbose=False): +def get_versions( + default: dict[str, str] = {"version": "unknown", "full": ""}, verbose: bool = False +) -> dict[str, str]: # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. Some # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which diff --git a/llvmlite/binding/__init__.py b/llvmlite/binding/__init__.py index 9c6eedf92..ada2dd849 100644 --- a/llvmlite/binding/__init__.py +++ b/llvmlite/binding/__init__.py @@ -13,4 +13,4 @@ from .value import * from .analysis import * from .object_file import * -from .context import * \ No newline at end of file +from .context import * diff --git a/llvmlite/binding/analysis.py b/llvmlite/binding/analysis.py index 25d0572b3..1da52dfa7 100644 --- a/llvmlite/binding/analysis.py +++ b/llvmlite/binding/analysis.py @@ -2,14 +2,18 @@ A collection of analysis utilities """ +from __future__ import annotations + from ctypes import POINTER, c_char_p, c_int +from typing import Any from llvmlite import ir from llvmlite.binding import ffi from llvmlite.binding.module import parse_assembly +from llvmlite.binding.value import ValueRef -def get_function_cfg(func, show_inst=True): +def get_function_cfg(func: ir.Function | ValueRef, show_inst: bool = True) -> str: """Return a string of the control-flow graph of the function in DOT format. If the input `func` is not a materialized function, the module containing the function is parsed to create an actual LLVM module. @@ -27,7 +31,7 @@ def get_function_cfg(func, show_inst=True): return str(dotstr) -def view_dot_graph(graph, filename=None, view=False): +def view_dot_graph(graph: Any, filename: str | None = None, view: bool = False) -> Any: """ View the given DOT source. If view is True, the image is rendered and viewed by the default application in the system. The file path of @@ -47,20 +51,23 @@ def view_dot_graph(graph, filename=None, view=False): """ # Optionally depends on graphviz package - import graphviz as gv + # could not be resolved + import graphviz as gv # type: ignore - src = gv.Source(graph) + src = gv.Source(graph) # type: ignore if view: # Returns the output file path - return src.render(filename, view=view) + return src.render(filename, view=view) # type: ignore else: # Attempts to show the graph in IPython notebook try: - __IPYTHON__ + __IPYTHON__ # type: ignore # is undefined except NameError: return src else: - import IPython.display as display + # could not be resolved + import IPython.display as display # type: ignore + format = 'svg' return display.SVG(data=src.pipe(format)) diff --git a/llvmlite/binding/common.py b/llvmlite/binding/common.py index 3f5746a5b..7344e766f 100644 --- a/llvmlite/binding/common.py +++ b/llvmlite/binding/common.py @@ -1,12 +1,14 @@ +from __future__ import annotations + import atexit -def _encode_string(s): +def _encode_string(s: str) -> bytes: encoded = s.encode('utf-8') return encoded -def _decode_string(b): +def _decode_string(b: bytes) -> str: return b.decode('utf-8') @@ -17,14 +19,14 @@ def _decode_string(b): _shutting_down = [False] -def _at_shutdown(): +def _at_shutdown() -> None: _shutting_down[0] = True atexit.register(_at_shutdown) -def _is_shutting_down(_shutting_down=_shutting_down): +def _is_shutting_down(_shutting_down: list[bool] = _shutting_down) -> bool: """ Whether the interpreter is currently shutting down. For use in finalizers, __del__ methods, and similar; it is advised diff --git a/llvmlite/binding/context.py b/llvmlite/binding/context.py index 7dffb82a7..81cea7e32 100644 --- a/llvmlite/binding/context.py +++ b/llvmlite/binding/context.py @@ -1,24 +1,28 @@ +from __future__ import annotations + +from typing import Any + from llvmlite.binding import ffi -def create_context(): +def create_context() -> ContextRef: return ContextRef(ffi.lib.LLVMPY_ContextCreate()) -def get_global_context(): +def get_global_context() -> GlobalContextRef: return GlobalContextRef(ffi.lib.LLVMPY_GetGlobalContext()) class ContextRef(ffi.ObjectRef): - def __init__(self, context_ptr): + def __init__(self, context_ptr: Any) -> None: super(ContextRef, self).__init__(context_ptr) - def _dispose(self): + def _dispose(self) -> None: ffi.lib.LLVMPY_ContextDispose(self) class GlobalContextRef(ContextRef): - def _dispose(self): + def _dispose(self) -> None: pass diff --git a/llvmlite/binding/dylib.py b/llvmlite/binding/dylib.py index e22542c49..874939408 100644 --- a/llvmlite/binding/dylib.py +++ b/llvmlite/binding/dylib.py @@ -1,10 +1,13 @@ -from ctypes import c_void_p, c_char_p, c_bool, POINTER +from __future__ import annotations + +from ctypes import POINTER, c_bool, c_char_p, c_void_p +from typing import Any from llvmlite.binding import ffi from llvmlite.binding.common import _encode_string -def address_of_symbol(name): +def address_of_symbol(name: str) -> Any: """ Get the in-process address of symbol named *name*. An integer is returned, or None if the symbol isn't found. @@ -12,7 +15,7 @@ def address_of_symbol(name): return ffi.lib.LLVMPY_SearchAddressOfSymbol(_encode_string(name)) -def add_symbol(name, address): +def add_symbol(name: str, address: int) -> None: """ Register the *address* of global symbol *name*. This will make it usable (e.g. callable) from LLVM-compiled functions. @@ -20,7 +23,7 @@ def add_symbol(name, address): ffi.lib.LLVMPY_AddSymbol(_encode_string(name), c_void_p(address)) -def load_library_permanently(filename): +def load_library_permanently(filename: str) -> None: """ Load an external library """ diff --git a/llvmlite/binding/executionengine.py b/llvmlite/binding/executionengine.py index 07cb8dab0..0d7058d32 100644 --- a/llvmlite/binding/executionengine.py +++ b/llvmlite/binding/executionengine.py @@ -1,15 +1,20 @@ +from __future__ import annotations + from ctypes import (POINTER, c_char_p, c_bool, c_void_p, c_int, c_uint64, c_size_t, CFUNCTYPE, string_at, cast, py_object, Structure) +from typing import Any, Callable from llvmlite.binding import ffi, targets, object_file - +from llvmlite.binding.module import ModuleRef # Just check these weren't optimized out of the DLL. ffi.lib.LLVMPY_LinkInMCJIT -def create_mcjit_compiler(module, target_machine): +def create_mcjit_compiler( + module: ModuleRef, target_machine: targets.TargetMachine +) -> ExecutionEngine: """ Create a MCJIT ExecutionEngine from the given *module* and *target_machine*. @@ -24,7 +29,7 @@ def create_mcjit_compiler(module, target_machine): return ExecutionEngine(engine, module=module) -def check_jit_execution(): +def check_jit_execution() -> None: """ Check the system allows execution of in-memory JITted functions. An exception is raised otherwise. @@ -45,16 +50,16 @@ class ExecutionEngine(ffi.ObjectRef): """ _object_cache = None - def __init__(self, ptr, module): + def __init__(self, ptr: Any, module: ModuleRef) -> None: """ Module ownership is transferred to the EE """ self._modules = set([module]) - self._td = None + self._td: targets.TargetData | None = None module._owned = True ffi.ObjectRef.__init__(self, ptr) - def get_function_address(self, name): + def get_function_address(self, name: str) -> Any: """ Return the address of the function named *name* as an integer. @@ -62,7 +67,7 @@ def get_function_address(self, name): """ return ffi.lib.LLVMPY_GetFunctionAddress(self, name.encode("ascii")) - def get_global_value_address(self, name): + def get_global_value_address(self, name: str) -> Any: """ Return the address of the global value named *name* as an integer. @@ -70,11 +75,11 @@ def get_global_value_address(self, name): """ return ffi.lib.LLVMPY_GetGlobalValueAddress(self, name.encode("ascii")) - def add_global_mapping(self, gv, addr): + def add_global_mapping(self, gv: Any, addr: Any) -> None: # XXX unused? ffi.lib.LLVMPY_AddGlobalMapping(self, gv, addr) - def add_module(self, module): + def add_module(self, module: ModuleRef) -> None: """ Ownership of module is transferred to the execution engine """ @@ -84,27 +89,27 @@ def add_module(self, module): module._owned = True self._modules.add(module) - def finalize_object(self): + def finalize_object(self) -> None: """ Make sure all modules owned by the execution engine are fully processed and "usable" for execution. """ ffi.lib.LLVMPY_FinalizeObject(self) - def run_static_constructors(self): + def run_static_constructors(self) -> None: """ Run static constructors which initialize module-level static objects. """ ffi.lib.LLVMPY_RunStaticConstructors(self) - def run_static_destructors(self): + def run_static_destructors(self) -> None: """ Run static destructors which perform module-level cleanup of static resources. """ ffi.lib.LLVMPY_RunStaticDestructors(self) - def remove_module(self, module): + def remove_module(self, module: ModuleRef) -> None: """ Ownership of module is returned """ @@ -115,7 +120,7 @@ def remove_module(self, module): module._owned = False @property - def target_data(self): + def target_data(self) -> targets.TargetData: """ The TargetData for this execution engine. """ @@ -126,7 +131,7 @@ def target_data(self): self._td._owned = True return self._td - def enable_jit_events(self): + def enable_jit_events(self) -> Any: """ Enable JIT events for profiling of generated code. Return value indicates whether connection to profiling tool @@ -135,7 +140,7 @@ def enable_jit_events(self): ret = ffi.lib.LLVMPY_EnableJITEvents(self) return ret - def _find_module_ptr(self, module_ptr): + def _find_module_ptr(self, module_ptr: Any) -> ModuleRef | None: """ Find the ModuleRef corresponding to the given pointer. """ @@ -145,7 +150,7 @@ def _find_module_ptr(self, module_ptr): return module return None - def add_object_file(self, obj_file): + def add_object_file(self, obj_file: str | object_file.ObjectFileRef) -> None: """ Add object file to the jit. object_file can be instance of :class:ObjectFile or a string representing file system path @@ -155,7 +160,11 @@ def add_object_file(self, obj_file): ffi.lib.LLVMPY_MCJITAddObjectFile(self, obj_file) - def set_object_cache(self, notify_func=None, getbuffer_func=None): + def set_object_cache( + self, + notify_func: Callable[[ModuleRef, bytes], Any] | None = None, + getbuffer_func: Callable[[ModuleRef], Any] | None = None, + ) -> None: """ Set the object cache "notifyObjectCompiled" and "getBuffer" callbacks to the given Python functions. @@ -168,7 +177,7 @@ def set_object_cache(self, notify_func=None, getbuffer_func=None): # cycles. ffi.lib.LLVMPY_SetObjectCache(self, self._object_cache) - def _raw_object_cache_notify(self, data): + def _raw_object_cache_notify(self, data: Any) -> None: """ Low-level notify hook. """ @@ -186,7 +195,7 @@ def _raw_object_cache_notify(self, data): "for unknown module %s" % (module_ptr,)) self._object_cache_notify(module, buf) - def _raw_object_cache_getbuffer(self, data): + def _raw_object_cache_getbuffer(self, data: Any) -> None: """ Low-level getbuffer hook. """ @@ -206,7 +215,7 @@ def _raw_object_cache_getbuffer(self, data): data[0].buf_ptr = ffi.lib.LLVMPY_CreateByteString(buf, len(buf)) data[0].buf_len = len(buf) - def _dispose(self): + def _dispose(self) -> None: # The modules will be cleaned up by the EE for mod in self._modules: mod.detach() @@ -222,13 +231,13 @@ class _ObjectCacheRef(ffi.ObjectRef): Internal: an ObjectCache instance for use within an ExecutionEngine. """ - def __init__(self, obj): + def __init__(self, obj: Any) -> None: ptr = ffi.lib.LLVMPY_CreateObjectCache(_notify_c_hook, _getbuffer_c_hook, obj) ffi.ObjectRef.__init__(self, ptr) - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeObjectCache(self) diff --git a/llvmlite/binding/ffi.py b/llvmlite/binding/ffi.py index 948769215..2cc5793ca 100644 --- a/llvmlite/binding/ffi.py +++ b/llvmlite/binding/ffi.py @@ -1,12 +1,17 @@ +from __future__ import annotations + import ctypes import threading import importlib.resources +from types import TracebackType +from typing import Any, Callable, Type, cast + from llvmlite.binding.common import _decode_string, _is_shutting_down from llvmlite.utils import get_library_name -def _make_opaque_ref(name): +def _make_opaque_ref(name: str) -> Any: newcls = type(name, (ctypes.Structure,), {}) return ctypes.POINTER(newcls) @@ -46,35 +51,42 @@ class _LLVMLock: Also, callbacks can be attached so that every time the lock is acquired and released the corresponding callbacks will be invoked. """ - def __init__(self): + + def __init__(self) -> None: # The reentrant lock is needed for callbacks that re-enter # the Python interpreter. self._lock = threading.RLock() - self._cblist = [] + self._cblist: list[tuple[Any, Any]] = [] - def register(self, acq_fn, rel_fn): + def register(self, acq_fn: Any, rel_fn: Any) -> None: """Register callbacks that are invoked immediately after the lock is acquired (``acq_fn()``) and immediately before the lock is released (``rel_fn()``). """ self._cblist.append((acq_fn, rel_fn)) - def unregister(self, acq_fn, rel_fn): + def unregister(self, acq_fn: Any, rel_fn: Any) -> None: """Remove the registered callbacks. """ self._cblist.remove((acq_fn, rel_fn)) - def __enter__(self): + def __enter__(self) -> None: self._lock.acquire() # Invoke all callbacks - for acq_fn, rel_fn in self._cblist: + for acq_fn, _ in self._cblist: acq_fn() - def __exit__(self, *exc_details): + def __exit__( + self, + exception_type: Type[BaseException] | None, + exception_instance: BaseException | None, + exc_traceback: TracebackType | None, + ) -> bool | None: # Invoke all callbacks - for acq_fn, rel_fn in self._cblist: + for _, rel_fn in self._cblist: rel_fn() self._lock.release() + return None class _lib_wrapper(object): @@ -85,12 +97,12 @@ class _lib_wrapper(object): """ __slots__ = ['_lib', '_fntab', '_lock'] - def __init__(self, lib): + def __init__(self, lib: Any) -> None: self._lib = lib - self._fntab = {} + self._fntab: dict[str, _lib_fn_wrapper] = {} self._lock = _LLVMLock() - def __getattr__(self, name): + def __getattr__(self, name: str) -> _lib_fn_wrapper: try: return self._fntab[name] except KeyError: @@ -101,15 +113,15 @@ def __getattr__(self, name): return wrapped @property - def _name(self): + def _name(self) -> str: """The name of the library passed in the CDLL constructor. For duck-typing a ctypes.CDLL """ - return self._lib._name + return cast(str, self._lib._name) @property - def _handle(self): + def _handle(self) -> Any: """The system handle used to access the library. For duck-typing a ctypes.CDLL @@ -126,27 +138,27 @@ class _lib_fn_wrapper(object): """ __slots__ = ['_lock', '_cfn'] - def __init__(self, lock, cfn): + def __init__(self, lock: _LLVMLock, cfn: Any) -> None: self._lock = lock self._cfn = cfn @property - def argtypes(self): + def argtypes(self) -> Any: return self._cfn.argtypes @argtypes.setter - def argtypes(self, argtypes): + def argtypes(self, argtypes: Any) -> None: self._cfn.argtypes = argtypes @property - def restype(self): + def restype(self) -> Any: return self._cfn.restype @restype.setter - def restype(self, restype): + def restype(self, restype: Any) -> None: self._cfn.restype = restype - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: with self._lock: return self._cfn(*args, **kwargs) @@ -157,7 +169,7 @@ def __call__(self, *args, **kwargs): pkgname = ".".join(__name__.split(".")[0:-1]) try: _lib_handle = importlib.resources.path(pkgname, _lib_name) - lib = ctypes.CDLL(str(_lib_handle.__enter__())) + _lib = ctypes.CDLL(str(_lib_handle.__enter__())) # on windows file handles to the dll file remain open after # loading, therefore we can not exit the context manager # which might delete the file @@ -167,17 +179,17 @@ def __call__(self, *args, **kwargs): raise OSError(msg) -lib = _lib_wrapper(lib) +lib = _lib_wrapper(_lib) -def register_lock_callback(acq_fn, rel_fn): +def register_lock_callback(acq_fn: Any, rel_fn: Any) -> None: """Register callback functions for lock acquire and release. *acq_fn* and *rel_fn* are callables that take no arguments. """ lib._lock.register(acq_fn, rel_fn) -def unregister_lock_callback(acq_fn, rel_fn): +def unregister_lock_callback(acq_fn: Any, rel_fn: Any) -> None: """Remove the registered callback functions for lock acquire and release. The arguments are the same as used in `register_lock_callback()`. """ @@ -197,7 +209,7 @@ class OutputString(object): _as_parameter_ = _DeadPointer() @classmethod - def from_return(cls, ptr): + def from_return(cls, ptr: Any) -> OutputString: """Constructing from a pointer returned from the C-API. The pointer must be allocated with LLVMPY_CreateString. @@ -206,64 +218,80 @@ def from_return(cls, ptr): Because ctypes auto-converts *restype* of *c_char_p* into a python string, we must use *c_void_p* to obtain the raw pointer. """ - return cls(init=ctypes.cast(ptr, ctypes.c_char_p)) - - def __init__(self, owned=True, init=None): - self._ptr = init if init is not None else ctypes.c_char_p(None) - self._as_parameter_ = ctypes.byref(self._ptr) + # c_char_p cannot be None in init + return cls(init=ctypes.cast(ptr, ctypes.c_char_p)) # type: ignore + + def __init__(self, owned: bool = True, init: None = None) -> None: + self._ptr: ctypes.c_char_p | None = ( + init if init is not None else ctypes.c_char_p(None) + ) + # _CArgObject != _DeadPointer + self._as_parameter_ = ctypes.byref(self._ptr) # type: ignore self._owned = owned - def close(self): + def close(self) -> None: if self._ptr is not None: if self._owned: lib.LLVMPY_DisposeString(self._ptr) self._ptr = None del self._as_parameter_ - def __enter__(self): + def __enter__(self) -> OutputString: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exception_type: Type[BaseException] | None, + exception_instance: BaseException | None, + exc_traceback: TracebackType | None, + ) -> bool | None: self.close() + return None - def __del__(self, _is_shutting_down=_is_shutting_down): + def __del__( + self, _is_shutting_down: Callable[[], bool] = _is_shutting_down + ) -> None: # Avoid errors trying to rely on globals and modules at interpreter # shutdown. if not _is_shutting_down(): if self.close is not None: self.close() - def __str__(self): + def __str__(self) -> str: if self._ptr is None: return "" s = self._ptr.value assert s is not None return _decode_string(s) - def __bool__(self): + def __bool__(self) -> bool: return bool(self._ptr) __nonzero__ = __bool__ @property - def bytes(self): + def bytes(self) -> bytes | None: """Get the raw bytes of content of the char pointer. """ + if self._ptr is None: + return None return self._ptr.value -def ret_string(ptr): +def ret_string(ptr: Any) -> str | None: """To wrap string return-value from C-API. """ if ptr is not None: return str(OutputString.from_return(ptr)) + return None -def ret_bytes(ptr): +def ret_bytes(ptr: Any) -> bytes | None: """To wrap bytes return-value from C-API. """ if ptr is not None: return OutputString.from_return(ptr).bytes + return None class ObjectRef(object): @@ -275,14 +303,14 @@ class ObjectRef(object): # Whether this object pointer is owned by another one. _owned = False - def __init__(self, ptr): + def __init__(self, ptr: Any) -> None: if ptr is None: raise ValueError("NULL pointer") self._ptr = ptr self._as_parameter_ = ptr self._capi = lib - def close(self): + def close(self) -> None: """ Close this object and do any required clean-up actions. """ @@ -292,7 +320,7 @@ def close(self): finally: self.detach() - def detach(self): + def detach(self) -> None: """ Detach the underlying LLVM resource without disposing of it. """ @@ -301,7 +329,7 @@ def detach(self): self._closed = True self._ptr = None - def _dispose(self): + def _dispose(self) -> None: """ Dispose of the underlying LLVM resource. Should be overriden by subclasses. Automatically called by close(), __del__() and @@ -309,38 +337,50 @@ def _dispose(self): """ @property - def closed(self): + def closed(self) -> bool: """ Whether this object has been closed. A closed object can't be used anymore. """ return self._closed - def __enter__(self): + def __enter__(self) -> ObjectRef: assert hasattr(self, "close") if self._closed: raise RuntimeError("%s instance already closed" % (self.__class__,)) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exception_type: Type[BaseException] | None, + exception_instance: BaseException | None, + exc_traceback: TracebackType | None, + ) -> bool | None: self.close() + return None - def __del__(self, _is_shutting_down=_is_shutting_down): + def __del__( + self, _is_shutting_down: Callable[[], bool] = _is_shutting_down + ) -> None: if not _is_shutting_down(): if self.close is not None: self.close() - def __bool__(self): + def __bool__(self) -> bool: return bool(self._ptr) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if not hasattr(other, "_ptr"): return False - return ctypes.addressof(self._ptr[0]) == \ - ctypes.addressof(other._ptr[0]) + if self._ptr is None: + return False + # apparently mypy doesn't understand hasattr + if other._ptr is None: # type: ignore + return False + return ctypes.addressof(self._ptr[0]) == ctypes.addressof(other._ptr[0]) # type: ignore __nonzero__ = __bool__ # XXX useful? - def __hash__(self): + def __hash__(self) -> int: return hash(ctypes.cast(self._ptr, ctypes.c_void_p).value) diff --git a/llvmlite/binding/initfini.py b/llvmlite/binding/initfini.py index 4466d9da2..7be9fe64d 100644 --- a/llvmlite/binding/initfini.py +++ b/llvmlite/binding/initfini.py @@ -1,16 +1,18 @@ +from __future__ import annotations + from ctypes import c_uint from llvmlite.binding import ffi -def initialize(): +def initialize() -> None: """ Initialize the LLVM core. """ ffi.lib.LLVMPY_InitializeCore() -def initialize_all_targets(): +def initialize_all_targets() -> None: """ Initialize all targets. Necessary before targets can be looked up via the :class:`Target` class. @@ -20,7 +22,7 @@ def initialize_all_targets(): ffi.lib.LLVMPY_InitializeAllTargetMCs() -def initialize_all_asmprinters(): +def initialize_all_asmprinters() -> None: """ Initialize all code generators. Necessary before generating any assembly or machine code via the :meth:`TargetMachine.emit_object` @@ -29,7 +31,7 @@ def initialize_all_asmprinters(): ffi.lib.LLVMPY_InitializeAllAsmPrinters() -def initialize_native_target(): +def initialize_native_target() -> None: """ Initialize the native (host) target. Necessary before doing any code generation. @@ -37,21 +39,21 @@ def initialize_native_target(): ffi.lib.LLVMPY_InitializeNativeTarget() -def initialize_native_asmprinter(): +def initialize_native_asmprinter() -> None: """ Initialize the native ASM printer. """ ffi.lib.LLVMPY_InitializeNativeAsmPrinter() -def initialize_native_asmparser(): +def initialize_native_asmparser() -> None: """ Initialize the native ASM parser. """ ffi.lib.LLVMPY_InitializeNativeAsmParser() -def shutdown(): +def shutdown() -> None: ffi.lib.LLVMPY_Shutdown() @@ -61,8 +63,8 @@ def shutdown(): ffi.lib.LLVMPY_GetVersionInfo.restype = c_uint -def _version_info(): - v = [] +def _version_info() -> tuple[int, ...]: + v: list[int] = [] x = ffi.lib.LLVMPY_GetVersionInfo() while x: v.append(x & 0xff) diff --git a/llvmlite/binding/linker.py b/llvmlite/binding/linker.py index 31d1e26ff..cd4f34ea6 100644 --- a/llvmlite/binding/linker.py +++ b/llvmlite/binding/linker.py @@ -1,8 +1,14 @@ +from __future__ import annotations + from ctypes import c_int, c_char_p, POINTER +from typing import TYPE_CHECKING from llvmlite.binding import ffi +if TYPE_CHECKING: + from llvmlite.binding.module import ModuleRef + -def link_modules(dst, src): +def link_modules(dst: ModuleRef, src: ModuleRef) -> None: with ffi.OutputString() as outerr: err = ffi.lib.LLVMPY_LinkModules(dst, src, outerr) # The underlying module was destroyed diff --git a/llvmlite/binding/module.py b/llvmlite/binding/module.py index dcbb1faa6..9cf79e8f2 100644 --- a/llvmlite/binding/module.py +++ b/llvmlite/binding/module.py @@ -1,21 +1,23 @@ +from __future__ import annotations + from ctypes import (c_char_p, byref, POINTER, c_bool, create_string_buffer, c_size_t, string_at) +from typing import Any from llvmlite.binding import ffi from llvmlite.binding.linker import link_modules from llvmlite.binding.common import _decode_string, _encode_string from llvmlite.binding.value import ValueRef, TypeRef -from llvmlite.binding.context import get_global_context +from llvmlite.binding.context import GlobalContextRef, get_global_context -def parse_assembly(llvmir, context=None): +def parse_assembly(llvmir: str, context: GlobalContextRef | None = None) -> ModuleRef: """ Create Module from a LLVM IR string """ if context is None: context = get_global_context() - llvmir = _encode_string(llvmir) - strbuf = c_char_p(llvmir) + strbuf = c_char_p(_encode_string(llvmir)) with ffi.OutputString() as errmsg: mod = ModuleRef( ffi.lib.LLVMPY_ParseAssembly(context, strbuf, errmsg), @@ -26,7 +28,7 @@ def parse_assembly(llvmir, context=None): return mod -def parse_bitcode(bitcode, context=None): +def parse_bitcode(bitcode: bytes, context: GlobalContextRef | None = None) -> ModuleRef: """ Create Module from a LLVM *bitcode* (a bytes object). """ @@ -49,16 +51,16 @@ class ModuleRef(ffi.ObjectRef): A reference to a LLVM module. """ - def __init__(self, module_ptr, context): + def __init__(self, module_ptr: Any, context: GlobalContextRef) -> None: super(ModuleRef, self).__init__(module_ptr) self._context = context - def __str__(self): + def __str__(self) -> str: with ffi.OutputString() as outstr: ffi.lib.LLVMPY_PrintModuleToString(self, outstr) return str(outstr) - def as_bitcode(self): + def as_bitcode(self) -> bytes: """ Return the module's LLVM bitcode, as a bytes object. """ @@ -73,10 +75,10 @@ def as_bitcode(self): finally: ffi.lib.LLVMPY_DisposeString(ptr) - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeModule(self) - def get_function(self, name): + def get_function(self, name: str) -> ValueRef: """ Get a ValueRef pointing to the function named *name*. NameError is raised if the symbol isn't found. @@ -86,7 +88,7 @@ def get_function(self, name): raise NameError(name) return ValueRef(p, 'function', dict(module=self)) - def get_global_variable(self, name): + def get_global_variable(self, name: str) -> ValueRef: """ Get a ValueRef pointing to the global variable named *name*. NameError is raised if the symbol isn't found. @@ -96,7 +98,7 @@ def get_global_variable(self, name): raise NameError(name) return ValueRef(p, 'global', dict(module=self)) - def get_struct_type(self, name): + def get_struct_type(self, name: str) -> TypeRef: """ Get a TypeRef pointing to a structure type named *name*. NameError is raised if the struct type isn't found. @@ -106,7 +108,7 @@ def get_struct_type(self, name): raise NameError(name) return TypeRef(p) - def verify(self): + def verify(self) -> None: """ Verify the module IR's correctness. RuntimeError is raised on error. """ @@ -115,25 +117,25 @@ def verify(self): raise RuntimeError(str(outmsg)) @property - def name(self): + def name(self) -> str: """ The module's identifier. """ return _decode_string(ffi.lib.LLVMPY_GetModuleName(self)) @name.setter - def name(self, value): + def name(self, value: str) -> None: ffi.lib.LLVMPY_SetModuleName(self, _encode_string(value)) @property - def source_file(self): + def source_file(self) -> str: """ The module's original source file name """ return _decode_string(ffi.lib.LLVMPY_GetModuleSourceFileName(self)) @property - def data_layout(self): + def data_layout(self) -> str: """ This module's data layout specification, as a string. """ @@ -143,13 +145,13 @@ def data_layout(self): return str(outmsg) @data_layout.setter - def data_layout(self, strrep): + def data_layout(self, strrep: str) -> None: ffi.lib.LLVMPY_SetDataLayout(self, create_string_buffer( strrep.encode('utf8'))) @property - def triple(self): + def triple(self) -> str: """ This module's target "triple" specification, as a string. """ @@ -159,12 +161,12 @@ def triple(self): return str(outmsg) @triple.setter - def triple(self, strrep): + def triple(self, strrep: str) -> None: ffi.lib.LLVMPY_SetTarget(self, create_string_buffer( strrep.encode('utf8'))) - def link_in(self, other, preserve=False): + def link_in(self, other: ModuleRef, preserve: bool = False) -> None: """ Link the *other* module into this one. The *other* module will be destroyed unless *preserve* is true. @@ -174,7 +176,7 @@ def link_in(self, other, preserve=False): link_modules(self, other) @property - def global_variables(self): + def global_variables(self) -> _GlobalsIterator: """ Return an iterator over this module's global variables. The iterator will yield a ValueRef for each global variable. @@ -187,7 +189,7 @@ def global_variables(self): return _GlobalsIterator(it, dict(module=self)) @property - def functions(self): + def functions(self) -> _FunctionsIterator: """ Return an iterator over this module's functions. The iterator will yield a ValueRef for each function. @@ -196,7 +198,7 @@ def functions(self): return _FunctionsIterator(it, dict(module=self)) @property - def struct_types(self): + def struct_types(self) -> _TypesIterator: """ Return an iterator over the struct types defined in the module. The iterator will yield a TypeRef. @@ -204,21 +206,22 @@ def struct_types(self): it = ffi.lib.LLVMPY_ModuleTypesIter(self) return _TypesIterator(it, dict(module=self)) - def clone(self): + def clone(self) -> ModuleRef: return ModuleRef(ffi.lib.LLVMPY_CloneModule(self), self._context) class _Iterator(ffi.ObjectRef): - kind = None + kind: str | None = None - def __init__(self, ptr, parents): + def __init__(self, ptr: Any, parents: dict[str, ModuleRef]) -> None: ffi.ObjectRef.__init__(self, ptr) self._parents = parents assert self.kind is not None - def __next__(self): - vp = self._next() + def __next__(self) -> ValueRef: + # Cannot access member _next for _Iterator + vp = self._next() # type: ignore if vp: return ValueRef(vp, self.kind, self._parents) else: @@ -226,7 +229,7 @@ def __next__(self): next = __next__ - def __iter__(self): + def __iter__(self) -> _Iterator: return self @@ -234,10 +237,10 @@ class _GlobalsIterator(_Iterator): kind = 'global' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeGlobalsIter(self) - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_GlobalsIterNext(self) @@ -245,10 +248,10 @@ class _FunctionsIterator(_Iterator): kind = 'function' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeFunctionsIter(self) - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_FunctionsIterNext(self) @@ -256,20 +259,22 @@ class _TypesIterator(_Iterator): kind = 'type' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeTypesIter(self) - def __next__(self): + # incompatible with supertype + def __next__(self) -> TypeRef: # type: ignore vp = self._next() if vp: return TypeRef(vp) else: raise StopIteration - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_TypesIterNext(self) - next = __next__ + # incompatible with supertype + next = __next__ # type: ignore # ============================================================================= diff --git a/llvmlite/binding/object_file.py b/llvmlite/binding/object_file.py index e5961b079..4869b4734 100644 --- a/llvmlite/binding/object_file.py +++ b/llvmlite/binding/object_file.py @@ -1,52 +1,55 @@ +from __future__ import annotations + +from typing import Any, Iterator + from llvmlite.binding import ffi -from ctypes import (c_bool, c_char_p, c_char, c_size_t, string_at, c_uint64, - POINTER) +from ctypes import c_bool, c_char_p, c_size_t, string_at, c_uint64 class SectionIteratorRef(ffi.ObjectRef): - def name(self): + def name(self) -> Any: return ffi.lib.LLVMPY_GetSectionName(self) - def is_text(self): + def is_text(self) -> Any: return ffi.lib.LLVMPY_IsSectionText(self) - def size(self): + def size(self) -> Any: return ffi.lib.LLVMPY_GetSectionSize(self) - def address(self): + def address(self) -> Any: return ffi.lib.LLVMPY_GetSectionAddress(self) - def data(self): + def data(self) -> bytes: return string_at(ffi.lib.LLVMPY_GetSectionContents(self), self.size()) - def is_end(self, object_file): + def is_end(self, object_file: ObjectFileRef) -> Any: return ffi.lib.LLVMPY_IsSectionIteratorAtEnd(object_file, self) - def next(self): + def next(self) -> Any: ffi.lib.LLVMPY_MoveToNextSection(self) - def _dispose(self): + def _dispose(self) -> None: ffi.lib.LLVMPY_DisposeSectionIterator(self) class ObjectFileRef(ffi.ObjectRef): @classmethod - def from_data(cls, data): + def from_data(cls, data: bytes) -> ObjectFileRef: return cls(ffi.lib.LLVMPY_CreateObjectFile(data, len(data))) @classmethod - def from_path(cls, path): + def from_path(cls, path: str) -> ObjectFileRef: with open(path, 'rb') as f: data = f.read() return cls(ffi.lib.LLVMPY_CreateObjectFile(data, len(data))) - def sections(self): + def sections(self) -> Iterator[Any]: it = SectionIteratorRef(ffi.lib.LLVMPY_GetSections(self)) while not it.is_end(self): yield it it.next() - def _dispose(self): + def _dispose(self) -> None: ffi.lib.LLVMPY_DisposeObjectFile(self) @@ -76,7 +79,7 @@ def _dispose(self): ffi.lib.LLVMPY_GetSectionAddress.restype = c_uint64 ffi.lib.LLVMPY_GetSectionContents.argtypes = [ffi.LLVMSectionIteratorRef] -ffi.lib.LLVMPY_GetSectionContents.restype = POINTER(c_char) +ffi.lib.LLVMPY_GetSectionContents.restype = c_char_p ffi.lib.LLVMPY_IsSectionText.argtypes = [ffi.LLVMSectionIteratorRef] ffi.lib.LLVMPY_IsSectionText.restype = c_bool diff --git a/llvmlite/binding/options.py b/llvmlite/binding/options.py index 15eedfaaf..100f4bb0c 100644 --- a/llvmlite/binding/options.py +++ b/llvmlite/binding/options.py @@ -1,9 +1,11 @@ +from __future__ import annotations + from llvmlite.binding import ffi from llvmlite.binding.common import _encode_string from ctypes import c_char_p -def set_option(name, option): +def set_option(name: str, option: str) -> None: """ Set the given LLVM "command-line" option. diff --git a/llvmlite/binding/passmanagers.py b/llvmlite/binding/passmanagers.py index 6f9b7aa31..43fcfd8e7 100644 --- a/llvmlite/binding/passmanagers.py +++ b/llvmlite/binding/passmanagers.py @@ -1,17 +1,23 @@ +from __future__ import annotations + from ctypes import c_bool, c_int, c_size_t, Structure, byref -from collections import namedtuple +from dataclasses import dataclass from enum import IntFlag -from llvmlite.binding import ffi +from typing import Any -_prunestats = namedtuple('PruneStats', - ('basicblock diamond fanout fanout_raise')) +from llvmlite.binding import ffi +from llvmlite.binding.module import ModuleRef -class PruneStats(_prunestats): - """ Holds statistics from reference count pruning. - """ +@dataclass +class PruneStats: + """Holds statistics from reference count pruning.""" + basicblock: int + diamond: int + fanout: int + fanout_raise: int - def __add__(self, other): + def __add__(self, other: object) -> PruneStats: if not isinstance(other, PruneStats): msg = 'PruneStats can only be added to another PruneStats, got {}.' raise TypeError(msg.format(type(other))) @@ -20,7 +26,7 @@ def __add__(self, other): self.fanout + other.fanout, self.fanout_raise + other.fanout_raise) - def __sub__(self, other): + def __sub__(self, other: object) -> PruneStats: if not isinstance(other, PruneStats): msg = ('PruneStats can only be subtracted from another PruneStats, ' 'got {}.') @@ -39,7 +45,7 @@ class _c_PruneStats(Structure): ('fanout_raise', c_size_t)] -def dump_refprune_stats(printout=False): +def dump_refprune_stats(printout: bool = False) -> PruneStats: """ Returns a namedtuple containing the current values for the refop pruning statistics. If kwarg `printout` is True the stats are printed to stderr, default is False. @@ -53,7 +59,7 @@ def dump_refprune_stats(printout=False): stats.fanout_raise) -def set_time_passes(enable): +def set_time_passes(enable: bool) -> None: """Enable or disable the pass timers. Parameters @@ -65,7 +71,7 @@ def set_time_passes(enable): ffi.lib.LLVMPY_SetTimePasses(c_bool(enable)) -def report_and_reset_timings(): +def report_and_reset_timings() -> str: """Returns the pass timings report and resets the LLVM internal timers. Pass timers are enabled by ``set_time_passes()``. If the timers are not @@ -81,11 +87,11 @@ def report_and_reset_timings(): return str(buf) -def create_module_pass_manager(): +def create_module_pass_manager() -> ModulePassManager: return ModulePassManager() -def create_function_pass_manager(module): +def create_function_pass_manager(module: ModuleRef) -> FunctionPassManager: return FunctionPassManager(module) @@ -101,82 +107,85 @@ class PassManager(ffi.ObjectRef): """PassManager """ - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposePassManager(self) - def add_constant_merge_pass(self): + def add_constant_merge_pass(self) -> None: """See http://llvm.org/docs/Passes.html#constmerge-merge-duplicate-global-constants.""" # noqa E501 ffi.lib.LLVMPY_AddConstantMergePass(self) - def add_dead_arg_elimination_pass(self): + def add_dead_arg_elimination_pass(self) -> None: """See http://llvm.org/docs/Passes.html#deadargelim-dead-argument-elimination.""" # noqa E501 ffi.lib.LLVMPY_AddDeadArgEliminationPass(self) - def add_function_attrs_pass(self): + def add_function_attrs_pass(self) -> None: """See http://llvm.org/docs/Passes.html#functionattrs-deduce-function-attributes.""" # noqa E501 ffi.lib.LLVMPY_AddFunctionAttrsPass(self) - def add_function_inlining_pass(self, threshold): + def add_function_inlining_pass(self, threshold: int) -> None: """See http://llvm.org/docs/Passes.html#inline-function-integration-inlining.""" # noqa E501 ffi.lib.LLVMPY_AddFunctionInliningPass(self, threshold) - def add_global_dce_pass(self): + def add_global_dce_pass(self) -> None: """See http://llvm.org/docs/Passes.html#globaldce-dead-global-elimination.""" # noqa E501 ffi.lib.LLVMPY_AddGlobalDCEPass(self) - def add_global_optimizer_pass(self): + def add_global_optimizer_pass(self) -> None: """See http://llvm.org/docs/Passes.html#globalopt-global-variable-optimizer.""" # noqa E501 ffi.lib.LLVMPY_AddGlobalOptimizerPass(self) - def add_ipsccp_pass(self): + def add_ipsccp_pass(self) -> None: """See http://llvm.org/docs/Passes.html#ipsccp-interprocedural-sparse-conditional-constant-propagation.""" # noqa E501 ffi.lib.LLVMPY_AddIPSCCPPass(self) - def add_dead_code_elimination_pass(self): + def add_dead_code_elimination_pass(self) -> None: """See http://llvm.org/docs/Passes.html#dce-dead-code-elimination.""" ffi.lib.LLVMPY_AddDeadCodeEliminationPass(self) - def add_cfg_simplification_pass(self): + def add_cfg_simplification_pass(self) -> None: """See http://llvm.org/docs/Passes.html#simplifycfg-simplify-the-cfg.""" ffi.lib.LLVMPY_AddCFGSimplificationPass(self) - def add_gvn_pass(self): + def add_gvn_pass(self) -> None: """See http://llvm.org/docs/Passes.html#gvn-global-value-numbering.""" ffi.lib.LLVMPY_AddGVNPass(self) - def add_instruction_combining_pass(self): + def add_instruction_combining_pass(self) -> None: """See http://llvm.org/docs/Passes.html#passes-instcombine.""" ffi.lib.LLVMPY_AddInstructionCombiningPass(self) - def add_licm_pass(self): + def add_licm_pass(self) -> None: """See http://llvm.org/docs/Passes.html#licm-loop-invariant-code-motion.""" # noqa E501 ffi.lib.LLVMPY_AddLICMPass(self) - def add_sccp_pass(self): + def add_sccp_pass(self) -> None: """See http://llvm.org/docs/Passes.html#sccp-sparse-conditional-constant-propagation.""" # noqa E501 ffi.lib.LLVMPY_AddSCCPPass(self) - def add_sroa_pass(self): + def add_sroa_pass(self) -> None: """See http://llvm.org/docs/Passes.html#scalarrepl-scalar-replacement-of-aggregates-dt. Note that this pass corresponds to the ``opt -sroa`` command-line option, despite the link above.""" # noqa E501 ffi.lib.LLVMPY_AddSROAPass(self) - def add_type_based_alias_analysis_pass(self): + def add_type_based_alias_analysis_pass(self) -> None: ffi.lib.LLVMPY_AddTypeBasedAliasAnalysisPass(self) - def add_basic_alias_analysis_pass(self): + def add_basic_alias_analysis_pass(self) -> None: """See http://llvm.org/docs/AliasAnalysis.html#the-basicaa-pass.""" ffi.lib.LLVMPY_AddBasicAliasAnalysisPass(self) - def add_loop_rotate_pass(self): + def add_loop_rotate_pass(self) -> None: """http://llvm.org/docs/Passes.html#loop-rotate-rotate-loops.""" ffi.lib.LLVMPY_LLVMAddLoopRotatePass(self) # Non-standard LLVM passes - def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, - subgraph_limit=1000): + def add_refprune_pass( + self, + subpasses_flags: RefPruneSubpasses = RefPruneSubpasses.ALL, + subgraph_limit: int = 1000, + ) -> None: """Add Numba specific Reference count pruning pass. Parameters @@ -194,13 +203,12 @@ def add_refprune_pass(self, subpasses_flags=RefPruneSubpasses.ALL, class ModulePassManager(PassManager): - - def __init__(self, ptr=None): + def __init__(self, ptr: Any = None) -> None: if ptr is None: ptr = ffi.lib.LLVMPY_CreatePassManager() PassManager.__init__(self, ptr) - def run(self, module): + def run(self, module: ModuleRef) -> Any: """ Run optimization passes on the given module. """ @@ -208,28 +216,27 @@ def run(self, module): class FunctionPassManager(PassManager): - - def __init__(self, module): + def __init__(self, module: ModuleRef) -> None: ptr = ffi.lib.LLVMPY_CreateFunctionPassManager(module) self._module = module module._owned = True PassManager.__init__(self, ptr) - def initialize(self): + def initialize(self) -> Any: """ Initialize the FunctionPassManager. Returns True if it produced any changes (?). """ return ffi.lib.LLVMPY_InitializeFunctionPassManager(self) - def finalize(self): + def finalize(self) -> Any: """ Finalize the FunctionPassManager. Returns True if it produced any changes (?). """ return ffi.lib.LLVMPY_FinalizeFunctionPassManager(self) - def run(self, function): + def run(self, function: Any) -> Any: """ Run optimization passes on the given function. """ diff --git a/llvmlite/binding/targets.py b/llvmlite/binding/targets.py index a7e6ffdc3..b79835349 100644 --- a/llvmlite/binding/targets.py +++ b/llvmlite/binding/targets.py @@ -1,12 +1,20 @@ +from __future__ import annotations + import os from ctypes import (POINTER, c_char_p, c_longlong, c_int, c_size_t, c_void_p, string_at) +from typing import Any, cast as _cast + +# from typing_extensions import Literal from llvmlite.binding import ffi from llvmlite.binding.common import _decode_string, _encode_string +from llvmlite.binding.module import ModuleRef +from llvmlite.binding.passmanagers import PassManager +from llvmlite.binding.value import TypeRef -def get_process_triple(): +def get_process_triple() -> str: """ Return a target triple suitable for generating code for the current process. An example when the default triple from ``get_default_triple()`` is not be @@ -18,13 +26,14 @@ def get_process_triple(): return str(out) +# subclassing built-ins isn't exactly great in < 3.9 class FeatureMap(dict): """ Maps feature name to a boolean indicating the availability of the feature. Extends ``dict`` to add `.flatten()` method. """ - def flatten(self, sort=True): + def flatten(self, sort: bool = True) -> str: """ Args ---- @@ -43,7 +52,7 @@ def flatten(self, sort=True): for k, v in iterator) -def get_host_cpu_features(): +def get_host_cpu_features() -> FeatureMap: """ Returns a dictionary-like object indicating the CPU features for current architecture and whether they are enabled for this CPU. The key-value pairs @@ -68,7 +77,7 @@ def get_host_cpu_features(): return outdict -def get_default_triple(): +def get_default_triple() -> str: """ Return the default target triple LLVM is configured to produce code for. """ @@ -77,7 +86,7 @@ def get_default_triple(): return str(out) -def get_host_cpu_name(): +def get_host_cpu_name() -> str: """ Get the name of the host's CPU, suitable for using with :meth:`Target.create_target_machine()`. @@ -94,7 +103,7 @@ def get_host_cpu_name(): } -def get_object_format(triple=None): +def get_object_format(triple: str | None = None) -> str: """ Get the object format for the given *triple* string (or the default triple if omitted). @@ -106,7 +115,7 @@ def get_object_format(triple=None): return _object_formats[res] -def create_target_data(layout): +def create_target_data(layout: str) -> TargetData: """ Create a TargetData instance for the given *layout* string. """ @@ -119,48 +128,48 @@ class TargetData(ffi.ObjectRef): Use :func:`create_target_data` to create instances. """ - def __str__(self): + def __str__(self) -> str: if self._closed: return "" with ffi.OutputString() as out: ffi.lib.LLVMPY_CopyStringRepOfTargetData(self, out) return str(out) - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeTargetData(self) - def get_abi_size(self, ty): + def get_abi_size(self, ty: Any) -> int: """ Get ABI size of LLVM type *ty*. """ - return ffi.lib.LLVMPY_ABISizeOfType(self, ty) + return _cast(int, ffi.lib.LLVMPY_ABISizeOfType(self, ty)) - def get_element_offset(self, ty, position): + def get_element_offset(self, ty: Any, position: int) -> int: """ Get byte offset of type's ty element at the given position """ - offset = ffi.lib.LLVMPY_OffsetOfElement(self, ty, position) + offset = _cast(int, ffi.lib.LLVMPY_OffsetOfElement(self, ty, position)) if offset == -1: raise ValueError("Could not determined offset of {}th " "element of the type '{}'. Is it a struct" "type?".format(position, str(ty))) return offset - def get_pointee_abi_size(self, ty): + def get_pointee_abi_size(self, ty: TypeRef) -> int: """ Get ABI size of pointee type of LLVM pointer type *ty*. """ - size = ffi.lib.LLVMPY_ABISizeOfElementType(self, ty) + size = _cast(int, ffi.lib.LLVMPY_ABISizeOfElementType(self, ty)) if size == -1: raise RuntimeError("Not a pointer type: %s" % (ty,)) return size - def get_pointee_abi_alignment(self, ty): + def get_pointee_abi_alignment(self, ty: TypeRef) -> int: """ Get minimum ABI alignment of pointee type of LLVM pointer type *ty*. """ - size = ffi.lib.LLVMPY_ABIAlignmentOfElementType(self, ty) + size = _cast(int, ffi.lib.LLVMPY_ABIAlignmentOfElementType(self, ty)) if size == -1: raise RuntimeError("Not a pointer type: %s" % (ty,)) return size @@ -178,7 +187,7 @@ class Target(ffi.ObjectRef): # persistent object. @classmethod - def from_default_triple(cls): + def from_default_triple(cls) -> Target: """ Create a Target instance for the default triple. """ @@ -186,13 +195,14 @@ def from_default_triple(cls): return cls.from_triple(triple) @classmethod - def from_triple(cls, triple): + def from_triple(cls, triple: str) -> Target: """ Create a Target instance for the given triple (a string). """ with ffi.OutputString() as outerr: - target = ffi.lib.LLVMPY_GetTargetFromTriple(triple.encode('utf8'), - outerr) + target = _cast(Target, + ffi.lib.LLVMPY_GetTargetFromTriple(triple.encode('utf8'), outerr) + ) if not target: raise RuntimeError(str(outerr)) target = cls(target) @@ -200,25 +210,33 @@ def from_triple(cls, triple): return target @property - def name(self): + def name(self) -> str: s = ffi.lib.LLVMPY_GetTargetName(self) return _decode_string(s) @property - def description(self): + def description(self) -> str: s = ffi.lib.LLVMPY_GetTargetDescription(self) return _decode_string(s) @property - def triple(self): + def triple(self) -> str: return self._triple - def __str__(self): + def __str__(self) -> str: return "".format(self.name, self.description) - def create_target_machine(self, cpu='', features='', - opt=2, reloc='default', codemodel='jitdefault', - printmc=False, jit=False, abiname=''): + def create_target_machine( + self, + cpu: str = "", + features: str = "", + opt: int = 2, # Literal[0, 1, 2, 3] + reloc: str = "default", + codemodel: str = "jitdefault", + printmc: bool = False, + jit: bool = False, + abiname: str = "", + ) -> TargetMachine: """ Create a new TargetMachine for this target and the given options. @@ -262,30 +280,30 @@ def create_target_machine(self, cpu='', features='', class TargetMachine(ffi.ObjectRef): - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeTargetMachine(self) - def add_analysis_passes(self, pm): + def add_analysis_passes(self, pm: PassManager) -> None: """ Register analysis passes for this target machine with a pass manager. """ ffi.lib.LLVMPY_AddAnalysisPasses(self, pm) - def set_asm_verbosity(self, verbose): + def set_asm_verbosity(self, verbose: bool) -> None: """ Set whether this target machine will emit assembly with human-readable comments describing control flow, debug information, and so on. """ ffi.lib.LLVMPY_SetTargetMachineAsmVerbosity(self, verbose) - def emit_object(self, module): + def emit_object(self, module: ModuleRef) -> bytes: """ Represent the module as a code object, suitable for use with the platform's linker. Returns a byte string. """ return self._emit_to_memory(module, use_object=True) - def emit_assembly(self, module): + def emit_assembly(self, module: ModuleRef) -> str: """ Return the raw assembler of the module, as a string. @@ -293,7 +311,7 @@ def emit_assembly(self, module): """ return _decode_string(self._emit_to_memory(module, use_object=False)) - def _emit_to_memory(self, module, use_object=False): + def _emit_to_memory(self, module: ModuleRef, use_object: bool = False) -> bytes: """Returns bytes of object code of the module. Args @@ -316,17 +334,17 @@ def _emit_to_memory(self, module, use_object=False): ffi.lib.LLVMPY_DisposeMemoryBuffer(mb) @property - def target_data(self): + def target_data(self) -> TargetData: return TargetData(ffi.lib.LLVMPY_CreateTargetMachineData(self)) @property - def triple(self): + def triple(self) -> str: with ffi.OutputString() as out: ffi.lib.LLVMPY_GetTargetMachineTriple(self, out) return str(out) -def has_svml(): +def has_svml() -> bool: """ Returns True if SVML was enabled at FFI support compile time. """ diff --git a/llvmlite/binding/transforms.py b/llvmlite/binding/transforms.py index 82c5dc157..24a6035e3 100644 --- a/llvmlite/binding/transforms.py +++ b/llvmlite/binding/transforms.py @@ -1,44 +1,47 @@ +from __future__ import annotations + +from typing import Any, cast as _cast from ctypes import c_uint, c_bool from llvmlite.binding import ffi from llvmlite.binding import passmanagers -def create_pass_manager_builder(): +def create_pass_manager_builder() -> PassManagerBuilder: return PassManagerBuilder() class PassManagerBuilder(ffi.ObjectRef): __slots__ = () - def __init__(self, ptr=None): + def __init__(self, ptr: Any = None) -> None: if ptr is None: ptr = ffi.lib.LLVMPY_PassManagerBuilderCreate() ffi.ObjectRef.__init__(self, ptr) @property - def opt_level(self): + def opt_level(self) -> int: """ The general optimization level as an integer between 0 and 3. """ - return ffi.lib.LLVMPY_PassManagerBuilderGetOptLevel(self) + return _cast(int, ffi.lib.LLVMPY_PassManagerBuilderGetOptLevel(self)) @opt_level.setter - def opt_level(self, level): + def opt_level(self, level: int) -> None: ffi.lib.LLVMPY_PassManagerBuilderSetOptLevel(self, level) @property - def size_level(self): + def size_level(self) -> int: """ Whether and how much to optimize for size. An integer between 0 and 2. """ - return ffi.lib.LLVMPY_PassManagerBuilderGetSizeLevel(self) + return _cast(int, ffi.lib.LLVMPY_PassManagerBuilderGetSizeLevel(self)) @size_level.setter - def size_level(self, size): + def size_level(self, size: int) -> None: ffi.lib.LLVMPY_PassManagerBuilderSetSizeLevel(self, size) @property - def inlining_threshold(self): + def inlining_threshold(self) -> int: """ The integer threshold for inlining a function into another. The higher, the more likely inlining a function is. This attribute is write-only. @@ -46,51 +49,51 @@ def inlining_threshold(self): raise NotImplementedError("inlining_threshold is write-only") @inlining_threshold.setter - def inlining_threshold(self, threshold): + def inlining_threshold(self, threshold: int) -> None: ffi.lib.LLVMPY_PassManagerBuilderUseInlinerWithThreshold( self, threshold) @property - def disable_unroll_loops(self): + def disable_unroll_loops(self) -> bool: """ If true, disable loop unrolling. """ - return ffi.lib.LLVMPY_PassManagerBuilderGetDisableUnrollLoops(self) + return _cast(bool, ffi.lib.LLVMPY_PassManagerBuilderGetDisableUnrollLoops(self)) @disable_unroll_loops.setter - def disable_unroll_loops(self, disable=True): + def disable_unroll_loops(self, disable: bool = True) -> None: ffi.lib.LLVMPY_PassManagerBuilderSetDisableUnrollLoops(self, disable) @property - def loop_vectorize(self): + def loop_vectorize(self) -> bool: """ If true, allow vectorizing loops. """ - return ffi.lib.LLVMPY_PassManagerBuilderGetLoopVectorize(self) + return _cast(bool, ffi.lib.LLVMPY_PassManagerBuilderGetLoopVectorize(self)) @loop_vectorize.setter - def loop_vectorize(self, enable=True): - return ffi.lib.LLVMPY_PassManagerBuilderSetLoopVectorize(self, enable) + def loop_vectorize(self, enable: bool = True) -> bool: + return _cast(bool, ffi.lib.LLVMPY_PassManagerBuilderSetLoopVectorize(self, enable)) @property - def slp_vectorize(self): + def slp_vectorize(self) -> bool: """ If true, enable the "SLP vectorizer", which uses a different algorithm from the loop vectorizer. Both may be enabled at the same time. """ - return ffi.lib.LLVMPY_PassManagerBuilderGetSLPVectorize(self) + return _cast(bool, ffi.lib.LLVMPY_PassManagerBuilderGetSLPVectorize(self)) @slp_vectorize.setter - def slp_vectorize(self, enable=True): - return ffi.lib.LLVMPY_PassManagerBuilderSetSLPVectorize(self, enable) + def slp_vectorize(self, enable: bool = True) -> bool: + return _cast(bool, ffi.lib.LLVMPY_PassManagerBuilderSetSLPVectorize(self, enable)) - def _populate_module_pm(self, pm): + def _populate_module_pm(self, pm: passmanagers.PassManager) -> None: ffi.lib.LLVMPY_PassManagerBuilderPopulateModulePassManager(self, pm) - def _populate_function_pm(self, pm): + def _populate_function_pm(self, pm: passmanagers.PassManager) -> None: ffi.lib.LLVMPY_PassManagerBuilderPopulateFunctionPassManager(self, pm) - def populate(self, pm): + def populate(self, pm: passmanagers.PassManager) -> None: if isinstance(pm, passmanagers.ModulePassManager): self._populate_module_pm(pm) elif isinstance(pm, passmanagers.FunctionPassManager): @@ -98,7 +101,7 @@ def populate(self, pm): else: raise TypeError(pm) - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_PassManagerBuilderDispose(self) diff --git a/llvmlite/binding/value.py b/llvmlite/binding/value.py index 4e21b3ee5..4c3a9d06c 100644 --- a/llvmlite/binding/value.py +++ b/llvmlite/binding/value.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +from typing import Any, cast as _cast from ctypes import POINTER, c_char_p, c_int, c_size_t, c_uint, c_bool, c_void_p import enum @@ -47,21 +50,21 @@ class TypeRef(ffi.ObjectRef): """A weak reference to a LLVM type """ @property - def name(self): + def name(self) -> str | None: """ Get type name """ return ffi.ret_string(ffi.lib.LLVMPY_GetTypeName(self)) @property - def is_pointer(self): + def is_pointer(self) -> bool: """ Returns true is the type is a pointer type. """ - return ffi.lib.LLVMPY_TypeIsPointer(self) + return _cast(bool, ffi.lib.LLVMPY_TypeIsPointer(self)) @property - def element_type(self): + def element_type(self) -> TypeRef: """ Returns the pointed-to type. When the type is not a pointer, raises exception. @@ -70,115 +73,115 @@ def element_type(self): raise ValueError("Type {} is not a pointer".format(self)) return TypeRef(ffi.lib.LLVMPY_GetElementType(self)) - def __str__(self): - return ffi.ret_string(ffi.lib.LLVMPY_PrintType(self)) + def __str__(self) -> str: + return _cast(str, ffi.ret_string(ffi.lib.LLVMPY_PrintType(self))) class ValueRef(ffi.ObjectRef): """A weak reference to a LLVM value. """ - def __init__(self, ptr, kind, parents): + def __init__(self, ptr: Any, kind: str | None, parents: dict[str, Any]) -> None: self._kind = kind self._parents = parents ffi.ObjectRef.__init__(self, ptr) - def __str__(self): + def __str__(self) -> str: with ffi.OutputString() as outstr: ffi.lib.LLVMPY_PrintValueToString(self, outstr) return str(outstr) @property - def module(self): + def module(self) -> Any: """ The module this function or global variable value was obtained from. """ return self._parents.get('module') @property - def function(self): + def function(self) -> Any: """ The function this argument or basic block value was obtained from. """ return self._parents.get('function') @property - def block(self): + def block(self) -> Any: """ The block this instruction value was obtained from. """ return self._parents.get('block') @property - def instruction(self): + def instruction(self) -> Any: """ The instruction this operand value was obtained from. """ return self._parents.get('instruction') @property - def is_global(self): + def is_global(self) -> bool: return self._kind == 'global' @property - def is_function(self): + def is_function(self) -> bool: return self._kind == 'function' @property - def is_block(self): + def is_block(self) -> bool: return self._kind == 'block' @property - def is_argument(self): + def is_argument(self) -> bool: return self._kind == 'argument' @property - def is_instruction(self): + def is_instruction(self) -> bool: return self._kind == 'instruction' @property - def is_operand(self): + def is_operand(self) -> bool: return self._kind == 'operand' @property - def name(self): + def name(self) -> str: return _decode_string(ffi.lib.LLVMPY_GetValueName(self)) @name.setter - def name(self, val): + def name(self, val: str) -> None: ffi.lib.LLVMPY_SetValueName(self, _encode_string(val)) @property - def linkage(self): + def linkage(self) -> Linkage: return Linkage(ffi.lib.LLVMPY_GetLinkage(self)) @linkage.setter - def linkage(self, value): + def linkage(self, value: Linkage) -> None: if not isinstance(value, Linkage): value = Linkage[value] ffi.lib.LLVMPY_SetLinkage(self, value) @property - def visibility(self): + def visibility(self) -> Visibility: return Visibility(ffi.lib.LLVMPY_GetVisibility(self)) @visibility.setter - def visibility(self, value): + def visibility(self, value: Visibility) -> None: if not isinstance(value, Visibility): value = Visibility[value] ffi.lib.LLVMPY_SetVisibility(self, value) @property - def storage_class(self): + def storage_class(self) -> StorageClass: return StorageClass(ffi.lib.LLVMPY_GetDLLStorageClass(self)) @storage_class.setter - def storage_class(self, value): + def storage_class(self, value: StorageClass) -> None: if not isinstance(value, StorageClass): value = StorageClass[value] ffi.lib.LLVMPY_SetDLLStorageClass(self, value) - def add_function_attribute(self, attr): + def add_function_attribute(self, attr: str) -> None: """Only works on function value Parameters @@ -196,7 +199,7 @@ def add_function_attribute(self, attr): ffi.lib.LLVMPY_AddFunctionAttr(self, attrval) @property - def type(self): + def type(self) -> TypeRef: """ This value's LLVM type. """ @@ -204,7 +207,7 @@ def type(self): return TypeRef(ffi.lib.LLVMPY_TypeOf(self)) @property - def is_declaration(self): + def is_declaration(self) -> bool: """ Whether this value (presumably global) is defined in the current module. @@ -212,10 +215,10 @@ def is_declaration(self): if not (self.is_global or self.is_function): raise ValueError('expected global or function value, got %s' % (self._kind,)) - return ffi.lib.LLVMPY_IsDeclaration(self) + return _cast(bool, ffi.lib.LLVMPY_IsDeclaration(self)) @property - def attributes(self): + def attributes(self) -> Any: """ Return an iterator over this value's attributes. The iterator will yield a string for each attribute. @@ -240,7 +243,7 @@ def attributes(self): return itr @property - def blocks(self): + def blocks(self) -> _BlocksIterator: """ Return an iterator over this function's blocks. The iterator will yield a ValueRef for each block. @@ -253,7 +256,7 @@ def blocks(self): return _BlocksIterator(it, parents) @property - def arguments(self): + def arguments(self) -> _ArgumentsIterator: """ Return an iterator over this function's arguments. The iterator will yield a ValueRef for each argument. @@ -266,7 +269,7 @@ def arguments(self): return _ArgumentsIterator(it, parents) @property - def instructions(self): + def instructions(self) -> _InstructionsIterator: """ Return an iterator over this block's instructions. The iterator will yield a ValueRef for each instruction. @@ -279,7 +282,7 @@ def instructions(self): return _InstructionsIterator(it, parents) @property - def operands(self): + def operands(self) -> _OperandsIterator: """ Return an iterator over this instruction's operands. The iterator will yield a ValueRef for each operand. @@ -293,7 +296,7 @@ def operands(self): return _OperandsIterator(it, parents) @property - def opcode(self): + def opcode(self) -> str | None: if not self.is_instruction: raise ValueError('expected instruction value, got %s' % (self._kind,)) @@ -302,10 +305,10 @@ def opcode(self): class _ValueIterator(ffi.ObjectRef): - kind = None # derived classes must specify the Value kind value + kind: str | None = None # derived classes must specify the Value kind value # as class attribute - def __init__(self, ptr, parents): + def __init__(self, ptr: Any, parents: dict[str, Any]) -> None: ffi.ObjectRef.__init__(self, ptr) # Keep parent objects (module, function, etc) alive self._parents = parents @@ -313,8 +316,9 @@ def __init__(self, ptr, parents): raise NotImplementedError('%s must specify kind attribute' % (type(self).__name__,)) - def __next__(self): - vp = self._next() + def __next__(self) -> ValueRef: + # parent class + vp = self._next() # type: ignore if vp: return ValueRef(vp, self.kind, self._parents) else: @@ -322,14 +326,15 @@ def __next__(self): next = __next__ - def __iter__(self): + def __iter__(self) -> _ValueIterator: return self class _AttributeIterator(ffi.ObjectRef): - def __next__(self): - vp = self._next() + def __next__(self) -> Any: + # unknwon member next + vp = self._next() # type: ignore if vp: return vp else: @@ -337,25 +342,25 @@ def __next__(self): next = __next__ - def __iter__(self): + def __iter__(self) -> _AttributeIterator: return self class _AttributeListIterator(_AttributeIterator): - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeAttributeListIter(self) - def _next(self): + def _next(self) -> bytes | None: return ffi.ret_bytes(ffi.lib.LLVMPY_AttributeListIterNext(self)) class _AttributeSetIterator(_AttributeIterator): - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeAttributeSetIter(self) - def _next(self): + def _next(self) -> bytes | None: return ffi.ret_bytes(ffi.lib.LLVMPY_AttributeSetIterNext(self)) @@ -363,10 +368,10 @@ class _BlocksIterator(_ValueIterator): kind = 'block' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeBlocksIter(self) - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_BlocksIterNext(self) @@ -374,10 +379,10 @@ class _ArgumentsIterator(_ValueIterator): kind = 'argument' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeArgumentsIter(self) - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_ArgumentsIterNext(self) @@ -385,10 +390,10 @@ class _InstructionsIterator(_ValueIterator): kind = 'instruction' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeInstructionsIter(self) - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_InstructionsIterNext(self) @@ -396,10 +401,10 @@ class _OperandsIterator(_ValueIterator): kind = 'operand' - def _dispose(self): + def _dispose(self) -> None: self._capi.LLVMPY_DisposeOperandsIter(self) - def _next(self): + def _next(self) -> Any: return ffi.lib.LLVMPY_OperandsIterNext(self) diff --git a/llvmlite/ir/__init__.py b/llvmlite/ir/__init__.py index b7a0737b2..32e417a33 100644 --- a/llvmlite/ir/__init__.py +++ b/llvmlite/ir/__init__.py @@ -2,6 +2,9 @@ This subpackage implements the LLVM IR classes in pure python """ +from __future__ import annotations + + from .types import * from .values import * from .module import * diff --git a/llvmlite/ir/_utils.py b/llvmlite/ir/_utils.py index 8287d77af..294847907 100644 --- a/llvmlite/ir/_utils.py +++ b/llvmlite/ir/_utils.py @@ -1,4 +1,7 @@ +from __future__ import annotations + from collections import defaultdict +from typing import Any class DuplicatedNameError(NameError): @@ -6,14 +9,14 @@ class DuplicatedNameError(NameError): class NameScope(object): - def __init__(self): - self._useset = set(['']) - self._basenamemap = defaultdict(int) + def __init__(self) -> None: + self._useset: set[str] = set(['']) + self._basenamemap: defaultdict[str, int] = defaultdict(int) - def is_used(self, name): + def is_used(self, name: str) -> bool: return name in self._useset - def register(self, name, deduplicate=False): + def register(self, name: str, deduplicate: bool = False) -> str: if deduplicate: name = self.deduplicate(name) elif self.is_used(name): @@ -21,7 +24,7 @@ def register(self, name, deduplicate=False): self._useset.add(name) return name - def deduplicate(self, name): + def deduplicate(self, name: str) -> str: basename = name while self.is_used(name): ident = self._basenamemap[basename] + 1 @@ -29,52 +32,58 @@ def deduplicate(self, name): name = "{0}.{1}".format(basename, ident) return name - def get_child(self): - return type(self)(parent=self) + def get_child(self) -> Any: + # FIXME: which types can this produce?? + return type(self)(parent=self) # type: ignore -class _StrCaching(object): +class _StrCaching: + # FIXME: self.__cached_str missing + # FIXME: self.to_string() missing - def _clear_string_cache(self): + def _clear_string_cache(self) -> None: try: - del self.__cached_str + del self.__cached_str # type: ignore except AttributeError: pass - def __str__(self): + def __str__(self) -> str: try: - return self.__cached_str + return self.__cached_str # type: ignore except AttributeError: - s = self.__cached_str = self._to_string() - return s + s = self.__cached_str = self._to_string() # type: ignore + return s # type: ignore -class _StringReferenceCaching(object): +class _StringReferenceCaching: + # FIXME: self.__cached_refstr missing + # FIXME: self._get_reference missing - def get_reference(self): + def get_reference(self) -> str: try: - return self.__cached_refstr + return self.__cached_refstr # type: ignore except AttributeError: - s = self.__cached_refstr = self._get_reference() - return s + s = self.__cached_refstr = self._get_reference() # type: ignore + return s # type: ignore -class _HasMetadata(object): +class _HasMetadata: + # FIXME: self.metadata missing - def set_metadata(self, name, node): + def set_metadata(self, name: str, node: Any) -> None: """ Attach unnamed metadata *node* to the metadata slot *name* of this value. """ - self.metadata[name] = node + self.metadata[name] = node # type: ignore - def _stringify_metadata(self, leading_comma=False): - if self.metadata: - buf = [] + def _stringify_metadata(self, leading_comma: bool = False) -> str: + if self.metadata: # type: ignore + buf: list[str] = [] if leading_comma: buf.append("") buf += ["!{0} {1}".format(k, v.get_reference()) - for k, v in self.metadata.items()] + for k, v in self.metadata.items()] # type: ignore return ', '.join(buf) else: return '' diff --git a/llvmlite/ir/builder.py b/llvmlite/ir/builder.py index e648869da..c173bd363 100644 --- a/llvmlite/ir/builder.py +++ b/llvmlite/ir/builder.py @@ -1,7 +1,19 @@ +from __future__ import annotations + import contextlib -import functools +from typing import Any, ContextManager, Iterator, Sequence, cast + +# from typing_extensions import Literal from llvmlite.ir import instructions, types, values +from llvmlite.ir.instructions import ( + Branch, + CallInstr, + CastInstr, + Instruction, + LandingPadInstr, +) +from llvmlite.ir.module import Module _CMP_MAP = { '>': 'gt', @@ -13,172 +25,8 @@ } -def _unop(opname, cls=instructions.Instruction): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, arg, name='', flags=()): - instr = cls(self.block, arg.type, opname, [arg], name, flags) - self._insert(instr) - return instr - - return wrapped - - return wrap - - -def _binop(opname, cls=instructions.Instruction): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, lhs, rhs, name='', flags=()): - if lhs.type != rhs.type: - raise ValueError("Operands must be the same type, got (%s, %s)" - % (lhs.type, rhs.type)) - instr = cls(self.block, lhs.type, opname, (lhs, rhs), name, flags) - self._insert(instr) - return instr - - return wrapped - - return wrap - - -def _binop_with_overflow(opname, cls=instructions.Instruction): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, lhs, rhs, name=''): - if lhs.type != rhs.type: - raise ValueError("Operands must be the same type, got (%s, %s)" - % (lhs.type, rhs.type)) - ty = lhs.type - if not isinstance(ty, types.IntType): - raise TypeError("expected an integer type, got %s" % (ty,)) - bool_ty = types.IntType(1) - - mod = self.module - fnty = types.FunctionType(types.LiteralStructType([ty, bool_ty]), - [ty, ty]) - fn = mod.declare_intrinsic("llvm.%s.with.overflow" % (opname,), - [ty], fnty) - ret = self.call(fn, [lhs, rhs], name=name) - return ret - - return wrapped - - return wrap - - -def _uniop(opname, cls=instructions.Instruction): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, operand, name=''): - instr = cls(self.block, operand.type, opname, [operand], name) - self._insert(instr) - return instr - - return wrapped - - return wrap - - -def _uniop_intrinsic_int(opname): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, operand, name=''): - if not isinstance(operand.type, types.IntType): - raise TypeError( - "expected an integer type, got %s" % - operand.type) - fn = self.module.declare_intrinsic(opname, [operand.type]) - return self.call(fn, [operand], name) - - return wrapped - - return wrap - - -def _uniop_intrinsic_float(opname): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, operand, name=''): - if not isinstance( - operand.type, (types.FloatType, types.DoubleType)): - raise TypeError("expected a float type, got %s" % operand.type) - fn = self.module.declare_intrinsic(opname, [operand.type]) - return self.call(fn, [operand], name) - - return wrapped - - return wrap - - -def _uniop_intrinsic_with_flag(opname): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, operand, flag, name=''): - if not isinstance(operand.type, types.IntType): - raise TypeError( - "expected an integer type, got %s" % - operand.type) - if not(isinstance(flag.type, types.IntType) and - flag.type.width == 1): - raise TypeError("expected an i1 type, got %s" % flag.type) - fn = self.module.declare_intrinsic( - opname, [operand.type, flag.type]) - return self.call(fn, [operand, flag], name) - - return wrapped - - return wrap - - -def _triop_intrinsic(opname): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, a, b, c, name=''): - if a.type != b.type or b.type != c.type: - raise TypeError( - "expected types to be the same, got %s, %s, %s" % ( - a.type, - b.type, - c.type)) - elif not isinstance( - a.type, - (types.HalfType, types.FloatType, types.DoubleType)): - raise TypeError( - "expected an floating point type, got %s" % - a.type) - fn = self.module.declare_intrinsic(opname, [a.type, b.type, c.type]) - return self.call(fn, [a, b, c], name) - - return wrapped - - return wrap - - -def _castop(opname, cls=instructions.CastInstr): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, val, typ, name=''): - if val.type == typ: - return val - instr = cls(self.block, opname, val, typ, name) - self._insert(instr) - return instr - - return wrapped - - return wrap - - -def _label_suffix(label, suffix): - """Returns (label + suffix) or a truncated version if it's too long. - Parameters - ---------- - label : str - Label name - suffix : str - Label suffix - """ +def _label_suffix(label: str, suffix: str) -> str: + """Returns (label + suffix) or a truncated version if it's too long.""" if len(label) > 50: nhead = 25 return ''.join([label[:nhead], '..', suffix]) @@ -186,88 +34,254 @@ def _label_suffix(label, suffix): return label + suffix -class IRBuilder(object): - def __init__(self, block=None): - self._block = block +class IRBuilder: + def __init__(self, block: values.Block | None = None) -> None: + self._block: values.Block | None = block self._anchor = len(block.instructions) if block else 0 self.debug_metadata = None + def _unop( + self, + opname: str, + arg: values.Constant, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> Instruction: + instr = Instruction( + self.block, + arg.type, + opname, + [arg], + name, + flags, + ) + self._insert(instr) + return instr + + def _binop( + self, + opname: str, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[str, ...] = (), + ) -> Instruction: + if lhs.type != rhs.type: + raise ValueError( + f"Operands must be the same type, got ({lhs.type}, {rhs.type})" + ) + instr = Instruction( + self.block, + lhs.type, + opname, + [lhs, rhs], + name, + flags, + ) + self._insert(instr) + return instr + + def _binop_with_overflow( + self, + opname: str, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + ) -> Instruction: + if lhs.type != rhs.type: + raise ValueError( + f"Operands must be the same type, got ({lhs.type}, {rhs.type})" + ) + ty = lhs.type + if not isinstance(ty, types.IntType): + raise TypeError(f"expected an integer type, got {ty}") + bool_ty = types.IntType(1) + + mod = self.module + fnty = types.FunctionType(types.LiteralStructType([ty, bool_ty]), [ty, ty]) + fn = mod.declare_intrinsic(f"llvm.{opname}.with.overflow", [ty], fnty) + # FIXME: is this cast ok? + ret = self.call(cast(values.Function, fn), [lhs, rhs], name=name) + return ret + + def _uniop( + self, + opname: str, + operand: values.Constant, + name: str = "", + ) -> Instruction: + instr = Instruction( + self.block, + operand.type, + opname, + [operand], + name, + ) + self._insert(instr) + return instr + + def _uniop_intrinsic_int( + self, + opname: str, + operand: values.Constant, + name: str = "", + ) -> Instruction: + + if not isinstance(operand.type, types.IntType): + raise TypeError(f"expected an integer type, got {operand.type}") + fn = self.module.declare_intrinsic(opname, [operand.type]) + # FIXME: is the cast here ok? + return self.call(cast(values.Function, fn), [operand], name) + + def _uniop_intrinsic_float( + self, + opname: str, + operand: values.Constant, + name: str = "", + ) -> CallInstr: + if not isinstance(operand.type, (types.FloatType, types.DoubleType)): + raise TypeError(f"expected a float type, got {operand.type}") + fn = self.module.declare_intrinsic(opname, [operand.type]) + # FIXME: is the cast here ok? + return self.call(cast(values.Function, fn), [operand], name) + + def _uniop_intrinsic_with_flag( + self, + opname: str, + operand: values.Constant, + flag: Any, + name: str = "", + ) -> CallInstr: + if not isinstance(operand.type, types.IntType): + raise TypeError(f"expected an integer type, got {operand.type}") + if not (isinstance(flag.type, types.IntType) and flag.type.width == 1): + raise TypeError(f"expected an i1 type, got {flag.type}") + fn = self.module.declare_intrinsic(opname, [operand.type, flag.type]) + # FIXME: is the cast here ok? + return self.call(cast(values.Function, fn), [operand, flag], name) + + def _triop_intrinsic( + self, + opname: str, + a: values.Constant, + b: values.Constant, + c: values.Constant, + name: str = "", + ) -> CallInstr: + if a.type != b.type or b.type != c.type: + raise TypeError( + f"expected types to be the same, got {a.type}, {b.type}, {c.type}" + ) + elif not isinstance( + a.type, (types.HalfType, types.FloatType, types.DoubleType) + ): + raise TypeError(f"expected an floating point type, got {a.type}") + fn = self.module.declare_intrinsic(opname, [a.type, b.type, c.type]) + # FIXME: is the cast here ok? + return self.call(cast(values.Function, fn), [a, b, c], name) + + def _castop( + self, + opname: str, + val: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | CastInstr: + if val.type == typ: + return val + instr = instructions.CastInstr( + self.block, + opname, + val, + typ, + name, + ) + self._insert(instr) + return instr + @property - def block(self): + def block(self) -> values.Block: """ The current basic block. """ - return self._block - - basic_block = block + block = self._block + if block is None: + raise AttributeError("Block is None") + return block @property - def function(self): + def function(self) -> values.Function: """ The current function. """ - return self.block.parent + # FIXME: is the cast here ok? + return cast(values.Function, self.block.parent) @property - def module(self): + def module(self) -> Module: """ The current module. """ - return self.block.parent.module + # FIXME: is the cast here ok? + return cast(values.Function, self.block.parent).module - def position_before(self, instr): + def position_before(self, instr: Instruction) -> None: """ Position immediately before the given instruction. The current block is also changed to the instruction's basic block. """ - self._block = instr.parent + # FIXME: is the cast here ok? + self._block = cast(values.Block, instr.parent) self._anchor = self._block.instructions.index(instr) - def position_after(self, instr): + def position_after(self, instr: Instruction) -> None: """ Position immediately after the given instruction. The current block is also changed to the instruction's basic block. """ - self._block = instr.parent + # FIXME: is the cast here ok? + self._block = cast(values.Block, instr.parent) self._anchor = self._block.instructions.index(instr) + 1 - def position_at_start(self, block): + def position_at_start(self, block: values.Block) -> None: """ Position at the start of the basic *block*. """ self._block = block self._anchor = 0 - def position_at_end(self, block): + def position_at_end(self, block: values.Block) -> None: """ Position at the end of the basic *block*. """ self._block = block self._anchor = len(block.instructions) - def append_basic_block(self, name=''): + def append_basic_block(self, name: str = "") -> values.Block: """ Append a basic block, with the given optional *name*, to the current function. The current block is not changed. The new block is returned. """ return self.function.append_basic_block(name) - def remove(self, instr): + def remove(self, instr: Instruction) -> None: """Remove the given instruction.""" - idx = self._block.instructions.index(instr) - del self._block.instructions[idx] - if self._block.terminator == instr: - self._block.terminator = None + # FIXME: is this cast ok? + block = cast(values.Block, self._block) + idx = block.instructions.index(instr) + del block.instructions[idx] + if block.terminator == instr: + block.terminator = None if self._anchor > idx: self._anchor -= 1 @contextlib.contextmanager - def goto_block(self, block): + def goto_block(self, block: values.Block) -> Iterator[None]: """ A context manager which temporarily positions the builder at the end of basic block *bb* (but before any terminator). """ - old_block = self.basic_block + old_block = self.block term = block.terminator if term is not None: self.position_before(term) @@ -279,7 +293,7 @@ def goto_block(self, block): self.position_at_end(old_block) @contextlib.contextmanager - def goto_entry_block(self): + def goto_entry_block(self) -> Iterator[None]: """ A context manager which temporarily positions the builder at the end of the function's entry block. @@ -288,14 +302,20 @@ def goto_entry_block(self): yield @contextlib.contextmanager - def _branch_helper(self, bbenter, bbexit): + def _branch_helper( + self, bbenter: values.Block, bbexit: values.Block + ) -> Iterator[values.Block]: self.position_at_end(bbenter) yield bbexit - if self.basic_block.terminator is None: + if self.block.terminator is None: self.branch(bbexit) @contextlib.contextmanager - def if_then(self, pred, likely=None): + def if_then( + self, + pred: values.Constant, + likely: bool | None = None, + ) -> Iterator[values.Block]: """ A context manager which sets up a conditional basic block based on the given predicate (a i1 value). If the conditional block @@ -305,7 +325,7 @@ def if_then(self, pred, likely=None): predicate is likely to be true or not, and metadata is issued for LLVM's optimizers to account for that. """ - bb = self.basic_block + bb = self.block bbif = self.append_basic_block(name=_label_suffix(bb.name, '.if')) bbend = self.append_basic_block(name=_label_suffix(bb.name, '.endif')) br = self.cbranch(pred, bbif, bbend) @@ -318,7 +338,9 @@ def if_then(self, pred, likely=None): self.position_at_end(bbend) @contextlib.contextmanager - def if_else(self, pred, likely=None): + def if_else( + self, pred: values.Constant, likely: bool | None = None + ) -> Iterator[tuple[ContextManager[values.Block], ContextManager[values.Block]]]: """ A context manager which sets up two conditional basic blocks based on the given predicate (a i1 value). @@ -333,7 +355,9 @@ def if_else(self, pred, likely=None): with otherwise: # emit instructions for when the predicate is false """ - bb = self.basic_block + bb = self.block + if bb is None: + raise AttributeError("No basic block set") bbif = self.append_basic_block(name=_label_suffix(bb.name, '.if')) bbelse = self.append_basic_block(name=_label_suffix(bb.name, '.else')) bbend = self.append_basic_block(name=_label_suffix(bb.name, '.endif')) @@ -348,13 +372,14 @@ def if_else(self, pred, likely=None): self.position_at_end(bbend) - def _insert(self, instr): - if self.debug_metadata is not None and 'dbg' not in instr.metadata: - instr.metadata['dbg'] = self.debug_metadata - self._block.instructions.insert(self._anchor, instr) + def _insert(self, instr: Instruction) -> None: + if self.debug_metadata is not None and "dbg" not in instr.metadata: + instr.metadata["dbg"] = self.debug_metadata + # FIXME: is this cast ok? + cast(values.Block, self._block).instructions.insert(self._anchor, instr) self._anchor += 1 - def _set_terminator(self, term): + def _set_terminator(self, term: Instruction) -> Instruction: assert not self.block.is_terminated self._insert(term) self.block.terminator = term @@ -364,179 +389,247 @@ def _set_terminator(self, term): # Arithmetic APIs # - @_binop('shl') - def shl(self, lhs, rhs, name=''): + def shl( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Left integer shift: name = lhs << rhs """ + return self._binop("shl", lhs, rhs, name) - @_binop('lshr') - def lshr(self, lhs, rhs, name=''): + def lshr( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Logical (unsigned) right integer shift: name = lhs >> rhs """ + return self._binop("lshr", lhs, rhs, name) - @_binop('ashr') - def ashr(self, lhs, rhs, name=''): + def ashr( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Arithmetic (signed) right integer shift: name = lhs >> rhs """ + return self._binop("ashr", lhs, rhs, name) - @_binop('add') - def add(self, lhs, rhs, name=''): + def add( + self, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[str, ...] = (), + ) -> instructions.Instruction: """ Integer addition: name = lhs + rhs """ + return self._binop("add", lhs, rhs, name, flags) - @_binop('fadd') - def fadd(self, lhs, rhs, name=''): + def fadd( + self, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[str, ...] = (), + ) -> instructions.Instruction: """ Floating-point addition: name = lhs + rhs """ + return self._binop("fadd", lhs, rhs, name, flags) - @_binop('sub') - def sub(self, lhs, rhs, name=''): + def sub( + self, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[str, ...] = (), + ) -> instructions.Instruction: """ Integer subtraction: name = lhs - rhs """ + return self._binop("sub", lhs, rhs, name, flags) - @_binop('fsub') - def fsub(self, lhs, rhs, name=''): + def fsub( + self, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[str, ...] = (), + ) -> instructions.Instruction: """ Floating-point subtraction: name = lhs - rhs """ + return self._binop("fsub", lhs, rhs, name, flags) - @_binop('mul') - def mul(self, lhs, rhs, name=''): + def mul( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Integer multiplication: name = lhs * rhs """ + return self._binop("mul", lhs, rhs, name) - @_binop('fmul') - def fmul(self, lhs, rhs, name=''): + def fmul( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Floating-point multiplication: name = lhs * rhs """ + return self._binop("fmul", lhs, rhs, name) - @_binop('udiv') - def udiv(self, lhs, rhs, name=''): + def udiv( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Unsigned integer division: name = lhs / rhs """ + return self._binop("udiv", lhs, rhs, name) - @_binop('sdiv') - def sdiv(self, lhs, rhs, name=''): + def sdiv( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Signed integer division: name = lhs / rhs """ + return self._binop("sdiv", lhs, rhs, name) - @_binop('fdiv') - def fdiv(self, lhs, rhs, name=''): + def fdiv( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Floating-point division: name = lhs / rhs """ + return self._binop("fdiv", lhs, rhs, name) - @_binop('urem') - def urem(self, lhs, rhs, name=''): + def urem( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Unsigned integer remainder: name = lhs % rhs """ + return self._binop("urem", lhs, rhs, name) - @_binop('srem') - def srem(self, lhs, rhs, name=''): + def srem( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Signed integer remainder: name = lhs % rhs """ + return self._binop("srem", lhs, rhs, name) - @_binop('frem') - def frem(self, lhs, rhs, name=''): + def frem( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Floating-point remainder: name = lhs % rhs """ + return self._binop("frem", lhs, rhs, name) - @_binop('or') - def or_(self, lhs, rhs, name=''): + def or_( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Bitwise integer OR: name = lhs | rhs """ + return self._binop("or", lhs, rhs, name) - @_binop('and') - def and_(self, lhs, rhs, name=''): + def and_( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Bitwise integer AND: name = lhs & rhs """ + return self._binop("and", lhs, rhs, name) - @_binop('xor') - def xor(self, lhs, rhs, name=''): + def xor( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Bitwise integer XOR: name = lhs ^ rhs """ + return self._binop("xor", lhs, rhs, name) - @_binop_with_overflow('sadd') - def sadd_with_overflow(self, lhs, rhs, name=''): + def sadd_with_overflow( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Signed integer addition with overflow: name = {result, overflow bit} = lhs + rhs """ + return self._binop_with_overflow("sadd", lhs, rhs, name) - @_binop_with_overflow('smul') - def smul_with_overflow(self, lhs, rhs, name=''): + def smul_with_overflow( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Signed integer multiplication with overflow: name = {result, overflow bit} = lhs * rhs """ + return self._binop_with_overflow("smul", lhs, rhs, name) - @_binop_with_overflow('ssub') - def ssub_with_overflow(self, lhs, rhs, name=''): + def ssub_with_overflow( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Signed integer subtraction with overflow: name = {result, overflow bit} = lhs - rhs """ + return self._binop_with_overflow("ssub", lhs, rhs, name) - @_binop_with_overflow('uadd') - def uadd_with_overflow(self, lhs, rhs, name=''): + def uadd_with_overflow( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Unsigned integer addition with overflow: name = {result, overflow bit} = lhs + rhs """ + return self._binop_with_overflow("uadd", lhs, rhs, name) - @_binop_with_overflow('umul') - def umul_with_overflow(self, lhs, rhs, name=''): + def umul_with_overflow( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Unsigned integer multiplication with overflow: name = {result, overflow bit} = lhs * rhs """ + return self._binop_with_overflow("umul", lhs, rhs, name) - @_binop_with_overflow('usub') - def usub_with_overflow(self, lhs, rhs, name=''): + def usub_with_overflow( + self, lhs: values.Constant, rhs: values.Constant, name: str = "" + ) -> instructions.Instruction: """ Unsigned integer subtraction with overflow: name = {result, overflow bit} = lhs - rhs """ + return self._binop_with_overflow("usub", lhs, rhs, name) # # Unary APIs # - def not_(self, value, name=''): + def not_( + self, + value: values.Constant, + name: str = "", + ) -> instructions.Instruction: """ Bitwise integer complement: name = ~value @@ -547,25 +640,41 @@ def not_(self, value, name=''): rhs = values.Constant(value.type, -1) return self.xor(value, rhs, name=name) - def neg(self, value, name=''): + def neg( + self, + value: values.Constant, + name: str = "", + ) -> instructions.Instruction: """ Integer negative: name = -value """ return self.sub(values.Constant(value.type, 0), value, name=name) - @_unop('fneg') - def fneg(self, arg, name='', flags=()): + def fneg( + self, + arg: values.Constant, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> instructions.Instruction: """ Floating-point negative: name = -arg """ + return self._unop("fneg", arg, name, flags) # # Comparison APIs # - def _icmp(self, prefix, cmpop, lhs, rhs, name): + def _icmp( + self, + prefix: str, + cmpop: str, # Literal["==", "!=", "<", "<=", ">", ">="], + lhs: values.Constant, + rhs: values.Constant, + name: str, + ) -> instructions.ICMPInstr: try: op = _CMP_MAP[cmpop] except KeyError: @@ -576,7 +685,13 @@ def _icmp(self, prefix, cmpop, lhs, rhs, name): self._insert(instr) return instr - def icmp_signed(self, cmpop, lhs, rhs, name=''): + def icmp_signed( + self, + cmpop: str, # Literal["==", "!=", "<", "<=", ">", ">="], + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + ) -> instructions.ICMPInstr: """ Signed integer comparison: name = lhs rhs @@ -585,7 +700,13 @@ def icmp_signed(self, cmpop, lhs, rhs, name=''): """ return self._icmp('s', cmpop, lhs, rhs, name) - def icmp_unsigned(self, cmpop, lhs, rhs, name=''): + def icmp_unsigned( + self, + cmpop: str, # Literal["==", "!=", "<", "<=", ">", ">="], + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + ) -> instructions.ICMPInstr: """ Unsigned integer (or pointer) comparison: name = lhs rhs @@ -594,7 +715,14 @@ def icmp_unsigned(self, cmpop, lhs, rhs, name=''): """ return self._icmp('u', cmpop, lhs, rhs, name) - def fcmp_ordered(self, cmpop, lhs, rhs, name='', flags=()): + def fcmp_ordered( + self, + cmpop: str, # Literal["==", "!=", "<", "<=", ">", ">=", "ord", "uno"], + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> instructions.FCMPInstr: """ Floating-point ordered comparison: name = lhs rhs @@ -606,11 +734,24 @@ def fcmp_ordered(self, cmpop, lhs, rhs, name='', flags=()): else: op = cmpop instr = instructions.FCMPInstr( - self.block, op, lhs, rhs, name=name, flags=flags) + self.block, + op, + lhs, + rhs, + name=name, + flags=list(flags), + ) self._insert(instr) return instr - def fcmp_unordered(self, cmpop, lhs, rhs, name='', flags=()): + def fcmp_unordered( + self, + cmpop: str, # Literal["==", "!=", "<", "<=", ">", ">=", "ord", "uno"], + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> instructions.FCMPInstr: """ Floating-point unordered comparison: name = lhs rhs @@ -622,17 +763,36 @@ def fcmp_unordered(self, cmpop, lhs, rhs, name='', flags=()): else: op = cmpop instr = instructions.FCMPInstr( - self.block, op, lhs, rhs, name=name, flags=flags) + self.block, + op, + lhs, + rhs, + name=name, + flags=list(flags), + ) self._insert(instr) return instr - def select(self, cond, lhs, rhs, name='', flags=()): + def select( + self, + cond: values.Constant, + lhs: values.Constant, + rhs: values.Constant, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> instructions.SelectInstr: """ Ternary select operator: name = cond ? lhs : rhs """ - instr = instructions.SelectInstr(self.block, cond, lhs, rhs, name=name, - flags=flags) + instr = instructions.SelectInstr( + self.block, + cond, + lhs, + rhs, + name=name, + flags=flags, + ) self._insert(instr) return instr @@ -640,174 +800,302 @@ def select(self, cond, lhs, rhs, name='', flags=()): # Cast APIs # - @_castop('trunc') - def trunc(self, value, typ, name=''): + def trunc( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Truncating integer downcast to a smaller type: name = (typ) value """ + return self._castop("trunc", value, typ, name) - @_castop('zext') - def zext(self, value, typ, name=''): + def zext( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Zero-extending integer upcast to a larger type: name = (typ) value """ + return self._castop("zext", value, typ, name) - @_castop('sext') - def sext(self, value, typ, name=''): + def sext( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Sign-extending integer upcast to a larger type: name = (typ) value """ + return self._castop("sext", value, typ, name) - @_castop('fptrunc') - def fptrunc(self, value, typ, name=''): + def fptrunc( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Floating-point downcast to a less precise type: name = (typ) value """ + return self._castop("fptrunc", value, typ, name) - @_castop('fpext') - def fpext(self, value, typ, name=''): + def fpext( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Floating-point upcast to a more precise type: name = (typ) value """ + return self._castop("fpext", value, typ, name) - @_castop('bitcast') - def bitcast(self, value, typ, name=''): + def bitcast( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Pointer cast to a different pointer type: name = (typ) value """ + return self._castop("bitcast", value, typ, name) - @_castop('addrspacecast') - def addrspacecast(self, value, typ, name=''): + def addrspacecast( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Pointer cast to a different address space: name = (typ) value """ + return self._castop("addrspacecast", value, typ, name) - @_castop('fptoui') - def fptoui(self, value, typ, name=''): + def fptoui( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Convert floating-point to unsigned integer: name = (typ) value """ + return self._castop("fptoui", value, typ, name) - @_castop('uitofp') - def uitofp(self, value, typ, name=''): + def uitofp( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Convert unsigned integer to floating-point: name = (typ) value """ + return self._castop("uitofp", value, typ, name) - @_castop('fptosi') - def fptosi(self, value, typ, name=''): + def fptosi( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Convert floating-point to signed integer: name = (typ) value """ + return self._castop("fptosi", value, typ, name) - @_castop('sitofp') - def sitofp(self, value, typ, name=''): + def sitofp( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Convert signed integer to floating-point: name = (typ) value """ + return self._castop("sitofp", value, typ, name) - @_castop('ptrtoint') - def ptrtoint(self, value, typ, name=''): + def ptrtoint( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Cast pointer to integer: name = (typ) value """ + return self._castop("ptrtoint", value, typ, name) - @_castop('inttoptr') - def inttoptr(self, value, typ, name=''): + def inttoptr( + self, + value: values.Constant, + typ: types.Type, + name: str = "", + ) -> values.Constant | instructions.CastInstr: """ Cast integer to pointer: name = (typ) value """ + return self._castop("inttoptr", value, typ, name) # # Memory APIs # - def alloca(self, typ, size=None, name=''): + def alloca( + self, + typ: types.Type, + size: values.Value | None = None, + name: str = "", + ) -> instructions.AllocaInstr: """ Stack-allocate a slot for *size* elements of the given type. (default one element) """ + # FIXME: values.Value has no .type if size is None: pass elif isinstance(size, (values.Value, values.Constant)): - assert isinstance(size.type, types.IntType) + assert isinstance(size.type, types.IntType) # type: ignore else: # If it is not a Value instance, # assume to be a Python integer. size = values.Constant(types.IntType(32), size) - al = instructions.AllocaInstr(self.block, typ, size, name) + al = instructions.AllocaInstr( + self.block, + typ, + size, + name, + ) self._insert(al) return al - def load(self, ptr, name='', align=None): + def load( + self, + ptr: values.Value, + name: str = "", + align: int | None = None, + ) -> instructions.LoadInstr: """ Load value from pointer, with optional guaranteed alignment: name = *ptr """ - if not isinstance(ptr.type, types.PointerType): - msg = "cannot load from value of type %s (%r): not a pointer" - raise TypeError(msg % (ptr.type, str(ptr))) - ld = instructions.LoadInstr(self.block, ptr, name) + # FIXME: values.Value has no .type + if not isinstance(ptr.type, types.PointerType): # type: ignore + raise TypeError( + f"cannot load from value of type {ptr.type} " # type: ignore + f"({str(ptr)!r}): not a pointer" + ) + ld = instructions.LoadInstr( + self.block, + ptr, + name, + ) ld.align = align self._insert(ld) return ld - def store(self, value, ptr, align=None): + def store( + self, + value: values.Value, + ptr: values.Value, + align: int | None = None, + ) -> instructions.StoreInstr: """ Store value to pointer, with optional guaranteed alignment: *ptr = name """ - if not isinstance(ptr.type, types.PointerType): - msg = "cannot store to value of type %s (%r): not a pointer" - raise TypeError(msg % (ptr.type, str(ptr))) - if ptr.type.pointee != value.type: - raise TypeError("cannot store %s to %s: mismatching types" - % (value.type, ptr.type)) - st = instructions.StoreInstr(self.block, value, ptr) + # FIXME: values.Value has no .type + if not isinstance(ptr.type, types.PointerType): # type: ignore + raise TypeError( + f"cannot store to value of type {ptr.type} " # type: ignore + f"({str(ptr)!r}): not a pointer" + ) + if ptr.type.pointee != value.type: # type: ignore + raise TypeError( + f"cannot store {value.type} to " # type: ignore + f"{ptr.type}: mismatching types" # type: ignore + ) + st = instructions.StoreInstr( + self.block, + value, + ptr, + ) st.align = align self._insert(st) return st - def load_atomic(self, ptr, ordering, align, name=''): + def load_atomic( + self, + ptr: values.Value, + ordering: str, # Literal["seq_cst"] and which others? + align: int, + name: str = "", + ) -> instructions.LoadAtomicInstr: """ Load value from pointer, with optional guaranteed alignment: name = *ptr """ - if not isinstance(ptr.type, types.PointerType): - msg = "cannot load from value of type %s (%r): not a pointer" - raise TypeError(msg % (ptr.type, str(ptr))) + # FIXME: values.Value has no .type + if not isinstance(ptr.type, types.PointerType): # type: ignore + raise TypeError( + f"cannot load from value of type {ptr.type} ({str(ptr)!r}): not a pointer" # type: ignore + ) ld = instructions.LoadAtomicInstr( - self.block, ptr, ordering, align, name) + self.block, + ptr, + ordering, + align, + name, + ) self._insert(ld) return ld - def store_atomic(self, value, ptr, ordering, align): + def store_atomic( + self, + value: values.Value, + ptr: values.Value, + ordering: str, # Literal["seq_cst"] and which others? + align: int, + ) -> instructions.StoreAtomicInstr: """ Store value to pointer, with optional guaranteed alignment: *ptr = name """ - if not isinstance(ptr.type, types.PointerType): - msg = "cannot store to value of type %s (%r): not a pointer" - raise TypeError(msg % (ptr.type, str(ptr))) - if ptr.type.pointee != value.type: - raise TypeError("cannot store %s to %s: mismatching types" - % (value.type, ptr.type)) + # FIXME: values.Value has no .type + if not isinstance(ptr.type, types.PointerType): # type: ignore + msg = "cannot store to value of type {} ({!r}): not a pointer" + raise TypeError(msg.format(ptr.type, str(ptr))) # type: ignore + if ptr.type.pointee != value.type: # type: ignore + raise TypeError( + "cannot store {} to {}: mismatching types".format(value.type, ptr.type) # type: ignore + ) st = instructions.StoreAtomicInstr( - self.block, value, ptr, ordering, align) + self.block, + value, + ptr, + ordering, + align, + ) self._insert(st) return st @@ -815,15 +1103,20 @@ def store_atomic(self, value, ptr, ordering, align): # Terminators APIs # - def switch(self, value, default): + def switch(self, value: values.Value, default: Any) -> instructions.SwitchInstr: """ Create a switch-case with a single *default* target. """ - swt = instructions.SwitchInstr(self.block, 'switch', value, default) + swt = instructions.SwitchInstr( + self.block, + "switch", + value, + default, + ) self._set_terminator(swt) return swt - def branch(self, target): + def branch(self, target: values.Block) -> Branch: """ Unconditional branch to *target*. """ @@ -831,7 +1124,12 @@ def branch(self, target): self._set_terminator(br) return br - def cbranch(self, cond, truebr, falsebr): + def cbranch( + self, + cond: values.Constant, + truebr: values.Block, + falsebr: values.Block, + ) -> instructions.ConditionalBranch: """ Conditional branch to *truebr* if *cond* is true, else to *falsebr*. """ @@ -840,7 +1138,7 @@ def cbranch(self, cond, truebr, falsebr): self._set_terminator(br) return br - def branch_indirect(self, addr): + def branch_indirect(self, addr: values.Block) -> instructions.IndirectBranch: """ Indirect branch to target *addr*. """ @@ -848,32 +1146,41 @@ def branch_indirect(self, addr): self._set_terminator(br) return br - def ret_void(self): + def ret_void(self) -> instructions.Instruction: """ Return from function without a value. """ return self._set_terminator( instructions.Ret(self.block, "ret void")) - def ret(self, value): + def ret(self, value: values.Value) -> instructions.Instruction: """ Return from function with the given *value*. """ return self._set_terminator( instructions.Ret(self.block, "ret", value)) - def resume(self, landingpad): + def resume(self, landingpad: LandingPadInstr) -> Branch: """ Resume an in-flight exception. """ - br = instructions.Branch(self.block, "resume", [landingpad]) + br = Branch(self.block, "resume", [landingpad]) self._set_terminator(br) return br # Call APIs - def call(self, fn, args, name='', cconv=None, tail=False, fastmath=(), - attrs=(), arg_attrs=None): + def call( + self, + fn: values.Function | instructions.InlineAsm, + args: Sequence[values.Constant], + name: str = "", + cconv: None = None, + tail: bool = False, + fastmath: tuple[Any, ...] = (), + attrs: tuple[Any, ...] = (), + arg_attrs: None = None, + ) -> instructions.CallInstr: """ Call function *fn* with *args*: name = fn(args...) @@ -884,14 +1191,27 @@ def call(self, fn, args, name='', cconv=None, tail=False, fastmath=(), self._insert(inst) return inst - def asm(self, ftype, asm, constraint, args, side_effect, name=''): + def asm( + self, + ftype: types.FunctionType, + asm: str, + constraint: Any, + args: Sequence[values.Constant], + side_effect: bool, + name: str = "", + ) -> CallInstr: """ Inline assembler. """ - asm = instructions.InlineAsm(ftype, asm, constraint, side_effect) - return self.call(asm, args, name) + asm_instr = instructions.InlineAsm(ftype, asm, constraint, side_effect) + return self.call(asm_instr, args, name) - def load_reg(self, reg_type, reg_name, name=''): + def load_reg( + self, + reg_type: types.Type, + reg_name: str, + name: str = "", + ) -> CallInstr: """ Load a register value into an LLVM value. Example: v = load_reg(IntType(32), "eax") @@ -899,7 +1219,13 @@ def load_reg(self, reg_type, reg_name, name=''): ftype = types.FunctionType(reg_type, []) return self.asm(ftype, "", "={%s}" % reg_name, [], False, name) - def store_reg(self, value, reg_type, reg_name, name=''): + def store_reg( + self, + value: values.Constant, + reg_type: types.Type, + reg_name: str, + name: str = "", + ) -> CallInstr: """ Store an LLVM value inside a register Example: @@ -908,18 +1234,42 @@ def store_reg(self, value, reg_type, reg_name, name=''): ftype = types.FunctionType(types.VoidType(), [reg_type]) return self.asm(ftype, "", "{%s}" % reg_name, [value], True, name) - def invoke(self, fn, args, normal_to, unwind_to, - name='', cconv=None, fastmath=(), attrs=(), arg_attrs=None): - inst = instructions.InvokeInstr(self.block, fn, args, normal_to, - unwind_to, name=name, cconv=cconv, - fastmath=fastmath, attrs=attrs, - arg_attrs=arg_attrs) + def invoke( + self, + fn: values.Function, + args: Sequence[values.Constant], + normal_to: values.Block, + unwind_to: values.Block, + name: str = "", + cconv: str | None = None, + fastmath: tuple[str, ...] = (), + attrs: tuple[str, ...] = (), + arg_attrs: dict[int, tuple[Any, ...]] | None = None, + ) -> instructions.InvokeInstr: + inst = instructions.InvokeInstr( + self.block, + fn, + args, + normal_to, + unwind_to, + name=name, + cconv=cconv, + fastmath=fastmath, + attrs=attrs, + arg_attrs=arg_attrs, + ) self._set_terminator(inst) return inst # GEP APIs - def gep(self, ptr, indices, inbounds=False, name=''): + def gep( + self, + ptr: values.Constant, + indices: list[values.Constant], + inbounds: bool = False, + name: str = "", + ) -> instructions.GEPInstr: """ Compute effective address (getelementptr): name = getelementptr ptr, @@ -931,7 +1281,12 @@ def gep(self, ptr, indices, inbounds=False, name=''): # Vector Operations APIs - def extract_element(self, vector, idx, name=''): + def extract_element( + self, + vector: values.Constant, + idx: values.Constant, + name: str = "", + ) -> instructions.ExtractElement: """ Returns the value at position idx. """ @@ -939,7 +1294,13 @@ def extract_element(self, vector, idx, name=''): self._insert(instr) return instr - def insert_element(self, vector, value, idx, name=''): + def insert_element( + self, + vector: values.Constant, + value: values.Constant, + idx: values.Constant, + name: str = "", + ) -> instructions.InsertElement: """ Returns vector with vector[idx] replaced by value. The result is undefined if the idx is larger or equal the vector length. @@ -949,7 +1310,13 @@ def insert_element(self, vector, value, idx, name=''): self._insert(instr) return instr - def shuffle_vector(self, vector1, vector2, mask, name=''): + def shuffle_vector( + self, + vector1: values.Constant, + vector2: values.Constant, + mask: values.Constant, + name: str = "", + ) -> instructions.ShuffleVector: """ Constructs a permutation of elements from *vector1* and *vector2*. Returns a new vector in the same length of *mask*. @@ -964,7 +1331,12 @@ def shuffle_vector(self, vector1, vector2, mask, name=''): # Aggregate APIs - def extract_value(self, agg, idx, name=''): + def extract_value( + self, + agg: values.Constant, + idx: values.Constant | list[values.Constant], + name: str = "", + ) -> instructions.ExtractValue: """ Extract member number *idx* from aggregate. """ @@ -974,7 +1346,13 @@ def extract_value(self, agg, idx, name=''): self._insert(instr) return instr - def insert_value(self, agg, value, idx, name=''): + def insert_value( + self, + agg: values.Constant, + value: values.Constant, + idx: values.Constant | list[values.Constant], + name: str = "", + ) -> instructions.InsertValue: """ Insert *value* into member number *idx* from aggregate. """ @@ -986,25 +1364,45 @@ def insert_value(self, agg, value, idx, name=''): # PHI APIs - def phi(self, typ, name='', flags=()): + def phi( + self, + typ: types.Type, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> instructions.PhiInstr: inst = instructions.PhiInstr(self.block, typ, name=name, flags=flags) self._insert(inst) return inst # Special API - def unreachable(self): + def unreachable(self) -> instructions.Unreachable: inst = instructions.Unreachable(self.block) self._set_terminator(inst) return inst - def atomic_rmw(self, op, ptr, val, ordering, name=''): + def atomic_rmw( + self, + op: str, + ptr: values.Constant, + val: values.Constant, + ordering: str, + name: str = "", + ) -> instructions.AtomicRMW: inst = instructions.AtomicRMW( self.block, op, ptr, val, ordering, name=name) self._insert(inst) return inst - def cmpxchg(self, ptr, cmp, val, ordering, failordering=None, name=''): + def cmpxchg( + self, + ptr: values.Constant, + cmp: values.Constant, + val: values.Constant, + ordering: str, + failordering: str | None = None, + name: str = "", + ) -> instructions.CmpXchg: """ Atomic compared-and-set: atomic { @@ -1023,19 +1421,30 @@ def cmpxchg(self, ptr, cmp, val, ordering, failordering=None, name=''): self._insert(inst) return inst - def landingpad(self, typ, name='', cleanup=False): + def landingpad( + self, + typ: types.Type, + name: str = "", + cleanup: bool = False, + ) -> instructions.LandingPadInstr: inst = instructions.LandingPadInstr(self.block, typ, name, cleanup) self._insert(inst) return inst - def assume(self, cond): + def assume(self, cond: values.Constant) -> CallInstr: """ Optimizer hint: assume *cond* is always true. """ fn = self.module.declare_intrinsic("llvm.assume") - return self.call(fn, [cond]) + # FIXME: is the cast here ok? + return self.call(cast(values.Function, fn), [cond]) - def fence(self, ordering, targetscope=None, name=''): + def fence( + self, + ordering: str, # Literal["acquire", "release", "acq_rel", "seq_cst"], + targetscope: None = None, + name: str = "", + ) -> instructions.Fence: """ Add a memory barrier, preventing certain reorderings of load and/or store accesses with @@ -1045,47 +1454,55 @@ def fence(self, ordering, targetscope=None, name=''): self._insert(inst) return inst - @_uniop_intrinsic_int("llvm.bswap") - def bswap(self, cond): + def bswap(self, cond: values.Constant, name: str = "") -> Instruction: + """ Used to byte swap integer values with an even number of bytes (positive multiple of 16 bits) """ + return self._uniop_intrinsic_int("llvm.bswap", cond, name) - @_uniop_intrinsic_int("llvm.bitreverse") - def bitreverse(self, cond): + def bitreverse(self, cond: values.Constant, name: str = "") -> Instruction: """ Reverse the bitpattern of an integer value; for example 0b10110110 becomes 0b01101101. """ + return self._uniop_intrinsic_int("llvm.bitreverse", cond, name) - @_uniop_intrinsic_int("llvm.ctpop") - def ctpop(self, cond): + def ctpop(self, cond: values.Constant, name: str = "") -> Instruction: """ Counts the number of bits set in a value. """ + return self._uniop_intrinsic_int("llvm.ctpop", cond, name) - @_uniop_intrinsic_with_flag("llvm.ctlz") - def ctlz(self, cond, flag): + def ctlz(self, cond: values.Constant, flag: bool, name: str = "") -> Instruction: """ Counts leading zero bits in *value*. Boolean *flag* indicates whether the result is defined for ``0``. """ + return self._uniop_intrinsic_with_flag("llvm.ctlz", cond, flag, name) - @_uniop_intrinsic_with_flag("llvm.cttz") - def cttz(self, cond, flag): + def cttz(self, cond: values.Constant, flag: bool, name: str = "") -> Instruction: """ Counts trailing zero bits in *value*. Boolean *flag* indicates whether the result is defined for ``0``. """ + return self._uniop_intrinsic_with_flag("llvm.cttz", cond, flag, name) - @_triop_intrinsic("llvm.fma") - def fma(self, a, b, c): + def fma( + self, a: values.Constant, b: values.Constant, c: values.Constant, name: str = "" + ) -> Instruction: """ Perform the fused multiply-add operation. """ + return self._triop_intrinsic("llvm.fma", a, b, c, name) - def convert_from_fp16(self, a, to=None, name=''): + def convert_from_fp16( + self, + a: values.Constant, + to: values.Constant | None = None, + name: str = "", + ) -> CallInstr: """ Convert from an i16 to the given FP type """ @@ -1098,10 +1515,11 @@ def convert_from_fp16(self, a, to=None, name=''): opname = 'llvm.convert.from.fp16' fn = self.module.declare_intrinsic(opname, [to]) - return self.call(fn, [a], name) + # FIXME: is the cast here ok? + return self.call(cast(values.Function, fn), [a], name) - @_uniop_intrinsic_float("llvm.convert.to.fp16") - def convert_to_fp16(self, a): + def convert_to_fp16(self, a: values.Constant, name: str = "") -> Instruction: """ Convert the given FP number to an i16 """ + return self._uniop_intrinsic_float("llvm.convert.to.fp16", a, name) diff --git a/llvmlite/ir/context.py b/llvmlite/ir/context.py index 47d1ebbe2..69588bf3d 100644 --- a/llvmlite/ir/context.py +++ b/llvmlite/ir/context.py @@ -1,13 +1,14 @@ -from llvmlite.ir import _utils -from llvmlite.ir import types +from __future__ import annotations + +from llvmlite.ir import _utils, types class Context(object): - def __init__(self): + def __init__(self) -> None: self.scope = _utils.NameScope() - self.identified_types = {} + self.identified_types: dict[str, types.IdentifiedStructType] = {} - def get_identified_type(self, name): + def get_identified_type(self, name: str) -> types.IdentifiedStructType: if name not in self.identified_types: self.scope.register(name) ty = types.IdentifiedStructType(self, name) diff --git a/llvmlite/ir/instructions.py b/llvmlite/ir/instructions.py index 02711b71a..a0c399a38 100644 --- a/llvmlite/ir/instructions.py +++ b/llvmlite/ir/instructions.py @@ -2,75 +2,105 @@ Implementation of LLVM IR instructions. """ +from __future__ import annotations + +from typing import Any, Iterable, Sequence, cast +# from typing_extensions import Literal + from llvmlite.ir import types from llvmlite.ir.values import (Block, Function, Value, NamedValue, Constant, MetaDataArgument, MetaDataString, AttributeSet, Undefined, ArgumentAttributes) from llvmlite.ir._utils import _HasMetadata +from llvmlite.ir.module import Module class Instruction(NamedValue, _HasMetadata): - def __init__(self, parent, typ, opname, operands, name='', flags=()): - super(Instruction, self).__init__(parent, typ, name=name) + def __init__( + self, + parent: Module | Block, + typ: types.Type, + opname: str, + operands: list[Value], + name: str = "", + flags: Iterable[Any] = (), + ) -> None: + super().__init__(parent, typ, name=name) assert isinstance(parent, Block) assert isinstance(flags, (tuple, list)) self.opname = opname self.operands = operands self.flags = list(flags) - self.metadata = {} + self.metadata: dict[str, Any] = {} @property - def function(self): - return self.parent.function + def function(self) -> Function: + # FIXME: is the cast here ok? + return cast(Block, self.parent).function @property - def module(self): - return self.parent.function.module + def module(self) -> Module: + return self.parent.function.module # type: ignore - def descr(self, buf): + def descr(self, buf: list[str]) -> None: opname = self.opname if self.flags: opname = ' '.join([opname] + self.flags) - operands = ', '.join([op.get_reference() for op in self.operands]) + # FIXME: get_reference is unknown + operands = ', '.join([op.get_reference() for op in self.operands]) # type: ignore typ = self.type metadata = self._stringify_metadata(leading_comma=True) buf.append("{0} {1} {2}{3}\n" .format(opname, typ, operands, metadata)) - def replace_usage(self, old, new): + def replace_usage(self, old: Constant, new: Constant) -> None: if old in self.operands: - ops = [] + ops: list[Value] = [] for op in self.operands: ops.append(new if op is old else op) - self.operands = tuple(ops) + self.operands = ops self._clear_string_cache() - def __repr__(self): + def __repr__(self) -> str: return "" % ( self.__class__.__name__, self.name, self.type, self.opname, self.operands) class CallInstrAttributes(AttributeSet): - _known = frozenset(['noreturn', 'nounwind', 'readonly', 'readnone', - 'noinline', 'alwaysinline']) + # FIXME: _known is tuple in AttributeSet + _known = frozenset( # type: ignore + ["noreturn", "nounwind", "readonly", "readnone", "noinline", "alwaysinline"] + ) class FastMathFlags(AttributeSet): - _known = frozenset(['fast', 'nnan', 'ninf', 'nsz', 'arcp', 'contract', - 'afn', 'reassoc']) + # FIXME: _known is tuple in AttributeSet + _known = frozenset( # type: ignore + ["fast", "nnan", "ninf", "nsz", "arcp", "contract", "afn", "reassoc"] + ) class CallInstr(Instruction): - def __init__(self, parent, func, args, name='', cconv=None, tail=False, - fastmath=(), attrs=(), arg_attrs=None): + def __init__( + self, + parent: Block, + func: Function | InlineAsm, + args: Sequence[Constant | MetaDataArgument | Function], + name: str = "", + cconv: str | None = None, + tail: bool = False, + fastmath: tuple[str, ...] = (), + attrs: tuple[Any, ...] = (), + arg_attrs: dict[int, tuple[Any, ...]] | None = None, + ): self.cconv = (func.calling_convention if cconv is None and isinstance(func, Function) else cconv) self.tail = tail self.fastmath = FastMathFlags(fastmath) self.attributes = CallInstrAttributes(attrs) - self.arg_attributes = {} + self.arg_attributes: dict[int, ArgumentAttributes] = {} if arg_attrs: for idx, attrs in arg_attrs.items(): if not (0 <= idx < len(args)): @@ -79,12 +109,15 @@ def __init__(self, parent, func, args, name='', cconv=None, tail=False, self.arg_attributes[idx] = ArgumentAttributes(attrs) # Fix and validate arguments + # FIXME: InlineASM has no function_type args = list(args) - for i in range(len(func.function_type.args)): + for i in range(len(func.function_type.args)): # type: ignore arg = args[i] - expected_type = func.function_type.args[i] - if (isinstance(expected_type, types.MetaDataType) and - arg.type != expected_type): + expected_type = func.function_type.args[i] # type: ignore + if ( + isinstance(expected_type, types.MetaDataType) + and arg.type != expected_type + ): arg = MetaDataArgument(arg) if arg.type != expected_type: msg = ("Type of #{0} arg mismatch: {1} != {2}" @@ -93,32 +126,33 @@ def __init__(self, parent, func, args, name='', cconv=None, tail=False, args[i] = arg super(CallInstr, self).__init__(parent, func.function_type.return_type, - "call", [func] + list(args), name=name) + "call", [func] + list(args), name=name) # type: ignore @property - def callee(self): - return self.operands[0] + def callee(self) -> Function: + # FIXME: this can only ever by Function, not Constant | MetaDataArgument + return self.operands[0] # type: ignore @callee.setter - def callee(self, newcallee): + def callee(self, newcallee: Function) -> None: self.operands[0] = newcallee @property - def args(self): + def args(self) -> list[Value]: return self.operands[1:] - def replace_callee(self, newfunc): + def replace_callee(self, newfunc: Function) -> None: if newfunc.function_type != self.callee.function_type: raise TypeError("New function has incompatible type") self.callee = newfunc @property - def called_function(self): + def called_function(self) -> Constant | Function | MetaDataArgument: """Alias for llvmpy""" return self.callee - def _descr(self, buf, add_metadata): - def descr_arg(i, a): + def _descr(self, buf: list[str], add_metadata: bool) -> None: + def descr_arg(i: int, a: Any) -> str: if i in self.arg_attributes: attrs = ' '.join(self.arg_attributes[i]._to_list()) + ' ' else: @@ -129,7 +163,7 @@ def descr_arg(i, a): fnty = self.callee.function_type # Only print function type if variable-argument if fnty.var_arg: - ty = fnty + ty: types.Type = fnty # Otherwise, just print the return type. else: # Fastmath flag work only in this case @@ -138,23 +172,34 @@ def descr_arg(i, a): if self.cconv: callee_ref = "{0} {1}".format(self.cconv, callee_ref) buf.append("{tail}{op}{fastmath} {callee}({args}){attr}{meta}\n".format( - tail='tail ' if self.tail else '', - op=self.opname, - callee=callee_ref, - fastmath=''.join([" " + attr for attr in self.fastmath]), - args=args, - attr=''.join([" " + attr for attr in self.attributes]), - meta=(self._stringify_metadata(leading_comma=True) - if add_metadata else ""), - )) - - def descr(self, buf): + tail='tail ' if self.tail else '', + op=self.opname, + callee=callee_ref, + fastmath=''.join([" " + attr for attr in self.fastmath]), + args=args, + attr=''.join([" " + attr for attr in self.attributes]), + meta=(self._stringify_metadata(leading_comma=True) + if add_metadata else ""), + )) + + def descr(self, buf: list[str]) -> None: self._descr(buf, add_metadata=True) class InvokeInstr(CallInstr): - def __init__(self, parent, func, args, normal_to, unwind_to, name='', - cconv=None, fastmath=(), attrs=(), arg_attrs=None): + def __init__( + self, + parent: Block, + func: Function, + args: Sequence[Constant | MetaDataArgument | Function], + normal_to: Block, + unwind_to: Block, + name: str = "", + cconv: str | None = None, + fastmath: tuple[str, ...] = (), + attrs: tuple[Any, ...] = (), + arg_attrs: dict[int, tuple[Any, ...]] | None = None, + ) -> None: assert isinstance(normal_to, Block) assert isinstance(unwind_to, Block) super(InvokeInstr, self).__init__(parent, func, args, name, cconv, @@ -164,32 +209,36 @@ def __init__(self, parent, func, args, normal_to, unwind_to, name='', self.normal_to = normal_to self.unwind_to = unwind_to - def descr(self, buf): + def descr(self, buf: list[str]) -> None: super(InvokeInstr, self)._descr(buf, add_metadata=False) buf.append(" to label {0} unwind label {1}{metadata}\n".format( - self.normal_to.get_reference(), - self.unwind_to.get_reference(), - metadata=self._stringify_metadata(leading_comma=True), - )) + self.normal_to.get_reference(), + self.unwind_to.get_reference(), + metadata=self._stringify_metadata(leading_comma=True), + )) class Terminator(Instruction): - def __init__(self, parent, opname, operands): - super(Terminator, self).__init__(parent, types.VoidType(), opname, - operands) - - def descr(self, buf): + def __init__( + self, + parent: Block, + opname: str, + operands: list[Value], + ) -> None: + super().__init__(parent, types.VoidType(), opname, operands) + + def descr(self, buf: list[str]) -> None: opname = self.opname - operands = ', '.join(["{0} {1}".format(op.type, op.get_reference()) + # FIXME: type and get_reference unknown + operands = ', '.join(["{0} {1}".format(op.type, op.get_reference()) # type: ignore for op in self.operands]) metadata = self._stringify_metadata(leading_comma=True) buf.append("{0} {1}{2}".format(opname, operands, metadata)) class PredictableInstr(Instruction): - - def set_weights(self, weights): - operands = [MetaDataString(self.module, "branch_weights")] + def set_weights(self, weights: Iterable[int]) -> None: + operands: list[Constant] = [MetaDataString(self.module, "branch_weights")] # type: ignore for w in weights: if w < 0: raise ValueError("branch weight must be a positive integer") @@ -199,25 +248,33 @@ def set_weights(self, weights): class Ret(Terminator): - def __init__(self, parent, opname, return_value=None): + def __init__( + self, + parent: Block, + opname: str, + return_value: Value | None = None, + ) -> None: operands = [return_value] if return_value is not None else [] super(Ret, self).__init__(parent, opname, operands) @property - def return_value(self): + def return_value(self) -> Value | None: if self.operands: return self.operands[0] else: return None - def descr(self, buf): + def descr(self, buf: list[str]) -> None: return_value = self.return_value metadata = self._stringify_metadata(leading_comma=True) if return_value is not None: - buf.append("{0} {1} {2}{3}\n" - .format(self.opname, return_value.type, - return_value.get_reference(), - metadata)) + buf.append("{0} {1} {2}{3}\n".format( + self.opname, + return_value.type, # type: ignore + return_value.get_reference(), # type: ignore + metadata, + ) + ) else: buf.append("{0}{1}\n".format(self.opname, metadata)) @@ -231,54 +288,55 @@ class ConditionalBranch(PredictableInstr, Terminator): class IndirectBranch(PredictableInstr, Terminator): - def __init__(self, parent, opname, addr): + def __init__(self, parent: Block, opname: str, addr: Block) -> None: super(IndirectBranch, self).__init__(parent, opname, [addr]) - self.destinations = [] + self.destinations: list[Block] = [] @property - def address(self): + def address(self) -> Value: return self.operands[0] - def add_destination(self, block): + def add_destination(self, block: Block) -> None: assert isinstance(block, Block) self.destinations.append(block) - def descr(self, buf): + def descr(self, buf: list[str]) -> None: destinations = ["label {0}".format(blk.get_reference()) for blk in self.destinations] buf.append("indirectbr {0} {1}, [{2}] {3}\n".format( - self.address.type, - self.address.get_reference(), + self.address.type, # type: ignore + self.address.get_reference(), # type: ignore ', '.join(destinations), - self._stringify_metadata(leading_comma=True), - )) + self._stringify_metadata(leading_comma=True), + )) class SwitchInstr(PredictableInstr, Terminator): - def __init__(self, parent, opname, val, default): + def __init__(self, parent: Block, opname: str, val: Value, default: Value) -> None: super(SwitchInstr, self).__init__(parent, opname, [val]) self.default = default - self.cases = [] + self.cases: list[tuple[Value, Block]] = [] @property - def value(self): + def value(self) -> Value: return self.operands[0] - def add_case(self, val, block): + def add_case(self, val: Value, block: Block) -> None: assert isinstance(block, Block) + # FIXME: this check should be unnecessary if not isinstance(val, Value): - val = Constant(self.value.type, val) + val = Constant(self.value.type, val) # type: ignore self.cases.append((val, block)) - def descr(self, buf): - cases = ["{0} {1}, label {2}".format(val.type, val.get_reference(), + def descr(self, buf: list[str]) -> None: + cases = ["{0} {1}, label {2}".format(val.type, val.get_reference(), # type: ignore blk.get_reference()) for val, blk in self.cases] buf.append("switch {0} {1}, label {2} [{3}] {4}\n".format( - self.value.type, - self.value.get_reference(), - self.default.get_reference(), + self.value.type, # type: ignore + self.value.get_reference(), # type: ignore + self.default.get_reference(), # type: ignore ' '.join(cases), self._stringify_metadata(leading_comma=True), )) @@ -289,40 +347,63 @@ class Resume(Terminator): class SelectInstr(Instruction): - def __init__(self, parent, cond, lhs, rhs, name='', flags=()): + def __init__( + self, + parent: Block, + cond: Constant, + lhs: Constant, + rhs: Constant, + name: str = "", + flags: tuple[Any, ...] = (), + ) -> None: assert lhs.type == rhs.type super(SelectInstr, self).__init__(parent, lhs.type, "select", [cond, lhs, rhs], name=name, flags=flags) @property - def cond(self): + def cond(self) -> Value: return self.operands[0] @property - def lhs(self): + def lhs(self) -> Value: return self.operands[1] @property - def rhs(self): + def rhs(self) -> Value: return self.operands[2] - def descr(self, buf): + def descr(self, buf: list[str]) -> None: buf.append("select {0} {1} {2}, {3} {4}, {5} {6} {7}\n".format( - ' '.join(self.flags), - self.cond.type, self.cond.get_reference(), - self.lhs.type, self.lhs.get_reference(), - self.rhs.type, self.rhs.get_reference(), - self._stringify_metadata(leading_comma=True), - )) + " ".join(self.flags), + self.cond.type, # type: ignore + self.cond.get_reference(), # type: ignore + self.lhs.type, # type: ignore + self.lhs.get_reference(), # type: ignore + self.rhs.type, # type: ignore + self.rhs.get_reference(), # type: ignore + self._stringify_metadata(leading_comma=True), + ) + ) class CompareInstr(Instruction): # Define the following in subclasses OPNAME = 'invalid-compare' - VALID_OP = {} - - def __init__(self, parent, op, lhs, rhs, name='', flags=[]): + VALID_OP: dict[str, str] = {} + VALID_FLAG: set[str] + + def __init__( + self, + parent: Block, + op: str, + lhs: Constant, + rhs: Constant, + name: str = "", + flags: list[Any] = [], + ) -> None: + # FIXME: mutable container as default argument + # FIXME: why is flags a list here? if op not in self.VALID_OP: raise ValueError("invalid comparison %r for %s" % (op, self.OPNAME)) for flag in flags: @@ -330,6 +411,7 @@ def __init__(self, parent, op, lhs, rhs, name='', flags=[]): raise ValueError("invalid flag %r for %s" % (flag, self.OPNAME)) opname = self.OPNAME if isinstance(lhs.type, types.VectorType): + typ: types.Type typ = types.VectorType(types.IntType(1), lhs.type.count) else: typ = types.IntType(1) @@ -338,16 +420,16 @@ def __init__(self, parent, op, lhs, rhs, name='', flags=[]): name=name) self.op = op - def descr(self, buf): + def descr(self, buf: list[str]) -> None: buf.append("{opname}{flags} {op} {ty} {lhs}, {rhs} {meta}\n".format( - opname=self.opname, - flags=''.join(' ' + it for it in self.flags), - op=self.op, - ty=self.operands[0].type, - lhs=self.operands[0].get_reference(), - rhs=self.operands[1].get_reference(), - meta=self._stringify_metadata(leading_comma=True), - )) + opname=self.opname, + flags=''.join(' ' + it for it in self.flags), + op=self.op, + ty=self.operands[0].type, # type: ignore + lhs=self.operands[0].get_reference(), # type: ignore + rhs=self.operands[1].get_reference(), # type: ignore + meta=self._stringify_metadata(leading_comma=True), + )) class ICMPInstr(CompareInstr): @@ -364,7 +446,7 @@ class ICMPInstr(CompareInstr): 'slt': 'signed less than', 'sle': 'signed less or equal', } - VALID_FLAG = set() + VALID_FLAG: set[str] = set() class FCMPInstr(CompareInstr): @@ -392,184 +474,233 @@ class FCMPInstr(CompareInstr): class CastInstr(Instruction): - def __init__(self, parent, op, val, typ, name=''): - super(CastInstr, self).__init__(parent, typ, op, [val], name=name) - - def descr(self, buf): + def __init__( + self, + parent: Block, + op: str, + val: Value, + typ: types.Type, + name: str = "", + ) -> None: + super().__init__(parent, typ, op, [val], name=name) + + def descr(self, buf: list[str]) -> None: buf.append("{0} {1} {2} to {3} {4}\n".format( - self.opname, - self.operands[0].type, - self.operands[0].get_reference(), - self.type, - self._stringify_metadata(leading_comma=True), - )) + self.opname, + self.operands[0].type, # type: ignore + self.operands[0].get_reference(), # type: ignore + self.type, + self._stringify_metadata(leading_comma=True), + ) + ) class LoadInstr(Instruction): - def __init__(self, parent, ptr, name=''): - super(LoadInstr, self).__init__(parent, ptr.type.pointee, "load", - [ptr], name=name) - self.align = None - - def descr(self, buf): + def __init__(self, parent: Block, ptr: Value, name: str = "") -> None: + super(LoadInstr, self).__init__( + parent, + ptr.type.pointee, # type: ignore + "load", + [ptr], + name=name, + ) + self.align: int | None = None + + def descr(self, buf: list[str]) -> None: [val] = self.operands if self.align is not None: align = ', align %d' % (self.align) else: align = '' buf.append("load {0}, {1} {2}{3}{4}\n".format( - val.type.pointee, - val.type, - val.get_reference(), - align, - self._stringify_metadata(leading_comma=True), - )) + val.type.pointee, # type: ignore + val.type, # type: ignore + val.get_reference(), # type: ignore + align, + self._stringify_metadata(leading_comma=True), + )) class StoreInstr(Instruction): - def __init__(self, parent, val, ptr): + def __init__(self, parent: Block, val: Value, ptr: Value) -> None: super(StoreInstr, self).__init__(parent, types.VoidType(), "store", [val, ptr]) + self.align: int | None = None - def descr(self, buf): + def descr(self, buf: list[str]) -> None: val, ptr = self.operands if self.align is not None: align = ', align %d' % (self.align) else: align = '' buf.append("store {0} {1}, {2} {3}{4}{5}\n".format( - val.type, - val.get_reference(), - ptr.type, - ptr.get_reference(), - align, - self._stringify_metadata(leading_comma=True), - )) + val.type, # type: ignore + val.get_reference(), # type: ignore + ptr.type, # type: ignore + ptr.get_reference(), # type: ignore + align, + self._stringify_metadata(leading_comma=True), + )) class LoadAtomicInstr(Instruction): - def __init__(self, parent, ptr, ordering, align, name=''): - super(LoadAtomicInstr, self).__init__(parent, ptr.type.pointee, - "load atomic", [ptr], name=name) + def __init__( + self, + parent: Block, + ptr: Value, + ordering: str, + align: int, + name: str = "", + ) -> None: + super().__init__( + parent, + ptr.type.pointee, # type: ignore + "load atomic", + [ptr], + name=name, + ) self.ordering = ordering self.align = align - def descr(self, buf): + def descr(self, buf: list[str]) -> None: [val] = self.operands buf.append("load atomic {0}, {1} {2} {3}, align {4}{5}\n".format( - val.type.pointee, - val.type, - val.get_reference(), - self.ordering, - self.align, - self._stringify_metadata(leading_comma=True), - )) + val.type.pointee, # type: ignore + val.type, # type: ignore + val.get_reference(), # type: ignore + self.ordering, + self.align, + self._stringify_metadata(leading_comma=True), + )) class StoreAtomicInstr(Instruction): - def __init__(self, parent, val, ptr, ordering, align): + def __init__( + self, parent: Block, val: Value, ptr: Value, ordering: str, align: int + ) -> None: super(StoreAtomicInstr, self).__init__(parent, types.VoidType(), "store atomic", [val, ptr]) self.ordering = ordering self.align = align - def descr(self, buf): + def descr(self, buf: list[str]) -> None: val, ptr = self.operands buf.append("store atomic {0} {1}, {2} {3} {4}, align {5}{6}\n".format( - val.type, - val.get_reference(), - ptr.type, - ptr.get_reference(), - self.ordering, - self.align, - self._stringify_metadata(leading_comma=True), - )) + val.type, # type: ignore + val.get_reference(), # type: ignore + ptr.type, # type: ignore + ptr.get_reference(), # type: ignore + self.ordering, + self.align, + self._stringify_metadata(leading_comma=True), + )) class AllocaInstr(Instruction): - def __init__(self, parent, typ, count, name): - operands = [count] if count else () + def __init__( + self, parent: Block, typ: types.Type, count: Value | None, name: str + ) -> None: + operands = [count] if count else [] super(AllocaInstr, self).__init__(parent, typ.as_pointer(), "alloca", operands, name) self.align = None - def descr(self, buf): - buf.append("{0} {1}".format(self.opname, self.type.pointee)) + def descr(self, buf: list[str]) -> None: + buf.append("{0} {1}".format(self.opname, self.type.pointee)) # type: ignore if self.operands: - op, = self.operands - buf.append(", {0} {1}".format(op.type, op.get_reference())) + (op,) = self.operands + buf.append(", {0} {1}".format(op.type, op.get_reference())) # type: ignore if self.align is not None: buf.append(", align {0}".format(self.align)) - if self.metadata: + if self.metadata: # type: ignore buf.append(self._stringify_metadata(leading_comma=True)) class GEPInstr(Instruction): - def __init__(self, parent, ptr, indices, inbounds, name): + def __init__( + self, + parent: Block, + ptr: Constant, + indices: list[Constant], + inbounds: bool, + name: str, + ) -> None: typ = ptr.type lasttyp = None lastaddrspace = 0 for i in indices: - lasttyp, typ = typ, typ.gep(i) + lasttyp, typ = typ, typ.gep(i) # type: ignore # inherit the addrspace from the last seen pointer if isinstance(lasttyp, types.PointerType): - lastaddrspace = lasttyp.addrspace + lastaddrspace = lasttyp.addrspace # type: ignore if (not isinstance(typ, types.PointerType) and isinstance(lasttyp, types.PointerType)): - typ = lasttyp + typ = lasttyp # type: ignore else: - typ = typ.as_pointer(lastaddrspace) + typ = typ.as_pointer(lastaddrspace) # type: ignore super(GEPInstr, self).__init__(parent, typ, "getelementptr", - [ptr] + list(indices), name=name) + [ptr] + list(indices), name=name) # type: ignore self.pointer = ptr self.indices = indices self.inbounds = inbounds - def descr(self, buf): + def descr(self, buf: list[str]) -> None: indices = ['{0} {1}'.format(i.type, i.get_reference()) for i in self.indices] op = "getelementptr inbounds" if self.inbounds else "getelementptr" buf.append("{0} {1}, {2} {3}, {4} {5}\n".format( - op, - self.pointer.type.pointee, - self.pointer.type, - self.pointer.get_reference(), - ', '.join(indices), - self._stringify_metadata(leading_comma=True), - )) + op, + self.pointer.type.pointee, # type: ignore + self.pointer.type, + self.pointer.get_reference(), + ', '.join(indices), + self._stringify_metadata(leading_comma=True), + )) class PhiInstr(Instruction): - def __init__(self, parent, typ, name, flags=()): - super(PhiInstr, self).__init__(parent, typ, "phi", (), name=name, + def __init__( + self, + parent: Block, + typ: types.Type, + name: str, + flags: tuple[Any, ...] = (), + ) -> None: + super(PhiInstr, self).__init__(parent, typ, "phi", [], name=name, flags=flags) - self.incomings = [] + self.incomings: list[tuple[Constant, Block]] = [] - def descr(self, buf): + def descr(self, buf: list[str]) -> None: incs = ', '.join('[{0}, {1}]'.format(v.get_reference(), b.get_reference()) for v, b in self.incomings) buf.append("phi {0} {1} {2} {3}\n".format( - ' '.join(self.flags), - self.type, - incs, - self._stringify_metadata(leading_comma=True), - )) + " ".join(self.flags), + self.type, + incs, + self._stringify_metadata(leading_comma=True), + )) - def add_incoming(self, value, block): + def add_incoming(self, value: Constant, block: Block) -> None: assert isinstance(block, Block) self.incomings.append((value, block)) - def replace_usage(self, old, new): + def replace_usage(self, old: Constant, new: Constant) -> None: self.incomings = [((new if val is old else val), blk) for (val, blk) in self.incomings] class ExtractElement(Instruction): - def __init__(self, parent, vector, index, name=''): + def __init__( + self, + parent: Block, + vector: Constant, + index: Constant, + name: str = "", + ) -> None: if not isinstance(vector.type, types.VectorType): raise TypeError("vector needs to be of VectorType.") if not isinstance(index.type, types.IntType): @@ -578,15 +709,23 @@ def __init__(self, parent, vector, index, name=''): super(ExtractElement, self).__init__(parent, typ, "extractelement", [vector, index], name=name) - def descr(self, buf): - operands = ", ".join("{0} {1}".format( - op.type, op.get_reference()) for op in self.operands) + def descr(self, buf: list[str]) -> None: + operands = ", ".join( + "{0} {1}".format(op.type, op.get_reference()) for op in self.operands # type: ignore + ) buf.append("{opname} {operands}\n".format( opname=self.opname, operands=operands)) class InsertElement(Instruction): - def __init__(self, parent, vector, value, index, name=''): + def __init__( + self, + parent: Block, + vector: Constant, + value: Constant, + index: Constant, + name: str = "", + ) -> None: if not isinstance(vector.type, types.VectorType): raise TypeError("vector needs to be of VectorType.") if not value.type == vector.type.element: @@ -599,15 +738,22 @@ def __init__(self, parent, vector, value, index, name=''): super(InsertElement, self).__init__(parent, typ, "insertelement", [vector, value, index], name=name) - def descr(self, buf): + def descr(self, buf: list[str]) -> None: operands = ", ".join("{0} {1}".format( - op.type, op.get_reference()) for op in self.operands) + op.type, op.get_reference()) for op in self.operands) # type: ignore buf.append("{opname} {operands}\n".format( opname=self.opname, operands=operands)) class ShuffleVector(Instruction): - def __init__(self, parent, vector1, vector2, mask, name=''): + def __init__( + self, + parent: Block, + vector1: Constant, + vector2: Constant, + mask: Constant, + name: str = "", + ) -> None: if not isinstance(vector1.type, types.VectorType): raise TypeError("vector1 needs to be of VectorType.") if vector2 != Undefined: @@ -623,27 +769,34 @@ def __init__(self, parent, vector1, vector2, mask, name=''): index_range = range(vector1.type.count if vector2 == Undefined else 2 * vector1.type.count) - if not all(ii.constant in index_range for ii in mask.constant): + if not all(ii.constant in index_range for ii in mask.constant): # type: ignore raise IndexError( "mask values need to be in {0}".format(index_range), ) super(ShuffleVector, self).__init__(parent, typ, "shufflevector", [vector1, vector2, mask], name=name) - def descr(self, buf): + def descr(self, buf: list[str]) -> None: buf.append("shufflevector {0} {1}\n".format( - ", ".join("{0} {1}".format(op.type, op.get_reference()) + ", ".join("{0} {1}".format(op.type, op.get_reference()) # type: ignore for op in self.operands), self._stringify_metadata(leading_comma=True), )) class ExtractValue(Instruction): - def __init__(self, parent, agg, indices, name=''): + def __init__( + self, + parent: Block, + agg: Constant, + indices: list[Constant], + name: str = "", + ) -> None: typ = agg.type try: for i in indices: - typ = typ.elements[i] + # FIXME: Type has no .elements + typ = typ.elements[i] # type: ignore except (AttributeError, IndexError): raise TypeError("Can't index at %r in %s" % (list(indices), agg.type)) @@ -654,7 +807,7 @@ def __init__(self, parent, agg, indices, name=''): self.aggregate = agg self.indices = indices - def descr(self, buf): + def descr(self, buf: list[str]) -> None: indices = [str(i) for i in self.indices] buf.append("extractvalue {0} {1}, {2} {3}\n".format( @@ -666,11 +819,18 @@ def descr(self, buf): class InsertValue(Instruction): - def __init__(self, parent, agg, elem, indices, name=''): + def __init__( + self, + parent: Block, + agg: Constant, + elem: Constant, + indices: list[Constant], + name: str = "", + ) -> None: typ = agg.type try: for i in indices: - typ = typ.elements[i] + typ = typ.elements[i] # type: ignore except (AttributeError, IndexError): raise TypeError("Can't index at %r in %s" % (list(indices), agg.type)) @@ -684,7 +844,7 @@ def __init__(self, parent, agg, elem, indices, name=''): self.value = elem self.indices = indices - def descr(self, buf): + def descr(self, buf: list[str]) -> None: indices = [str(i) for i in self.indices] buf.append("insertvalue {0} {1}, {2} {3}, {4} {5}\n".format( @@ -696,53 +856,67 @@ def descr(self, buf): class Unreachable(Instruction): - def __init__(self, parent): + def __init__(self, parent: Block) -> None: super(Unreachable, self).__init__(parent, types.VoidType(), - "unreachable", (), name='') + "unreachable", [], name='') - def descr(self, buf): + def descr(self, buf: list[str]) -> None: buf += (self.opname, "\n") -class InlineAsm(object): - def __init__(self, ftype, asm, constraint, side_effect=False): +class InlineAsm: + def __init__( + self, + ftype: types.FunctionType, + asm: str, + constraint: Value, + side_effect: bool = False, + ) -> None: self.type = ftype.return_type self.function_type = ftype self.asm = asm self.constraint = constraint self.side_effect = side_effect - def descr(self, buf): + def descr(self, buf: list[str]) -> None: sideeffect = 'sideeffect' if self.side_effect else '' fmt = 'asm {sideeffect} "{asm}", "{constraint}"\n' buf.append(fmt.format(sideeffect=sideeffect, asm=self.asm, constraint=self.constraint)) - def get_reference(self): - buf = [] + def get_reference(self) -> str: + buf: list[str] = [] self.descr(buf) return "".join(buf) - def __str__(self): + def __str__(self) -> str: return "{0} {1}".format(self.type, self.get_reference()) class AtomicRMW(Instruction): - def __init__(self, parent, op, ptr, val, ordering, name): + def __init__( + self, + parent: Block, + op: str, + ptr: Constant, + val: Constant, + ordering: str, + name: str, + ) -> None: super(AtomicRMW, self).__init__(parent, val.type, "atomicrmw", - (ptr, val), name=name) + (ptr, val), name=name) # type: ignore self.operation = op self.ordering = ordering - def descr(self, buf): + def descr(self, buf: list[str]) -> None: ptr, val = self.operands fmt = ("atomicrmw {op} {ptrty} {ptr}, {valty} {val} {ordering} " "{metadata}\n") buf.append(fmt.format(op=self.operation, - ptrty=ptr.type, - ptr=ptr.get_reference(), - valty=val.type, - val=val.get_reference(), + ptrty=ptr.type, # type: ignore + ptr=ptr.get_reference(), # type: ignore + valty=val.type, # type: ignore + val=val.get_reference(), # type: ignore ordering=self.ordering, metadata=self._stringify_metadata( leading_comma=True), @@ -754,36 +928,48 @@ class CmpXchg(Instruction): older llvm versions. """ - def __init__(self, parent, ptr, cmp, val, ordering, failordering, name): + def __init__( + self, + parent: Block, + ptr: Constant, + cmp: Constant, + val: Constant, + ordering: str, + failordering: str, + name: str, + ) -> None: outtype = types.LiteralStructType([val.type, types.IntType(1)]) super(CmpXchg, self).__init__(parent, outtype, "cmpxchg", - (ptr, cmp, val), name=name) + (ptr, cmp, val), name=name) # type: ignore self.ordering = ordering self.failordering = failordering - def descr(self, buf): + def descr(self, buf: list[str]) -> None: ptr, cmpval, val = self.operands fmt = "cmpxchg {ptrty} {ptr}, {ty} {cmp}, {ty} {val} {ordering} " \ "{failordering} {metadata}\n" - buf.append(fmt.format(ptrty=ptr.type, - ptr=ptr.get_reference(), - ty=cmpval.type, - cmp=cmpval.get_reference(), - val=val.get_reference(), - ordering=self.ordering, - failordering=self.failordering, - metadata=self._stringify_metadata( - leading_comma=True), - )) - - -class _LandingPadClause(object): - def __init__(self, value): + buf.append( + fmt.format( + ptrty=ptr.type, # type: ignore + ptr=ptr.get_reference(), # type: ignore + ty=cmpval.type, # type: ignore + cmp=cmpval.get_reference(), # type: ignore + val=val.get_reference(), # type: ignore + ordering=self.ordering, + failordering=self.failordering, + metadata=self._stringify_metadata( + leading_comma=True), + )) + + +class _LandingPadClause: + def __init__(self, value: Constant) -> None: self.value = value + # FIXME: does not contain self.kind - def __str__(self): + def __str__(self) -> str: return "{kind} {type} {value}".format( - kind=self.kind, + kind=self.kind, # type: ignore type=self.value.type, value=self.value.get_reference()) @@ -795,24 +981,26 @@ class CatchClause(_LandingPadClause): class FilterClause(_LandingPadClause): kind = 'filter' - def __init__(self, value): + def __init__(self, value: Constant) -> None: assert isinstance(value, Constant) assert isinstance(value.type, types.ArrayType) super(FilterClause, self).__init__(value) class LandingPadInstr(Instruction): - def __init__(self, parent, typ, name='', cleanup=False): + def __init__( + self, parent: Block, typ: types.Type, name: str = "", cleanup: bool = False + ) -> None: super(LandingPadInstr, self).__init__(parent, typ, "landingpad", [], name=name) self.cleanup = cleanup - self.clauses = [] + self.clauses: list[_LandingPadClause] = [] - def add_clause(self, clause): + def add_clause(self, clause: _LandingPadClause) -> None: assert isinstance(clause, _LandingPadClause) self.clauses.append(clause) - def descr(self, buf): + def descr(self, buf: list[str]) -> None: fmt = "landingpad {type}{cleanup}{clauses}\n" buf.append(fmt.format(type=self.type, cleanup=' cleanup' if self.cleanup else '', @@ -832,8 +1020,14 @@ class Fence(Instruction): VALID_FENCE_ORDERINGS = {"acquire", "release", "acq_rel", "seq_cst"} - def __init__(self, parent, ordering, targetscope=None, name=''): - super(Fence, self).__init__(parent, types.VoidType(), "fence", (), + def __init__( + self, + parent: Block, + ordering: str, # Literal["acquire", "release", "acq_rel", "seq_cst"], + targetscope: None = None, + name: str = "", + ) -> None: + super(Fence, self).__init__(parent, types.VoidType(), "fence", [], name=name) if ordering not in self.VALID_FENCE_ORDERINGS: msg = "Invalid fence ordering \"{0}\"! Should be one of {1}." @@ -842,7 +1036,7 @@ def __init__(self, parent, ordering, targetscope=None, name=''): self.ordering = ordering self.targetscope = targetscope - def descr(self, buf): + def descr(self, buf: list[str]) -> None: if self.targetscope is None: syncscope = "" else: diff --git a/llvmlite/ir/module.py b/llvmlite/ir/module.py index 464f91ec3..4e7b985f9 100644 --- a/llvmlite/ir/module.py +++ b/llvmlite/ir/module.py @@ -1,24 +1,37 @@ +from __future__ import annotations + import collections +from typing import Any, Iterable, NoReturn, OrderedDict, Sequence from llvmlite.ir import context, values, types, _utils -class Module(object): - def __init__(self, name='', context=context.global_context): +class Module: + def __init__( + self, + name: str = "", + # this one is fine with pylance, but not with mypy for some reason + context: context.Context = context.global_context, # type: ignore + ) -> None: self.context = context self.name = name # name is for debugging/informational self.data_layout = "" self.scope = _utils.NameScope() self.triple = 'unknown-unknown-unknown' - self.globals = collections.OrderedDict() + self.globals: OrderedDict[ + str, values.GlobalVariable + ] = collections.OrderedDict() # Innamed metadata nodes. - self.metadata = [] + self.metadata: list[values.MDValue | values.DIValue] = [] # Named metadata nodes - self.namedmetadata = {} + self.namedmetadata: dict[str, values.NamedMetaData] = {} # Cache for metadata node deduplication - self._metadatacache = {} + self._metadatacache: dict[Any, Any] = {} - def _fix_metadata_operands(self, operands): + def _fix_metadata_operands( + self, operands: list[Iterable[values.Value] | str | None] + ) -> list[types.MetaDataType | values.MetaDataString | values.MDValue]: + op: Any fixed_ops = [] for op in operands: if op is None: @@ -29,20 +42,25 @@ def _fix_metadata_operands(self, operands): op = values.MetaDataString(self, op) elif isinstance(op, (list, tuple)): # A sequence creates a metadata node reference - op = self.add_metadata(op) + op = self.add_metadata(op) # type: ignore fixed_ops.append(op) return fixed_ops - def _fix_di_operands(self, operands): + def _fix_di_operands( + self, operands: Iterable[tuple[str, values.Value]] + ) -> list[tuple[str, values.MDValue]]: fixed_ops = [] for name, op in operands: if isinstance(op, (list, tuple)): # A sequence creates a metadata node reference - op = self.add_metadata(op) - fixed_ops.append((name, op)) - return fixed_ops + op = self.add_metadata(op) # type: ignore + fixed_ops.append((name, op)) # type: ignore + return fixed_ops # type: ignore - def add_metadata(self, operands): + def add_metadata( + self, + operands: list[values.Constant], + ) -> values.MDValue: """ Add an unnamed metadata to the module with the given *operands* (a sequence of values) or return a previous equivalent metadata. @@ -52,7 +70,7 @@ def add_metadata(self, operands): if not isinstance(operands, (list, tuple)): raise TypeError("expected a list or tuple of metadata values, " "got %r" % (operands,)) - operands = self._fix_metadata_operands(operands) + operands = self._fix_metadata_operands(operands) # type: ignore key = tuple(operands) if key not in self._metadatacache: n = len(self.metadata) @@ -62,7 +80,9 @@ def add_metadata(self, operands): md = self._metadatacache[key] return md - def add_debug_info(self, kind, operands, is_distinct=False): + def add_debug_info( + self, kind: str, operands: dict[str, values.Value], is_distinct: bool = False + ) -> values.DIValue: """ Add debug information metadata to the module with the given *operands* (a dict of values with string keys) or return @@ -72,17 +92,19 @@ def add_debug_info(self, kind, operands, is_distinct=False): A DIValue instance is returned, it can then be associated to e.g. an instruction. """ - operands = tuple(sorted(self._fix_di_operands(operands.items()))) - key = (kind, operands, is_distinct) + op_tuple = tuple(sorted(self._fix_di_operands(operands.items()))) + key = (kind, op_tuple, is_distinct) if key not in self._metadatacache: n = len(self.metadata) - di = values.DIValue(self, is_distinct, kind, operands, name=str(n)) + di = values.DIValue(self, is_distinct, kind, op_tuple, name=str(n)) self._metadatacache[key] = di else: di = self._metadatacache[key] return di - def add_named_metadata(self, name, element=None): + def add_named_metadata( + self, name: str, element: None = None + ) -> values.NamedMetaData: """ Add a named metadata node to the module, if it doesn't exist, or return the existing node. @@ -97,17 +119,17 @@ def add_named_metadata(self, name, element=None): if name in self.namedmetadata: nmd = self.namedmetadata[name] else: - nmd = self.namedmetadata[name] = values.NamedMetaData(self) + nmd = self.namedmetadata[name] = values.NamedMetaData(self) # type: ignore if element is not None: if not isinstance(element, values.Value): - element = self.add_metadata(element) - if not isinstance(element.type, types.MetaDataType): + element = self.add_metadata(element) # type: ignore + if not isinstance(element.type, types.MetaDataType): # type: ignore raise TypeError("wrong type for metadata element: got %r" % (element,)) - nmd.add(element) - return nmd + nmd.add(element) # type: ignore + return nmd # type: ignore - def get_named_metadata(self, name): + def get_named_metadata(self, name: str) -> values.NamedMetaData: """ Return the metadata node with the given *name*. KeyError is raised if no such node exists (contrast with add_named_metadata()). @@ -115,7 +137,7 @@ def get_named_metadata(self, name): return self.namedmetadata[name] @property - def functions(self): + def functions(self) -> list[values.Function]: """ A list of functions declared or defined in this module. """ @@ -123,40 +145,46 @@ def functions(self): if isinstance(v, values.Function)] @property - def global_values(self): + def global_values(self) -> Iterable[values.GlobalVariable]: """ An iterable of global values in this module. """ return self.globals.values() - def get_global(self, name): + def get_global(self, name: str) -> values.GlobalVariable: """ Get a global value by name. """ return self.globals[name] - def add_global(self, globalvalue): + def add_global(self, globalvalue: values.GlobalVariable) -> None: """ Add a new global value. """ assert globalvalue.name not in self.globals self.globals[globalvalue.name] = globalvalue - def get_unique_name(self, name=''): + def get_unique_name(self, name: str = "") -> str: """ Get a unique global name with the following *name* hint. """ return self.scope.deduplicate(name) - def declare_intrinsic(self, intrinsic, tys=(), fnty=None): - def _error(): + def declare_intrinsic( + self, + intrinsic: str, + tys: Sequence[types.Type] = (), + fnty: types.FunctionType | None = None, + ) -> values.GlobalVariable | values.Function: + def _error() -> NoReturn: raise NotImplementedError("unknown intrinsic %r with %d types" % (intrinsic, len(tys))) - + suffixes: list[str] if intrinsic in {'llvm.cttz', 'llvm.ctlz', 'llvm.fma'}: - suffixes = [tys[0].intrinsic_name] + # FIXME: .intrinsic_name is unknown + suffixes = [tys[0].intrinsic_name] # type: ignore else: - suffixes = [t.intrinsic_name for t in tys] + suffixes = [t.intrinsic_name for t in tys] # type: ignore name = '.'.join([intrinsic] + suffixes) if name in self.globals: return self.globals[name] @@ -171,7 +199,7 @@ def _error(): if intrinsic == 'llvm.powi': fnty = types.FunctionType(tys[0], [tys[0], types.IntType(32)]) elif intrinsic == 'llvm.pow': - fnty = types.FunctionType(tys[0], tys * 2) + fnty = types.FunctionType(tys[0], tys * 2) # type: ignore elif intrinsic == 'llvm.convert.from.fp16': fnty = types.FunctionType(tys[0], [types.IntType(16)]) elif intrinsic == 'llvm.convert.to.fp16': @@ -190,7 +218,7 @@ def _error(): _error() elif len(tys) == 3: if intrinsic in ('llvm.memcpy', 'llvm.memmove'): - tys = tys + [types.IntType(1)] + tys = tys + [types.IntType(1)] # type: ignore fnty = types.FunctionType(types.VoidType(), tys) elif intrinsic == 'llvm.fma': tys = [tys[0]] * 3 @@ -201,10 +229,10 @@ def _error(): _error() return values.Function(self, fnty, name=name) - def get_identified_types(self): + def get_identified_types(self) -> dict[str, types.IdentifiedStructType]: return self.context.identified_types - def _get_body_lines(self): + def _get_body_lines(self) -> list[str]: # Type declarations lines = [it.get_declaration() for it in self.get_identified_types().values()] @@ -212,8 +240,8 @@ def _get_body_lines(self): lines += [str(v) for v in self.globals.values()] return lines - def _get_metadata_lines(self): - mdbuf = [] + def _get_metadata_lines(self) -> list[str]: + mdbuf: list[str] = [] for k, v in self.namedmetadata.items(): mdbuf.append("!{name} = !{{ {operands} }}".format( name=k, operands=', '.join(i.get_reference() @@ -222,16 +250,16 @@ def _get_metadata_lines(self): mdbuf.append(str(md)) return mdbuf - def _stringify_body(self): + def _stringify_body(self) -> str: # For testing return "\n".join(self._get_body_lines()) - def _stringify_metadata(self): + def _stringify_metadata(self) -> str: # For testing return "\n".join(self._get_metadata_lines()) - def __repr__(self): - lines = [] + def __repr__(self) -> str: + lines: list[str] = [] # Header lines += [ '; ModuleID = "%s"' % (self.name,), diff --git a/llvmlite/ir/transforms.py b/llvmlite/ir/transforms.py index a69113d36..d7f439925 100644 --- a/llvmlite/ir/transforms.py +++ b/llvmlite/ir/transforms.py @@ -1,61 +1,66 @@ +from __future__ import annotations + from llvmlite.ir import CallInstr +from llvmlite.ir.instructions import Instruction +from llvmlite.ir.module import Module +from llvmlite.ir.values import Block, Function -class Visitor(object): - def visit(self, module): +class Visitor: + def visit(self, module: Module) -> None: self._module = module for func in module.functions: self.visit_Function(func) - def visit_Function(self, func): + def visit_Function(self, func: Function) -> None: self._function = func for bb in func.blocks: self.visit_BasicBlock(bb) - def visit_BasicBlock(self, bb): + def visit_BasicBlock(self, bb: Block) -> None: self._basic_block = bb for instr in bb.instructions: self.visit_Instruction(instr) - def visit_Instruction(self, instr): + def visit_Instruction(self, instr: Instruction) -> None: raise NotImplementedError @property - def module(self): + def module(self) -> Module: return self._module @property - def function(self): + def function(self) -> Function: return self._function @property - def basic_block(self): + def basic_block(self) -> Block: return self._basic_block class CallVisitor(Visitor): - def visit_Instruction(self, instr): + def visit_Instruction(self, instr: Instruction) -> None: if isinstance(instr, CallInstr): self.visit_Call(instr) - def visit_Call(self, instr): + def visit_Call(self, instr: CallInstr) -> None: raise NotImplementedError class ReplaceCalls(CallVisitor): - def __init__(self, orig, repl): - super(ReplaceCalls, self).__init__() + def __init__(self, orig: Function, repl: Function) -> None: + super().__init__() self.orig = orig self.repl = repl - self.calls = [] + self.calls: list[CallInstr] = [] - def visit_Call(self, instr): + def visit_Call(self, instr: CallInstr) -> None: if instr.callee == self.orig: instr.replace_callee(self.repl) self.calls.append(instr) -def replace_all_calls(mod, orig, repl): +def replace_all_calls(mod: Module, orig: Function, repl: Function) -> list[CallInstr]: """Replace all calls to `orig` to `repl` in module `mod`. Returns the references to the returned calls """ diff --git a/llvmlite/ir/types.py b/llvmlite/ir/types.py index 00740c488..c9b11836f 100644 --- a/llvmlite/ir/types.py +++ b/llvmlite/ir/types.py @@ -2,12 +2,19 @@ Classes that are LLVM types """ +from __future__ import annotations + import struct +from typing import Any, Iterable, Iterator +from llvmlite.binding.targets import TargetData +from llvmlite.binding.value import TypeRef from llvmlite.ir._utils import _StrCaching +from llvmlite.ir.context import Context +from llvmlite.ir.values import Constant, Value -def _wrapname(x): +def _wrapname(x: str) -> str: return '"{0}"'.format(x.replace('\\', '\\5c').replace('"', '\\22')) @@ -18,19 +25,24 @@ class Type(_StrCaching): is_pointer = False null = 'zeroinitializer' - def __repr__(self): + def __repr__(self) -> str: return "<%s %s>" % (type(self), str(self)) - def _to_string(self): + def _to_string(self) -> str: raise NotImplementedError - def as_pointer(self, addrspace=0): - return PointerType(self, addrspace) + def as_pointer(self, addrspace: int = 0) -> PointerType: + # FIXME: either not all types can be pointers or type + # needs to implement intrinsic_name + return PointerType(self, addrspace) # type: ignore - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not (self == other) - def _get_ll_pointer_type(self, target_data, context=None): + # FIXME: target_data is unused? + def _get_ll_pointer_type( + self, target_data: TargetData, context: Context | None = None + ) -> TypeRef: """ Convert this type object to an LLVM type. """ @@ -43,16 +55,20 @@ def _get_ll_pointer_type(self, target_data, context=None): m = Module(context=context) foo = GlobalVariable(m, self, name="foo") with parse_assembly(str(m)) as llmod: - return llmod.get_global_variable(foo.name).type + return llmod.get_global_variable(foo.name).type # type: ignore - def get_abi_size(self, target_data, context=None): + def get_abi_size( + self, target_data: TargetData, context: Context | None = None + ) -> int: """ Get the ABI size of this type according to data layout *target_data*. """ llty = self._get_ll_pointer_type(target_data, context) return target_data.get_pointee_abi_size(llty) - def get_abi_alignment(self, target_data, context=None): + def get_abi_alignment( + self, target_data: TargetData, context: Context | None = None + ) -> Any: """ Get the minimum ABI alignment of this type according to data layout *target_data*. @@ -60,40 +76,39 @@ def get_abi_alignment(self, target_data, context=None): llty = self._get_ll_pointer_type(target_data, context) return target_data.get_pointee_abi_alignment(llty) - def format_constant(self, value): + def format_constant(self, value: float) -> str: """ Format constant *value* of this type. This method may be overriden by subclasses. """ return str(value) - def wrap_constant_value(self, value): + def wrap_constant_value(self, value: Type | str) -> Type | str: """ Wrap constant *value* if necessary. This method may be overriden by subclasses (especially aggregate types). """ return value - def __call__(self, value): + def __call__(self, value: Type | None) -> Constant: """ Create a LLVM constant of this type with the given Python value. """ - from llvmlite.ir import Constant - return Constant(self, value) + return Constant(typ=self, constant=value) class MetaDataType(Type): - - def _to_string(self): + def _to_string(self) -> str: return "metadata" - def as_pointer(self): + # FIXME: invalid method override with Type.as_pointer(self, addrspace) + def as_pointer(self) -> PointerType: # type: ignore raise TypeError - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, MetaDataType) - def __hash__(self): + def __hash__(self) -> int: return hash(MetaDataType) @@ -102,7 +117,7 @@ class LabelType(Type): The label type is the type of e.g. basic blocks. """ - def _to_string(self): + def _to_string(self) -> str: return "label" @@ -113,38 +128,40 @@ class PointerType(Type): is_pointer = True null = 'null' - def __init__(self, pointee, addrspace=0): + def __init__( + self, pointee: IntType | PointerType | FunctionType, addrspace: int = 0 + ) -> None: assert not isinstance(pointee, VoidType) self.pointee = pointee self.addrspace = addrspace - def _to_string(self): + def _to_string(self) -> str: if self.addrspace != 0: return "{0} addrspace({1})*".format(self.pointee, self.addrspace) else: return "{0}*".format(self.pointee) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, PointerType): return (self.pointee, self.addrspace) == (other.pointee, other.addrspace) else: return False - def __hash__(self): + def __hash__(self) -> int: return hash(PointerType) - def gep(self, i): + def gep(self, i: Type) -> Type: """ Resolve the type of the i-th element (for getelementptr lookups). """ - if not isinstance(i.type, IntType): - raise TypeError(i.type) + if not isinstance(i.type, IntType): # type: ignore + raise TypeError(i.type) # type: ignore return self.pointee @property - def intrinsic_name(self): - return 'p%d%s' % (self.addrspace, self.pointee.intrinsic_name) + def intrinsic_name(self) -> str: + return 'p%d%s' % (self.addrspace, self.pointee.intrinsic_name) # type: ignore class VoidType(Type): @@ -152,13 +169,13 @@ class VoidType(Type): The type for empty values (e.g. a function returning no value). """ - def _to_string(self): + def _to_string(self) -> str: return 'void' - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, VoidType) - def __hash__(self): + def __hash__(self) -> int: return hash(VoidType) @@ -167,12 +184,14 @@ class FunctionType(Type): The type for functions. """ - def __init__(self, return_type, args, var_arg=False): + def __init__( + self, return_type: Type, args: Iterable[Type], var_arg: bool = False + ) -> None: self.return_type = return_type self.args = tuple(args) self.var_arg = var_arg - def _to_string(self): + def _to_string(self) -> str: if self.args: strargs = ', '.join([str(a) for a in self.args]) if self.var_arg: @@ -184,14 +203,14 @@ def _to_string(self): else: return '{0} ()'.format(self.return_type) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, FunctionType): return (self.return_type == other.return_type and self.args == other.args and self.var_arg == other.var_arg) else: return False - def __hash__(self): + def __hash__(self) -> int: return hash(FunctionType) @@ -200,9 +219,10 @@ class IntType(Type): The type for integers. """ null = '0' - _instance_cache = {} + _instance_cache: dict[int, IntType] = {} + width: int - def __new__(cls, bits): + def __new__(cls, bits: int) -> IntType: # Cache all common integer types if 0 <= bits <= 128: try: @@ -213,72 +233,76 @@ def __new__(cls, bits): return cls.__new(bits) @classmethod - def __new(cls, bits): + def __new(cls, bits: int) -> IntType: assert isinstance(bits, int) and bits >= 0 self = super(IntType, cls).__new__(cls) self.width = bits return self - def __getnewargs__(self): + def __getnewargs__(self) -> tuple[int]: return self.width, - def __copy__(self): + def __copy__(self) -> IntType: return self - def _to_string(self): + def _to_string(self) -> str: return 'i%u' % (self.width,) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, IntType): return self.width == other.width else: return False - def __hash__(self): + def __hash__(self) -> int: return hash(IntType) - def format_constant(self, val): + # FIXME: IncompatibleMethodOverride + def format_constant(self, val: int | bool) -> str: # type: ignore if isinstance(val, bool): return str(val).lower() else: return str(val) - def wrap_constant_value(self, val): + # FIXME: IncompatibleMethodOverride: val instead of value + def wrap_constant_value(self, val: int | None) -> int: # type: ignore if val is None: return 0 return val @property - def intrinsic_name(self): + def intrinsic_name(self) -> str: return str(self) -def _as_float(value): +def _as_float(value: float) -> float: """ Truncate to single-precision float. """ - return struct.unpack('f', struct.pack('f', value))[0] + return struct.unpack('f', struct.pack('f', value))[0] # type: ignore -def _as_half(value): +def _as_half(value: float) -> float: """ Truncate to half-precision float. """ try: - return struct.unpack('e', struct.pack('e', value))[0] + return struct.unpack('e', struct.pack('e', value))[0] # type: ignore except struct.error: # 'e' only added in Python 3.6+ return _as_float(value) -def _format_float_as_hex(value, packfmt, unpackfmt, numdigits): +def _format_float_as_hex( + value: float, packfmt: str, unpackfmt: str, numdigits: int +) -> str: raw = struct.pack(packfmt, float(value)) intrep = struct.unpack(unpackfmt, raw)[0] out = '{{0:#{0}x}}'.format(numdigits).format(intrep) return out -def _format_double(value): +def _format_double(value: float) -> str: """ Format *value* as a hexadecimal string of its IEEE double precision representation. @@ -287,19 +311,20 @@ def _format_double(value): class _BaseFloatType(Type): + # FIXME: _BaseFloatType doesn't have _instance_cache - def __new__(cls): - return cls._instance_cache + def __new__(cls) -> _BaseFloatType: + return cls._instance_cache # type: ignore - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return isinstance(other, type(self)) - def __hash__(self): + def __hash__(self) -> int: return hash(type(self)) @classmethod - def _create_instance(cls): - cls._instance_cache = super(_BaseFloatType, cls).__new__(cls) + def _create_instance(cls) -> None: + cls._instance_cache = super(_BaseFloatType, cls).__new__(cls) # type: ignore class HalfType(_BaseFloatType): @@ -309,10 +334,10 @@ class HalfType(_BaseFloatType): null = '0.0' intrinsic_name = 'f16' - def __str__(self): + def __str__(self) -> str: return 'half' - def format_constant(self, value): + def format_constant(self, value: float) -> str: return _format_double(_as_half(value)) @@ -323,10 +348,10 @@ class FloatType(_BaseFloatType): null = '0.0' intrinsic_name = 'f32' - def __str__(self): + def __str__(self) -> str: return 'float' - def format_constant(self, value): + def format_constant(self, value: float) -> str: return _format_double(_as_float(value)) @@ -337,10 +362,10 @@ class DoubleType(_BaseFloatType): null = '0.0' intrinsic_name = 'f64' - def __str__(self): + def __str__(self) -> str: return 'double' - def format_constant(self, value): + def format_constant(self, value: float) -> str: return _format_double(value) @@ -348,15 +373,15 @@ def format_constant(self, value): _cls._create_instance() -class _Repeat(object): - def __init__(self, value, size): +class _Repeat: + def __init__(self, value: Type, size: int) -> None: self.value = value self.size = size - def __len__(self): + def __len__(self) -> int: return self.size - def __getitem__(self, item): + def __getitem__(self, item: int) -> Type: if 0 <= item < self.size: return self.value else: @@ -368,50 +393,53 @@ class VectorType(Type): The type for vectors of primitive data items (e.g. ""). """ - def __init__(self, element, count): + def __init__(self, element: Type, count: int) -> None: self.element = element self.count = count @property - def elements(self): + def elements(self) -> _Repeat: return _Repeat(self.element, self.count) - def __len__(self): + def __len__(self) -> int: return self.count - def _to_string(self): + def _to_string(self) -> str: return "<%d x %s>" % (self.count, self.element) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, VectorType): return self.element == other.element and self.count == other.count + return False - def __hash__(self): + def __hash__(self) -> int: # TODO: why does this not take self.element/self.count into account? return hash(VectorType) - def __copy__(self): + def __copy__(self) -> VectorType: return self - def format_constant(self, value): - itemstring = ", " .join(["{0} {1}".format(x.type, x.get_reference()) - for x in value]) + # FIXME: IncompatibleMethodOverride + def format_constant(self, value: Iterable[Value]) -> str: # type: ignore + itemstring = ", ".join( + ["{0} {1}".format(x.type, x.get_reference()) for x in value] # type: ignore + ) return "<{0}>".format(itemstring) - def wrap_constant_value(self, values): - from . import Value, Constant + # FIXME: IncompatibleMethodOverride + def wrap_constant_value(self, values: Iterable[Value]) -> Iterable[Value]: # type: ignore if not isinstance(values, (list, tuple)): if isinstance(values, Constant): if values.type != self.element: raise TypeError("expected {} for {}".format( self.element, values.type)) return (values, ) * self.count - return (Constant(self.element, values), ) * self.count + return (Constant(self.element, values), ) * self.count # type: ignore if len(values) != len(self): raise ValueError("wrong constant size for %s: got %d elements" % (self, len(values))) return [Constant(ty, val) if not isinstance(val, Value) else val - for ty, val in zip(self.elements, values)] + for ty, val in zip(self.elements, values)] # type: ignore class Aggregate(Type): @@ -420,16 +448,19 @@ class Aggregate(Type): See http://llvm.org/docs/LangRef.html#t-aggregate """ - def wrap_constant_value(self, values): - from . import Value, Constant + # FIXME: does not contain .elements + # FIXME: does not implement __len__ + # FIXME: IncompatibleMethodOverride + def wrap_constant_value(self, values: Iterable[Value]) -> Iterable[Value]: # type: ignore if not isinstance(values, (list, tuple)): return values - if len(values) != len(self): + if len(values) != len(self): # type: ignore raise ValueError("wrong constant size for %s: got %d elements" % (self, len(values))) return [Constant(ty, val) if not isinstance(val, Value) else val - for ty, val in zip(self.elements, values)] + for ty, val in zip(self.elements, values)] # type: ignore + class ArrayType(Aggregate): @@ -437,37 +468,39 @@ class ArrayType(Aggregate): The type for fixed-size homogenous arrays (e.g. "[f32 x 3]"). """ - def __init__(self, element, count): + def __init__(self, element: Type, count: int): self.element = element self.count = count @property - def elements(self): + def elements(self) -> _Repeat: return _Repeat(self.element, self.count) - def __len__(self): + def __len__(self) -> int: return self.count - def _to_string(self): + def _to_string(self) -> str: return "[%d x %s]" % (self.count, self.element) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, ArrayType): return self.element == other.element and self.count == other.count + return False - def __hash__(self): + def __hash__(self) -> int: return hash(ArrayType) - def gep(self, i): + def gep(self, i: IntType) -> Type: """ Resolve the type of the i-th element (for getelementptr lookups). """ - if not isinstance(i.type, IntType): - raise TypeError(i.type) + if not isinstance(i.type, IntType): # type: ignore + raise TypeError(i.type) # type: ignore return self.element - def format_constant(self, value): - itemstring = ", " .join(["{0} {1}".format(x.type, x.get_reference()) + # FIXME: incompatible method overrride + def format_constant(self, value: Iterable[Value]) -> str: # type: ignore + itemstring = ", " .join(["{0} {1}".format(x.type, x.get_reference()) # type: ignore for x in value]) return "[{0}]".format(itemstring) @@ -477,9 +510,10 @@ class BaseStructType(Aggregate): The base type for heterogenous struct types. """ _packed = False + # FIXME: BaseStructType has no self.elements @property - def packed(self): + def packed(self) -> bool: """ A boolean attribute that indicates whether the structure uses packed layout. @@ -487,46 +521,48 @@ def packed(self): return self._packed @packed.setter - def packed(self, val): + def packed(self, val: bool) -> None: self._packed = bool(val) - def __len__(self): - assert self.elements is not None - return len(self.elements) + def __len__(self) -> int: + assert self.elements is not None # type: ignore + return len(self.elements) # type: ignore - def __iter__(self): - assert self.elements is not None - return iter(self.elements) + def __iter__(self) -> Iterator[Type]: + assert self.elements is not None # type: ignore + return iter(self.elements) # type: ignore @property - def is_opaque(self): - return self.elements is None + def is_opaque(self) -> bool: + return self.elements is None # type: ignore - def structure_repr(self): + def structure_repr(self) -> str: """ Return the LLVM IR for the structure representation """ - ret = '{%s}' % ', '.join([str(x) for x in self.elements]) + # BaseStructType has not "elements" + ret = '{%s}' % ', '.join([str(x) for x in self.elements]) # type: ignore return self._wrap_packed(ret) - def format_constant(self, value): - itemstring = ", " .join(["{0} {1}".format(x.type, x.get_reference()) + # FIXME: incompatible method override + def format_constant(self, value: Iterable[Value]) -> str: # type: ignore + itemstring = ", " .join(["{0} {1}".format(x.type, x.get_reference()) # type: ignore for x in value]) ret = "{{{0}}}".format(itemstring) return self._wrap_packed(ret) - def gep(self, i): + def gep(self, i: IntType) -> Type: """ Resolve the type of the i-th element (for getelementptr lookups). *i* needs to be a LLVM constant, so that the type can be determined at compile-time. """ - if not isinstance(i.type, IntType): - raise TypeError(i.type) - return self.elements[i.constant] + if not isinstance(i.type, IntType): # type: ignore + raise TypeError(i.type) # type: ignore + return self.elements[i.constant] # type: ignore - def _wrap_packed(self, textrepr): + def _wrap_packed(self, textrepr: str) -> str: """ Internal helper to wrap textual repr of struct type into packed struct """ @@ -544,7 +580,7 @@ class LiteralStructType(BaseStructType): null = 'zeroinitializer' - def __init__(self, elems, packed=False): + def __init__(self, elems: Iterable[Type], packed: bool = False) -> None: """ *elems* is a sequence of types to be used as members. *packed* controls the use of packed layout. @@ -552,14 +588,15 @@ def __init__(self, elems, packed=False): self.elements = tuple(elems) self.packed = packed - def _to_string(self): + def _to_string(self) -> str: return self.structure_repr() - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, LiteralStructType): return self.elements == other.elements + return False - def __hash__(self): + def __hash__(self) -> int: return hash(LiteralStructType) @@ -573,7 +610,7 @@ class IdentifiedStructType(BaseStructType): """ null = 'zeroinitializer' - def __init__(self, context, name, packed=False): + def __init__(self, context: Context, name: str, packed: bool = False) -> None: """ *context* is a llvmlite.ir.Context. *name* is the identifier for the new struct type. @@ -582,13 +619,13 @@ def __init__(self, context, name, packed=False): assert name self.context = context self.name = name - self.elements = None + self.elements: None | tuple[Type, ...] = None self.packed = packed - def _to_string(self): + def _to_string(self) -> str: return "%{name}".format(name=_wrapname(self.name)) - def get_declaration(self): + def get_declaration(self) -> str: """ Returns the string for the declaration of the type """ @@ -599,15 +636,16 @@ def get_declaration(self): strrep=str(self), struct=self.structure_repr()) return out - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, IdentifiedStructType): return self.name == other.name + return False - def __hash__(self): + def __hash__(self) -> int: return hash(IdentifiedStructType) - def set_body(self, *elems): + def set_body(self, *elems: Iterable[Type]) -> None: if not self.is_opaque: raise RuntimeError("{name} is already defined".format( name=self.name)) - self.elements = tuple(elems) + self.elements = tuple(elems) # type: ignore diff --git a/llvmlite/ir/values.py b/llvmlite/ir/values.py index 853927724..26a94e89d 100644 --- a/llvmlite/ir/values.py +++ b/llvmlite/ir/values.py @@ -3,13 +3,20 @@ Instructions are in the instructions module. """ -import functools +from __future__ import annotations + import string import re +from typing import TYPE_CHECKING, Any, Iterable, Iterator +from typing import Type as PyType from llvmlite.ir import values, types, _utils from llvmlite.ir._utils import (_StrCaching, _StringReferenceCaching, _HasMetadata) +from llvmlite.ir.module import Module + +if TYPE_CHECKING: + from llvmlite.ir.instructions import Instruction _VALID_CHARS = (frozenset(map(ord, string.ascii_letters)) | @@ -28,7 +35,10 @@ } -def _escape_string(text, _map={}): +def _escape_string( + text: str | bytes | bytearray, + _map: dict[int, str] = {}, +) -> str: """ Escape the given bytestring for safe use as a LLVM array constant. Any unicode string input is first encoded with utf8 into bytes. @@ -48,176 +58,180 @@ def _escape_string(text, _map={}): return ''.join(buf) -def _binop(opname): - def wrap(fn): - @functools.wraps(fn) - def wrapped(lhs, rhs): - if lhs.type != rhs.type: - raise ValueError("Operands must be the same type, got (%s, %s)" - % (lhs.type, rhs.type)) - - fmt = "{0} ({1} {2}, {3} {4})".format(opname, - lhs.type, lhs.get_reference(), - rhs.type, rhs.get_reference()) - return FormattedConstant(lhs.type, fmt) - - return wrapped - return wrap - - -def _castop(opname): - def wrap(fn): - @functools.wraps(fn) - def wrapped(self, typ): - fn(self, typ) - if typ == self.type: - return self - - op = "{0} ({1} {2} to {3})".format(opname, self.type, - self.get_reference(), typ) - return FormattedConstant(typ, op) - - return wrapped - return wrap - - class _ConstOpMixin(object): """ A mixin defining constant operations, for use in constant-like classes. """ + # FIXME: self.type missing. How to communicate that Mixedin needs this? + # # Arithmetic APIs # - @_binop('shl') - def shl(self, other): + def _binop(self, opname: str, rhs: Constant) -> FormattedConstant: + if self.type != rhs.type: # type: ignore + raise ValueError( + "Operands must be the same type, got ({0}, {1})".format( + self.type, # type: ignore + rhs.type, + ) + ) + fmt = "{0} ({1} {2}, {3} {4})".format( + opname, + self.type, # type: ignore + self.get_reference(), # type: ignore + rhs.type, + rhs.get_reference(), + ) + return FormattedConstant( + self.type, # type: ignore + fmt, + ) + + def _castop(self, opname: str, typ: types.Type) -> Constant: + # Returns Constant | FormattedConstant, which is a child of Constant + if typ == self.type: # type: ignore + return self # type: ignore + + op = "{0} ({1} {2} to {3})".format( + opname, + self.type, # type: ignore + self.get_reference(), # type: ignore + typ, + ) + return FormattedConstant(typ, op) + + def shl(self, other: Constant) -> FormattedConstant: """ Left integer shift: lhs << rhs """ + return self._binop("shl", other) - @_binop('lshr') - def lshr(self, other): + def lshr(self, other: Constant) -> FormattedConstant: """ Logical (unsigned) right integer shift: lhs >> rhs """ + return self._binop("lshr", other) - @_binop('ashr') - def ashr(self, other): + def ashr(self, other: Constant) -> FormattedConstant: """ Arithmetic (signed) right integer shift: lhs >> rhs """ + return self._binop("ashr", other) - @_binop('add') - def add(self, other): + def add(self, other: Constant) -> FormattedConstant: """ Integer addition: lhs + rhs """ + return self._binop("add", other) - @_binop('fadd') - def fadd(self, other): + def fadd(self, other: Constant) -> FormattedConstant: """ Floating-point addition: lhs + rhs """ + return self._binop("fadd", other) - @_binop('sub') - def sub(self, other): + def sub(self, other: Constant) -> FormattedConstant: """ Integer subtraction: lhs - rhs """ + return self._binop("sub", other) - @_binop('fsub') - def fsub(self, other): + def fsub(self, other: Constant) -> FormattedConstant: """ Floating-point subtraction: lhs - rhs """ + return self._binop("fsub", other) - @_binop('mul') - def mul(self, other): + def mul(self, other: Constant) -> FormattedConstant: """ Integer multiplication: lhs * rhs """ + return self._binop("mul", other) - @_binop('fmul') - def fmul(self, other): + def fmul(self, other: Constant) -> FormattedConstant: """ Floating-point multiplication: lhs * rhs """ + return self._binop("fmul", other) - @_binop('udiv') - def udiv(self, other): + def udiv(self, other: Constant) -> FormattedConstant: """ Unsigned integer division: lhs / rhs """ + return self._binop("udiv", other) - @_binop('sdiv') - def sdiv(self, other): + def sdiv(self, other: Constant) -> FormattedConstant: """ Signed integer division: lhs / rhs """ + return self._binop("sdiv", other) - @_binop('fdiv') - def fdiv(self, other): + def fdiv(self, other: Constant) -> FormattedConstant: """ Floating-point division: lhs / rhs """ + return self._binop("fdiv", other) - @_binop('urem') - def urem(self, other): + def urem(self, other: Constant) -> FormattedConstant: """ Unsigned integer remainder: lhs % rhs """ + return self._binop("urem", other) - @_binop('srem') - def srem(self, other): + def srem(self, other: Constant) -> FormattedConstant: """ Signed integer remainder: lhs % rhs """ + return self._binop("srem", other) - @_binop('frem') - def frem(self, other): + def frem(self, other: Constant) -> FormattedConstant: """ Floating-point remainder: lhs % rhs """ + return self._binop("frem", other) - @_binop('or') - def or_(self, other): + def or_(self, other: Constant) -> FormattedConstant: """ Bitwise integer OR: lhs | rhs """ + return self._binop("or", other) - @_binop('and') - def and_(self, other): + def and_(self, other: Constant) -> FormattedConstant: """ Bitwise integer AND: lhs & rhs """ + return self._binop("and", other) - @_binop('xor') - def xor(self, other): + def xor(self, other: Constant) -> FormattedConstant: """ Bitwise integer XOR: lhs ^ rhs """ + return self._binop("xor", other) - def _cmp(self, prefix, sign, cmpop, other): - ins = prefix + 'cmp' + def _cmp( + self, prefix: str, sign: str, cmpop: str, other: Constant + ) -> FormattedConstant: + ins = prefix + "cmp" try: op = _CMP_MAP[cmpop] except KeyError: @@ -226,18 +240,26 @@ def _cmp(self, prefix, sign, cmpop, other): if not (prefix == 'i' and cmpop in ('==', '!=')): op = sign + op - if self.type != other.type: - raise ValueError("Operands must be the same type, got (%s, %s)" - % (self.type, other.type)) + if self.type != other.type: # type: ignore + raise ValueError( + "Operands must be the same type, got ({0}, {1})".format( + self.type, # type: ignore + other.type, + ) + ) fmt = "{0} {1} ({2} {3}, {4} {5})".format( - ins, op, - self.type, self.get_reference(), - other.type, other.get_reference()) + ins, + op, + self.type, # type: ignore + self.get_reference(), # type: ignore + other.type, + other.get_reference(), + ) return FormattedConstant(types.IntType(1), fmt) - def icmp_signed(self, cmpop, other): + def icmp_signed(self, cmpop: str, other: Constant) -> FormattedConstant: """ Signed integer comparison: lhs rhs @@ -246,7 +268,7 @@ def icmp_signed(self, cmpop, other): """ return self._cmp('i', 's', cmpop, other) - def icmp_unsigned(self, cmpop, other): + def icmp_unsigned(self, cmpop: str, other: Constant) -> FormattedConstant: """ Unsigned integer (or pointer) comparison: lhs rhs @@ -255,7 +277,7 @@ def icmp_unsigned(self, cmpop, other): """ return self._cmp('i', 'u', cmpop, other) - def fcmp_ordered(self, cmpop, other): + def fcmp_ordered(self, cmpop: str, other: Constant) -> FormattedConstant: """ Floating-point ordered comparison: lhs rhs @@ -264,7 +286,7 @@ def fcmp_ordered(self, cmpop, other): """ return self._cmp('f', 'o', cmpop, other) - def fcmp_unordered(self, cmpop, other): + def fcmp_unordered(self, cmpop: str, other: Constant) -> FormattedConstant: """ Floating-point unordered comparison: lhs rhs @@ -277,141 +299,144 @@ def fcmp_unordered(self, cmpop, other): # Unary APIs # - def not_(self): + def not_(self) -> FormattedConstant: """ Bitwise integer complement: ~value """ - if isinstance(self.type, types.VectorType): - rhs = values.Constant(self.type, (-1,) * self.type.count) + if isinstance(self.type, types.VectorType): # type: ignore + rhs = values.Constant(self.type, (-1,) * self.type.count) # type: ignore else: - rhs = values.Constant(self.type, -1) + rhs = values.Constant(self.type, -1) # type: ignore return self.xor(rhs) - def neg(self): + def neg(self) -> FormattedConstant: """ Integer negative: -value """ - zero = values.Constant(self.type, 0) - return zero.sub(self) + zero = values.Constant(self.type, 0) # type: ignore + return zero.sub(self) # type: ignore - def fneg(self): + def fneg(self) -> FormattedConstant: """ Floating-point negative: -value """ - fmt = "fneg ({0} {1})".format(self.type, self.get_reference()) - return FormattedConstant(self.type, fmt) + fmt = "fneg ({0} {1})".format(self.type, self.get_reference()) # type: ignore + return FormattedConstant(self.type, fmt) # type: ignore # # Cast APIs # - @_castop('trunc') - def trunc(self, typ): + def trunc(self, typ: types.Type) -> Constant: """ Truncating integer downcast to a smaller type. """ + return self._castop("trunc", typ) - @_castop('zext') - def zext(self, typ): + def zext(self, typ: types.Type) -> Constant: """ Zero-extending integer upcast to a larger type """ + return self._castop("zext", typ) - @_castop('sext') - def sext(self, typ): + def sext(self, typ: types.Type) -> Constant: """ Sign-extending integer upcast to a larger type. """ + return self._castop("sext", typ) - @_castop('fptrunc') - def fptrunc(self, typ): + def fptrunc(self, typ: types.Type) -> Constant: """ Floating-point downcast to a less precise type. """ + return self._castop("fptrunc", typ) - @_castop('fpext') - def fpext(self, typ): + def fpext(self, typ: types.Type) -> Constant: """ Floating-point upcast to a more precise type. """ + return self._castop("fpext", typ) - @_castop('bitcast') - def bitcast(self, typ): + def bitcast(self, typ: types.Type) -> Constant: """ Pointer cast to a different pointer type. """ + return self._castop("bitcast", typ) - @_castop('fptoui') - def fptoui(self, typ): + def fptoui(self, typ: types.Type) -> Constant: """ Convert floating-point to unsigned integer. """ + return self._castop("fptoui", typ) - @_castop('uitofp') - def uitofp(self, typ): + def uitofp(self, typ: types.Type) -> Constant: """ Convert unsigned integer to floating-point. """ + return self._castop("uitofp", typ) - @_castop('fptosi') - def fptosi(self, typ): + def fptosi(self, typ: types.Type) -> Constant: """ Convert floating-point to signed integer. """ + return self._castop("fptosi", typ) - @_castop('sitofp') - def sitofp(self, typ): + def sitofp(self, typ: types.Type) -> Constant: """ Convert signed integer to floating-point. """ + return self._castop("sitofp", typ) - @_castop('ptrtoint') - def ptrtoint(self, typ): + def ptrtoint(self, typ: types.Type) -> Constant: """ Cast pointer to integer. """ - if not isinstance(self.type, types.PointerType): + if not isinstance(self.type, types.PointerType): # type: ignore msg = "can only call ptrtoint() on pointer type, not '%s'" - raise TypeError(msg % (self.type,)) + raise TypeError(msg % (self.type,)) # type: ignore if not isinstance(typ, types.IntType): raise TypeError("can only ptrtoint() to integer type, not '%s'" % (typ,)) + return self._castop("ptrtoint", typ) - @_castop('inttoptr') - def inttoptr(self, typ): + def inttoptr(self, typ: types.Type) -> Constant: """ Cast integer to pointer. """ - if not isinstance(self.type, types.IntType): + if not isinstance(self.type, types.IntType): # type: ignore msg = "can only call inttoptr() on integer constants, not '%s'" - raise TypeError(msg % (self.type,)) + raise TypeError(msg % (self.type,)) # type: ignore if not isinstance(typ, types.PointerType): raise TypeError("can only inttoptr() to pointer type, not '%s'" % (typ,)) + return self._castop("inttoptr", typ) - def gep(self, indices): + def gep(self, indices: list[Constant]) -> FormattedConstant: """ Call getelementptr on this pointer constant. """ - if not isinstance(self.type, types.PointerType): + if not isinstance(self.type, types.PointerType): # type: ignore raise TypeError("can only call gep() on pointer constants, not '%s'" - % (self.type,)) + % (self.type,)) # type: ignore - outtype = self.type + outtype = self.type # type: ignore for i in indices: - outtype = outtype.gep(i) + outtype = outtype.gep(i) # type: ignore strindices = ["{0} {1}".format(idx.type, idx.get_reference()) for idx in indices] op = "getelementptr ({0}, {1} {2}, {3})".format( - self.type.pointee, self.type, - self.get_reference(), ', '.join(strindices)) - return FormattedConstant(outtype.as_pointer(self.addrspace), op) + self.type.pointee, # type: ignore + self.type, # type: ignore + self.get_reference(), # type: ignore + ", ".join(strindices), + ) + return FormattedConstant(outtype.as_pointer(self.addrspace), op) # type: ignore class Value(object): @@ -419,15 +444,21 @@ class Value(object): The base class for all values. """ - def __repr__(self): - return "" % (self.__class__.__name__, self.type,) + # FIXME: neither .type nor .get_reference() are defined here + + def __repr__(self) -> str: + return "".format( + self.__class__.__name__, + self.type, # type: ignore + ) class _Undefined(object): """ 'undef': a value for undefined values. """ - def __new__(cls): + + def __new__(cls) -> _Undefined: try: return Undefined except NameError: @@ -442,17 +473,28 @@ class Constant(_StrCaching, _StringReferenceCaching, _ConstOpMixin, Value): A constant LLVM value. """ - def __init__(self, typ, constant): + def __init__( + self, + typ: types.Type, + constant: types.Type + | int + | str + | tuple[int, ...] + | list[Constant] + | _Undefined + | bytearray + | None, + ) -> None: assert isinstance(typ, types.Type) assert not isinstance(typ, types.VoidType) self.type = typ - constant = typ.wrap_constant_value(constant) + constant = typ.wrap_constant_value(constant) # type: ignore self.constant = constant - def _to_string(self): + def _to_string(self) -> str: return '{0} {1}'.format(self.type, self.get_reference()) - def _get_reference(self): + def _get_reference(self) -> str: if self.constant is None: val = self.type.null @@ -463,12 +505,12 @@ def _get_reference(self): val = 'c"{0}"'.format(_escape_string(self.constant)) else: - val = self.type.format_constant(self.constant) + val = self.type.format_constant(self.constant) # type: ignore return val @classmethod - def literal_array(cls, elems): + def literal_array(cls: PyType[Constant], elems: list[Constant]) -> Constant: """ Construct a literal array constant made of the given members. """ @@ -482,7 +524,7 @@ def literal_array(cls, elems): return cls(types.ArrayType(ty, len(elems)), elems) @classmethod - def literal_struct(cls, elems): + def literal_struct(cls, elems: list[Constant]) -> Constant: """ Construct a literal structure constant made of the given members. """ @@ -490,24 +532,24 @@ def literal_struct(cls, elems): return cls(types.LiteralStructType(tys), elems) @property - def addrspace(self): + def addrspace(self) -> int: if not isinstance(self.type, types.PointerType): raise TypeError("Only pointer constant have address spaces") return self.type.addrspace - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Constant): return str(self) == str(other) else: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(str(self)) - def __repr__(self): + def __repr__(self) -> str: return "" % (self.type, self.constant) @@ -516,69 +558,75 @@ class FormattedConstant(Constant): A constant with an already formatted IR representation. """ - def __init__(self, typ, constant): + def __init__(self, typ: types.Type, constant: str) -> None: assert isinstance(constant, str) Constant.__init__(self, typ, constant) - def _to_string(self): - return self.constant + def _to_string(self) -> str: + # FIXME: self.constant can be types.Type! + return self.constant # type: ignore - def _get_reference(self): - return self.constant + def _get_reference(self) -> str: + # FIXME: self.constant can be types.Type! + return self.constant # type: ignore class NamedValue(_StrCaching, _StringReferenceCaching, Value): """ The base class for named values. """ + _name: str name_prefix = '%' deduplicate_name = True - def __init__(self, parent, type, name): + def __init__( + self, parent: Module | Function | Block, type: types.Type, name: str + ) -> None: assert parent is not None assert isinstance(type, types.Type) self.parent = parent self.type = type - self._set_name(name) + self.name = name - def _to_string(self): - buf = [] + def _to_string(self) -> str: + buf: list[str] = [] if not isinstance(self.type, types.VoidType): buf.append("{0} = ".format(self.get_reference())) self.descr(buf) return "".join(buf).rstrip() - def descr(self, buf): + def descr(self, buf: list[str]) -> None: raise NotImplementedError - def _get_name(self): + @property + def name(self) -> str: return self._name - def _set_name(self, name): + @name.setter + def name(self, name: str) -> None: name = self.parent.scope.register(name, deduplicate=self.deduplicate_name) self._name = name - name = property(_get_name, _set_name) - def _get_reference(self): + def _get_reference(self) -> str: name = self.name # Quote and escape value name if '\\' in name or '"' in name: name = name.replace('\\', '\\5c').replace('"', '\\22') return '{0}"{1}"'.format(self.name_prefix, name) - def __repr__(self): + def __repr__(self) -> str: return "" % ( self.__class__.__name__, self.name, self.type) @property - def function_type(self): + def function_type(self) -> types.FunctionType: ty = self.type if isinstance(ty, types.PointerType): - ty = self.type.pointee + ty = self.type.pointee # type: ignore if isinstance(ty, types.FunctionType): - return ty + return ty # type: ignore else: raise TypeError("Not a function: {0}".format(self.type)) @@ -589,30 +637,28 @@ class MetaDataString(NamedValue): node. """ - def __init__(self, parent, string): - super(MetaDataString, self).__init__(parent, - types.MetaDataType(), - name="") + def __init__(self, parent: Module | Block, string: str) -> None: + super(MetaDataString, self).__init__(parent, types.MetaDataType(), name="") self.string = string - def descr(self, buf): + def descr(self, buf: list[str]) -> None: buf += (self.get_reference(), "\n") - def _get_reference(self): + def _get_reference(self) -> str: return '!"{0}"'.format(_escape_string(self.string)) _to_string = _get_reference - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, MetaDataString): return self.string == other.string else: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.string) @@ -625,16 +671,18 @@ class MetaDataArgument(_StrCaching, _StringReferenceCaching, Value): automatically. """ - def __init__(self, value): + def __init__(self, value: Value) -> None: assert isinstance(value, Value) - assert not isinstance(value.type, types.MetaDataType) + assert not isinstance(value.type, types.MetaDataType) # type: ignore self.type = types.MetaDataType() self.wrapped_value = value - def _get_reference(self): + def _get_reference(self) -> str: # e.g. "i32* %2" - return "{0} {1}".format(self.wrapped_value.type, - self.wrapped_value.get_reference()) + return "{0} {1}".format( + self.wrapped_value.type, # type: ignore + self.wrapped_value.get_reference(), # type: ignore + ) _to_string = _get_reference @@ -646,11 +694,11 @@ class NamedMetaData(object): Do not instantiate directly, use Module.add_named_metadata() instead. """ - def __init__(self, parent): + def __init__(self, parent: Block) -> None: self.parent = parent - self.operands = [] + self.operands: list[MDValue] = [] - def add(self, md): + def add(self, md: MDValue) -> None: self.operands.append(md) @@ -662,15 +710,15 @@ class MDValue(NamedValue): """ name_prefix = '!' - def __init__(self, parent, values, name): + def __init__(self, parent: Module, values: list[Constant], name: str) -> None: super(MDValue, self).__init__(parent, types.MetaDataType(), name=name) self.operands = tuple(values) parent.metadata.append(self) - def descr(self, buf): - operands = [] + def descr(self, buf: list[str]) -> None: + operands: list[str] = [] for op in self.operands: if isinstance(op.type, types.MetaDataType): if isinstance(op, Constant) and op.constant is None: @@ -679,22 +727,21 @@ def descr(self, buf): operands.append(op.get_reference()) else: operands.append("{0} {1}".format(op.type, op.get_reference())) - operands = ', '.join(operands) - buf += ("!{{ {0} }}".format(operands), "\n") + buf += ("!{{ {0} }}".format(', '.join(operands)), "\n") - def _get_reference(self): + def _get_reference(self) -> str: return self.name_prefix + str(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, MDValue): return self.operands == other.operands else: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash(self.operands) @@ -706,7 +753,7 @@ class DIToken: Use this to wrap known constants, e.g. the DW_* enumerations. """ - def __init__(self, value): + def __init__(self, value: Constant) -> None: self.value = value @@ -718,16 +765,23 @@ class DIValue(NamedValue): """ name_prefix = '!' - def __init__(self, parent, is_distinct, kind, operands, name): + def __init__( + self, + parent: Module, + is_distinct: bool, + kind: str, + operands: Iterable[tuple[str, Value]], + name: str, + ) -> None: super(DIValue, self).__init__(parent, types.MetaDataType(), name=name) self.is_distinct = is_distinct self.kind = kind - self.operands = tuple(operands) + self.operands: tuple[tuple[str, Value], ...] = tuple(operands) parent.metadata.append(self) - def descr(self, buf): + def descr(self, buf: list[str]) -> None: if self.is_distinct: buf += ("distinct ",) operands = [] @@ -741,22 +795,20 @@ def descr(self, buf): elif isinstance(value, DIToken): strvalue = value.value elif isinstance(value, str): - strvalue = '"{}"'.format(_escape_string(value)) + strvalue = f'"{_escape_string(value)}"' # type: ignore elif isinstance(value, int): - strvalue = str(value) + strvalue = str(value) # type: ignore elif isinstance(value, NamedValue): - strvalue = value.get_reference() + strvalue = value.get_reference() # type: ignore else: - raise TypeError("invalid operand type for debug info: %r" - % (value,)) - operands.append("{0}: {1}".format(key, strvalue)) - operands = ', '.join(operands) - buf += ("!", self.kind, "(", operands, ")\n") + raise TypeError(f"invalid operand type for debug info: {value!r}") + operands.append(f"{key}: {strvalue}") # type: ignore + buf += ("!", self.kind, "(", ", ".join(operands), ")\n") - def _get_reference(self): + def _get_reference(self) -> str: return self.name_prefix + str(self.name) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, DIValue): return self.is_distinct == other.is_distinct and \ self.kind == other.kind and \ @@ -764,10 +816,10 @@ def __eq__(self, other): else: return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: return not self.__eq__(other) - def __hash__(self): + def __hash__(self) -> int: return hash((self.is_distinct, self.kind, self.operands)) @@ -778,12 +830,17 @@ class GlobalValue(NamedValue, _ConstOpMixin, _HasMetadata): name_prefix = '@' deduplicate_name = False - def __init__(self, *args, **kwargs): - super(GlobalValue, self).__init__(*args, **kwargs) + def __init__( + self, + parent: Module | Block | Function, + type: types.Type, + name: str, + ) -> None: + super(GlobalValue, self).__init__(parent=parent, type=type, name=name) self.linkage = '' self.storage_class = '' self.section = '' - self.metadata = {} + self.metadata = {} # type: ignore class GlobalVariable(GlobalValue): @@ -791,7 +848,9 @@ class GlobalVariable(GlobalValue): A global variable. """ - def __init__(self, module, typ, name, addrspace=0): + def __init__( + self, module: Module, typ: types.Type, name: str, addrspace: int = 0 + ) -> None: assert isinstance(typ, types.Type) super(GlobalVariable, self).__init__(module, typ.as_pointer(addrspace), name=name) @@ -801,9 +860,9 @@ def __init__(self, module, typ, name, addrspace=0): self.global_constant = False self.addrspace = addrspace self.align = None - self.parent.add_global(self) + self.parent.add_global(self) # type: ignore - def descr(self, buf): + def descr(self, buf: list[str]) -> None: if self.global_constant: kind = 'constant' else: @@ -834,7 +893,7 @@ def descr(self, buf): buf.append(" " + self.initializer.get_reference()) elif linkage not in ('external', 'extern_weak'): # emit 'undef' for non-external linkage GV - buf.append(" " + self.value_type(Undefined).get_reference()) + buf.append(" " + self.value_type(Undefined).get_reference()) # type: ignore if self.section: buf.append(", section \"%s\"" % (self.section,)) @@ -848,7 +907,7 @@ def descr(self, buf): buf.append("\n") -class AttributeSet(set): +class AttributeSet(set): # type: ignore """A set of string attribute. Only accept items listed in *_known*. @@ -857,24 +916,25 @@ class AttributeSet(set): """ _known = () - def __init__(self, args=()): + def __init__(self, args: tuple[str, ...] | str = ()) -> None: if isinstance(args, str): - args = [args] + args = [args] # type: ignore for name in args: self.add(name) - def add(self, name): + def add(self, name: str) -> None: if name not in self._known: raise ValueError('unknown attr {!r} for {}'.format(name, self)) return super(AttributeSet, self).add(name) - def __iter__(self): + def __iter__(self) -> Iterator[str]: # In sorted order return iter(sorted(super(AttributeSet, self).__iter__())) class FunctionAttributes(AttributeSet): - _known = frozenset([ + # FIXME: tuple in parent class + _known = frozenset([ # type: ignore 'argmemonly', 'alwaysinline', 'builtin', 'cold', 'inaccessiblememonly', 'inaccessiblemem_or_argmemonly', 'inlinehint', 'jumptable', 'minsize', 'naked', 'nobuiltin', 'noduplicate', @@ -884,12 +944,12 @@ class FunctionAttributes(AttributeSet): 'sanitize_memory', 'sanitize_thread', 'ssp', 'sspreg', 'sspstrong', 'uwtable']) - def __init__(self, args=()): - self._alignstack = 0 - self._personality = None + def __init__(self, args: tuple[str, ...] = ()) -> None: + self._alignstack: int = 0 + self._personality: GlobalValue | None = None super(FunctionAttributes, self).__init__(args) - def add(self, name): + def add(self, name: str) -> None: if ((name == 'alwaysinline' and 'noinline' in self) or (name == 'noinline' and 'alwaysinline' in self)): raise ValueError("Can't have alwaysinline and noinline") @@ -897,25 +957,25 @@ def add(self, name): super().add(name) @property - def alignstack(self): - return self._alignstack + def alignstack(self) -> int: + return self._alignstack # type: ignore @alignstack.setter - def alignstack(self, val): + def alignstack(self, val: int) -> None: assert val >= 0 self._alignstack = val @property - def personality(self): + def personality(self) -> GlobalValue | None: return self._personality @personality.setter - def personality(self, val): + def personality(self, val: GlobalValue | None) -> None: assert val is None or isinstance(val, GlobalValue) self._personality = val - def __repr__(self): - attrs = list(self) + def __repr__(self) -> str: + attrs: list[str] = list(self) if self.alignstack: attrs.append('alignstack({0:d})'.format(self.alignstack)) if self.personality: @@ -930,44 +990,45 @@ class Function(GlobalValue): Global Values are stored as a set of dependencies (attribute `depends`). """ - def __init__(self, module, ftype, name): + parent: Module + + def __init__(self, module: Module, ftype: types.FunctionType, name: str) -> None: assert isinstance(ftype, types.Type) super(Function, self).__init__(module, ftype.as_pointer(), name=name) self.ftype = ftype self.scope = _utils.NameScope() - self.blocks = [] + self.blocks: list[Block] = [] self.attributes = FunctionAttributes() - self.args = tuple([Argument(self, t) - for t in ftype.args]) - self.return_value = ReturnValue(self, ftype.return_type) - self.parent.add_global(self) + self.args = tuple([Argument(self, t) for t in ftype.args]) # type: ignore + self.return_value = ReturnValue(self, ftype.return_type) # type: ignore + self.parent.add_global(self) # type: ignore self.calling_convention = '' @property - def module(self): + def module(self) -> Module: return self.parent @property - def entry_basic_block(self): + def entry_basic_block(self) -> Block: return self.blocks[0] @property - def basic_blocks(self): + def basic_blocks(self) -> list[Block]: return self.blocks - def append_basic_block(self, name=''): + def append_basic_block(self, name: str = "") -> Block: blk = Block(parent=self, name=name) self.blocks.append(blk) return blk - def insert_basic_block(self, before, name=''): + def insert_basic_block(self, before: int, name: str = "") -> Block: """Insert block before """ blk = Block(parent=self, name=name) self.blocks.insert(before, blk) return blk - def descr_prototype(self, buf): + def descr_prototype(self, buf: list[str]) -> None: """ Describe the prototype ("head") of the function. """ @@ -975,8 +1036,8 @@ def descr_prototype(self, buf): ret = self.return_value args = ", ".join(str(a) for a in self.args) name = self.get_reference() - attrs = self.attributes - attrs = ' {}'.format(attrs) if attrs else '' + _attrs = self.attributes + attrs = ' {}'.format(_attrs) if _attrs else '' if any(self.args): vararg = ', ...' if self.ftype.var_arg else '' else: @@ -993,70 +1054,70 @@ def descr_prototype(self, buf): metadata=metadata) buf.append(prototype) - def descr_body(self, buf): + def descr_body(self, buf: list[str]) -> None: """ Describe of the body of the function. """ for blk in self.blocks: blk.descr(buf) - def descr(self, buf): + def descr(self, buf: list[str]) -> None: self.descr_prototype(buf) if self.blocks: buf.append("{\n") self.descr_body(buf) buf.append("}\n") - def __str__(self): - buf = [] + def __str__(self) -> str: + buf: list[str] = [] self.descr(buf) return "".join(buf) @property - def is_declaration(self): + def is_declaration(self) -> bool: return len(self.blocks) == 0 class ArgumentAttributes(AttributeSet): - _known = frozenset(['byval', 'inalloca', 'inreg', 'nest', 'noalias', + _known = frozenset(['byval', 'inalloca', 'inreg', 'nest', 'noalias', # type: ignore 'nocapture', 'nonnull', 'returned', 'signext', 'sret', 'zeroext']) - def __init__(self, args=()): + def __init__(self, args: tuple[Any, ...] = ()) -> None: self._align = 0 self._dereferenceable = 0 self._dereferenceable_or_null = 0 super(ArgumentAttributes, self).__init__(args) @property - def align(self): + def align(self) -> int: return self._align @align.setter - def align(self, val): + def align(self, val: int) -> None: assert isinstance(val, int) and val >= 0 self._align = val @property - def dereferenceable(self): + def dereferenceable(self) -> int: return self._dereferenceable @dereferenceable.setter - def dereferenceable(self, val): + def dereferenceable(self, val: int) -> None: assert isinstance(val, int) and val >= 0 self._dereferenceable = val @property - def dereferenceable_or_null(self): + def dereferenceable_or_null(self) -> int: return self._dereferenceable_or_null @dereferenceable_or_null.setter - def dereferenceable_or_null(self, val): + def dereferenceable_or_null(self, val: int) -> None: assert isinstance(val, int) and val >= 0 self._dereferenceable_or_null = val - def _to_list(self): - attrs = sorted(self) + def _to_list(self) -> list[str]: + attrs: list[str] = sorted(self) if self.align: attrs.append('align {0:d}'.format(self.align)) if self.dereferenceable: @@ -1068,17 +1129,19 @@ def _to_list(self): class _BaseArgument(NamedValue): - def __init__(self, parent, typ, name=''): - assert isinstance(typ, types.Type) - super(_BaseArgument, self).__init__(parent, typ, name=name) - self.parent = parent + def __init__(self, parent: Block, type: types.Type, name: str = "") -> None: + assert isinstance(type, types.Type) + super(_BaseArgument, self).__init__(parent, typ, name=name) # type: ignore self.attributes = ArgumentAttributes() - def __repr__(self): - return "" % (self.__class__.__name__, self.name, - self.type) + def __repr__(self) -> str: + return "".format( + self.__class__.__name__, + self.name, + self.type, # type: ignore + ) - def add_attribute(self, attr): + def add_attribute(self, attr: str) -> None: self.attributes.add(attr) @@ -1087,7 +1150,7 @@ class Argument(_BaseArgument): The specification of a function argument. """ - def __str__(self): + def __str__(self) -> str: attrs = self.attributes._to_list() if attrs: return "{0} {1} {2}".format(self.type, ' '.join(attrs), @@ -1101,7 +1164,7 @@ class ReturnValue(_BaseArgument): The specification of a function's return value. """ - def __str__(self): + def __str__(self) -> str: attrs = self.attributes._to_list() if attrs: return "{0} {1}".format(' '.join(attrs), self.type) @@ -1118,41 +1181,45 @@ class Block(NamedValue): instruction. """ - def __init__(self, parent, name=''): - super(Block, self).__init__(parent, types.LabelType(), name=name) + def __init__(self, parent: Function, name: str = "") -> None: + super(Block, self).__init__(parent=parent, type=types.LabelType(), name=name) self.scope = parent.scope - self.instructions = [] - self.terminator = None + self.instructions: list[Instruction] = [] + self.terminator: Instruction | None = None @property - def is_terminated(self): + def is_terminated(self) -> bool: return self.terminator is not None @property - def function(self): - return self.parent + def function(self) -> Function: + parent = self.parent + # FIXME: sure about this? + if not isinstance(parent, Function): + raise AttributeError("Parent must be function") + return parent @property - def module(self): - return self.parent.module + def module(self) -> Module: + return self.parent.module # type: ignore - def descr(self, buf): + def descr(self, buf: list[str]) -> None: buf.append("{0}:\n".format(self._format_name())) buf += [" {0}\n".format(instr) for instr in self.instructions] - def replace(self, old, new): + def replace(self, old: Constant, new: Constant) -> None: """Replace an instruction""" if old.type != new.type: raise TypeError("new instruction has a different type") - pos = self.instructions.index(old) - self.instructions.remove(old) - self.instructions.insert(pos, new) + pos = self.instructions.index(old) # type: ignore + self.instructions.remove(old) # type: ignore + self.instructions.insert(pos, new) # type: ignore - for bb in self.parent.basic_blocks: - for instr in bb.instructions: - instr.replace_usage(old, new) + for bb in self.parent.basic_blocks: # type: ignore + for instr in bb.instructions: # type: ignore + instr.replace_usage(old, new) # type: ignore - def _format_name(self): + def _format_name(self) -> str: # Per the LLVM Language Ref on identifiers, names matching the following # regex do not need to be quoted: [%@][-a-zA-Z$._][-a-zA-Z$._0-9]* # Otherwise, the identifier must be quoted and escaped. @@ -1168,17 +1235,17 @@ class BlockAddress(Value): The address of a basic block. """ - def __init__(self, function, basic_block): + def __init__(self, function: Function, basic_block: Block) -> None: assert isinstance(function, Function) assert isinstance(basic_block, Block) self.type = types.IntType(8).as_pointer() self.function = function self.basic_block = basic_block - def __str__(self): + def __str__(self) -> str: return '{0} {1}'.format(self.type, self.get_reference()) - def get_reference(self): + def get_reference(self) -> str: return "blockaddress({0}, {1})".format( self.function.get_reference(), self.basic_block.get_reference()) diff --git a/llvmlite/llvmpy/core.py b/llvmlite/llvmpy/core.py index 93e22ab1e..94ed5f34e 100644 --- a/llvmlite/llvmpy/core.py +++ b/llvmlite/llvmpy/core.py @@ -1,6 +1,12 @@ +from __future__ import annotations + +import builtins import itertools +from typing import Iterable from llvmlite import ir +from llvmlite.ir.instructions import FCMPInstr, ICMPInstr +from llvmlite.ir.values import GlobalVariable, MDValue, NamedMetaData from llvmlite import binding as llvm import warnings @@ -19,7 +25,7 @@ class LLVMException(Exception): _icmp_ct = itertools.count() -def _icmp_get(): +def _icmp_get() -> int: return next(_icmp_ct) @@ -69,124 +75,130 @@ def _icmp_get(): class Type(object): @staticmethod - def int(width=32): + def int(width: builtins.int = 32) -> ir.IntType: return ir.IntType(width) @staticmethod - def float(): - return ir.FloatType() + def float() -> ir.FloatType: + return ir.FloatType() # type: ignore @staticmethod - def half(): - return ir.HalfType() + def half() -> ir.HalfType: + return ir.HalfType() # type: ignore @staticmethod - def double(): - return ir.DoubleType() + def double() -> ir.DoubleType: + return ir.DoubleType() # type: ignore @staticmethod - def pointer(ty, addrspace=0): + def pointer( + ty: ir.IntType | ir.PointerType | ir.FunctionType, addrspace: builtins.int = 0 + ) -> ir.PointerType: return ir.PointerType(ty, addrspace) @staticmethod - def function(res, args, var_arg=False): + def function( + res: ir.Type, args: list[ir.Type], var_arg: bool = False + ) -> ir.FunctionType: return ir.FunctionType(res, args, var_arg=var_arg) @staticmethod - def struct(members): + def struct(members: Iterable[ir.Type]) -> ir.LiteralStructType: return ir.LiteralStructType(members) @staticmethod - def array(element, count): + def array(element: ir.Type, count: builtins.int) -> ir.ArrayType: return ir.ArrayType(element, count) @staticmethod - def void(): + def void() -> ir.VoidType: return ir.VoidType() class Constant(object): @staticmethod - def all_ones(ty): + def all_ones(ty: ir.IntType) -> ir.Constant: if isinstance(ty, ir.IntType): return Constant.int(ty, int('1' * ty.width, 2)) else: raise NotImplementedError(ty) @staticmethod - def int(ty, n): + def int(ty: ir.Type, n: builtins.int) -> ir.Constant: return ir.Constant(ty, n) @staticmethod - def int_signextend(ty, n): + def int_signextend(ty: ir.Type, n: builtins.int) -> ir.Constant: return ir.Constant(ty, n) @staticmethod - def real(ty, n): + def real(ty: ir.Type, n: builtins.int) -> ir.Constant: return ir.Constant(ty, n) @staticmethod - def struct(elems): + def struct(elems: list[ir.Constant]) -> ir.Constant: return ir.Constant.literal_struct(elems) @staticmethod - def null(ty): + def null(ty: ir.Type) -> ir.Constant: return ir.Constant(ty, None) @staticmethod - def undef(ty): + def undef(ty: ir.Type) -> ir.Constant: return ir.Constant(ty, ir.Undefined) @staticmethod - def stringz(string): - n = (len(string) + 1) - buf = bytearray((' ' * n).encode('ascii')) + def stringz(string: str) -> ir.Constant: + n = len(string) + 1 + buf = bytearray((" " * n).encode("ascii")) buf[-1] = 0 buf[:-1] = string.encode('utf-8') return ir.Constant(ir.ArrayType(ir.IntType(8), n), buf) @staticmethod - def array(typ, val): + def array(typ: ir.Type, val: list[ir.Constant]) -> ir.Constant: return ir.Constant(ir.ArrayType(typ, len(val)), val) @staticmethod - def bitcast(const, typ): - return const.bitcast(typ) + def bitcast(const: ir.Constant, typ: ir.Type) -> ir.FormattedConstant: + return const.bitcast(typ) # type: ignore @staticmethod - def inttoptr(const, typ): - return const.inttoptr(typ) + def inttoptr(const: ir.Constant, typ: ir.Type) -> ir.FormattedConstant: + return const.inttoptr(typ) # type: ignore @staticmethod - def gep(const, indices): + def gep(const: ir.Constant, indices: list[ir.Constant]) -> ir.FormattedConstant: return const.gep(indices) class Module(ir.Module): - def get_or_insert_function(self, fnty, name): + def get_or_insert_function(self, fnty: ir.FunctionType, name: str) -> ir.Function: if name in self.globals: - return self.globals[name] + return self.globals[name] # type: ignore else: return ir.Function(self, fnty, name) - def verify(self): - llvm.parse_assembly(str(self)) + def verify(self) -> None: + llvm.parse_assembly(str(self)) # type: ignore - def add_function(self, fnty, name): + def add_function(self, fnty: ir.FunctionType, name: str) -> ir.Function: return ir.Function(self, fnty, name) - def add_global_variable(self, ty, name, addrspace=0): + def add_global_variable( + self, ty: ir.Type, name: str, addrspace: builtins.int = 0 + ) -> ir.GlobalVariable: return ir.GlobalVariable(self, ty, self.get_unique_name(name), addrspace) - def get_global_variable_named(self, name): + def get_global_variable_named(self, name: str) -> ir.GlobalVariable: try: return self.globals[name] except KeyError: raise LLVMException(name) - def get_or_insert_named_metadata(self, name): + def get_or_insert_named_metadata(self, name: str) -> NamedMetaData: try: return self.get_named_metadata(name) except KeyError: @@ -196,11 +208,15 @@ def get_or_insert_named_metadata(self, name): class Function(ir.Function): @classmethod - def new(cls, module_obj, functy, name=''): + def new( + cls, module_obj: Module, functy: ir.FunctionType, name: str = "" + ) -> Function: return cls(module_obj, functy, name) @staticmethod - def intrinsic(module, intrinsic, tys): + def intrinsic( + module: Module, intrinsic: str, tys: list[ir.Type] + ) -> GlobalVariable | ir.Function: return module.declare_intrinsic(intrinsic, tys) @@ -242,33 +258,41 @@ def intrinsic(module, intrinsic, tys): class Builder(ir.IRBuilder): - - def icmp(self, pred, lhs, rhs, name=''): + def icmp( + self, pred: str, lhs: ir.Constant, rhs: ir.Constant, name: str = "" + ) -> ICMPInstr: if pred in _icmp_umap: - return self.icmp_unsigned(_icmp_umap[pred], lhs, rhs, name=name) + return self.icmp_unsigned(_icmp_umap[pred], lhs, rhs, name=name) # type: ignore else: - return self.icmp_signed(_icmp_smap[pred], lhs, rhs, name=name) + return self.icmp_signed(_icmp_smap[pred], lhs, rhs, name=name) # type: ignore - def fcmp(self, pred, lhs, rhs, name=''): + def fcmp( + self, pred: str, lhs: ir.Constant, rhs: ir.Constant, name: str = "" + ) -> FCMPInstr: if pred in _fcmp_umap: - return self.fcmp_unordered(_fcmp_umap[pred], lhs, rhs, name=name) + return self.fcmp_unordered(_fcmp_umap[pred], lhs, rhs, name=name) # type: ignore else: - return self.fcmp_ordered(_fcmp_omap[pred], lhs, rhs, name=name) + return self.fcmp_ordered(_fcmp_omap[pred], lhs, rhs, name=name) # type: ignore class MetaDataString(ir.MetaDataString): @staticmethod - def get(module, text): + def get(module: Module, text: str) -> MetaDataString: return MetaDataString(module, text) class MetaData(object): @staticmethod - def get(module, values): + def get(module: Module, values: list[ir.Constant]) -> MDValue: return module.add_metadata(values) class InlineAsm(ir.InlineAsm): @staticmethod - def get(*args, **kwargs): - return InlineAsm(*args, **kwargs) + def get( + ftype: ir.FunctionType, + asm: str, + constraint: ir.Value, + side_effect: bool = False, + ) -> InlineAsm: + return InlineAsm(ftype, asm, constraint, side_effect) diff --git a/llvmlite/llvmpy/passes.py b/llvmlite/llvmpy/passes.py index 873080c2a..6775ef5cc 100644 --- a/llvmlite/llvmpy/passes.py +++ b/llvmlite/llvmpy/passes.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Useful options to debug LLVM passes @@ -10,34 +12,43 @@ """ -from llvmlite import binding as llvm -from collections import namedtuple import warnings warnings.warn( "The module `llvmlite.llvmpy.passes` is deprecated and will be removed in " "the future. If you are using this code, it should be inlined into your " "own project.") +from typing import Any, NamedTuple, Union + +# from typing_extensions import Literal + +from llvmlite import binding as llvm +from llvmlite.binding.passmanagers import FunctionPassManager, PassManager +from llvmlite.binding.transforms import PassManagerBuilder +pms = NamedTuple( + "pms", [("pm", PassManager), ("fpm", Union[FunctionPassManager, None])] +) -def _inlining_threshold(optlevel, sizelevel=0): + +def _inlining_threshold(optlevel: int, sizelevel: int = 0) -> int: # Refer http://llvm.org/docs/doxygen/html/InlineSimple_8cpp_source.html if optlevel > 2: return 275 - # -Os if sizelevel == 1: return 75 - # -Oz if sizelevel == 2: return 25 - return 225 -def create_pass_manager_builder(opt=2, loop_vectorize=False, - slp_vectorize=False): +def create_pass_manager_builder( + opt: int = 2, # Literal[0, 1, 2, 3] + loop_vectorize: bool = False, + slp_vectorize: bool = False, +) -> PassManagerBuilder: pmb = llvm.create_pass_manager_builder() pmb.opt_level = opt pmb.loop_vectorize = loop_vectorize @@ -46,7 +57,7 @@ def create_pass_manager_builder(opt=2, loop_vectorize=False, return pmb -def build_pass_managers(**kws): +def build_pass_managers(**kws: dict[str, Any]) -> pms: mod = kws.get('mod') if not mod: raise NameError("module must be provided") @@ -60,34 +71,34 @@ def build_pass_managers(**kws): fpm = None with llvm.create_pass_manager_builder() as pmb: - pmb.opt_level = opt = kws.get('opt', 2) - pmb.loop_vectorize = kws.get('loop_vectorize', False) - pmb.slp_vectorize = kws.get('slp_vectorize', False) - pmb.inlining_threshold = _inlining_threshold(optlevel=opt) + pmb.opt_level = opt = kws.get('opt', 2) # type: ignore + pmb.loop_vectorize = kws.get('loop_vectorize', False) # type: ignore + pmb.slp_vectorize = kws.get('slp_vectorize', False) # type: ignore + pmb.inlining_threshold = _inlining_threshold(optlevel=opt) # type: ignore if mod: - tli = llvm.create_target_library_info(mod.triple) + tli = llvm.create_target_library_info(mod.triple) # type: ignore if kws.get('nobuiltins', False): # Disable all builtins (-fno-builtins) - tli.disable_all() + tli.disable_all() # type: ignore else: # Disable a list of builtins given for k in kws.get('disable_builtins', ()): - libf = tli.get_libfunc(k) - tli.set_unavailable(libf) + libf = tli.get_libfunc(k) # type: ignore + tli.set_unavailable(libf) # type: ignore - tli.add_pass(pm) + tli.add_pass(pm) # type: ignore if fpm is not None: - tli.add_pass(fpm) + tli.add_pass(fpm) # type: ignore tm = kws.get('tm') if tm: - tm.add_analysis_passes(pm) + tm.add_analysis_passes(pm) # type: ignore if fpm is not None: - tm.add_analysis_passes(fpm) + tm.add_analysis_passes(fpm) # type: ignore - pmb.populate(pm) + pmb.populate(pm) # type: ignore if fpm is not None: - pmb.populate(fpm) + pmb.populate(fpm) # type: ignore - return namedtuple("pms", ['pm', 'fpm'])(pm=pm, fpm=fpm) + return pms(pm=pm, fpm=fpm) diff --git a/llvmlite/py.typed b/llvmlite/py.typed new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/llvmlite/py.typed @@ -0,0 +1 @@ + diff --git a/llvmlite/utils.py b/llvmlite/utils.py index e07ecd370..6cb3b2d43 100644 --- a/llvmlite/utils.py +++ b/llvmlite/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys @@ -5,7 +7,8 @@ # This module must be importable without loading the binding, to avoid # bootstrapping issues in setup.py. -def get_library_name(): + +def get_library_name() -> str: """ Return the name of the llvmlite shared library file. """ @@ -19,7 +22,7 @@ def get_library_name(): return 'llvmlite.dll' -def get_library_files(): +def get_library_files() -> list[str]: """ Return the names of shared library files needed for this platform. """ From 54a872fd872ae2e87fc5cbf861f693e0ba50a062 Mon Sep 17 00:00:00 2001 From: Marvin van Aalst Date: Tue, 31 May 2022 16:45:20 +0200 Subject: [PATCH 2/2] renamed typ to type --- llvmlite/ir/values.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/llvmlite/ir/values.py b/llvmlite/ir/values.py index 26a94e89d..c457d8eb4 100644 --- a/llvmlite/ir/values.py +++ b/llvmlite/ir/values.py @@ -1131,15 +1131,12 @@ def _to_list(self) -> list[str]: class _BaseArgument(NamedValue): def __init__(self, parent: Block, type: types.Type, name: str = "") -> None: assert isinstance(type, types.Type) - super(_BaseArgument, self).__init__(parent, typ, name=name) # type: ignore + super(_BaseArgument, self).__init__(parent, type, name=name) self.attributes = ArgumentAttributes() def __repr__(self) -> str: - return "".format( - self.__class__.__name__, - self.name, - self.type, # type: ignore - ) + return "" % (self.__class__.__name__, self.name, + self.type) def add_attribute(self, attr: str) -> None: self.attributes.add(attr)