diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b1908c..e25fb45 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,10 +23,12 @@ jobs: - name: Get code size - name: Train MNIST run: PYTHONPATH="." python mnist.py - - name: Install torch for testing - run: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu + - name: Install mypy + torch for testing + run: pip install mypy torch --extra-index-url https://download.pytorch.org/whl/cpu - name: Test ops / dtype / optim run: | PYTHONPATH="." python test/test_ops.py PYTHONPATH="." python test/test_dtype.py PYTHONPATH="." python test/test_optim.py + - name: Check types with mypy + run: mypy \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..10d1df0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: local + hooks: + - id: tests + name: tests + entry: env PYTHONPATH="." pytest test/ + language: system + always_run: true + pass_filenames: false + - id: mypy + name: mypy + entry: mypy + language: system + always_run: true + pass_filenames: false \ No newline at end of file diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..4e3c031 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,9 @@ +[mypy] +warn_unused_configs = True +files = teenygrad +ignore_missing_imports = True +check_untyped_defs = True +explicit_package_bases = True +warn_unreachable = True +warn_redundant_casts = True +warn_unused_ignores = True \ No newline at end of file diff --git a/teenygrad/helpers.py b/teenygrad/helpers.py index e11d219..7ad89af 100644 --- a/teenygrad/helpers.py +++ b/teenygrad/helpers.py @@ -10,7 +10,7 @@ def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x def flatten(l:Iterator): return [item for sublist in l for item in sublist] def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python -def all_int(t: Tuple[Any, ...]) -> Tuple[int, ...]: return all(isinstance(s, int) for s in t) +def all_int(t: Tuple[Any, ...]) -> bool: return all(isinstance(s, int) for s in t) def round_up(num, amt:int): return (num+amt-1)//amt * amt @functools.lru_cache(maxsize=None) diff --git a/teenygrad/lazy.py b/teenygrad/lazy.py index 84cc4e6..41ee96e 100644 --- a/teenygrad/lazy.py +++ b/teenygrad/lazy.py @@ -12,6 +12,8 @@ class LazyBuffer: def __init__(self, buf: np.ndarray): self._np = buf + @property + def base(self): return self @property def dtype(self): return dtypes.from_np(self._np.dtype) @property @@ -21,7 +23,8 @@ def shape(self): return self._np.shape def __repr__(self): return f"" def schedule(self, seen=None): return [] - def is_unrealized_const(self): return False + def is_unrealized_contiguous_const(self): return False + def copy_to_device(self, device:str) -> LazyBuffer: return self @staticmethod def fromCPU(x): return LazyBuffer(x) diff --git a/teenygrad/ops.py b/teenygrad/ops.py index 308d966..e72c98a 100644 --- a/teenygrad/ops.py +++ b/teenygrad/ops.py @@ -1,4 +1,5 @@ from enum import Enum, auto +from typing import Optional class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702 @@ -10,4 +11,5 @@ class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto( class Device: DEFAULT = "CPU" _buffers = ["CPU"] - def canonicalize(x): return "CPU" + @staticmethod + def canonicalize(device:Optional[str]) -> str: return "CPU" diff --git a/teenygrad/tensor.py b/teenygrad/tensor.py index 1d147ac..190e862 100644 --- a/teenygrad/tensor.py +++ b/teenygrad/tensor.py @@ -71,9 +71,9 @@ def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, by data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item()) else: data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) - else: raise RuntimeError(f"can't create Tensor from {data} with type {type(data)}") # data is a LazyBuffer, but it might be on the wrong device + if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") self.lazydata = data if data.device == device else data.copy_to_device(device) def __repr__(self): @@ -673,8 +673,8 @@ def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tens return (x, y) def _to_float(self, x:Union[Tensor, float]): - return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_const() and not x.requires_grad \ - and x.lazydata.st.contiguous and self._broadcasted(x)[0].shape == self.shape else x + return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \ + and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: x = self._to_float(x)