diff --git a/import_from_tinygrad.py b/import_from_tinygrad.py index 4f4fd0e..745402d 100755 --- a/import_from_tinygrad.py +++ b/import_from_tinygrad.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import pathlib -FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py", "../test/test_optim.py"] +FILES = ["tensor.py", "mlops.py", "dtype.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py", "../test/test_optim.py"] src = pathlib.Path("../tinygrad/tinygrad") dest = pathlib.Path("teenygrad") diff --git a/teenygrad/device.py b/teenygrad/device.py new file mode 100644 index 0000000..3a920e3 --- /dev/null +++ b/teenygrad/device.py @@ -0,0 +1,15 @@ +from typing import Optional, Any +from teenygrad.dtype import DType +import numpy as np + +class Device: + DEFAULT = "CPU" + _devices = ["CPU"] + @staticmethod + def canonicalize(device:Optional[str]) -> str: return "CPU" + +class Buffer: + def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options=None): + self.device, self.size, self.dtype, self._buf = device, size, dtype, opaque[1] if isinstance(opaque, tuple) else opaque + def copyin(self, buf): self._buf = np.frombuffer(buf, dtype=self.dtype.np) + def as_buffer(self): return self._buf.data diff --git a/teenygrad/dtype.py b/teenygrad/dtype.py new file mode 100644 index 0000000..d108e49 --- /dev/null +++ b/teenygrad/dtype.py @@ -0,0 +1,103 @@ +from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union +from dataclasses import dataclass +import numpy as np # TODO: remove numpy +import functools + +Scalar = Union[float, int, bool] + +@dataclass(frozen=True, order=True) +class DType: + priority: int # this determines when things get upcasted + itemsize: int + name: str + fmt: Optional[str] + count: int + def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}" + def vec(self, sz:int): + assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}" + return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) + def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self + # TODO: someday this will be removed with the "remove numpy" project + @property + def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None + +# dependent typing? +@dataclass(frozen=True, repr=False) +class ImageDType(DType): + shape: Tuple[int, ...] 
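For readers skimming the new teenygrad/device.py above: a minimal usage sketch of the Buffer wrapper, assuming the teenygrad package from this diff is importable (the `_buf` attribute is internal and is poked at here only for illustration):

```python
import numpy as np
from teenygrad.device import Buffer
from teenygrad.dtype import dtypes

# wrap an existing numpy array in a Buffer (the "opaque" argument)
a = np.arange(4, dtype=np.float32)
buf = Buffer("CPU", a.size, dtypes.float32, a)
print(bytes(buf.as_buffer()))     # raw bytes of the backing array

# copyin replaces the backing store from a raw byte buffer
buf.copyin(np.ones(4, dtype=np.float32).tobytes())
print(buf._buf)                   # [1. 1. 1. 1.]
```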
# arbitrary arg for the dtype, used in image for the shape
+  base: DType
+  def scalar(self): return self.base
+  def vec(self, sz:int): return self.base.vec(sz)
+  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
+
+# @dataclass(frozen=True, init=False, repr=False, eq=False)
+class PtrDType(DType):
+  def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
+  def __repr__(self): return f"ptr.{super().__repr__()}"
+  def __hash__(self): return super().__hash__()
+  def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
+  def __ne__(self, dt): return not (self == dt)
+
+def cast_scalar(scalar: Scalar, dtype:DType):
+  return int(scalar) if dtypes.is_int(dtype) else float(scalar) if dtypes.is_float(dtype) else bool(scalar)
+
+class dtypes:
+  @staticmethod
+  def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
+  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
+  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
+  @staticmethod
+  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
+  @staticmethod
+  def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
+  @staticmethod # NOTE: isinstance(True, int) is True in python
+  def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
+  @staticmethod
+  def fields() -> Dict[str, DType]: return DTYPES_DICT
+  bool: Final[DType] = DType(0, 1, "bool", '?', 1)
+  int8: Final[DType] = DType(1, 1, "char", 'b', 1)
+  uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
+  int16: Final[DType] = DType(3, 2, "short", 'h', 1)
+  uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
+  int32: Final[DType] = DType(5, 4, "int", 'i', 1)
+  uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
+  int64: Final[DType] = DType(7, 8, "long", 'l', 1)
+  uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
+  float16: Final[DType] = DType(9, 2, "half", 'e', 1)
+  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
+  bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
+  float32: Final[DType] = DType(11, 4, "float", 'f', 1)
+  float64: Final[DType] = DType(12, 8, "double", 'd', 1)
+
+  # dtype aliases
+  half = float16; float = float32; double = float64 # noqa: E702
+  uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
+  char = int8; short = int16; int = int32; long = int64 # noqa: E702
+
+  # NOTE: these are image dtypes
+  @staticmethod
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
+  @staticmethod
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
+
+  default_float: ClassVar[DType] = float32
+  default_int: ClassVar[DType] = int32
+
+# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
+# we don't support weak type and complex type
+promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
+  dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
+  dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
+  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
+
+@functools.lru_cache(None)
+def _get_recursive_parents(dtype:DType) -> Set[DType]:
+  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
+@functools.lru_cache(None)
+def least_upper_dtype(*ds:DType) -> DType:
+  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
+def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
+
+# HACK: staticmethods are not callable in 3.8 so we have to compare the class
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
+INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
diff --git a/teenygrad/helpers.py b/teenygrad/helpers.py
index 7ad89af..89068dc 100644
--- a/teenygrad/helpers.py
+++ b/teenygrad/helpers.py
@@ -1,61 +1,25 @@
-from typing import Union, Tuple, Iterator, Optional, Final, Any
-import os, functools, platform
-import numpy as np
-from math import prod # noqa: F401 # pylint:disable=unused-import
-from dataclasses import dataclass
+from typing import Union, Tuple, Sequence, Any, Iterable, Dict, TypeVar
+import os, functools, platform, operator
+T = TypeVar("T")
+U = TypeVar("U")
 OSX = platform.system() == "Darwin"
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
 def dedup(x): return list(dict.fromkeys(x)) # retains list order
 def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
-def flatten(l:Iterator): return [item for sublist in l for item in sublist]
+def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
+def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_int(t: Tuple[Any, ...]) -> bool: return all(isinstance(s, int) for s in t)
+def all_int(t: Sequence[Any]) -> bool: return all(isinstance(s, int) for s in t)
 def round_up(num, amt:int): return (num+amt-1)//amt * amt
+def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+  assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
+  return {k:v for d in ds for k,v in d.items()}
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
 @functools.lru_cache(maxsize=None)
 def getenv(key, default=0): return type(default)(os.getenv(key, default))
-DEBUG = getenv("DEBUG")
+DEBUG, WINO, IMAGE = getenv("DEBUG"), getenv("WINO"), 0
 CI = os.getenv("CI", "") != ""
-
-@dataclass(frozen=True, order=True)
-class DType:
-  priority: int # this determines when things get upcasted
-  itemsize: int
-  name: str
-  np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
-  sz: int = 1
-  def __repr__(self): return
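The promotion machinery added in dtype.py mirrors JAX's type-promotion lattice: least_upper_dtype intersects the transitive parents of its arguments and takes the minimum by priority. A few expected results, as a hedged sanity check against the lattice written above:

```python
from teenygrad.dtype import dtypes, least_upper_dtype, least_upper_float

# int8 and uint8 only meet at int16 in the lattice
assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16
# int64 and uint64 have no common integer parent, so they promote to float16
assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float16
# floats pass through least_upper_float unchanged; ints are lifted to float32
assert least_upper_float(dtypes.int32) == dtypes.float32
assert least_upper_float(dtypes.float16) == dtypes.float16
```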
f"dtypes.{self.name}" - -class dtypes: - @staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool - def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) - @staticmethod - def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64) - @staticmethod - def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) - @staticmethod - def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name] - bool: Final[DType] = DType(0, 1, "bool", np.bool_) - float16: Final[DType] = DType(9, 2, "half", np.float16) - half = float16 - float32: Final[DType] = DType(10, 4, "float", np.float32) - float = float32 - float64: Final[DType] = DType(11, 8, "double", np.float64) - double = float64 - int8: Final[DType] = DType(1, 1, "char", np.int8) - int16: Final[DType] = DType(3, 2, "short", np.int16) - int32: Final[DType] = DType(5, 4, "int", np.int32) - int64: Final[DType] = DType(7, 8, "long", np.int64) - uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) - uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) - uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) - uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) - - # NOTE: bfloat16 isn't supported in numpy - bfloat16: Final[DType] = DType(9, 2, "__bf16", None) - -DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod} - -PtrDType, ImageDType, IMAGE = None, None, 0 # junk to remove diff --git a/teenygrad/lazy.py b/teenygrad/lazy.py index 41ee96e..43fcdf8 100644 --- a/teenygrad/lazy.py +++ b/teenygrad/lazy.py @@ -1,5 +1,7 @@ from __future__ import annotations -from teenygrad.helpers import DType, dtypes, DEBUG +from teenygrad.helpers import DEBUG, prod +from teenygrad.dtype import DType, dtypes +from teenygrad.device import Buffer from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps import numpy as np @@ -9,17 +11,16 @@ def toCPU(self): return self.x class LazyBuffer: device = "CPU" - - def __init__(self, buf: np.ndarray): self._np = buf + def __init__(self, buf: np.ndarray): self.realized = Buffer("CPU", buf.size, dtypes.from_np(buf.dtype), buf) @property def base(self): return self @property - def dtype(self): return dtypes.from_np(self._np.dtype) + def dtype(self): return self.realized.dtype @property - def realized(self): return RawCPUBuffer(self._np) + def _np(self): return self.realized._buf @property - def shape(self): return self._np.shape + def shape(self): return self.realized._buf.shape def __repr__(self): return f"" def schedule(self, seen=None): return [] @@ -31,7 +32,9 @@ def fromCPU(x): return LazyBuffer(x) @staticmethod def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer: - if op == LoadOps.RAND: return LazyBuffer(np.random.default_rng(arg).random(size=shape, dtype=dtype.np)) + if op == LoadOps.CUSTOM: + arg(ret := Buffer(device, prod(shape), dtype)) + return ret._buf.reshape(shape) elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg, dtype=dtype.np)) elif op == LoadOps.EMPTY: return LazyBuffer(np.empty(shape, dtype=dtype.np)) else: raise NotImplementedError(op) @@ -52,16 +55,16 @@ def e(self, op, *srcs:LazyBuffer): elif op == BinaryOps.SUB: ret = self._np - srcs[0]._np elif op == BinaryOps.MUL: ret = self._np * srcs[0]._np elif op == BinaryOps.DIV: ret = 
self._np / srcs[0]._np + elif op == BinaryOps.XOR: ret = self._np ^ srcs[0]._np elif op == BinaryOps.MAX: ret = np.maximum(self._np, srcs[0]._np) elif op == BinaryOps.CMPLT: ret = self._np < srcs[0]._np + elif op == BinaryOps.CMPEQ: ret = self._np == srcs[0]._np elif op == TernaryOps.WHERE: ret = np.where(self._np, srcs[0]._np, srcs[1]._np) else: raise NotImplementedError(op) return LazyBuffer(ret.astype(self.dtype.np if len(srcs) == 0 else max(self.dtype, *[x.dtype for x in srcs]).np, copy=False)) - def r(self, op, new_shape): - if DEBUG >= 1: print(op, self, new_shape) - assert len(self.shape) == len(new_shape), "reduce shapes must have same dimensions" - axis = tuple(i for i,(a,b) in enumerate(zip(self.shape, new_shape)) if a != b) + def r(self, op, axis): + if DEBUG >= 1: print(op, self, axis) if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, dtype=self._np.dtype, keepdims=True)) elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(axis, keepdims=True)) else: raise NotImplementedError(op) diff --git a/teenygrad/mlops.py b/teenygrad/mlops.py index 3a2f7b0..dc91c88 100644 --- a/teenygrad/mlops.py +++ b/teenygrad/mlops.py @@ -1,6 +1,7 @@ import math -from typing import Tuple, Optional, cast -from teenygrad.helpers import argsort, DType +from typing import Tuple, Optional +from teenygrad.helpers import argsort +from teenygrad.dtype import DType from teenygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps from teenygrad.tensor import Function from teenygrad.lazy import LazyBuffer @@ -19,26 +20,25 @@ def forward(self, x:LazyBuffer, dtype:DType, bitcast:bool=False) -> LazyBuffer: self.input_dtype, self.bitcast = x.dtype, bitcast return x.cast(dtype, bitcast) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.cast(self.input_dtype, self.bitcast) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.cast(self.input_dtype, self.bitcast) # ************* unary ops ************* class Zero(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: return x.const(0) - def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.const(0) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.const(0) class Neg(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: return x.e(UnaryOps.NEG) - def backward(self, grad:LazyBuffer) -> LazyBuffer: return grad.e(UnaryOps.NEG) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(UnaryOps.NEG) class Sin(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x return x.e(UnaryOps.SIN) - def backward(self, grad:LazyBuffer) -> LazyBuffer: - return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: + return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output) # NOTE: maximum(x, 0) behaves differently where x=0 class Relu(Function): @@ -47,23 +47,21 @@ def forward(self, x:LazyBuffer) -> LazyBuffer: return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output) + return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output) class Log(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.x = x return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2))) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return 
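The new CMPEQ and XOR cases in LazyBuffer.e, and the switch of r() from a target shape to an explicit reduce axis, both bottom out in plain numpy. A standalone sketch of that mapping (hypothetical helper names, not the class methods themselves):

```python
import numpy as np
from teenygrad.ops import BinaryOps, TernaryOps, ReduceOps

def elementwise(op, x, *srcs):
  if op == BinaryOps.XOR:    return x ^ srcs[0]
  if op == BinaryOps.CMPEQ:  return x == srcs[0]
  if op == TernaryOps.WHERE: return np.where(x, srcs[0], srcs[1])
  raise NotImplementedError(op)

def reduce(op, x, axis):
  # note r() now takes the reduce axis directly instead of a target shape
  return x.sum(axis, keepdims=True) if op == ReduceOps.SUM else x.max(axis, keepdims=True)

a, b = np.array([1, 2, 3]), np.array([1, 0, 3])
print(elementwise(BinaryOps.CMPEQ, a, b))                          # [ True False  True]
print(reduce(ReduceOps.SUM, np.arange(6).reshape(2, 3), axis=1))   # [[ 3] [12]]
```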
grad_output.e(BinaryOps.DIV, self.x) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.DIV, self.x) class Exp(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2) return self.ret - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return self.ret.e(BinaryOps.MUL, grad_output) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output) class Sqrt(Function): def forward(self, x:LazyBuffer) -> LazyBuffer: @@ -87,20 +85,25 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # ************* binary ops ************* class Less(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - return x.e(BinaryOps.CMPLT, y) + def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPLT, y) + def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None + +class Eq(Function): + def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.CMPEQ, y) + def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return None, None + +class Xor(Function): + def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.XOR, y) class Add(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - return x.e(BinaryOps.ADD, y) + def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.ADD, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output if self.needs_input_grad[0] else None, \ grad_output if self.needs_input_grad[1] else None class Sub(Function): - def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: - return x.e(BinaryOps.SUB, y) + def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: return x.e(BinaryOps.SUB, y) def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output if self.needs_input_grad[0] else None, \ @@ -122,39 +125,38 @@ def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer: def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \ - grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None + grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501 # ************* ternary ops ************* class Where(Function): def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer: self.x = x - return x.e(TernaryOps.WHERE, y, z) + return self.x.e(TernaryOps.WHERE, y, z) def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]: return None, \ - self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \ - self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None + self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \ + self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None # ************* reduce ops ************* class Sum(Function): - def 
forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: + def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.input_shape = x.shape - return x.r(ReduceOps.SUM, new_shape) + return x.r(ReduceOps.SUM, axis) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.expand(self.input_shape) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.expand(self.input_shape) class Max(Function): - def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer: - self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape) + def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: + self.x, self.ret, self.axis = x, x.r(ReduceOps.MAX, axis), axis return self.ret def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # 1s in locations where the max was chosen (can be two locations) - max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape))) - div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape) + max_is_1s = self.x.e(BinaryOps.CMPEQ, self.ret.expand(self.x.shape)).cast(self.x.dtype) + div = max_is_1s.r(ReduceOps.SUM, self.axis).expand(self.x.shape) return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape)) # ************* movement ops ************* @@ -162,50 +164,42 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer: # NOTE: this is sum in reverse class Expand(Function): def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: - self.input_shape = x.shape + self.expanded_axis = tuple(i for i, (si, so) in enumerate(zip(x.shape, shape)) if si != so) return x.expand(shape) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.r(ReduceOps.SUM, self.input_shape) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.r(ReduceOps.SUM, self.expanded_axis) class Reshape(Function): def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer: self.input_shape = x.shape return x.reshape(shape) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.reshape(self.input_shape) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.reshape(self.input_shape) class Permute(Function): def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer: self.input_order = order return x.permute(order) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.permute(argsort(self.input_order)) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.permute(argsort(self.input_order)) class Pad(Function): def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer: self.narg = tuple([(p[0], s+p[0]) for s,p in zip(x.shape, arg)]) return x.pad(arg) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.shrink(self.narg) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.shrink(self.narg) class Shrink(Function): def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)]) return x.shrink(arg) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward" - # need this cast because mypy cannot narrow the type even with assert - return grad_output.pad(cast(Tuple[Tuple[int, int], ...], 
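Max.backward now builds its mask with CMPEQ instead of the old 1 - CMPLT trick and divides by the number of tied maxima so the gradient is shared. A numpy sketch of that arithmetic, assuming a 2-D input reduced over axis 1:

```python
import numpy as np

x = np.array([[1., 3., 3.],
              [2., 0., 1.]])
axis, grad_output = 1, np.array([[6.], [5.]])

max_is_1s = (x == x.max(axis, keepdims=True)).astype(x.dtype)   # CMPEQ + cast
div = max_is_1s.sum(axis, keepdims=True)                        # number of tied maxima per row
grad_x = max_is_1s / div * grad_output                          # broadcasting plays the role of expand
print(grad_x)   # [[0. 3. 3.]
                #  [5. 0. 0.]]
```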
self.narg)) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.pad(self.narg) class Flip(Function): def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer: self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))]) return x.stride(self.arg) - def backward(self, grad_output:LazyBuffer) -> LazyBuffer: - return grad_output.stride(self.arg) + def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.stride(self.arg) diff --git a/teenygrad/nn/optim.py b/teenygrad/nn/optim.py index b18448d..cd18b61 100644 --- a/teenygrad/nn/optim.py +++ b/teenygrad/nn/optim.py @@ -1,6 +1,6 @@ # sorted in order of increasing complexity from typing import List -from teenygrad.helpers import dedup +from teenygrad.helpers import dedup, getenv from teenygrad.tensor import Tensor class Optimizer: @@ -13,7 +13,7 @@ def __init__(self, params: List[Tensor], lr: float): assert len(self.params) != 0, "optimizer must have at least one param" self.device = self.params[0].device self.buffers: List[Tensor] = dedup([x for x in params if not x.requires_grad]) # buffers are still realized - self.lr = Tensor([lr], requires_grad=False, device=self.device).contiguous() + self.lr = lr if getenv("CONST_LR") else Tensor([lr], requires_grad=False, device=self.device).contiguous() def zero_grad(self): for param in self.params: param.grad = None @@ -22,6 +22,8 @@ def realize(self, extra=None): # NOTE: in extra is too late for most of the params due to issues with assign Tensor.corealize(extra + self.params + self.buffers if extra is not None else self.params + self.buffers) + def step(self) -> None: raise NotImplementedError + class SGD(Optimizer): def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, nesterov=False): super().__init__(params, lr) @@ -32,9 +34,12 @@ def __init__(self, params: List[Tensor], lr=0.001, momentum=0, weight_decay=0.0, def step(self) -> None: for i, t in enumerate(self.params): assert t.grad is not None - g = t.grad.realize() + self.wd * t.detach() + # this is needed since the grads can form a "diamond" + # TODO: fix this in lazy.py + t.grad.realize() + g = t.grad + self.wd * t.detach() if self.momentum: - self.b[i].assign(self.momentum * self.b[i] + g).realize() # NOTE: self.b[i] is zero on the first run, no if required + self.b[i].assign(self.momentum * self.b[i] + g) # NOTE: self.b[i] is zero on the first run, no if required g = (g + self.momentum * self.b[i]) if self.nesterov else self.b[i] t.assign(t.detach() - g * self.lr) self.realize(self.b) @@ -46,17 +51,16 @@ def Adam(params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-8): return LAM class LAMB(Optimizer): def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False): super().__init__(params, lr) - self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], requires_grad=False).realize() + self.b1, self.b2, self.eps, self.wd, self.adam, self.t = b1, b2, eps, wd, adam, Tensor([0], device=self.device, requires_grad=False).realize() self.m = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] self.v = [Tensor.zeros(*t.shape, device=t.device, requires_grad=False) for t in self.params] def step(self) -> None: - self.t.assign(self.t + 1).realize() + self.t.assign(self.t + 1) for i, t in enumerate(self.params): assert t.grad is not None - g = t.grad.realize() - self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g).realize() - 
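SGD.step above folds weight decay into the gradient, keeps a zero-initialized momentum buffer, and optionally applies Nesterov. A plain-numpy sketch of one update, with hypothetical argument defaults:

```python
import numpy as np

def sgd_step(t, grad, b, lr=0.001, momentum=0.9, wd=0.0, nesterov=False):
  g = grad + wd * t                       # weight decay folded into the gradient
  if momentum:
    b[:] = momentum * b + g               # momentum buffer, zero on the first call
    g = g + momentum * b if nesterov else b
  t -= lr * g
  return t, b

w, buf = np.ones(3), np.zeros(3)
w, buf = sgd_step(w, np.array([0.5, 0.5, 0.5]), buf)
print(w)   # slightly below 1.0 after one step
```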
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).realize() + self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad) + self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)) m_hat = self.m[i] / (1.0 - self.b1**self.t) v_hat = self.v[i] / (1.0 - self.b2**self.t) up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach() diff --git a/teenygrad/ops.py b/teenygrad/ops.py index e72c98a..149a8a7 100644 --- a/teenygrad/ops.py +++ b/teenygrad/ops.py @@ -1,15 +1,8 @@ from enum import Enum, auto -from typing import Optional -class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702 -class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702 +class UnaryOps(Enum): EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); NEG = auto() # noqa: E702 +class BinaryOps(Enum): + ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPEQ = auto(); XOR = auto() # noqa: E702 +class TernaryOps(Enum): WHERE = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 -class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702 -class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702 -class LoadOps(Enum): EMPTY = auto(); RAND = auto(); CONST = auto(); FROM = auto(); CONTIGUOUS = auto(); CUSTOM = auto() # noqa: E702 - -class Device: - DEFAULT = "CPU" - _buffers = ["CPU"] - @staticmethod - def canonicalize(device:Optional[str]) -> str: return "CPU" +class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto() # noqa: E702 \ No newline at end of file diff --git a/teenygrad/realize.py b/teenygrad/realize.py index d495deb..4d797c6 100644 --- a/teenygrad/realize.py +++ b/teenygrad/realize.py @@ -1 +1,2 @@ -def run_schedule(schedule, disable_logging=False): pass \ No newline at end of file +def run_schedule(schedule, disable_logging=False): pass +def create_schedule(outs, seen=None): return [] \ No newline at end of file diff --git a/teenygrad/tensor.py b/teenygrad/tensor.py index 190e862..0a216bd 100644 --- a/teenygrad/tensor.py +++ b/teenygrad/tensor.py @@ -1,20 +1,25 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, math -from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Any, Iterable, Set +import time, math, itertools +from contextlib import ContextDecorator +from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args from collections import defaultdict from functools import partialmethod, reduce -from itertools import accumulate import numpy as np -from teenygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up +from teenygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype, cast_scalar +from teenygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten, flat_mv from teenygrad.lazy import LazyBuffer -from teenygrad.ops import Device, LoadOps +from teenygrad.features.multi import 
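The LAMB moment updates shown above (the adam=True path; the trust-ratio part of LAMB sits outside the hunk quoted here) are the usual bias-corrected first and second moments. A numpy sketch of a single step:

```python
import numpy as np

def adam_like_step(t, grad, m, v, step, lr=0.001, b1=0.9, b2=0.999, eps=1e-6, wd=0.0):
  m[:] = b1 * m + (1 - b1) * grad
  v[:] = b2 * v + (1 - b2) * grad * grad
  m_hat = m / (1 - b1**step)            # bias correction for the running mean
  v_hat = v / (1 - b2**step)            # bias correction for the running variance
  up = m_hat / (np.sqrt(v_hat) + eps) + wd * t
  return t - lr * up

w = adam_like_step(np.ones(2), np.array([0.1, -0.1]), np.zeros(2), np.zeros(2), step=1)
print(w)   # each weight moves by roughly lr opposite the gradient sign on step 1
```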
MultiLazyBuffer +from teenygrad.ops import LoadOps +from teenygrad.device import Device, Buffer from teenygrad.shape.symbolic import sint -from teenygrad.realize import run_schedule +from teenygrad.realize import run_schedule, create_schedule + +# **** start with two base classes, Tensor and Function **** class Function: - def __init__(self, device:str, *tensors:Tensor): + def __init__(self, device:Union[str, Tuple[str, ...]], *tensors:Tensor): self.device = device self.needs_input_grad = [t.requires_grad for t in tensors] self.requires_grad = True if any(self.needs_input_grad) else None if None in self.needs_input_grad else False @@ -26,28 +31,57 @@ def backward(self, *args, **kwargs): raise RuntimeError(f"backward not implement @classmethod def apply(fxn:Type[Function], *x:Tensor, **kwargs) -> Tensor: ctx = fxn(x[0].device, *x) - ret = Tensor(ctx.forward(*[t.lazydata for t in x], **kwargs), device=ctx.device, requires_grad=ctx.requires_grad) - if ctx.requires_grad and not Tensor.no_grad: ret._ctx = ctx # used by autograd engine + ret = Tensor.__new__(Tensor) + ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None + ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine return ret import teenygrad.mlops as mlops -# **** start with two base classes, Tensor and Function **** +def _loadop(op, shape:Tuple[sint,...], dtype:DType, device:Union[str, Tuple[str, ...]], arg=None, src:Optional[LazyBuffer]=None): + if isinstance(device, str): return LazyBuffer.loadop(op, shape, dtype, device, arg, src) + return MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype, d, arg, src) for d in device], None) + +def _fromcpu(x: np.ndarray) -> LazyBuffer: + ret = LazyBuffer.loadop(LoadOps.EMPTY, x.shape, dtypes.from_np(x.dtype), "EXT") + if x.size == 0: + ret.realized = Buffer("EXT", 0, dtypes.from_np(x.dtype), (memoryview(bytearray()), None)) + else: + ret.realized = Buffer("EXT", prod(x.shape), dtypes.from_np(x.dtype), (flat_mv(np.require(x, requirements='C').data), x)) + return ret + +def _get_winograd_matcols(mat, dims:int, shp:Tuple[sint, ...], device:Union[str, Tuple[str, ...]]) -> List[List[Tensor]]: + return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device) for m in mat], dim=dim) + for k in range(len(mat[0]))] for dim in range(dims)] + +# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 +def _apply_winograd_matrix(mat, t:Tensor, dims:int): + # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t + # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic + t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims + # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...) 
+ matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device) + # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t + return sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims)) class Tensor: __slots__ = "lazydata", "requires_grad", "grad", "_ctx" __deletable__ = ('_ctx',) training: ClassVar[bool] = False - class train: - def __init__(self, val=True): self.val = val - def __enter__(self): self.prev, Tensor.training = Tensor.training, self.val - def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): Tensor.training = self.prev + class train(ContextDecorator): + def __init__(self, mode:bool = True): self.mode = mode + def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode + def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev no_grad: ClassVar[bool] = False - default_type: ClassVar[DType] = dtypes.float32 - def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, bytes], device:Optional[str]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): + class inference_mode(ContextDecorator): + def __init__(self, mode:bool = True): self.mode = mode + def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode + def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev + def __init__(self, data:Union[None, Scalar, List, Tuple, LazyBuffer, np.ndarray, bytes, MultiLazyBuffer], + device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DType]=None, requires_grad:Optional[bool]=None): assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" - device = Device.canonicalize(device) + device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) # tensors have gradients, buffers do not self.grad: Optional[Tensor] = None @@ -58,23 +92,26 @@ def __init__(self, data:Union[None, int, float, list, LazyBuffer, np.ndarray, by # internal variables used for autograd graph construction self._ctx: Optional[Function] = None if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - elif isinstance(data, (int, float)): - data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or Tensor.default_type, device, data) - elif data is None or data.__class__ is list: - assert dtype is None or dtype.np is not None, f"{dtype} doesn't have a numpy dtype" - data = LazyBuffer.fromCPU(np.array([] if data is None else data, dtype=(dtype or Tensor.default_type).np)) - elif isinstance(data, bytes): - data = LazyBuffer.fromCPU(np.frombuffer(data, np.uint8)) + elif isinstance(data, get_args(Scalar)): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) + elif isinstance(data, bytes): data = _fromcpu(np.frombuffer(data, np.uint8)) + elif data is None: data = _loadop(LoadOps.EMPTY, (0,), dtype or dtypes.default_float, device) + elif isinstance(data, list): + if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtype or dtypes.bool + elif d and all_int(d): dtype = dtype or dtypes.default_int + else: dtype = dtype or dtypes.default_float + # NOTE: cast at the end for the dtypes that do not have a numpy dtype + data = _fromcpu(np.array(data, dtype.np)).cast(dtype) elif isinstance(data, np.ndarray): - assert dtype is None or dtype.np 
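_apply_winograd_matrix describes itself as "roughly kron(mat, mat) @ t". A small numpy check of that identity for a 2-D block, assuming an arbitrary 3x2 transform matrix:

```python
import numpy as np

mat = np.array([[1., 0.], [1., 1.], [0., 1.]])   # any (out, in) transform
t = np.arange(4.).reshape(2, 2)

# mat applied along dim 0 and dim 1 of the block ...
per_axis = np.einsum('ai,bj,ij->ab', mat, mat, t)
# ... equals the Kronecker product acting on the flattened block
via_kron = (np.kron(mat, mat) @ t.reshape(-1)).reshape(3, 3)
assert np.allclose(per_axis, via_kron)
```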
is not None, f"{dtype} doesn't have a numpy dtype" - if data.shape == (): - data = LazyBuffer.loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item()) - else: - data = LazyBuffer.fromCPU(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) + if data.shape == (): data = _loadop(LoadOps.CONST, tuple(), dtype or dtypes.from_np(data.dtype), device, data.item()) + else: data = _fromcpu(data.astype(dtype.np) if dtype is not None and dtype.np is not None else data) # data is a LazyBuffer, but it might be on the wrong device - if not isinstance(data, LazyBuffer): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") - self.lazydata = data if data.device == device else data.copy_to_device(device) + if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") + if isinstance(device, tuple): + # TODO: what if it's a MultiLazyBuffer on other devices? + self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = MultiLazyBuffer.from_sharded(data, device, None) if isinstance(data, LazyBuffer) else data + else: + self.lazydata = data if data.device == device else data.copy_to_device(device) def __repr__(self): return f"" @@ -82,8 +119,10 @@ def __repr__(self): # Python has a non moving GC, so this should be okay def __hash__(self): return id(self) + def __bool__(self): raise TypeError("__bool__ on Tensor is not defined") + @property - def device(self) -> str: return self.lazydata.device + def device(self) -> Union[str, Tuple[str, ...]]: return self.lazydata.device @property def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape @@ -95,88 +134,120 @@ def dtype(self) -> DType: return self.lazydata.dtype @staticmethod def corealize(lst:Iterable[Tensor]): - seen:Set[LazyBuffer] = set() - sched = [] - for t in lst: sched += t.lazydata.schedule(seen) - run_schedule(sched) + run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst]))) def realize(self) -> Tensor: - run_schedule(self.lazydata.schedule()) + Tensor.corealize([self]) return self def assign(self, x) -> Tensor: - # TODO: this is a hack for writing to DISK - if self.device.startswith("DISK"): - if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype) - self.contiguous().realize().lazydata.realized._copyin(x.numpy()) + # TODO: this is a hack for writing to DISK. remove with working assign + if isinstance(self.device, str) and self.device.startswith("DISK"): + if x.__class__ is not Tensor: x = Tensor(x, device="EXT", dtype=self.dtype) + self.contiguous().realize().lazydata.base.realized.copyin(x.numpy().data) return self if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype) - assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}" + # NOTE: we allow cross device assign + assert self.shape == x.shape, f"assign shape mismatch {self.shape} != {x.shape}" + if isinstance(self.lazydata, MultiLazyBuffer): + assert self.lazydata.axis == x.lazydata.axis assert not x.requires_grad # self requires_grad is okay? 
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}") - if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized + if self.dtype == x.dtype and not getenv("DISALLOW_ASSIGN"): + if isinstance(self.lazydata, MultiLazyBuffer): + for d,s in zip(x.lazydata.lbs, self.lazydata.lbs): d.output_buffer = s.base.realized + else: + if self.lazydata.base.realized is not None: x.lazydata.output_buffer = self.lazydata.base.realized self.lazydata = x.lazydata return self - def detach(self) -> Tensor: return Tensor(self.lazydata, device=self.device, requires_grad=False) + + def _data(self) -> memoryview: + if 0 in self.shape: return memoryview(bytearray(0)) + t = self if isinstance(self.device, str) else self.to(self.device[0]) # deal with multitensor + return cast(Buffer, t.cast(t.dtype.scalar()).contiguous().realize().lazydata.base.realized).as_buffer() + + def data(self) -> memoryview: + assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}" + assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" + return self._data().cast(self.dtype.fmt, self.shape if len(self.shape) else (1,)) + def item(self) -> Scalar: + assert self.dtype.fmt is not None, f"no fmt dtype for {self.dtype}" + assert self.numel() == 1, "must have one element for item" + return self._data().cast(self.dtype.fmt)[0] def numpy(self) -> np.ndarray: - assert all_int(self.shape), f"no numpy if shape is symbolic, {self.shape=}" - assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}" - return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape) - def item(self) -> Union[float, int]: return self.numpy().item() + assert self.dtype.np is not None, f"no np dtype for {self.dtype}" + assert all_int(self.shape), f"no data if shape is symbolic, {self.shape=}" + return np.frombuffer(self._data(), dtype=self.dtype.np).reshape(self.shape) - def to(self, device:Optional[str]) -> Tensor: + def to(self, device:Optional[Union[str, Tuple[str, ...]]]) -> Tensor: + device = tuple(Device.canonicalize(x) for x in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device) if device is None or device == self.device: return self - ret = Tensor(self.lazydata, device) - if self.grad: ret.grad = self.grad.to(device) + if not isinstance(device, str): return self.shard(device) + ret = Tensor(self.lazydata, device, requires_grad=self.requires_grad) + if self.grad is not None: ret.grad = self.grad.to(device) + if hasattr(self, '_ctx'): ret._ctx = self._ctx return ret - def to_(self, device:Optional[str]): - if device is None or device == self.device: return - if self.grad: self.grad = self.grad.to_(device) - _ret = Tensor(self.lazydata, device) - self.lazydata = _ret.lazydata + def to_(self, device:Optional[Union[str, Tuple[str, ...]]]): + real = self.to(device) + # TODO: is this assign? 
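Tensor.data() and Tensor.item() above lean on the stdlib memoryview.cast with the dtype's struct format. A minimal sketch of that reinterpretation, assuming float32 (format 'f'):

```python
import numpy as np

raw = np.array([[1., 2.], [3., 4.]], dtype=np.float32).tobytes()
mv = memoryview(raw)                 # format 'B', like a realized buffer's bytes
grid = mv.cast('f', (2, 2))          # data()-style: cast to the dtype fmt with the tensor shape
print(grid[1, 1])                    # 4.0
print(mv.cast('f')[0])               # item()-style scalar access -> 1.0
```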
+ if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata + self.lazydata = real.lazydata + + def shard(self, devices:Tuple[str, ...], axis:Optional[int]=None) -> Tensor: + assert isinstance(self.lazydata, LazyBuffer), "can't shard a MultiLazyBuffer" + canonical_devices = tuple(Device.canonicalize(x) for x in devices) + if axis is not None and axis < 0: axis += len(self.shape) + return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, canonical_devices, axis), device=canonical_devices, requires_grad=self.requires_grad) + + def shard_(self, devices:Tuple[str, ...], axis:Optional[int]=None): + self.lazydata = self.shard(devices, axis).lazydata + return self # ***** creation llop entrypoint ***** @staticmethod - def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=None, **kwargs): - assert isinstance(sz, int), f"cannot create with symbolic size {sz}" - return Tensor(LazyBuffer.loadop(op, (sz,), Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs) + def _loadop(op, shape, device:Optional[Union[Tuple[str], str]]=None, dtype:Optional[DType]=None, arg=None, **kwargs): + if isinstance(device, tuple): + return Tensor(MultiLazyBuffer([LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(d), arg) \ + for d in device], None), device, dtype, **kwargs) + return Tensor(LazyBuffer.loadop(op, shape, dtype or dtypes.default_float, Device.canonicalize(device), arg), device, dtype, **kwargs) @staticmethod - def empty(*shape, **kwargs): - return Tensor._loadop(LoadOps.EMPTY, prod((shape:=argfix(*shape))), **kwargs).reshape(shape) + def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, argfix(*shape), **kwargs) _seed: int = int(time.time()) @staticmethod def manual_seed(seed=0): Tensor._seed = seed @staticmethod - def rand(*shape, **kwargs): - Tensor._seed += 1 - return Tensor._loadop(LoadOps.RAND, prod((shape:=argfix(*shape))), arg=Tensor._seed, **kwargs).reshape(shape) + def rand(*shape, **kwargs): return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs) # ***** creation helper functions ***** @staticmethod - def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape) + def full(shape:Tuple[sint, ...], fill_value:Scalar, **kwargs): + return Tensor(fill_value, **kwargs).reshape((1, )*len(new_shape := argfix(shape))).expand(new_shape) @staticmethod - def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs) + def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0.0, **kwargs) @staticmethod - def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1, **kwargs) + def ones(*shape, **kwargs): return Tensor.full(argfix(*shape), 1.0, **kwargs) @staticmethod def arange(start, stop=None, step=1, **kwargs): if stop is None: stop, start = start, 0 - return Tensor.full((math.ceil((stop-start)/step),), step, **kwargs).cumsum() + (start - step) + dtype = kwargs.pop("dtype", dtypes.default_float if any(isinstance(x, float) for x in (start, stop, step)) else dtypes.default_int) + return (Tensor.full((math.ceil((stop-start)/step),), step, dtype=dtype, **kwargs).cumsum() + (start - step)).cast(dtype) @staticmethod - def eye(dim:int, **kwargs): return Tensor.full((dim,1),1,**kwargs).pad(((0,0),(0,dim))).reshape(dim*(dim+1)).shrink(((0,dim*dim),)).reshape(dim, dim) + def eye(dim:int, **kwargs): + return 
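Tensor.arange is built from full + cumsum plus a shift of (start - step). A plain-Python walk-through of why that yields the right values:

```python
import math

def arange(start, stop, step):
  n = math.ceil((stop - start) / step)
  vals, acc = [], 0
  for _ in range(n):          # Tensor.full((n,), step).cumsum()
    acc += step
    vals.append(acc)
  return [v + (start - step) for v in vals]   # shift so the first element is `start`

print(arange(2, 10, 3))   # [2, 5, 8]
```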
Tensor.ones((dim,1),**kwargs).pad((None,(0,dim))).flatten().shrink(((0,dim*dim),)).reshape(dim, dim) - def full_like(self, fill_value, **kwargs): return Tensor.full(self.shape, fill_value=fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) + def full_like(self, fill_value:Scalar, **kwargs): + return Tensor.full(self.shape, fill_value, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) def zeros_like(self, **kwargs): return self.full_like(0, **kwargs) def ones_like(self, **kwargs): return self.full_like(1, **kwargs) @@ -185,70 +256,69 @@ def ones_like(self, **kwargs): return self.full_like(1, **kwargs) @staticmethod def randn(*shape, dtype:Optional[DType]=None, **kwargs) -> Tensor: # https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - src = Tensor.rand(2, *shape, **kwargs) - return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(Tensor.default_type if dtype is None else dtype) + src = Tensor.rand((2, *argfix(*shape)), **kwargs) + return src[0].mul(2*math.pi).cos().mul((1 - src[1]).log().mul(-2).sqrt()).cast(dtype or dtypes.default_float) @staticmethod - def randint(*shape, low=0, high=10, **kwargs) -> Tensor: - return (Tensor.rand(*shape, **kwargs)*(high-low)+low).cast(dtypes.int32) + def randint(*shape, low=0, high=10, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=low, high=high, dtype=dtypes.int32, **kwargs) @staticmethod def normal(*shape, mean=0.0, std=1.0, **kwargs) -> Tensor: return (std * Tensor.randn(*shape, **kwargs)) + mean @staticmethod def uniform(*shape, low=0.0, high=1.0, **kwargs) -> Tensor: - dtype = kwargs.pop("dtype", Tensor.default_type) + dtype = kwargs.pop("dtype", dtypes.default_float) return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low @staticmethod - def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(shape)**-0.5) + def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul(prod(argfix(*shape))**-0.5) # https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform @staticmethod - def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5) + def glorot_uniform(*shape, **kwargs) -> Tensor: + return Tensor.uniform(*shape, low=-1.0, high=1.0, **kwargs).mul((6/(argfix(*shape)[0]+prod(argfix(*shape)[1:])))**0.5) # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_ @staticmethod def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor: - bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:])) + bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs) # https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_ @staticmethod def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor: - std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:])) + std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(argfix(*shape)[1:])) return Tensor.normal(*shape, mean=0.0, std=std, **kwargs) def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> Tensor: assert 1 <= self.ndim <= 2 and num_samples > 0, f"{self.ndim=} must be 1 or 2 dim, {num_samples=} must be positive" assert replacement or num_samples == 1, "no 
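multinomial samples by comparing uniform draws against a normalized CDF and summing the comparisons to get an index. A numpy sketch of the same trick, assuming a fixed seed:

```python
import numpy as np

weights = np.array([0.1, 0.2, 0.7])
cdf = weights.cumsum() / weights.sum()          # [0.1, 0.3, 1.0]
u = np.random.default_rng(0).random(10_000)
samples = (u[:, None] >= cdf[None, :]).sum(1)   # index = how many CDF bins each draw passed
print(np.bincount(samples, minlength=3) / len(samples))   # ~[0.1, 0.2, 0.7]
```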
replacement only supports num_samples = 1" weight = self.unsqueeze(0) if self.ndim == 1 else self - cdf = (cw := weight.cumsum(1)) / cw[:, -1].unsqueeze(1) - unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1) + cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1) + unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device) indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0)) - return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.int32) + return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int) # ***** toposort and backward pass ***** def deepwalk(self): - def _deepwalk(node, visited, nodes): + def _deepwalk(node, visited): visited.add(node) if getattr(node, "_ctx", None): for i in node._ctx.parents: - if i not in visited: _deepwalk(i, visited, nodes) - nodes.append(node) - return nodes - return _deepwalk(self, set(), []) + if i not in visited: yield from _deepwalk(i, visited) + yield node + return list(_deepwalk(self, set())) def backward(self) -> Tensor: assert self.shape == tuple(), f"backward can only be called for scalar tensors, but it has shape {self.shape})" # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous # this is "implicit gradient creation" - self.grad = Tensor(1, device=self.device, requires_grad=False) + self.grad = Tensor(1.0, device=self.device, requires_grad=False) for t0 in reversed(self.deepwalk()): - assert (t0.grad is not None) + if t0.grad is None: raise RuntimeError("tensor has no grad") grads = t0._ctx.backward(t0.grad.lazydata) grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None for g in ([grads] if len(t0._ctx.parents) == 1 else grads)] @@ -263,204 +333,265 @@ def backward(self) -> Tensor: def reshape(self, shape, *args) -> Tensor: new_shape = argfix(shape, *args) - return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)])) - def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))])) + new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else (s if s is not None else self.shape[i]) for i,s in enumerate(new_shape)]) + return mlops.Reshape.apply(self, shape=new_shape) if new_shape != self.shape else self + def expand(self, shape, *args) -> Tensor: + new_shape = tuple([x if x != -1 and x is not None else s for s,x in zip(self.shape, argfix(shape, *args))]) + return mlops.Expand.apply(self, shape=new_shape) if new_shape != self.shape else self def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args)) def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)]) - def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) if any(x is not None and x != (0,s) for x,s in zip(arg, self.shape)) else self - def pad(self, arg:Tuple[Optional[Tuple[int, int]], ...], value:float=0.0) -> Tensor: + def shrink(self, arg:Tuple[Optional[Tuple[sint, sint]], ...]) -> Tensor: + if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self + return mlops.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in 
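The reshape wrapper above infers a single -1 dimension as -prod(old_shape) // prod(new_shape), which works because prod(new_shape) still contains the -1. A tiny worked example:

```python
from math import prod

old_shape, req = (2, 3, 4), (4, -1)
resolved = tuple(-prod(old_shape) // prod(req) if s == -1 else s for s in req)
print(resolved)   # (4, 6):  -24 // -4 == 6
```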
zip(arg, self.shape))) + def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor: if all(x is None or x == (0,0) for x in arg): return self ret = mlops.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg))) return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value) # ***** movement hlops ***** - # - Negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element - # - A slice i:j returns the elements with indices in [i, j) - # - If omitted, i and j will default to 0 and N, respectively, where N is the length of the sequence - # - Negative values for i and j are taken relative to the end of the sequence - # - Both i and j will be clamped to the range (-N, N], where N in the length of the sequence - # - Indexing with None on a given axis will add a new dimension of size one before that axis - # - Empty slices are not allowed (tensors with 0s in shape have to be supported first, for all backends). - # - For a slice [i:j:k] finding the correct indices is delegated to slice.indices(len). - # - Strides > 1 and < 0 are now allowed!: - # - This works by applying Shrink -> [[Flip -> ] Pad -> Reshape -> Shrink] -> Reshape (ops in brackets are optional) - # - Idea of stride < 0 support: - # - Do the slice first, flip the axes were slice.step is negative, do slice.step -> -slice.step. Go to steps below. - # - Idea of stride `s` > 1 support (Pad -> Reshape -> Shrink): - # - Instead of doing [::s] on axis [dim_sz], do [:, 0] on axes [dim_sz_padded // s, s]. - # - So pad dim_sz with as many zeros as needed (dim_sz -> dim_sz_padded) so that reshape to [dim_sz_padded // s, s] - # is possible. - # - Apply Shrink to do the slice [:, 0] on axes of shapes [dim_sz_padded // s, s]. 
- # - Fancy indexing and combined indexing is supported - # - Combined indexing works by letting regular slicing finish first -> computing the resulting dims w.r.t to Tensors passed in -> fancy indexing - # - Any Tensors passed in __getitem__ will perform (CMPEQ with arange -> MUL with self -> SUM_REDUCE) iteratively - # - The first iteration will expand the dim of self while consecutive iterations will reduce the dim - # - There's a special case where a permute is needed at the end: - # - if first Tensor passed in (expand dims) is not at dim 0 - # - and following Tensors does not follow consecutively to the end of fancy indexing's dims - def __getitem__(self, val) -> Tensor: # val: Union[int, slice, Tensor, None, Ellipsis, Tuple[Union[int, slice, Tensor, None, Ellipsis], ...]] - def normalize_int(e, i, dim_sz): - if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1 - raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}") - - orig_slices = list(val) if isinstance(val, tuple) else [val] - count = defaultdict(list) - for i,v in enumerate(orig_slices): count[type(v)].append(i) - - if (num_slices := len(count[int]) + len(count[slice]) + len(count[Tensor])) > len(self.shape): raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}") - if len(ellipsis_found := count[type(Ellipsis)]) > 1: raise IndexError("an index can only have a single ellipsis ('...')") - - ellipsis_idx = ellipsis_found[0] if ellipsis_found else len(orig_slices) - orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices) - - valid_slices = [v for v in orig_slices if v is not None] - valid_slices = [v if isinstance(v, slice) else slice(y_ := normalize_int(v, i, dim_sz), y_+1) if isinstance(v, int) else slice(None) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))] - - start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ()) - new_slice = tuple(((0, 0) if e < s else (s, e)) if st > 0 else ((0, 0) if e > s else (e+1, s+1)) for s, e, st in zip(start, stop, strides)) - sliced_tensor = self.shrink(new_slice).flip(axis=[i for i, s in enumerate(strides) if s < 0]) - new_shape = sliced_tensor.shape + # Supported Indexing Implementations: + # 1. Int indexing (no copy) + # - for all dims where there's int, shrink -> reshape + # - negative indices are taken relative to the end of the sequence, so X[-2] returns the 2nd-to-last element + # - X = Tensor.rand(4,5,9); X[2,-2] shrinks the Tensor to X.shrink(((2, 3), (3, 4), (0, 9))) -> X.shape=(1,1,9) + # - Then we reshape (collapse) the int dim away such that for X: (1,1,9) -> (9,) + # 2. Slice indexing (no copy) + # - for all dims where slice is start:end:stride, shrink -> Optional[flip] -> pad -> reshape -> shrink + # - first shrink the Tensor to X.shrink(((start, end),)) + # - then we apply stride through Optional[flip] -> pad -> reshape -> shrink + # - flip where dim value is negative + # - pad 0's on dims such that reshaping [dim_size_padded] -> [dim_size_padded // stride, stride] is possible + # - shrink [dim_size_padded // stride, stride] -> [dim_size_padded // stride, 1] + # - reshape [dim_size_padded // stride, 1] -> [dim_size_padded // stride] and now you have your stride + # 3. None indexing (no copy) + # - reshape (inject) a dim at the dim where there's None + # 4. 
Tensor indexing (copy) + # - use Tensor.arange == tensor_index to create a mask + # - apply mask to self by mask * self for dims where index is a tensor + # - (mask * self).sum(dim) to reduce to correct shape + # Tiny Things: + # 1. Supported indices: Union[int, slice, Tensor, None, List, Tuple, Ellipsis] + # - for any list, List[Union[List, Tuple, int]], must have homogeneous shape + # - for any tuple, Tuple[Union[List, Tuple, int]], must have homogeneous shape + # 2. Bool indexing is not supported + # 3. Out of bounds Tensor indexing results in 0 + # - e.g: Tensor([1, 2, 3])[Tensor([4, 3, 2])] -> [0, 0, 3] index 4 and 3 are OOB + def __getitem__(self, indices) -> Tensor: + # 1. indices normalization and validation + # treat internal tuples and lists as Tensors and standardize indices to list type + if isinstance(indices, list) and all_int(indices): indices = [Tensor(indices, self.device, requires_grad=False)] + elif isinstance(indices, (tuple, list)): + indices = [Tensor(list(i), self.device, requires_grad=False) if isinstance(i, (tuple, list)) else i for i in indices] + else: indices = [indices] + + # turn scalar Tensors into const val for int indexing if possible + indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices] + # move Tensor indices to the same device as self + indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices] + + # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None) + ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis] + fill_idx = ellipsis_idx[0] if ellipsis_idx else len(indices) + num_indices = len(indices) - len(ellipsis_idx) - sum(1 for i in indices if i is None) + indices[fill_idx:fill_idx+1] = [slice(None)] * (len(self.shape) - num_indices) + + # use Dict[type, List[dimension]] to track elements in indices + type_dim: DefaultDict[Union[type, None], List[int]] = defaultdict(list) + + # record None for dimension injection later and filter None and record rest of indices + type_dim[None] = [dim for dim, i in enumerate(indices) if i is None] + indices_filtered = [v for v in indices if v is not None] + for dim,i in enumerate(indices_filtered): type_dim[type(i)].append(dim) + + for index_type in type_dim: + if index_type not in [None, int, slice, Tensor]: raise IndexError(f"{index_type=} not supported") + if len(ellipsis_idx) > 1: raise IndexError("indices can only have a single ellipsis ('...')") + if num_indices > self.ndim: raise IndexError(f"too many {num_indices=} for {self.ndim=}") + + # 2. basic indexing, uses only movement ops (no copy) + # currently indices_filtered: Tuple[Union[slice, int, Tensor], ...] 
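# Aside (illustration only, not part of the patch): the arange-compare -> multiply -> sum
# gather that the tensor-indexing comments above describe, sketched in NumPy for the 1-D case.
# Out-of-bounds indices match no arange position, so they reduce to 0 as documented.
import numpy as np

def onehot_gather(x: np.ndarray, idx: np.ndarray) -> np.ndarray:
  mask = np.arange(len(x))[None, :] == idx[:, None]   # one row of one-hots per index
  return (mask * x[None, :]).sum(axis=1)              # mask * self, then reduce the gathered axis

print(onehot_gather(np.array([1, 2, 3]), np.array([4, 3, 2])))   # [0 0 3]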
+ # turn indices in indices_filtered to Tuple[shrink_arg, strides] + for dim in type_dim[int]: + if (index := indices_filtered[dim]) >= (size := self.shape[dim]) or index < -size: + raise IndexError(f"{index=} is out of bounds on {dim=} with {size=}") + indices_filtered[dim] = ((index, index+1), 1) if index >= 0 else ((size+index, size+index+1), 1) + for dim in type_dim[slice]: + if (index := indices_filtered[dim]).step == 0: raise ValueError(f"{index=} on {dim=} cannot have 0 as step") + s, e, st = index.indices(self.shape[dim]) + indices_filtered[dim] = ((0, 0) if (st * (e - s)) < 0 else (s, e) if st > 0 else (e+1, s+1), st) + # record tensors and skip all Tensor dims for basic indexing + tensor_index: List[Tensor] = [] + for dim in type_dim[Tensor]: + tensor_index.append(index := indices_filtered[dim]) + if not dtypes.is_int(index.dtype): raise IndexError(f"{index.dtype=} on {dim=} is not supported, only int tensor indexing is supported") + indices_filtered[dim] = ((0, self.shape[dim]), 1) + + new_slice, strides = ((),()) if not indices_filtered else zip(*indices_filtered) + ret = self.shrink(new_slice).flip(tuple(i for i, s in enumerate(strides) if s < 0)) if any(abs(s) != 1 for s in strides): strides = tuple(abs(s) for s in strides) - # Pad: add pad at the end: [dim_sz] -> [dim_sz_padded] - padded_tensor = sliced_tensor.pad(tuple((0, s-(dim_sz % s) if dim_sz % s != 0 else 0) for s, dim_sz in zip(strides, sliced_tensor.shape))) - # Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s] - reshaped_tensor = padded_tensor.reshape(flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))) - new_shape = reshaped_tensor.shape[::2] - # Shrink: do [:, 0] - sliced_tensor = reshaped_tensor.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))) - - final_shape, it_shape, dim, tensors, dim_collapsed = [], iter(new_shape), [], [], 0 - for i,s in enumerate(orig_slices): - if s is None: final_shape.append(1) - else: # s is int or slice or Tensor - dim_shape = next(it_shape) - if isinstance(s, int): - dim_collapsed += 1 - else: - assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}" - final_shape.append(dim_shape) - if isinstance(s, Tensor): - tensors.append(s) - dim.append(i-dim_collapsed) - ret = sliced_tensor.reshape(tuple(final_shape)) - - if tensors: # Fancy/tensor indexing - # normalize idx - # TODO: first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm - idx = [t.sign().contiguous().__neg__().contiguous().relu() * ret.shape[d] + t for d,t in zip(dim, tensors)] - max_dim = max(i.ndim for i in idx) + ret = ret.pad(tuple((0, round_up(sh, s) - sh) for s, sh in zip(strides, ret.shape))) + ret = ret.reshape(tuple(flatten((sh // s, s) for s, sh in zip(strides, ret.shape)))) + ret = ret.shrink(tuple(flatten(((0, sh), (0, 1)) for sh in ret.shape[::2]))).reshape(ret.shape[::2]) + + # inject 1 for dim where it's None and collapse dim for int + new_shape = list(ret.shape) + for dim in type_dim[None]: new_shape.insert(dim, 1) + for dim in (dims_collapsed := tuple(dim + sum(1 for d in type_dim[None] if dim >= d) for dim in reversed(type_dim[int]))): new_shape.pop(dim) + + ret = ret.reshape(new_shape) + assert all_int(ret.shape), f"does not support symbolic shape {ret.shape}" + + # 3. 
advanced indexing (copy) + if type_dim[Tensor]: + # calculate dim of current ret by subtracting dims collapsed and adding dims injected up until tensor_dim + def calc_dim(tensor_dim:int) -> int: + return tensor_dim - sum(1 for d in dims_collapsed if tensor_dim >= d) + sum(1 for d in type_dim[None] if tensor_dim >= d) + + # track tensor_dim and tensor_index using a dict + # calc_dim to get dim and use that to normalize the negative tensor indices + idx: Dict[int,Tensor] = {(dim := calc_dim(td)):(tensor<0).where(ret.shape[dim],0) + tensor for td,tensor in zip(type_dim[Tensor],tensor_index)} + # compute sum_dim, arange, and idx - sum_dim = [d if n==0 else d+max_dim-n for n,d in enumerate(dim)] - arange = [Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(*[1]*sd, ret.shape[d], *[1]*(ret.ndim + max_dim - n - sd - 1)) for n,(sd,d) in enumerate(zip(sum_dim, dim))] - first_idx = [idx[0].reshape(*[1]*dim[0], *[1]*(1 + max_dim - idx[0].ndim), *idx[0].shape, *[1]*(ret.ndim - dim[0] - 1))] - rest_idx = [i.reshape(*[1]*dim[0], *[1]*(max_dim - i.ndim), *i.shape, *[1]*(ret.ndim - dim[0] - n)) for n,i in enumerate(idx[1:], 1)] - idx = first_idx + rest_idx - ret = ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*max_dim, *ret.shape[sum_dim[0]+1:]) - # iteratively fancy index - for a,i,sd in zip(arange, idx, sum_dim): ret = (a==i).mul(ret).sum(sd) + max_idx_dim, first_dim, last_dim = max(i.ndim for i in idx.values()), min(idx.keys()), max(idx.keys()) + sum_dim = tuple(d if n==0 else d+max_idx_dim-n for n,d in enumerate(idx.keys())) + arange = [Tensor.arange(ret.shape[d], requires_grad=False, device=self.device).reshape(ret.shape[d], *[1]*(ret.ndim+max_idx_dim-n-sd-1)) \ + for n,(sd,d) in enumerate(zip(sum_dim, idx.keys()))] + reshaped_idx = [i.reshape(i.shape + (1,)*(ret.ndim - first_dim - (n or 1))) for n,i in enumerate(idx.values())] + ret = ret.reshape(ret.shape[:first_dim+1] + (1,)*max_idx_dim + ret.shape[first_dim+1:]) + + # iteratively eq -> mul -> sum fancy index + try: + for a,i,sd in zip(arange, reshaped_idx, sum_dim): ret = (a==i).mul(ret).sum(sd) + except AssertionError as exc: raise IndexError("cannot broadcast indices") from exc + # special permute case - if dim[0] != 0 and len(dim) != 1 and dim != list(range(dim[0], dim[-1]+1)): + if first_dim != 0 and len(idx) != 1 and tuple(idx.keys()) != tuple(range(first_dim, last_dim+1)): ret_dims = list(range(ret.ndim)) - ret = ret.permute(ret_dims[dim[0]:dim[0]+max_dim] + ret_dims[:dim[0]] + ret_dims[dim[0]+max_dim:]) + ret = ret.permute(ret_dims[first_dim:first_dim+max_idx_dim] + ret_dims[:first_dim] + ret_dims[first_dim+max_idx_dim:]) return ret - def __setitem__(self,s,v): return self.__getitem__(s).assign(v) + def __setitem__(self,indices,v): return self.__getitem__(indices).assign(v) # NOTE: using slice is discouraged and things should migrate to pad and shrink def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor: - arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)]) - padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)]) - return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)])) + arg_ = tuple(a if a is not None else (0, s) for s,a in zip(self.shape, arg)) + padding = tuple((max(0, -l), max(0, r-s)) for s,(l,r) in zip(self.shape, arg_)) + return self.pad(padding, value=value).shrink(tuple((l + pl, r + pl) for (l,r),(pl,_) in zip(arg_, padding))) 
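# Aside (illustration only, not part of the patch): `slice` above is just pad-then-shrink.
# The same bookkeeping for a single axis in NumPy, with a made-up helper name; bounds that
# reach past the ends become padding filled with `value`.
import numpy as np

def slice_1d(x: np.ndarray, lo: int, hi: int, value: float = 0) -> np.ndarray:
  pl, pr = max(0, -lo), max(0, hi - len(x))            # how far the request reaches past each end
  xp = np.pad(x, (pl, pr), constant_values=value)      # pad so the requested range is in bounds
  return xp[lo + pl : hi + pl]                         # then a plain shrink

print(slice_1d(np.array([1, 2, 3]), -1, 4))   # [0 1 2 3 0]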
- def gather(self: Tensor, idx: Tensor, dim: int) -> Tensor: + def gather(self:Tensor, idx:Tensor, dim:int) -> Tensor: assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim" assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape" if dim < 0: dim += self.ndim - idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1) + idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1) permarg = list(range(self.ndim)) permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]] - return ((idx == Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) + return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink( + tuple([*[(0,sh) for sh in idx.shape[1:-1]], (0,self.shape[dim])])).unsqueeze(0)).sum(-1).transpose(ax1=0, ax2=dim) - def cat(self, *args, dim=0) -> Tensor: - dim = (dim + len(self.shape)) if dim < 0 else dim + def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: + if dim < 0: dim += self.ndim assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args) catargs = [self, *args] - assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated" - shapes = [s.shape[dim] for s in catargs] - shape_cumsum = [0, *accumulate(shapes)] - slc = [[(0, 0) for _ in self.shape] for _ in catargs] - for shp,k,s in zip(shapes, shape_cumsum[:-1], slc): s[dim] = (k, shape_cumsum[-1] - k - shp) + cat_dims = [s.shape[dim] for s in catargs] + cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)] + slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs] + for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d) return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)]) @staticmethod - def stack(tensors, dim=0) -> Tensor: - first = tensors[0].unsqueeze(dim) - unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors[1:]] + def stack(tensors:Sequence[Tensor], dim:int=0) -> Tensor: + unsqueezed_tensors = [tensor.unsqueeze(dim) for tensor in tensors] # checks for shapes and number of dimensions delegated to cat - return first.cat(*unsqueezed_tensors, dim=dim) + return unsqueezed_tensors[0].cat(*unsqueezed_tensors[1:], dim=dim) - def repeat(self, repeats) -> Tensor: + def repeat(self, repeats:Sequence[int]) -> Tensor: base_shape = (1,) * (len(repeats) - self.ndim) + self.shape new_shape = [x for b in base_shape for x in [1, b]] expand_shape = [x for rs in zip(repeats, base_shape) for x in rs] final_shape = [r*s for r,s in zip(repeats, base_shape)] return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) + def _resolve_dim(self, dim:int, *, outer:bool=False) -> int: + if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer): + raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}") + return dim + self.ndim+outer if dim < 0 else dim + + def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]: + assert all_int(self.shape), f"does not support symbolic shape {self.shape}" + dim = self._resolve_dim(dim) + if isinstance(sizes, int): sizes = [min(sizes, self.shape[dim]-i) for i in 
range(0, max(1, self.shape[dim]), max(1, sizes))] + assert sum(sizes) == self.shape[dim], f"expect sizes to sum exactly to {self.shape[dim]}, but got {sum(sizes)}" + return tuple(self[sl] for sl in [tuple([slice(None)]*dim + [slice(sum(sizes[:i]), sum(sizes[:i + 1]))]) for i in range(len(sizes))]) + def chunk(self, num:int, dim:int=0) -> List[Tensor]: assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num) - slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)] - return [self[tuple(sl)] for sl in slice_params] - - def squeeze(self, dim=None) -> Tensor: - if dim is None: return self if 1 not in self.shape else self.reshape(*[size for size in self.shape if size != 1]) - if dim <= 0 and self.ndim == 0: return self # This is to match PyTorch behavior - if not -self.ndim <= dim < self.ndim: raise IndexError(f"Dimension out of range (expected to be in range of [{-self.ndim if self.ndim > 0 else self.ndim-1}, {self.ndim-1 if self.ndim > 0 else self.ndim}], but got {dim})") - if dim < 0: dim += self.ndim - return self if self.shape[dim] != 1 else self.reshape(*[size for idx, size in enumerate(self.shape) if idx != dim]) + dim = self._resolve_dim(dim) + assert num > 0, f"expect num to be greater than 0, got: {num}" + return list(self.split(math.ceil(self.shape[dim]/num) if self.shape[dim] else [0]*num, dim=dim)) - def unsqueeze(self, dim) -> Tensor: - if dim < 0: dim = len(self.shape) + dim + 1 + def squeeze(self, dim:Optional[int]=None) -> Tensor: + if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1)) + dim = self._resolve_dim(dim) + return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:]) + + def unsqueeze(self, dim:int) -> Tensor: + dim = self._resolve_dim(dim, outer=True) return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) # (padding_left, padding_right, padding_top, padding_bottom) - def pad2d(self, padding:Union[List[int], Tuple[int, ...]], value:float=0) -> Tensor: + def pad2d(self, padding:Sequence[int], value:float=0) -> Tensor: slc = [(-p0, s+p1) for p0,p1,s in zip(padding[::2], padding[1::2], self.shape[::-1])][::-1] return self.slice([(0,s) for s in self.shape[:-(len(padding)//2)]] + slc, value=value) @property def T(self) -> Tensor: return self.transpose() def transpose(self, ax1=1, ax2=0) -> Tensor: - order = list(range(len(self.shape))) + order = list(range(self.ndim)) order[ax1], order[ax2] = order[ax2], order[ax1] return self.permute(order) - def flatten(self, start_dim=0): return self.reshape(shape=self.shape[:start_dim] + (-1,)) + def flatten(self, start_dim=0, end_dim=-1): + start_dim, end_dim = start_dim + self.ndim if start_dim < 0 else start_dim, end_dim + self.ndim if end_dim < 0 else end_dim + return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:]) + def unflatten(self, dim:int, sizes:Tuple[int,...]): + if dim < 0: dim += self.ndim + return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:]) # ***** reduce ops ***** def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False) -> Tensor: - axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if isinstance(axis, int) else list(axis)) - axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_] + axis_: Tuple[int, ...] 
= tuple(range(len(self.shape))) if axis is None else ((axis,) if isinstance(axis, int) else tuple(axis)) + axis_ = tuple(x if x >= 0 else x+len(self.shape) for x in axis_) shape = tuple(s for i,s in enumerate(self.shape) if i not in axis_) - if 0 in self.shape and 0 not in shape: return Tensor.full(tuple(1 if s == 0 else s for s in self.shape) if keepdim else shape, {mlops.Sum: 0, mlops.Max: -float("inf")}[fxn]) - ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)])) + ret = fxn.apply(self, axis=axis_) return ret if keepdim else ret.reshape(shape=shape) - def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim) + def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None): + if acc_dtype is None: acc_dtype = least_upper_dtype(self.dtype, dtypes.uint) if dtypes.is_unsigned(self.dtype) else \ + least_upper_dtype(self.dtype, dtypes.int) if (dtypes.is_int(self.dtype) or self.dtype==dtypes.bool) else \ + least_upper_dtype(self.dtype, dtypes.float) + # cast back to float16 or bfloat16 to match torch / jax behavior, but we use float for acc + output_dtype = self.dtype if self.dtype in (dtypes.float16, dtypes.bfloat16) else acc_dtype + return self.cast(acc_dtype)._reduce(mlops.Sum, axis, keepdim).cast(output_dtype) + def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, keepdim) def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim)) def mean(self, axis=None, keepdim=False): assert all_int(self.shape), "does not support symbolic shape" out = self.sum(axis=axis, keepdim=keepdim) - return out.mul(prod(out.shape)/prod(self.shape)) if 0 not in self.shape else out - def std(self, axis=None, keepdim=False, correction=1): + return out.div(prod(self.shape) / prod(out.shape)) if 0 not in out.shape else out + def var(self, axis=None, keepdim=False, correction=1): assert all_int(self.shape), "does not support symbolic shape" square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim) - return square_sum.div(prod(self.shape)/prod(square_sum.shape)-correction).sqrt() + return square_sum.div(max(0, prod(self.shape)/prod(square_sum.shape)-correction)) + def std(self, axis=None, keepdim=False, correction=1): return self.var(axis, keepdim, correction).sqrt() + def _softmax(self, axis): + if len(self.shape) == 0: + assert axis in [-1, 0], f"{axis=} out of range of [-1, 0]" + axis = None m = self - self.max(axis=axis, keepdim=True) e = m.exp() return m, e, e.sum(axis=axis, keepdim=True) @@ -475,159 +606,192 @@ def log_softmax(self, axis=-1): def argmax(self, axis=None, keepdim=False): if axis is None: - idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape) - return prod(self.shape) - idx.max() - 1 + idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape) + return (prod(self.shape) - idx.max() - 1).cast(dtypes.default_int) axis = axis + len(self.shape) if axis < 0 else axis m = self == self.max(axis=axis, keepdim=True) - idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) - return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1 + idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], 
*[1]*(self.ndim-axis-1)) + return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.default_int) def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim) + @staticmethod + def einsum(formula:str, *raw_xs) -> Tensor: + xs:Tuple[Tensor] = argfix(*raw_xs) + formula = formula.replace(" ", "") + inputs_str, output = formula.split("->") if "->" in formula else (formula, sorted(formula)) + inputs = [x for x in cast(str,inputs_str).split(',')] + assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}" + + # map the value of each letter in the formula + letter_val = sorted(merge_dicts([{letter:dim for letter, dim in zip(letters, tensor.shape)} for letters, tensor in zip(inputs, xs)]).items()) + + xs_:List[Tensor] = [] + lhs = [sorted(enumerate(s), key=lambda e:e[1]) for s in inputs] + for x,(order,letters) in zip(xs, [list(zip(*l)) for l in lhs]): + # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters + xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val])) + + rhs_order, rhs_letters = tuple(zip(*sorted(enumerate(output), key=lambda e:e[1]))) or ([], []) + # sum over all axes that's not in the output, then permute to the output order + return reduce(lambda a,b:a*b, xs_).sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in rhs_letters]).permute(rhs_order) + # ***** processing ops ***** def _pool(self, k_:Tuple[sint, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor: assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}" assert all_int(self.shape) and all_int(k_), f"does not support symbolic {self.shape=}, {k_=}" s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_)) - assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" - slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):] + assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" + noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):] if any(k > s for k,s in zip(k_, s_)) or any(d != 1 for d in d_): o_ = [(i - d * (k-1) - 1)//s + 1 for i,d,k,s in zip(i_, d_, k_, s_)] - e_ = [math.ceil(k*(i+d) / i) for k,i,d in zip(k_, i_, d_)] # expands such that we don't need padding - xup = self.reshape(*prefix, *flatten((1,i) for i in i_)).expand(*prefix, *flatten((e,i) for e,i in zip(e_, i_))).reshape(*prefix, *[e*i for e,i in zip(e_, i_)]) - # slide by dilation - xup = xup.slice(slc_prefix + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]) - xup = xup.reshape(*prefix, *flatten((k,i+d) for k,i,d in zip(k_, i_, d_))) - xup = xup.slice(slc_prefix + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))) - # handle stride, and permute to move reduce to the end - xup = xup.reshape(*prefix, *flatten((k,o,s) for k,o,s in zip(k_, o_, s_))) - xup = xup.slice(slc_prefix + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))) - xup = xup.reshape(*prefix, *flatten((k,o) for k,o in zip(k_, o_))) - return xup.permute(*range(len(prefix)), *[len(prefix)+i*2+1 for i in range(len(k_))], *[len(prefix)+i*2 for i in range(len(k_))]) + # repeats such that we don't need padding + xup = self.repeat([1]*len(noop_) + [math.ceil(k*(i+d) / i) for k,i,d 
in zip(k_, i_, d_)]) + # slice by dilation + xup = xup.slice(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)]).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_))) + # handle stride + xup = xup.slice(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_))) + xup = xup.slice(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_))) + # permute to move reduce to the end + return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))]) # TODO: once the shapetracker can optimize well, remove this alternative implementation. or not if the CPU implementation doesn't use ShapeTracker o_ = [(i+(s-k))//s for i,s,k in zip(i_, s_, k_)] - xup = self.slice(slc_prefix + [(0,o*s) for o,s in zip(o_, s_)]) - xup = xup.reshape(*prefix, *flatten(((o, s) for o,s in zip(o_, s_)))) - xup = xup.slice(slc_prefix + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))) - return xup.permute(*range(len(prefix)), *[len(prefix)+i*2 for i in range(len(k_))], *[len(prefix)+i*2+1 for i in range(len(k_))]) + xup = self.slice(noop_ + [(0,o*s) for o,s in zip(o_, s_)]) + xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_)))) + xup = xup.slice(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_))) + return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))]) # NOTE: these work for more than 2D - def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) - def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool(make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) + def avg_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool( + make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).mean(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) + def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1): return self._pool( + make_pair(kernel_size), stride if stride is not None else kernel_size, dilation).max(axis=tuple(range(0-len(make_pair(kernel_size)), 0))) def conv_transpose2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, output_padding=0) -> Tensor: HW, trailing = weight.shape[2:], list(range(3, len(weight.shape)+1)) - x, w = self, weight.reshape(groups, weight.shape[0]//groups, weight.shape[1], *weight.shape[2:]).permute(0,2,1,*trailing).flip(trailing) + x, w = self, weight.unflatten(0, (groups, -1)).permute(0,2,1,*trailing).flip(trailing) stride = make_pair(stride, len(HW)) if any(s>1 for s in stride): - x = x.reshape(*x.shape[:2], *flatten((k,1) for k in x.shape[2:])) - x = x.pad(((0,0), (0,0), *flatten(((0,0),(0,s-1)) for s in stride))) - x = x.reshape(*x.shape[:2], *[k*s for k,s in zip(x.shape[2::2], stride)]) - x = x.shrink(((0,x.shape[0]), (0,x.shape[1]), *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) - padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list(zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) - return x.conv2d(w.reshape(w.shape[0]*w.shape[1],*w.shape[2:]), groups=groups, bias=bias, 
dilation=dilation, padding=padding) - - wino = int(getenv("WINO", "0")) - def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0) -> Tensor: + x = x.reshape(None, None, *flatten((k,1) for k in x.shape[2:])) + x = x.pad((None, None, *flatten((None,(0,s-1)) for s in stride))) + x = x.reshape(None, None, *[k*s for k,s in zip(x.shape[2::2], stride)]) + x = x.shrink((None, None, *[(0,k-(s-1)) for k,s in zip(x.shape[2:], stride)])) + padding = flatten((((k-1)*d-p,(k-1)*d-p+op) for k,d,p,op in reversed(list( + zip(HW, make_pair(dilation, len(HW)), make_pair(padding, len(HW)), make_pair(output_padding, len(HW))))))) + return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding) + + def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding=0, acc_dtype:Optional[DType]=None) -> Tensor: (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] - assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" - if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" - padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) + assert groups*cin == cin_ and len(self.shape) == len(weight.shape), f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" # noqa: E501 + if isinstance(padding, (tuple,list)): assert len(padding) == 2*len(HW) or len(padding) == len(HW), f"Expected padding of length {2*len(HW)} or {len(HW)}, but got {len(padding)} for tensor of shape {self.shape}" # noqa: E501 + padding_ = [padding]*2*len(HW) if isinstance(padding, int) else (padding if len(padding) == 2*len(HW) else [p for p in padding for _ in range(2)][::-1]) # noqa: E501 # conv2d is a pooling op (with padding) x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W) rcout, oyx = cout//groups, x.shape[2:-len(HW)] - if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not Tensor.wino: + if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO: # normal conv - x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) + x = x.reshape(bs, groups, cin, 1, *oyx, *HW).expand(bs, groups, cin, rcout, *oyx, *HW).permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) # noqa: E501 # conv! 
broadcasted to (bs, groups, rcout, *oyx, cin, *HW) - ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx) + ret = (x * weight.reshape(1, groups, rcout, *[1] * len(oyx), cin, *HW)).sum([-1-i for i in range(1+len(oyx))], keepdim=True, acc_dtype=acc_dtype).reshape(bs, cout, *oyx) # noqa: E501 return ret if bias is None else ret.add(bias.reshape(1, -1, *[1] * len(HW))) - # winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 - def apply_matrix(mat, t, dim=0): return t if dim == len(HW) else Tensor.stack([apply_matrix(mat, sum(mm*t[j] for j,mm in enumerate(m) if mm), dim=dim+1) for m in mat]) HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles - winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]] - winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order almost doubles compilation time + winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] + winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time # todo: stride == dilation # use padding to round up to 4x4 output tiles - d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # (bs, cin_, tyx, HWI) - d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))).contiguous_backward() # move HW to the front: # (HWI, bs, cin_, tyx) + # (bs, cin_, tyx, HWI) + d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501 + # move HW to the front: # (HWI, bs, cin_, tyx) + d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))) tyx = d.shape[-len(HWI):] # dim of tiling g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front # compute 6x6 winograd tiles: GgGt, BtdB - gfactors = apply_matrix(winograd_G, g).contiguous().reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) - dfactors = apply_matrix(winograd_Bt, d).contiguous().reshape(*HWI, bs, groups, 1, cin, *tyx) # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) + # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) + gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) + # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) + dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx) - ret = apply_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW))) # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) + # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) + ret = _apply_winograd_matrix(winograd_At, (gfactors * 
dfactors).sum(axis=-1-len(HW), acc_dtype=acc_dtype), len(HW)) - ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) - ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final + # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) + ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) + # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final + ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink(tuple((0, s) for s in [bs, cout, *oyx])) return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward() - def dot(self, w:Tensor) -> Tensor: + def dot(self, w:Tensor, acc_dtype:Optional[DType]=None) -> Tensor: n1, n2 = len(self.shape), len(w.shape) assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" + assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})" x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1]) w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2)) - return (x*w).sum(-1) + return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype)) + + def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DType]=None) -> Tensor: + return x.dot(self, acc_dtype=acc_dtype) if reverse else self.dot(x, acc_dtype=acc_dtype) - def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: return self.transpose(axis,-1).pad2d((self.shape[axis]-int(not _first_zero),0))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1) + def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: + pl_sz = self.shape[axis] - int(not _first_zero and self.shape[axis] != 0) + return self.transpose(axis,-1).pad2d((pl_sz,0))._pool((self.shape[axis] or 1,)).sum(-1).transpose(axis,-1) def cumsum(self, axis:int=0) -> Tensor: # TODO: someday the optimizer will find this on it's own # for now this is a two stage cumsum SPLIT = 256 if self.shape[axis] <= SPLIT*2: return self._cumsum(axis) ret = self.transpose(axis,-1).pad2d((round_up(self.shape[axis], SPLIT)-self.shape[axis], 0)) - ret = ret.reshape(*ret.shape[0:-1], ret.shape[-1]//SPLIT, SPLIT)._cumsum(-1) + ret = ret.unflatten(-1, (-1, SPLIT))._cumsum(-1) base_add = ret[..., -1]._cumsum(-1, _first_zero=True)[..., :-1] base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1]) - def fix(x:Tensor): return x.reshape(*ret.shape[0:-2], ret.shape[-2] * ret.shape[-1])[..., -self.shape[axis]:].transpose(axis,-1) + def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -self.shape[axis]:].transpose(axis,-1) return fix(ret) + fix(base_add) @staticmethod - def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) - def triu(self, k:int=0) -> Tensor: - assert all_int(self.shape), f"does not support 
symbolic shape {self.shape}" - return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self)) - def tril(self, k:int=0) -> Tensor: - assert all_int(self.shape), f"does not support symbolic shape {self.shape}" - return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self) + def _tri(r:sint, c:sint, k:int=0, **kwargs) -> Tensor: + assert all_int((r,c)), "does not support symbolic" + if r == 0: return Tensor.zeros((r, c), **kwargs) + return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c) + def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, device=self.device).where(self, 0) + def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, device=self.device).where(0, self) # ***** mlops (unary) ***** - def neg(self): return mlops.Neg.apply(self) + def logical_not(self): return mlops.Eq.apply(*self._broadcasted(False)) + def neg(self): return mlops.Neg.apply(self) if self.dtype != dtypes.bool else self.logical_not() def contiguous(self): return mlops.Contiguous.apply(self) def contiguous_backward(self): return mlops.ContiguousBackward.apply(self) - def log(self): return mlops.Log.apply(self) - def log2(self): return mlops.Log.apply(self)/math.log(2) - def exp(self): return mlops.Exp.apply(self) + def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype))) + def log2(self): return self.log()/math.log(2) + def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype))) def exp2(self): return mlops.Exp.apply(self*math.log(2)) def relu(self): return mlops.Relu.apply(self) - def sigmoid(self): return mlops.Sigmoid.apply(self) - def sin(self): return mlops.Sin.apply(self) - def sqrt(self): return mlops.Sqrt.apply(self) - def rsqrt(self): return (1/self).sqrt() + def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype))) + def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype))) + def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype))) + def rsqrt(self): return self.reciprocal().sqrt() def cos(self): return ((math.pi/2)-self).sin() def tan(self): return self.sin() / self.cos() # ***** math functions (unary) ***** - def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype) + def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).cast(self.dtype) def ceil(self: Tensor) -> Tensor: return (self > (b := self.trunc())).where(b+1, b) def floor(self: Tensor) -> Tensor: return (self < (b := self.trunc())).where(b-1, b) + def round(self: Tensor) -> Tensor: + return ((self > 0) == ((b := self.cast(dtypes.int32) / 2.0).cast(dtypes.int32) == b)).where((self - 0.5).ceil(), (self + 0.5).floor()) def square(self): return self*self def clip(self, min_, max_): return self.maximum(min_).minimum(max_) def abs(self): return self.relu() + (-self).relu() - def sign(self): return self / (self.abs() + 1e-10) + def sign(self): return ((self.float()) / (self.float().abs() + 1e-12)).cast(self.dtype) def reciprocal(self): return 1.0/self # ***** activation functions (unary) ***** @@ -652,52 +816,59 @@ def mish(self): return self * self.softplus().tanh() def softplus(self, beta=1): return (1/beta) * (1 + (self*beta).exp()).log() def softsign(self): return self / (1 + self.abs()) - # ***** broadcasted 
binary mlops ***** + # ***** broadcasted elementwise mlops ***** - def _broadcasted(self, y:Union[Tensor, float], reverse:bool=False) -> Tuple[Tensor, Tensor]: + def _broadcasted(self, y:Union[Tensor, Scalar], reverse:bool=False, match_dtype:bool=True) -> Tuple[Tensor, Tensor]: x: Tensor = self if not isinstance(y, Tensor): - if 0 in x.shape: return x, x.full_like(y) - y = Tensor(y, device=self.device, requires_grad=False, dtype=self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32) + # make y a Tensor + assert isinstance(y, (float, int, bool)), f"{type(y)=}, {y=}" + if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype + else: y_dtype = dtypes.from_py(y) + y = Tensor(cast_scalar(y, y_dtype), self.device, y_dtype, requires_grad=False) + + if match_dtype: + output_dtype = least_upper_dtype(x.dtype, y.dtype) + x, y = x.cast(output_dtype), y.cast(output_dtype) + if reverse: x, y = y, x - if (xshape:=x.shape) == (yshape:=y.shape): return (x, y) - shape_delta = len(xshape) - len(yshape) - if shape_delta > 0: y = y.reshape((1,) * shape_delta + yshape) - elif shape_delta < 0: x = x.reshape((1,) * -shape_delta + xshape) - if (xshape:=x.shape) == (yshape:=y.shape): return (x, y) + # left pad shape with 1s + if len(y.shape) < len(x.shape): y = y.reshape((1,) * (len(x.shape) - len(y.shape)) + y.shape) + elif len(x.shape) < len(y.shape): x = x.reshape((1,) * (len(y.shape) - len(x.shape)) + x.shape) - shape_ret = tuple([max(x, y) for x, y in zip(xshape, yshape)]) - if xshape != shape_ret: x = x.expand(shape_ret) - if yshape != shape_ret: y = y.expand(shape_ret) - return (x, y) + broadcasted_shape = tuple(0 if xi==0 or yi==0 else max(xi, yi) for xi, yi in zip(x.shape, y.shape)) + return x.expand(broadcasted_shape), y.expand(broadcasted_shape) - def _to_float(self, x:Union[Tensor, float]): - return x.lazydata.base.op.arg if isinstance(x, Tensor) and x.lazydata.is_unrealized_contiguous_const() \ + def _to_const_val(self, x:Union[Tensor, Scalar]) -> Union[Tensor, Scalar]: + # TODO: update with multi + return x.lazydata.base.arg if isinstance(x, Tensor) and isinstance(x.lazydata, LazyBuffer) and x.lazydata.is_unrealized_contiguous_const() \ and not x.requires_grad and self._broadcasted(x)[0].shape == self.shape else x - def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self - def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else (-self if reverse else self) - def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - if x.__class__ is not Tensor and x == 0.0: return mlops.Zero.apply(self) - if x.__class__ is not Tensor and x == -1.0: return -self - return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self - def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x) - def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: - x = self._to_float(x) - if x.__class__ is not Tensor and not reverse: + def add(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: + x = 
self._to_const_val(x) + return mlops.Add.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else self + def sub(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: + x = self._to_const_val(x) + return mlops.Sub.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x else (-self if reverse else self) + def mul(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: + x = self._to_const_val(x) + if not isinstance(x, Tensor) and x == 0.0: return mlops.Zero.apply(self) + if not isinstance(x, Tensor) and x == -1.0: return -self + return mlops.Mul.apply(*self._broadcasted(x, reverse)) if isinstance(x, Tensor) or x != 1.0 else self + def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: + x = self._to_const_val(x) + if not isinstance(x, Tensor) and not reverse and x != 0: return self.mul(1/x) + if isinstance(x, Tensor) and dtypes.is_float(x.dtype): return mlops.Div.apply(*self._broadcasted(x, reverse)) + return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse)) + def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse)) + + def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: + x = self._to_const_val(x) + if not isinstance(x, Tensor) and not reverse: # simple pow identities if x < 0: return self.reciprocal().pow(-x) - if x == 3.0: return self*self*self - if x == 2.0: return self*self - if x == 1.0: return self + if x in [3,2,1,0]: return reduce(lambda acc,_: acc * self, range(int(x)), mlops.Zero.apply(self)+1) if x == 0.5: return self.sqrt() if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(math.log(x)).exp() ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(math.log(abs(x))).exp() @@ -708,18 +879,21 @@ def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor: # we need 0 to be positive so we need to correct base_sign when the base is 0 base_sign = base_sign - (1.5 * (1 - (self.sign().abs() if not reverse else x.sign().abs() if isinstance(x, Tensor) else abs(int(bool(x)))))) # inject nan if the base is negative and the power is not an integer - to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign + to_nan = (((x - x.trunc()) * 1e10).abs().clip(0, 1) if isinstance(x, Tensor) else \ + int(bool(x - int(x))) if not reverse else ((self - self.trunc()) * 1e10).abs().clip(0, 1)) * base_sign inject_nan = ((((-to_nan) * 2) + 1)).log().add(1) if isinstance(to_nan, Tensor) else 1 if not to_nan else float("nan") return ar.mul(sign * base_sign + (1 - base_sign)).mul(inject_nan) - def matmul(self, x:Tensor, reverse=False) -> Tensor: return x.dot(self) if reverse else self.dot(x) - def maximum(self, x:Union[Tensor, float]) -> Tensor: return (self<x).detach().where(x, (self>x).detach().where(self, (self+x)/2)) - def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x)) + def maximum(self, x:Union[Tensor, Scalar]) -> Tensor: + return (self<x).detach().where(x, (self==x).detach().where(((self * 0.5 + x * 0.5).cast(self.dtype)), self)) + def minimum(self, x:Union[Tensor, Scalar]) -> Tensor: return -((-self).maximum(-x)) - def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]): - x_,y = self._broadcasted(input_) - x,z = x_._broadcasted(other) - return mlops.Where.apply(x, *y._broadcasted(z)) + def where(self:Tensor, input_:Union[Tensor, Scalar], other:Union[Tensor, Scalar]): + if isinstance(input_, Tensor): input_, other = input_._broadcasted(other) + elif isinstance(other, Tensor): other, input_ =
other._broadcasted(input_) + x_,y = self._broadcasted(input_, match_dtype=False) + x,z = x_._broadcasted(other, match_dtype=False) + return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z)) # ***** op wrappers (wasted lines to make the typechecker happy) ***** @@ -731,6 +905,7 @@ def __mul__(self, x) -> Tensor: return self.mul(x) def __pow__(self, x) -> Tensor: return self.pow(x) def __truediv__(self, x) -> Tensor: return self.div(x) def __matmul__(self, x) -> Tensor: return self.matmul(x) + def __xor__(self, x) -> Tensor: return self.xor(x) def __radd__(self, x) -> Tensor: return self.add(x, True) def __rsub__(self, x) -> Tensor: return self.sub(x, True) @@ -738,6 +913,7 @@ def __rmul__(self, x) -> Tensor: return self.mul(x, True) def __rpow__(self, x) -> Tensor: return self.pow(x, True) def __rtruediv__(self, x) -> Tensor: return self.div(x, True) def __rmatmul__(self, x) -> Tensor: return self.matmul(x, True) + def __rxor__(self, x) -> Tensor: return self.xor(x, True) def __iadd__(self, x) -> Tensor: return self.assign(self.add(x)) def __isub__(self, x) -> Tensor: return self.assign(self.sub(x)) @@ -745,13 +921,14 @@ def __imul__(self, x) -> Tensor: return self.assign(self.mul(x)) def __ipow__(self, x) -> Tensor: return self.assign(self.pow(x)) def __itruediv__(self, x) -> Tensor: return self.assign(self.div(x)) def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x)) + def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x)) def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False)) def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True)) - def __ge__(self, x) -> Tensor: return 1.0-(self<x) - def __le__(self, x) -> Tensor: return 1.0-(self>x) - def __ne__(self, x) -> Tensor: return (self<x) + (self>x) # type: ignore - def __eq__(self, x) -> Tensor: return 1.0-(self != x) # type: ignore + def __ge__(self, x) -> Tensor: return (self<x).logical_not() + def __le__(self, x) -> Tensor: return (self>x).logical_not() + def __eq__(self, x) -> Tensor: return mlops.Eq.apply(*self._broadcasted(x, True)) # type: ignore[override] + def __ne__(self, x) -> Tensor: return (self==x).logical_not() # type: ignore[override] # ***** functional nn ops ***** @@ -765,40 +942,52 @@ def layernorm(self, axis=-1, eps:float=1e-5) -> Tensor: y = (self - self.mean(axis, keepdim=True)) return y.mul((y*y).mean(axis, keepdim=True).add(eps).rsqrt()) - def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor) -> Tensor: - x = (self - mean.reshape(shape=[1, -1, 1, 1])) - if weight: x = x * weight.reshape(shape=[1, -1, 1, 1]) - ret = x.mul(invstd.reshape(shape=[1, -1, 1, 1]) if len(invstd.shape) == 1 else invstd) - return (ret + bias.reshape(shape=[1, -1, 1, 1])) if bias else ret + def batchnorm(self, weight:Optional[Tensor], bias:Optional[Tensor], mean:Tensor, invstd:Tensor, axis:Union[int,Tuple[int,...]]=1) -> Tensor: + axis_ = argfix(axis) + shape = tuple(s if ax in axis_ else 1 for ax, s in enumerate(self.shape)) + x = self - mean.reshape(shape) + if weight is not None: x = x * weight.reshape(shape) + ret = x.mul(invstd.reshape(shape) if len(invstd.shape) == len(axis_) else invstd) + return (ret + bias.reshape(shape)) if bias is not None else ret def dropout(self, p=0.5) -> Tensor: if not Tensor.training or p == 0: return self - mask = (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p).cast(dtypes.bool) - return self * mask * (1/(1.0 - p)) + return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p)) - def
scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: + def one_hot(self, num_classes:int) -> Tensor: + return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0) + + def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, + dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # NOTE: it works if key, value have symbolic shape assert all_int(self.shape), f"does not support symbolic shape {self.shape}" if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool) - if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask) - return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value + if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), 0) + qk = self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + return ((qk+attn_mask) if attn_mask is not None else qk).softmax(-1).dropout(dropout_p) @ value def binary_crossentropy(self, y:Tensor) -> Tensor: return (-y*self.log() - (1-y)*(1-self).log()).mean() def binary_crossentropy_logits(self, y:Tensor) -> Tensor: - return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean() + return (self.maximum(0) - y * self + (1 + self.abs().neg().exp()).log()).mean() - def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor: + def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1, label_smoothing=0.0) -> Tensor: + assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0.0, 1.0]" # NOTE: self is a logits input - loss_mask = Y != ignore_index - y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) - y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) - return self.log_softmax().mul(y).sum() / loss_mask.sum() + log_probs, loss_mask = self.log_softmax(), (Y != ignore_index) + y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) + y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1]) + smoothing = -1 * label_smoothing * (log_probs.mean(-1) * loss_mask).sum() / loss_mask.sum() + return (1 - label_smoothing) * (log_probs * y).sum() / loss_mask.sum() + smoothing # ***** cast ops ***** - def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self + def cast(self, dtype:DType) -> Tensor: + if self.dtype == dtype: return self + # hack for devices that don't support bfloat16 + if self.dtype == dtypes.bfloat16: return self.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1<<16).bitcast(dtypes.float32).cast(dtype) + return mlops.Cast.apply(self, dtype=dtype) def bitcast(self, dtype:DType) -> Tensor: assert self.dtype.itemsize == dtype.itemsize, "can't bitcast mismatched dtype itemsizes" return mlops.Cast.apply(self, dtype=dtype, bitcast=True) if self.dtype != dtype else self @@ -815,10 +1004,18 @@ def nbytes(self) -> int: return self.numel() * self.element_size() def is_floating_point(self) -> bool: return 
dtypes.is_float(self.dtype) # register functions to move between devices -for device in Device._buffers: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device)) +for device in Device._devices: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device)) if IMAGE: # if IMAGE>0 we install these replacement functions in Tensor (hack!) from teenygrad.features.image import image_conv2d, image_dot setattr(Tensor, "conv2d", image_conv2d) setattr(Tensor, "dot", image_dot) + +# TODO: remove the custom op and replace with threefry +def custom_random(out:Buffer): + Tensor._seed += 1 + if DEBUG >= 2: print(f"*** {out.device} rand seed {Tensor._seed} size {out.size:<15d} dtype {out.dtype}") + rng = np.random.default_rng(Tensor._seed) + rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=out.dtype.np, copy=False) + out.copyin(rng_np_buffer.data) diff --git a/test/test_dtype.py b/test/test_dtype.py index a79d31a..70148a2 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,31 +1,38 @@ -import unittest +import unittest, operator, sys import numpy as np -from teenygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX -from teenygrad.ops import Device -from teenygrad.tensor import Tensor, dtypes +import torch from typing import Any, List +from teenygrad.helpers import CI, getenv, DEBUG, OSX, temp +from teenygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype +from tinygrad import Device, Tensor, dtypes +from hypothesis import given, settings, strategies as strat -def is_dtype_supported(dtype: DType): - # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) - # for LLVM, it segfaults because it can't link to the casting function - if dtype == dtypes.half: return not (CI and Device.DEFAULT in ["GPU", "LLVM"]) and Device.DEFAULT != "WEBGPU" and getenv("CUDACPU") != 1 +settings.register_profile("my_profile", max_examples=200, deadline=None) +settings.load_profile("my_profile") + +core_dtypes = list(DTYPES_DICT.values()) +floats = [dt for dt in core_dtypes if dtypes.is_float(dt)] +def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): if dtype == dtypes.bfloat16: return False # numpy doesn't support bf16, tested separately in TestBFloat16DType - if dtype == dtypes.float64: return Device.DEFAULT not in ["WEBGPU", "METAL"] and (not OSX and Device.DEFAULT == "GPU") - if dtype in [dtypes.int8, dtypes.uint8]: return Device.DEFAULT not in ["WEBGPU"] - if dtype in [dtypes.int16, dtypes.uint16]: return Device.DEFAULT not in ["WEBGPU", "TORCH"] - if dtype == dtypes.uint32: return Device.DEFAULT not in ["TORCH"] - if dtype in [dtypes.int64, dtypes.uint64]: return Device.DEFAULT not in ["WEBGPU", "TORCH"] - if dtype == dtypes.bool: - # host-shareablity is a requirement for storage buffers, but 'bool' type is not host-shareable - if Device.DEFAULT == "WEBGPU": return False + if device in ["WEBGPU", "WEBGL"]: return dtype in [dtypes.float, dtypes.int32, dtypes.uint32] + # for CI GPU, cl_khr_fp16 isn't supported + # for CI LLVM, it segfaults because it can't link to the casting function + # CUDA in CI uses CUDACPU that does not support half + # PYTHON supports half memoryview in 3.12+ https://github.com/python/cpython/issues/90751 + if dtype == dtypes.half: + if device in ["GPU", "LLVM", "CUDA"]: return not CI + if device == "PYTHON": return sys.version_info >= (3, 12) + if dtype == dtypes.float64: return device != "METAL" and not (OSX and device == "GPU") return True 
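# Aside (illustration only, not part of the patch): the pattern the custom_random hook added
# earlier in this patch relies on: a Generator seeded per call produces float32 noise, which is
# converted to the target dtype and copied into a raw buffer. The bytearray below is a stand-in
# for the real Buffer, and fill_random is a made-up name for this sketch.
import numpy as np

def fill_random(out: bytearray, size: int, np_dtype, seed: int) -> None:
  rng = np.random.default_rng(seed)                                    # deterministic for a given seed
  vals = rng.random(size=size, dtype=np.float32).astype(np_dtype, copy=False)
  out[:] = vals.tobytes()                                              # analogous to Buffer.copyin

buf = bytearray(10 * np.dtype(np.float16).itemsize)
fill_random(buf, 10, np.float16, seed=42)
print(np.frombuffer(buf, dtype=np.float16))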
-def get_available_cast_dtypes(dtype: DType) -> List[DType]: return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # dont cast internal dtypes +def get_available_cast_dtypes(dtype: DType) -> List[DType]: + if not is_dtype_supported(dtype): return [] + return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] # dont cast internal dtypes def _test_to_np(a:Tensor, np_dtype, target): if DEBUG >= 2: print(a) na = a.numpy() - if DEBUG >= 2: print(na, na.dtype, a.lazydata.realized) + if DEBUG >= 2: print(na, na.dtype, a.lazydata.base.realized) try: assert na.dtype == np_dtype np.testing.assert_allclose(na, target) @@ -36,21 +43,26 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target): if DEBUG >= 2: print(tensor.numpy()) try: assert tensor.dtype == target_dtype - np.testing.assert_allclose(tensor.numpy(), target) + np.testing.assert_allclose(tensor.numpy(), target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7) except AssertionError as e: raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e -def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target) -def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist()) -def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target) +def _test_op(fxn, target_dtype:DType, target): + _assert_eq(fxn(), target_dtype, target) +def _test_cast(a:Tensor, target_dtype:DType): + _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np))) +def _test_bitcast(a:Tensor, target_dtype:DType, target=None): + _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist()) class TestDType(unittest.TestCase): DTYPE: Any = None DATA: Any = None @classmethod def setUpClass(cls): - if not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") - cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() if dtypes.is_int(cls.DTYPE) else np.random.choice([True, False], size=10).tolist() if cls.DTYPE == dtypes.bool else np.random.uniform(0, 1, size=10).tolist() + if not cls.DTYPE or not is_dtype_supported(cls.DTYPE): raise unittest.SkipTest("dtype not supported") + if dtypes.is_int(cls.DTYPE): cls.DATA = np.random.randint(0, 100, size=10, dtype=cls.DTYPE.np).tolist() + elif cls.DTYPE == dtypes.bool: cls.DATA = np.random.choice([True, False], size=10).tolist() + else: cls.DATA = np.random.uniform(0, 1, size=10).tolist() def setUp(self): if self.DTYPE is None: raise unittest.SkipTest("base class") @@ -66,42 +78,63 @@ def test_casts_from(self): list(map( )) def test_same_size_ops(self): - def get_target_dtype(dtype): - if any([dtypes.is_float(dtype), dtypes.is_float(self.DTYPE)]): return max([dtype, self.DTYPE], key=lambda x: x.priority) - return dtype if dtypes.is_unsigned(dtype) else self.DTYPE list(map( - lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype, target_dtype=get_target_dtype(dtype)) if dtype.itemsize == self.DTYPE.itemsize else None, + lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None, get_available_cast_dtypes(self.DTYPE) )) - def test_upcast_ops(self): list(map( - lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if 
dtype.itemsize > self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) + def test_upcast_ops(self): + list(map( + lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None, + get_available_cast_dtypes(self.DTYPE) )) def test_upcast_to_ops(self): list(map( - lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) + lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None, + get_available_cast_dtypes(self.DTYPE) )) + def test_bitcast(self): + if Device.DEFAULT == "WEBGL": raise unittest.SkipTest("no bitcast in WebGL GLSL") + if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast") + list(map( + lambda dtype: + _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None, + get_available_cast_dtypes(self.DTYPE) + )) + + def test_dtypes_fields(self): + fields = dtypes.fields() + self.assertTrue(all(isinstance(value, DType) for value in fields.values())) + self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None)) + + def test_resulting_and_init_dtypes_match(self): + dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"])) + data = [1., 2., 0., 0.5, -1.5, 5.25] + for dt in dtypes: + arr = np.asarray(data, dtype=dt) + tin = Tensor(arr).numpy() + tor = torch.as_tensor(arr).detach().numpy() + assert dt is tin.dtype is tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}" + np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3) def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None): - if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return + target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype) + if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return if a_dtype == dtypes.bool or b_dtype == dtypes.bool: return - target_dtype = target_dtype or (max([a_dtype, b_dtype], key=lambda x: x.priority) if a_dtype.priority != b_dtype.priority else max([a_dtype, b_dtype], key=lambda x: x.itemsize)) _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8]) + _assert_eq((Tensor([1], dtype=a_dtype).cast(b_dtype)+Tensor([1], dtype=a_dtype).cast(b_dtype)).cast(a_dtype), a_dtype, [2]) _assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16]) _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]]) _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy()) +@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP"], "bfloat16 not supported") class TestBFloat16DType(unittest.TestCase): - def setUp(self): - if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported") def test_bf16_to_float(self): with self.assertRaises(AssertionError): - _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000]) + _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32) def test_float_to_bf16(self): with self.assertRaises(AssertionError): - _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000]) + 
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16) # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16) @@ -111,13 +144,13 @@ def test_bf16(self): back = t.cast(dtypes.float32) assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) + @unittest.skipIf(getenv("HIPCPU"), "no real HIP device exists in CI") def test_bf16_disk_write_read(self): - from extra.utils import temp t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32) t.to(f"disk:{temp('f32')}").realize() # hack to "cast" f32 -> bf16 - dat = open(temp('f32'), "rb").read() + with open(temp('f32'), "rb") as f: dat = f.read() adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)]) with open(temp('bf16'), "wb") as f: f.write(adat) @@ -125,43 +158,107 @@ def test_bf16_disk_write_read(self): back = t.cast(dtypes.float32) assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) + +@unittest.skipUnless(Device.DEFAULT in ["HIP"], "bfloat16 not supported") +class TestBFloat16DTypeCast(unittest.TestCase): + def test_f16_to_bf16_conversion(self): + original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16) + converted_tensor = original_tensor.cast(dtypes.bfloat16) + self.assertEqual(converted_tensor.dtype, dtypes.bfloat16) + back_to_float32 = converted_tensor.cast(dtypes.float32) + original_to_float32 = original_tensor.cast(dtypes.float32) + np.testing.assert_allclose(back_to_float32.numpy(), original_to_float32.numpy(), rtol=1e-2, atol=1e-3) + + def test_f16_to_bf16_edge_cases(self): + edge_cases = Tensor([0.0, -0.0, float('inf'), float('-inf'), float('nan')], dtype=dtypes.float16) + converted = edge_cases.cast(dtypes.bfloat16).cast(dtypes.float32) + np.testing.assert_equal(converted.numpy(), edge_cases.cast(dtypes.float32).numpy()) + + def test_f16_to_bf16_range_precision(self): + large_value = Tensor([65504.0], dtype=dtypes.float16) # Max representable in float16 + small_value = Tensor([6.1035e-5], dtype=dtypes.float16) # Smallest positive normal float16 + large_converted = large_value.cast(dtypes.bfloat16).cast(dtypes.float32) + small_converted = small_value.cast(dtypes.bfloat16).cast(dtypes.float32) + np.testing.assert_allclose(large_converted.numpy(), large_value.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3) + np.testing.assert_equal(small_converted.numpy(), small_value.cast(dtypes.float32).numpy()) + + def test_f16_to_bf16_randomized(self): + np.random.seed(42) # For reproducibility + random_values = Tensor(np.random.uniform(-65504, 65504, 1000), dtype=dtypes.float16) + converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32) + np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3) + class TestHalfDtype(TestDType): DTYPE = dtypes.half -class TestFloatDType(TestDType): DTYPE = dtypes.float +class TestFloatDType(TestDType): + DTYPE = dtypes.float + + def test_float_to_uint(self): + _test_op(lambda: Tensor([-0.9, -0.3, 1.2], dtype=dtypes.float32).cast(dtypes.uint32), dtypes.uint32, + [0, 0, 1]) + +class TestDoubleDtype(TestDType): + DTYPE = dtypes.double + @unittest.skipIf(getenv("CUDACPU"), "conversion not supported on CUDACPU") + def test_float64_increased_precision(self): + for func in [ + lambda t: t.exp(), + lambda t: t.exp2(), + lambda t: t.log(), + lambda t: t.log2(), + lambda t: t.sqrt(), + lambda t: t.rsqrt(), + lambda t: t.sin(), + lambda t: t.cos(), + lambda t: t.tan(), + lambda t: t.sigmoid(), + ]: + a = [2, 3, 4] + np.testing.assert_allclose(func(Tensor(a, 
dtype=self.DTYPE)).numpy(), func(torch.tensor(a, dtype=torch.float64)), rtol=1e-12, atol=1e-12) + + def test_float64_to_float32_cast_inf(self): + _test_op(lambda: Tensor([3.4e40, 3.4e38, 1, 0], dtype=dtypes.float64).cast(dtypes.float32), + dtypes.float32, [float('inf'), 3.4e38, 1, 0]) -class TestDoubleDtype(TestDType): DTYPE = dtypes.double class TestInt8Dtype(TestDType): DTYPE = dtypes.int8 @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") - def test_int8_to_uint8_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) + def test_int8_to_uint8_negative(self): + _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint8), dtypes.uint8, [255, 254, 253, 252]) + + def test_int8_to_uint16_negative(self): + _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4]) class TestUint8Dtype(TestDType): DTYPE = dtypes.uint8 @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently") - def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) + def test_uint8_to_int8_overflow(self): + _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4]) -@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH") +@unittest.skipIf(Device.DEFAULT == "WEBGL", "No bitcast on WebGL") class TestBitCast(unittest.TestCase): - def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432]) - @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch") - def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432]) - def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0]) - - # NOTE: these are the same as normal casts - def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252]) - def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4]) - @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") - def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612]) - @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch") - def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4]) - def test_shape_change_bitcast(self): with self.assertRaises(AssertionError): _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000]) + def test_bitcast_float_to_int32(self): + a = Tensor([1.,2,3]) + b = a.bitcast(dtypes.int32) + assert b.numpy()[0] == 0x3f800000 + + def test_bitcast_upcasted(self): + a = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000 + b = a.bitcast(dtypes.float32) + assert b.numpy()[0,0] == 1. 
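# Reference numpy sketch for two tricks exercised above: a bitcast is a reinterpretation of the
# same bytes (numpy's .view), and the f32 -> bf16 "cast" hack in test_bf16_disk_write_read just
# keeps the top two bytes of each little-endian float32. The expected values match the test
# assertions; the helper names here are illustrative, not tinygrad API.
import numpy as np

# bitcast: 1.0f reinterpreted as int32 is 0x3f800000 (see test_bitcast_float_to_int32)
assert np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32).view(np.int32)[0] == 0x3f800000

# bf16 truncation: bfloat16 is the high 16 bits of a float32, and on a little-endian host those
# live in bytes [2:4] of every 4-byte group, so slicing them out drops the low mantissa bits.
# 10000.0 (0x461c4000) becomes 0x461c0000 == 9984.0, matching the (9984., -1, -1000, -9984, 20)
# expectation in the disk write/read test.
def f32_to_bf16_bytes(x):
    dat = x.astype(np.float32).tobytes()
    return b''.join(dat[i+2:i+4] for i in range(0, len(dat), 4))

def bf16_bytes_back_to_f32(raw):
    padded = b''.join(b'\x00\x00' + raw[i:i+2] for i in range(0, len(raw), 2))
    return np.frombuffer(padded, dtype=np.float32)

vals = np.array([10000, -1, -1000, -10000, 20], dtype=np.float32)
assert bf16_bytes_back_to_f32(f32_to_bf16_bytes(vals)).tolist() == [9984., -1., -1000., -9984., 20.]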
+ class TestInt16Dtype(TestDType): DTYPE = dtypes.int16 -class TestUint16Dtype(TestDType): DTYPE = dtypes.uint16 + +class TestUint16Dtype(TestDType): + DTYPE = dtypes.uint16 + + def test_uint16_to_int8_overflow(self): + _test_op(lambda: Tensor([2**16-1, 2**16-2, 1, 0], dtype=dtypes.uint16).cast(dtypes.int8), dtypes.int8, [-1, -2, 1, 0]) class TestInt32Dtype(TestDType): DTYPE = dtypes.int32 class TestUint32Dtype(TestDType): DTYPE = dtypes.uint32 @@ -171,6 +268,14 @@ class TestUint64Dtype(TestDType): DTYPE = dtypes.uint64 class TestBoolDtype(TestDType): DTYPE = dtypes.bool +class TestImageDType(unittest.TestCase): + def test_image_scalar(self): + assert dtypes.imagef((10,10)).scalar() == dtypes.float32 + assert dtypes.imageh((10,10)).scalar() == dtypes.float32 + def test_image_vec(self): + assert dtypes.imagef((10,10)).vec(4) == dtypes.float32.vec(4) + assert dtypes.imageh((10,10)).vec(4) == dtypes.float32.vec(4) + class TestEqStrDType(unittest.TestCase): def test_image_ne(self): if ImageDType is None: raise unittest.SkipTest("no ImageDType support") @@ -183,12 +288,299 @@ def test_ptr_ne(self): if PtrDType is None: raise unittest.SkipTest("no PtrDType support") # TODO: is this the wrong behavior? assert PtrDType(dtypes.float32) == dtypes.float32 - #assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) + assert not (PtrDType(dtypes.float32) != dtypes.float32) + assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) + assert not (PtrDType(dtypes.float32) != PtrDType(dtypes.float32)) #assert PtrDType(dtypes.float32) != dtypes.float32 def test_strs(self): if PtrDType is None: raise unittest.SkipTest("no PtrDType support") self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") self.assertEqual(str(PtrDType(dtypes.float32)), "ptr.dtypes.float") +class TestHelpers(unittest.TestCase): + signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) + uints = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) + floats = (dtypes.float16, dtypes.float32, dtypes.float64) + + @given(strat.sampled_from(signed_ints+uints), strat.integers(min_value=1, max_value=8)) + def test_is_int(self, dtype, amt): + assert dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype) + assert not dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype) + + @given(strat.sampled_from(uints), strat.integers(min_value=1, max_value=8)) + def test_is_unsigned_uints(self, dtype, amt): + assert dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype) + + @given(strat.sampled_from(signed_ints), strat.integers(min_value=1, max_value=8)) + def test_is_unsigned_signed_ints(self, dtype, amt): + assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype) + + @given(strat.sampled_from(floats), strat.integers(min_value=1, max_value=8)) + def test_is_float(self, dtype, amt): + assert dtypes.is_float(dtype.vec(amt) if amt > 1 else dtype) + assert not dtypes.is_int(dtype.vec(amt) if amt > 1 else dtype) + assert not dtypes.is_unsigned(dtype.vec(amt) if amt > 1 else dtype) + + def test_bf16_is_float(self): + assert dtypes.is_float(dtypes.bfloat16) + + @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8)) + def test_scalar(self, dtype, amt): + assert dtype.vec(amt).scalar() == dtype + +class TestTypeSpec(unittest.TestCase): + def setUp(self): + self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float + def tearDown(self): + dtypes.default_int, dtypes.default_float = 
self.old_default_int, self.old_default_float + + def test_set_dtype_default(self): + dtypes.default_int = dtypes.int16 + assert dtypes.default_int == dtypes.int16 + dtypes.default_int = dtypes.int64 + assert dtypes.default_int == dtypes.int64 + dtypes.default_int = dtypes.int32 + assert dtypes.default_int == dtypes.int32 + dtypes.default_float = dtypes.float16 + assert dtypes.default_float == dtypes.float16 + dtypes.default_float = dtypes.float64 + assert dtypes.default_float == dtypes.float64 + + @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_creation(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + assert Tensor(True).dtype == dtypes.bool + assert Tensor(None).dtype == dtypes.default_float + assert Tensor(2).dtype == dtypes.default_int + assert Tensor(2.34).dtype == dtypes.default_float + assert Tensor([]).dtype == dtypes.default_float + assert Tensor([1]).dtype == dtypes.default_int + assert Tensor([1.1]).dtype == dtypes.default_float + assert Tensor([0,1], dtype=dtypes.bfloat16).dtype == dtypes.bfloat16 + + assert Tensor.eye(0).dtype == dtypes.default_float + assert Tensor.eye(3).dtype == dtypes.default_float + assert Tensor.eye(3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.eye(3, dtype=dtypes.int64).dtype == dtypes.int64 + + + @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_full(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + + assert Tensor.ones([2,3]).dtype == dtypes.default_float + assert Tensor.zeros([2,3]).dtype == dtypes.default_float + assert Tensor.full([2,3], 3.3).dtype == dtypes.default_float + assert Tensor.full([2,3], 3).dtype == dtypes.default_int + assert Tensor.full([2,3], True).dtype == dtypes.bool + + assert Tensor.zeros(3, 3).dtype == dtypes.default_float + assert Tensor.zeros(3, 3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.zeros(3, 3, dtype=dtypes.int64).dtype == dtypes.int64 + + assert Tensor.ones(3, 3).dtype == dtypes.default_float + assert Tensor.ones(3, 3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.ones(3, 3, dtype=dtypes.int64).dtype == dtypes.int64 + + assert Tensor.full((3, 3), 3).dtype == dtypes.default_int + assert Tensor.full((3, 3), 3.0).dtype == dtypes.default_float + assert Tensor.full((3, 3), 3, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.full((3, 3), 3, dtype=dtypes.int64).dtype == dtypes.int64 + + def test_reduce_0d_default(self): + assert Tensor.ones([2,3,0]).sum(2).dtype == dtypes.default_float + assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int + + @given(strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_arange(self, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + + assert Tensor.arange(5).dtype == dtypes.default_int + assert Tensor.arange(5.0).dtype == dtypes.default_float + assert Tensor.arange(5, dtype=dtypes.int16).dtype == dtypes.int16 + assert Tensor.arange(5, dtype=dtypes.int64).dtype == dtypes.int64 + assert Tensor.arange(5, dtype=dtypes.float16).dtype == dtypes.float16 + assert Tensor.arange(3, 9, 0.7).dtype == dtypes.default_float + assert 
Tensor.arange(3, 8.5, 3).dtype == dtypes.default_float + + @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WEBGPU doesn't follow the bool ops spec") + @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne])) + def test_bool_ops(self, dtype, op): + assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool + + @given(strat.sampled_from(core_dtypes), + strat.sampled_from([dtypes.int8,dtypes.int16,dtypes.int32,dtypes.int64]), strat.sampled_from([dtypes.float16,dtypes.float32,dtypes.float64])) + def test_functions_return_index(self, dtype, default_int, default_float): + dtypes.default_int, dtypes.default_float = default_int, default_float + assert Tensor([0, 1], dtype=dtype).argmax().dtype == dtypes.default_int + assert Tensor([0, 1], dtype=dtype).argmin().dtype == dtypes.default_int + assert Tensor([0, 1], dtype=dtype).multinomial().dtype == dtypes.default_int + +class TestTypePromotion(unittest.TestCase): + @given(strat.sampled_from(core_dtypes)) + def test_self_promo_to_self(self, dtype): + assert least_upper_dtype(dtype) == dtype + assert least_upper_dtype(dtype, dtype) == dtype + assert least_upper_dtype(dtype, dtype, dtype) == dtype + + @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes)) + def test_promo_resulted_higher_than_inputs(self, dtype1, dtype2): + result = least_upper_dtype(dtype1, dtype2) + assert result >= dtype1 and result >= dtype2 + + def test_dtype_promo(self): + assert least_upper_dtype(dtypes.bool, dtypes.int8) == dtypes.int8 + assert least_upper_dtype(dtypes.int8, dtypes.uint8) == dtypes.int16 + assert least_upper_dtype(dtypes.uint8, dtypes.int16) == dtypes.int16 + assert least_upper_dtype(dtypes.int16, dtypes.uint16) == dtypes.int32 + assert least_upper_dtype(dtypes.uint16, dtypes.int32) == dtypes.int32 + assert least_upper_dtype(dtypes.int32, dtypes.uint32) == dtypes.int64 + assert least_upper_dtype(dtypes.uint32, dtypes.int64) == dtypes.int64 + # similar to jax but we don't use weak type + assert least_upper_dtype(dtypes.int64, dtypes.uint64) == dtypes.float16 + assert least_upper_dtype(dtypes.float16, dtypes.float32) == dtypes.float32 + assert least_upper_dtype(dtypes.float32, dtypes.float64) == dtypes.float64 + + assert least_upper_dtype(dtypes.bool, dtypes.float32) == dtypes.float32 + assert least_upper_dtype(dtypes.bool, dtypes.float64) == dtypes.float64 + assert least_upper_dtype(dtypes.float16, dtypes.int64) == dtypes.float16 + assert least_upper_dtype(dtypes.float16, dtypes.uint64) == dtypes.float16 + + @given(strat.sampled_from(floats)) + def test_float_to_float(self, dt): + assert least_upper_float(dt) == dt + +class TestAutoCastType(unittest.TestCase): + def setUp(self): + self.old_default_int, self.old_default_float = dtypes.default_int, dtypes.default_float + def tearDown(self): + dtypes.default_int, dtypes.default_float = self.old_default_int, self.old_default_float + + @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_int(d) and is_dtype_supported(d)])) + def test_int_to_float_unary_func(self, dtype): + for func in [ + lambda t: t.exp(), + lambda t: t.exp2(), + lambda t: t.log(), + lambda t: t.log2(), + lambda t: t.sqrt(), + lambda t: t.rsqrt(), + lambda t: t.sin(), + lambda t: t.cos(), + lambda t: t.tan(), + lambda t: t.sigmoid(), + ]: + a = [2, 3, 4] + # float16 can have larger precision errors + np.testing.assert_allclose(func(Tensor(a, dtype=dtype)).numpy(), func(torch.tensor(a)), rtol=1e-3, 
atol=1e-3) + + @given(strat.sampled_from(core_dtypes)) + def test_broadcast_scalar(self, dt): + assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float) + assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int) + if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool: + assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt + + def test_sum(self): + assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int8)).sum().dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int16)).sum().dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int32)).sum().dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int64)).sum().dtype == dtypes.int64 + assert (Tensor([0, 1], dtype=dtypes.uint8)).sum().dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint16)).sum().dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64 + assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32 + assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64 + + def test_cumsum(self): + assert (Tensor([0, 1], dtype=dtypes.bool)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int8)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int16)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int32)).cumsum(0).dtype == dtypes.int32 + assert (Tensor([0, 1], dtype=dtypes.int64)).cumsum(0).dtype == dtypes.int64 + assert (Tensor([0, 1], dtype=dtypes.uint8)).cumsum(0).dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint16)).cumsum(0).dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32 + assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64 + assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16 + assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16 + assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32 + assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64 + + @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes)) + def test_matmul(self, dt1, dt2): + assert (Tensor([0, 1], dtype=dt1) @ Tensor([0, 1], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2) + + @staticmethod + def check_where_alternate_input_other(input_, other, data_type): + assert (Tensor([True, False]).where(input_, other)).dtype == data_type + assert (Tensor([True, False]).where(other, input_)).dtype == data_type + + @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes)) + def test_where_no_scalar(self, dt1, dt2): + self.check_where_alternate_input_other(Tensor(2, dtype=dt1), Tensor(3, dtype=dt2), least_upper_dtype(dt1, dt2)) + + @given(strat.sampled_from(core_dtypes)) + def test_where_one_scalar(self, dt): + t = Tensor(2, dtype=dt) + self.check_where_alternate_input_other(t, 3.2, (dt if dtypes.is_float(dt) else dtypes.default_float)) + self.check_where_alternate_input_other(t, 3, (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)) + 
self.check_where_alternate_input_other(t, True, dt) + + def test_where_two_scalars(self): + self.check_where_alternate_input_other(3.1, 3.2, dtypes.default_float) + self.check_where_alternate_input_other(3.1, 3, dtypes.default_float) + self.check_where_alternate_input_other(3.1, True, dtypes.default_float) + self.check_where_alternate_input_other(3, 2, dtypes.default_int) + self.check_where_alternate_input_other(3, True, dtypes.default_int) + self.check_where_alternate_input_other(False, True, dtypes.bool) + + @given(strat.sampled_from(core_dtypes), strat.sampled_from(core_dtypes)) + def test_maximum(self, dt1, dt2): + assert Tensor([0, 1, 2], dtype=dt1).maximum(Tensor([2, 0, 5], dtype=dt2)).dtype == least_upper_dtype(dt1, dt2) + + @given(strat.sampled_from(core_dtypes)) + def test_maximum_const(self, dt): + assert Tensor([1, 2], dtype=dt).maximum(3.1).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float) + assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int) + assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt + + def test_div(self): + assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float + assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float + assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32 + assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16 + + def test_div_const(self): + assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float + assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float + assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16 + assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16 + +class TestImplicitFunctionTypeChange(unittest.TestCase): + def test_functions(self): + result = [] + for func in [ + lambda t: t.exp(), + lambda t: t.exp2(), + lambda t: t.log(), + lambda t: t.log2(), + lambda t: t.sqrt(), + lambda t: t.sin(), + ]: + t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0])) + result.append(t.numpy().sum()) + + if Device.DEFAULT not in ["PYTHON", "CLANG"]: + assert all(result) + else: + # CLANG and PYTHON function default returns in double, and comparison to float can fail + # TODO: fix this + assert not all(result) + if __name__ == '__main__': unittest.main() diff --git a/test/test_ops.py b/test/test_ops.py index c57c817..23b7d9e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,11 +1,8 @@ -import torch -import time -import math +import time, math, unittest import numpy as np -import unittest -from teenygrad.tensor import Tensor -from teenygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes -from teenygrad.ops import Device +import torch +from teenygrad.helpers import getenv, IMAGE, DEBUG, CI +from tinygrad import Tensor, Device, dtypes if CI: import warnings @@ -13,25 +10,31 @@ FORWARD_ONLY = getenv("FORWARD_ONLY", 0) PRINT_TENSORS = getenv("PRINT_TENSORS", 0) -def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, forward_only=False, vals=None, a=-0.5, b=3): + +def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, grad_atol=1e-4, grad_rtol=1e-3, + forward_only=False, vals=None, low=-1.5, high=1.5): if tinygrad_fxn is None: tinygrad_fxn = torch_fxn - ts, tst = 
prepare_test_op(a, b, shps, vals, forward_only) + ts, tst = prepare_test_op(low, high, shps, vals, forward_only) st = time.monotonic() out = torch_fxn(*ts) torch_fp = time.monotonic() - st + # move inputs to a different device, test the device of intermediate tensors are correct + if mt:=getenv("MOVE_TENSOR", ""): + for t in tst: t.to_(mt) + st = time.monotonic() ret = tinygrad_fxn(*tst).realize() tinygrad_fp = time.monotonic() - st - def compare(s, x,y,atol,rtol): - if PRINT_TENSORS: print(s, x, y) - assert x.shape == y.shape, f"shape mismatch: tinygrad={x.shape} | torch={y.shape}" + def compare(s, tinygrad_output, torch_output, atol, rtol): + if PRINT_TENSORS: print(s, tinygrad_output, torch_output) + assert tinygrad_output.shape == torch_output.shape, f"shape mismatch: tinygrad={tinygrad_output.shape} | torch={torch_output.shape}" try: - np.testing.assert_allclose(x,y, atol=atol, rtol=rtol) - except Exception: - raise Exception(f"{s} failed shape {x.shape}") + np.testing.assert_allclose(tinygrad_output, torch_output, atol=atol, rtol=rtol) + except Exception as e: + raise Exception(f"{s} failed shape {tinygrad_output.shape}: {e}") if DEBUG >= 6: np.set_printoptions(linewidth=200, suppress=True) @@ -53,21 +56,23 @@ def compare(s, x,y,atol,rtol): for i, (t, tt) in enumerate(zip(ts, tst)): compare(f"backward pass tensor {i}", tt.grad.numpy(), t.grad.detach().numpy(), atol=grad_atol, rtol=grad_rtol) - if not CI: print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="") + if not CI: + print("\ntesting %40r torch/tinygrad fp: %.2f / %.2f ms bp: %.2f / %.2f ms " % \ + (shps, torch_fp*1000, tinygrad_fp*1000, torch_fbp*1000, tinygrad_fbp*1000), end="") -def prepare_test_op(a, b, shps, vals, forward_only=False): +def prepare_test_op(low, high, shps, vals, forward_only=False): torch.manual_seed(0) np.random.seed(0) if shps is None: ts = [torch.tensor(x, requires_grad=(not forward_only)) for x in vals] - else: ts = [torch.tensor((np.random.random(size=x) + a) * b, requires_grad=(not forward_only), dtype=torch.float32) for x in shps] + else: ts = [torch.tensor(np.random.uniform(low=low, high=high, size=x), requires_grad=(not forward_only), dtype=torch.float32) for x in shps] tst = [Tensor(x.detach().numpy(), requires_grad=(not forward_only and not FORWARD_ONLY)) for x in ts] return ts, tst class TestOps(unittest.TestCase): - def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, a=-0.5, b=3): + def helper_test_exception(self, shps, torch_fxn, tinygrad_fxn, expected, exact=False, vals=None, low=-1.5, high=1.5): if getenv("CUDACPU"): self.skipTest('helper_test_exception fails in CUDACPU') - ts, tst = prepare_test_op(a, b, shps, vals) + ts, tst = prepare_test_op(low, high, shps, vals) with self.assertRaises(expected) as torch_cm: torch_fxn(*ts) with self.assertRaises(expected) as tinygrad_cm: @@ -81,6 +86,7 @@ def test_full_like(self): helper_test_op([], lambda: torch.full_like(b, 4), lambda: Tensor.full_like(a, 4), forward_only=True) def test_full(self): helper_test_op([], lambda: torch.full((45,65), 4), lambda: Tensor.full((45,65), 4), forward_only=True) + def test_zeros(self): helper_test_op([], lambda: torch.zeros(45,65), lambda: Tensor.zeros(45,65), forward_only=True) helper_test_op([], lambda: torch.zeros([45,65]), lambda: Tensor.zeros([45,65]), forward_only=True) @@ -89,8 +95,10 @@ def test_zeros_like(self): a = Tensor([[1,2,3],[4,5,6]]) b = 
torch.tensor([[1,2,3],[4,5,6]]) helper_test_op([], lambda: torch.zeros_like(b), lambda: Tensor.zeros_like(a), forward_only=True) + def test_empty_0(self): helper_test_op([], lambda: torch.empty(45,65)*0/0, lambda: Tensor.empty(45,65)*0/0, forward_only=True) + def test_ones(self): helper_test_op([], lambda: torch.ones(45,65), lambda: Tensor.ones(45,65), forward_only=True) helper_test_op([], lambda: torch.ones([45,65]), lambda: Tensor.ones([45,65]), forward_only=True) @@ -99,9 +107,30 @@ def test_ones_like(self): a = Tensor([[1,2,3],[4,5,6]]) b = torch.tensor([[1,2,3],[4,5,6]]) helper_test_op([], lambda: torch.ones_like(b), lambda: Tensor.ones_like(a), forward_only=True) + def test_eye(self): helper_test_op([], lambda: torch.eye(10), lambda: Tensor.eye(10), forward_only=True) helper_test_op([], lambda: torch.eye(1), lambda: Tensor.eye(1), forward_only=True) + helper_test_op([], lambda: torch.eye(0), lambda: Tensor.eye(0), forward_only=True) + + def test_split(self): + def tensor(s): return torch.arange(math.prod(s)).reshape(s), Tensor.arange(math.prod(s)).reshape(s) + test_cases = [ + (tensor((10,)), 5, {}), + (tensor((10,)), [1,4,5], {}), + (tensor((10,)), 3, {}), + (tensor((3,4,)), 1, {}), + (tensor((3,4,)), 1, {'dim':1}), + (tensor((4,4,)), [2,2], {}), + (tensor((4,4,)), [2,2], {'dim':1}), + (tensor((10000,)), 2500, {}), + ] + + for (tor, ten), sizes, args in test_cases: + tor_splits, ten_splits = tor.split(sizes, **args), ten.split(sizes, **args) + assert len(tor_splits) == len(ten_splits) + for tor_chunk, ten_chunk in zip(tor_splits, ten_splits): + helper_test_op([], lambda: tor_chunk, lambda: ten_chunk, forward_only=True) def test_chunk(self): tor = torch.arange(13).repeat(8, 1).chunk(6, 1) @@ -133,8 +162,6 @@ def test_arange(self): helper_test_op([], lambda: torch.arange(5, 10, 3), lambda: Tensor.arange(5, 10, 3), forward_only=True) helper_test_op([], lambda: torch.arange(10, 5, -3), lambda: Tensor.arange(10, 5, -3), forward_only=True) helper_test_op([], lambda: torch.arange(11, 5, -3), lambda: Tensor.arange(11, 5, -3), forward_only=True) - def test_arange_simple(self): - helper_test_op([], lambda: torch.arange(10), lambda: Tensor.arange(10), forward_only=True) def test_arange_big(self): helper_test_op([], lambda: torch.arange(256), lambda: Tensor.arange(256), forward_only=True) @@ -148,11 +175,13 @@ def test_sum_collapse_neg(self): helper_test_op([], lambda: (-torch.ones(3,3)).sum(axis=1), lambda: (-Tensor.ones(3,3)).sum(axis=1), forward_only=True) def test_sum_pad_collapse(self): - helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1), lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True) + helper_test_op([], lambda: torch.nn.functional.pad(torch.ones(256,256), pad=(0,64,0,0)).sum(axis=1), + lambda: Tensor.ones(256,256).pad(((0,0), (0,64))).sum(axis=1), forward_only=True) # this is more complex and won't fold for a while def test_sum_cat_collapse(self): - helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1), lambda: Tensor.cat(Tensor.ones(256,256), Tensor.zeros(256,64), dim=1).sum(axis=1), forward_only=True) + helper_test_op([], lambda: torch.cat([torch.ones(256,256), torch.zeros(256,64)], dim=1).sum(axis=1), + lambda: Tensor.cat(Tensor.ones(256,256), Tensor.zeros(256,64), dim=1).sum(axis=1), forward_only=True) def test_max_dont_collapse(self): helper_test_op([], lambda: torch.ones(256,256).max(1)[0], lambda: Tensor.ones(256,256).max(1), forward_only=True) @@ 
-180,7 +209,8 @@ def _test_cmp(self, fxn, reverse=True): helper_test_op(shps, fxn, fxn, forward_only=True) helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]]) helper_test_op(None, lambda x,y: fxn(x,2), lambda x,y: fxn(x,2), forward_only=True, vals=[[0.,1,2], [2.,1,0]]) - helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]]) + if Device.DEFAULT != "WEBGPU": # bool is not HOST_SHARABLE, so it cannot be used as a storage buffer type + helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]]) if reverse: helper_test_op(None, lambda x,y: fxn(2,y), lambda x,y: fxn(2,y), forward_only=True, vals=[[0.,1,2], [2.,1,0]]) def test_cmp_eq(self): self._test_cmp(lambda x,y: x==y, reverse=False) @@ -196,6 +226,11 @@ def test_cmp_eq_backwards(self): tt1 = Tensor.ones(4, requires_grad=True) tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 == tt2).sum().backward) + tt = Tensor.randn(4, requires_grad=True) + (tt*(tt == 0)).sum().backward() + t = torch.tensor(tt.numpy(), requires_grad=True) + (t*(t == 0)).sum().backward() + np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5) def test_cmp_lt_backwards(self): t1 = torch.ones(4, requires_grad=True) @@ -204,142 +239,193 @@ def test_cmp_lt_backwards(self): tt1 = Tensor.ones(4, requires_grad=True) tt2 = Tensor.ones(4, requires_grad=True) self.assertRaises(RuntimeError, (tt1 < tt2).sum().backward) + tt = Tensor.randn(4, requires_grad=True) + (tt*(tt < 0)).sum().backward() + t = torch.tensor(tt.numpy(), requires_grad=True) + (t*(t < 0)).sum().backward() + np.testing.assert_allclose(t.grad.numpy(), tt.grad.numpy(), rtol=1e-5) - #@unittest.skip("this is broken with contiguous") def test_trunc(self): - helper_test_op([(45,65)], lambda x: torch.trunc(x), lambda x: x.trunc(), forward_only=True) - a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) - helper_test_op([], lambda: torch.trunc(b), lambda: Tensor.trunc(a), forward_only=True) - #@unittest.skip("this is broken with contiguous") + helper_test_op([(45,35)], lambda x: x.trunc(), forward_only=True) + helper_test_op(None, lambda x: x.trunc(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501]], forward_only=True) def test_floor(self): - helper_test_op([(45,65)], lambda x: torch.floor(x), lambda x: x.floor(), forward_only=True) - a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) - helper_test_op([], lambda: torch.floor(b), lambda: Tensor.floor(a), forward_only=True) - #@unittest.skip("this is broken with contiguous") + helper_test_op([(45,35)], lambda x: x.floor(), forward_only=True) + helper_test_op(None, lambda x: x.floor(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501]], forward_only=True) def test_ceil(self): - helper_test_op([(45,65)], lambda x: torch.ceil(x), lambda x: x.ceil(), forward_only=True) - a, b = Tensor([1.0, 2.1, 0.0, -5.0, -2.5]), torch.tensor([1.0, 2.1, 0.0, -5.0, -2.5]) - helper_test_op([], lambda: torch.ceil(b), lambda: Tensor.ceil(a), forward_only=True) + helper_test_op([(45,35)], lambda x: x.ceil(), forward_only=True) + helper_test_op(None, lambda x: x.ceil(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, -2.5, -2.501]], forward_only=True) + def test_round(self): + helper_test_op([(45,35)], lambda x: x.round(), forward_only=True) + helper_test_op(None, lambda x: x.round(), vals=[[1.499, 1.5, 1.501, 1.0, 2.1, 0.0, -5.0, -2.499, 
-2.5, -2.501]], forward_only=True) + helper_test_op(None, lambda x: x.round(), vals=[[2.5, -1.5]], forward_only=True) + def test_tril(self): - helper_test_op([(3,3)], lambda x: x.tril(), lambda x: x.tril()) - helper_test_op([(3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) - helper_test_op([(3,3)], lambda x: x.tril(-1), lambda x: x.tril(-1)) - helper_test_op([(5,3,3)], lambda x: x.tril(), lambda x: x.tril()) - helper_test_op([(5,3,3)], lambda x: x.tril(1), lambda x: x.tril(1)) + helper_test_op([(3,3)], lambda x: x.tril()) + helper_test_op([(3,3)], lambda x: x.tril(1)) + helper_test_op([(3,3)], lambda x: x.tril(-1)) + helper_test_op([(5,3,3)], lambda x: x.tril()) + helper_test_op([(5,0,3)], lambda x: x.tril()) + helper_test_op([(5,3,3)], lambda x: x.tril(1)) def test_triu(self): - helper_test_op([(3,3)], lambda x: x.triu(), lambda x: x.triu()) - helper_test_op([(3,3)], lambda x: x.triu(1), lambda x: x.triu(1)) - helper_test_op([(3,3)], lambda x: x.triu(-1), lambda x: x.triu(-1)) - helper_test_op([(5,3,3)], lambda x: x.triu(), lambda x: x.triu()) - helper_test_op([(5,3,3)], lambda x: x.triu(1), lambda x: x.triu(1)) + helper_test_op([(3,3)], lambda x: x.triu()) + helper_test_op([(3,3)], lambda x: x.triu(1)) + helper_test_op([(3,3)], lambda x: x.triu(-1)) + helper_test_op([(5,3,3)], lambda x: x.triu()) + helper_test_op([(5,0,3)], lambda x: x.triu()) + helper_test_op([(5,3,3)], lambda x: x.triu(1)) + def test_maximum(self): helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum) helper_test_op([(), ()], torch.maximum, Tensor.maximum) - helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]]) - helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1, 0, 3, 4], [1, 2, 3, 0]], forward_only=True) + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], 3.]) + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]]) + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], True], forward_only=True) + helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[True, False, False], [True, True, False]], forward_only=True) def test_minimum(self): helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum) helper_test_op([(), ()], torch.minimum, Tensor.minimum) + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], 3.]) + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[1., 0., 3., -4.], [-1., -2., 3., 0.]]) + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], True], forward_only=True) + helper_test_op(None, torch.minimum, Tensor.minimum, vals=[[True, False, False], [True, True, False]], forward_only=True) + + def test_tiny_add(self): + helper_test_op([(3), (3)], lambda x,y: x+y, Tensor.add, forward_only=True) + def test_add(self): helper_test_op([(45,68), (45,68)], lambda x,y: x+y, Tensor.add) - def test_add_number(self): - helper_test_op([(), ()], lambda x,y: x+y, Tensor.add) + helper_test_op([(45,68), (45,68)], lambda x,y: x+y) + helper_test_op([(), ()], lambda x,y: x+y) def test_add3(self): helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: x+y+z) - def test_add_simple(self): - helper_test_op([(256), (256)], lambda x,y: x+y, Tensor.add, forward_only=True) def test_broadcasted_add(self): - helper_test_op([(45,65), (45,1)], lambda x,y: x+y, lambda x,y: x+y) - helper_test_op([(45,65), ()], lambda x,y: x+y, lambda x,y: x+y) + helper_test_op([(45,65), (45,1)], lambda x,y: x+y) + 
helper_test_op([(45,65), ()], lambda x,y: x+y) def test_broadcasted_add_2(self): - helper_test_op([(45,65), (65,)], lambda x,y: x+y, lambda x,y: x+y) + helper_test_op([(45,65), (65,)], lambda x,y: x+y) + def test_sub(self): helper_test_op([(45,65), (45,65)], lambda x,y: x-y, Tensor.sub) - helper_test_op([(), ()], lambda x,y: x-y, Tensor.sub) + helper_test_op([(45,65), (45,65)], lambda x,y: x-y) + helper_test_op([(), ()], lambda x,y: x-y) + def test_scalar_sub(self): + helper_test_op([(45,65)], lambda x: x-2) + helper_test_op([()], lambda x: x-2) + def test_scalar_rsub(self): + helper_test_op([(45,65)], lambda x: 2-x) + helper_test_op([()], lambda x: 2-x) + def test_neg(self): helper_test_op([(45,65)], lambda x: -x) - helper_test_op([()], lambda x: -x) + helper_test_op([(45,65)], lambda x: x.neg()) + helper_test_op([()], lambda x: x.neg()) + def test_logical_not(self): + helper_test_op(None, torch.logical_not, Tensor.logical_not, vals=[[True, False, True]], forward_only=True) + helper_test_op(None, torch.logical_not, Tensor.logical_not, vals=[[1.,2.,0.,0.5]], forward_only=True) + def test_mul(self): helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul) - def test_mul_number(self): - helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul) - def test_mul_const(self): - helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2) - helper_test_op([(45,65)], lambda x: x*-1, lambda x: x*-1) - helper_test_op([(45,65)], lambda x: 255*x, lambda x: 255*x) + helper_test_op([(64,64), (64,64)], lambda x,y: x*y) + helper_test_op([(), ()], lambda x,y: x*y) + def test_scalar_mul(self): + helper_test_op([(45,65)], lambda x: x*2) + helper_test_op([(45,65)], lambda x: x*-1) + helper_test_op([(45,65)], lambda x: 255*x) + helper_test_op([(45,65)], lambda x: 2*x) + helper_test_op([()], lambda x: x*2) + helper_test_op([()], lambda x: 2*x) + def test_div(self): helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div) - helper_test_op([(), ()], lambda x,y: x/y, Tensor.div) - helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]]) + helper_test_op([(45,65), (45,65)], lambda x,y: x/y) + helper_test_op([(), ()], lambda x,y: x/y) def test_div_int(self): - helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=[[3]]) - def test_div_const(self): - helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255) - helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1) - helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x) - helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2) - helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x) - helper_test_op([()], lambda x: x/2, lambda x: x/2) - helper_test_op([()], lambda x: 2/x, lambda x: 2/x) - @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf") - def test_mul_const_naninf(self): - helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf")) - helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf")) - helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan")) - @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf") - def test_div_const_naninf(self): - helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf")) - helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf")) - helper_test_op([(45,65)], lambda x: x/float("nan"), lambda x: x/float("nan")) - 
helper_test_op([(45,65)], lambda x: float("inf")/x, lambda x: float("inf")/x) - helper_test_op([(45,65)], lambda x: (-float("inf"))/x, lambda x: (-float("inf"))/x) - helper_test_op([(45,65)], lambda x: float("nan")/x, lambda x: float("nan")/x) + helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32)) + helper_test_op(None, lambda x: x/2, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32)) + def test_scalar_div(self): + helper_test_op([(45,65)], lambda x: x/255) + helper_test_op([(45,65)], lambda x: x/1) + helper_test_op([(45,65)], lambda x: 1/x) + helper_test_op([(45,65)], lambda x: x/2) + helper_test_op([(45,65)], lambda x: 2/x) + helper_test_op([()], lambda x: x/2) + helper_test_op([()], lambda x: 2/x) + + def test_mul_naninf(self): + helper_test_op([(45,65)], lambda x: x*math.inf) + helper_test_op([(45,65)], lambda x: x*-math.inf) + helper_test_op([(45,65)], lambda x: x*math.nan) + def test_div_naninf(self): + helper_test_op([(45,65)], lambda x: x/math.inf) + helper_test_op([(45,65)], lambda x: x/-math.inf) + helper_test_op([(45,65)], lambda x: x/math.nan) + helper_test_op([(45,65)], lambda x: math.inf/x) + helper_test_op([(45,65)], lambda x: (-math.inf)/x) + helper_test_op([(45,65)], lambda x: math.nan/x) + def test_pow_full(self): - helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0) + helper_test_op([(45,65), (45,65)], lambda x,y: x**y) + helper_test_op([(45,65), (45,65)], lambda x,y: x.pow(y)) def test_pow(self): - # TODO: why is a=0 for these tests? - helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0) - helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0) - helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0) - helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0) - helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0) + helper_test_op([(45,65)], lambda x: x**0) + helper_test_op([(45,65)], lambda x: x**1) + helper_test_op([(45,65)], lambda x: x**2) + helper_test_op([(45,65)], lambda x: x**3) + helper_test_op([(45,65)], lambda x: x**-2) + helper_test_op([()], lambda x: x**2) + helper_test_op([()], lambda x: x**-2) # Regression tests for https://github.com/tinygrad/tinygrad/issues/1151 - helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10) - helper_test_op([()], lambda x: x**3, lambda x: Tensor.pow(x,3), a=-10) + helper_test_op([(45,65)], lambda x: x**3, low=-30, high=-27) + helper_test_op([()], lambda x: x**3, low=-30, high=-27) # Regression tests for https://github.com/tinygrad/tinygrad/issues/1251 - helper_test_op([(45,65)], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10) - helper_test_op([(45,65)], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10) - helper_test_op([()], lambda x: x**0.2, lambda x: Tensor.pow(x,0.2), a=-10) - helper_test_op([()], lambda x: x**1.2, lambda x: Tensor.pow(x,1.2), a=-10) + helper_test_op([(45,65)], lambda x: x**0.2, low=-30, high=-27) + helper_test_op([(45,65)], lambda x: x**1.2, low=-30, high=-27) + helper_test_op([()], lambda x: x**0.2, low=-30, high=-27) + helper_test_op([()], lambda x: x**1.2, low=-30, high=-27) a, b = Tensor([0.0], requires_grad=True), torch.tensor([0.0], requires_grad=True) - helper_test_op([], lambda: b**1.1, lambda: a**1.1, ) + helper_test_op([], lambda: b**1.1, lambda: a**1.1) def test_pow_const(self): - helper_test_op([(45,65)], lambda x: x**1.0, lambda x: x**1.0) - 
helper_test_op([(45,65)], lambda x: x**-1.0, lambda x: x**-1.0) - helper_test_op([(45,65)], lambda x: 1.0**x, lambda x: 1.0**x) - helper_test_op([(45,65)], lambda x: x**2.0, lambda x: x**2.0) - helper_test_op([(45,65)], lambda x: 2.0**x, lambda x: 2.0**x) - helper_test_op([()], lambda x: x**2.0, lambda x: x**2.0) - helper_test_op([()], lambda x: 2.0**x, lambda x: 2.0**x) + helper_test_op([(45,65)], lambda x: x**1.0) + helper_test_op([(45,65)], lambda x: x**-1.0) + helper_test_op([(45,65)], lambda x: 1.0**x) + helper_test_op([(45,65)], lambda x: x**2.0) + helper_test_op([(45,65)], lambda x: 2.0**x) + helper_test_op([()], lambda x: x**2.0) + helper_test_op([()], lambda x: 2.0**x) + # TODO: fix 0**x and 0**0 == 1 + # helper_test_op(None, lambda x: 0**x, vals=[[-2.,-1,0,1,2,3]]) + # TODO: fix backward, should be nan + helper_test_op(None, lambda x: (-2)**x, vals=[[-2.,-1,0,1,2,3]], forward_only=True) + def test_sqrt(self): - helper_test_op([(45,65)], lambda x: x.sqrt(), Tensor.sqrt, a=0) - helper_test_op([()], lambda x: x.sqrt(), Tensor.sqrt, a=0) + helper_test_op([(45,65)], lambda x: x.sqrt()) + helper_test_op([()], lambda x: x.sqrt()) def test_rsqrt(self): - helper_test_op([(45,65)], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) - helper_test_op([()], lambda x: torch.rsqrt(x), Tensor.rsqrt, a=0) + helper_test_op([(45,65)], lambda x: x.rsqrt()) + helper_test_op([()], lambda x: x.rsqrt()) + + def test_xor(self): + tor = torch.tensor([[1,-8,1],[32,1,6]], dtype=torch.int) + ten = Tensor([[1,-8,1],[32,1,6]], dtype=dtypes.int32) + helper_test_op([], lambda: tor^tor, lambda: ten^ten, forward_only=True) + helper_test_op([], lambda: tor^0x1337, lambda: ten^0x1337, forward_only=True) + helper_test_op([], lambda: 0x1337^tor, lambda: 0x1337^ten, forward_only=True) def test_sin(self): - helper_test_op([(45,65)], lambda x: x.sin(), Tensor.sin, a=0) + helper_test_op([(45,65)], lambda x: x.sin()) + helper_test_op([()], lambda x: x.sin()) def test_cos(self): - helper_test_op([(45,65)], lambda x: x.cos(), Tensor.cos, a=0) + helper_test_op([(45,65)], lambda x: x.cos()) + helper_test_op([()], lambda x: x.cos()) def test_tan(self): - helper_test_op([(45,65)], lambda x: x.tan(), Tensor.tan, a=0) + helper_test_op([(45,65)], lambda x: x.tan()) + helper_test_op([()], lambda x: x.tan()) def test_relu(self): - helper_test_op([(64,64)], lambda x: x.relu(), Tensor.relu) - helper_test_op([()], lambda x: x.relu(), Tensor.relu) + helper_test_op([(64,64)], lambda x: x.relu()) + helper_test_op([()], lambda x: x.relu()) def test_relu_exact(self): - helper_test_op(None, lambda x: x.relu(), Tensor.relu, vals=[[-1.,0,1]]) + helper_test_op(None, lambda x: x.relu(), vals=[[-1.,0,1]]) def test_relu_maximum_exact(self): helper_test_op(None, lambda x: torch.maximum(x, torch.zeros_like(x, requires_grad=False)), lambda x: Tensor.maximum(x, 0), vals=[[-1.,0,1]]) def test_leakyrelu(self): @@ -349,59 +435,156 @@ def test_celu(self): for val in range(1, 5): helper_test_op([(45,65)], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) helper_test_op([()], lambda x: torch.nn.functional.celu(x,val), lambda x: x.celu(val)) + def test_abs(self): - helper_test_op([(45,65)], lambda x: torch.abs(x), Tensor.abs) - helper_test_op([()], lambda x: torch.abs(x), Tensor.abs) + helper_test_op([(45,65)], torch.abs, Tensor.abs) + helper_test_op([()], torch.abs, Tensor.abs) + def test_log(self): - helper_test_op([(45,65)], lambda x: torch.log(x), Tensor.log) - helper_test_op([()], lambda x: torch.log(x), Tensor.log) + helper_test_op([(45,65)], 
torch.log, Tensor.log) + helper_test_op([()], torch.log, Tensor.log) def test_log2(self): - helper_test_op([(45,65)], lambda x: torch.log2(x), Tensor.log2) - helper_test_op([()], lambda x: torch.log2(x), Tensor.log2) + helper_test_op([(45,65)], torch.log2, Tensor.log2) + helper_test_op([()], torch.log2, Tensor.log2) + def test_exp(self): - helper_test_op([(45,65)], lambda x: torch.exp(x), Tensor.exp) - helper_test_op([()], lambda x: torch.exp(x), Tensor.exp) + helper_test_op([(45,65)], torch.exp, Tensor.exp) + helper_test_op([()], torch.exp, Tensor.exp) def test_exp2(self): - helper_test_op([(45,65)], lambda x: torch.exp2(x), Tensor.exp2) - helper_test_op([()], lambda x: torch.exp2(x), Tensor.exp2) + helper_test_op([(45,65)], torch.exp2, Tensor.exp2) + helper_test_op([()], torch.exp2, Tensor.exp2) + def test_sign(self): - helper_test_op([(45,65)], lambda x: torch.sign(x), Tensor.sign) - helper_test_op([()], lambda x: torch.sign(x), Tensor.sign) + helper_test_op([(45,65)], torch.sign, Tensor.sign) + helper_test_op([()], torch.sign, Tensor.sign) def test_softsign(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.softsign(x), Tensor.softsign) - helper_test_op([()], lambda x: torch.nn.functional.softsign(x), Tensor.softsign) + helper_test_op([(45,65)], torch.nn.functional.softsign, Tensor.softsign) + helper_test_op([()], torch.nn.functional.softsign, Tensor.softsign) + def test_sigmoid(self): - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid) - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=100) - helper_test_op([(45,65)], lambda x: x.sigmoid(), Tensor.sigmoid, a=-100) - helper_test_op([()], lambda x: x.sigmoid(), Tensor.sigmoid, forward_only=True) + helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid) + helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=300, high=303) + helper_test_op([(45,65)], torch.sigmoid, Tensor.sigmoid, low=-300, high=-297) + helper_test_op([()], torch.sigmoid, Tensor.sigmoid) def test_softplus(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.softplus(x), Tensor.softplus, atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6) + helper_test_op([()], torch.nn.functional.softplus, Tensor.softplus, grad_atol=1e-6) + def test_gelu(self): helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu) - #helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=100) - helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, a=-100) + helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=300, high=303) + helper_test_op([(45,65)], lambda x: torch.nn.functional.gelu(x, approximate="tanh"), Tensor.gelu, low=-300, high=-297) def test_quick_gelu(self): helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu) - helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=100) - helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, a=-100) + helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=300, high=303) + helper_test_op([(45,65)], lambda x: x * torch.sigmoid(1.702 * x), Tensor.quick_gelu, low=-300, high=-297) helper_test_op([()], lambda x: x * 
torch.sigmoid(1.702 * x), Tensor.quick_gelu) + def test_elu(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x), Tensor.elu) + helper_test_op([(45,65)], torch.nn.functional.elu, Tensor.elu) helper_test_op([(45,65)], lambda x: torch.nn.functional.elu(x, alpha=0.1), lambda x: Tensor.elu(x, alpha=0.1)) - helper_test_op([()], lambda x: torch.nn.functional.elu(x), Tensor.elu) + helper_test_op([()], torch.nn.functional.elu, Tensor.elu) def test_relu6(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.relu6(x), Tensor.relu6) - helper_test_op([()], lambda x: torch.nn.functional.relu6(x), Tensor.relu6) + helper_test_op([(45,65)], torch.nn.functional.relu6, Tensor.relu6) + helper_test_op([()], torch.nn.functional.relu6, Tensor.relu6) def test_hardswish(self): - helper_test_op([(45,65)], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.hardswish(x), Tensor.hardswish, atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], torch.nn.functional.hardswish, Tensor.hardswish, grad_atol=1e-6) + helper_test_op([()], torch.nn.functional.hardswish, Tensor.hardswish, grad_atol=1e-6) def test_mish(self): - def _mish_pytorch(x): - return x*torch.tanh(torch.nn.functional.softplus(x)) - helper_test_op([(45,65)], _mish_pytorch, Tensor.mish, atol=1e-4) - helper_test_op([()], _mish_pytorch, Tensor.mish, atol=1e-4) + helper_test_op([(45,65)], torch.nn.functional.mish, Tensor.mish) + helper_test_op([()], torch.nn.functional.mish, Tensor.mish) + + def test_multinomial(self): + # NOTE: this is random, so it has a very large atol + helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), + lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.) 
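# Illustrative sketch (assumption, not taken from the patch): throughout this file the
# duplicated tinygrad lambda is dropped, which only works if helper_test_op falls back to
# applying the one callable to both torch tensors and tinygrad Tensors (they are duck-type
# compatible for these ops). Roughly, reusing the np/torch/Tensor names this test file
# already imports and ignoring the real gradient checks, dtype handling and low/high defaults:
def _sketch_helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, vals=None, low=-1.5, high=1.5, forward_only=False):
  tinygrad_fxn = tinygrad_fxn or torch_fxn                      # single-callable form
  data = vals if vals is not None else [np.random.uniform(low, high, size=s).astype(np.float32) for s in shps]
  ts  = [torch.tensor(d, requires_grad=not forward_only) for d in data]
  tst = [Tensor(d, requires_grad=not forward_only) for d in data]
  out_torch, out_tiny = torch_fxn(*ts), tinygrad_fxn(*tst)      # same lambda, two frameworks
  np.testing.assert_allclose(out_torch.detach().numpy(), out_tiny.numpy(), atol=atol)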
+ + def test_small_cumsum(self): + helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) + def test_simple_cumsum(self): + helper_test_op([(512)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) + helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) + def test_cumsum(self): + helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) + helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) + helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1)) + helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2)) + helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1)) + def test_cumsum_zero_axis(self): + helper_test_op([(2,0,4)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1)) + helper_test_op([(0,3)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0)) + helper_test_op([(2,3,0)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2)) + + def test_argmax(self): + self.assertEqual(torch.tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max + helper_test_op([(10,20)], lambda x: x.argmax(), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(0, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(1, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmax(1, True), forward_only=True) + + def test_argmin(self): + self.assertEqual(torch.tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy()) + helper_test_op([(10,20)], lambda x: x.argmin(), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(0, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(1, False), forward_only=True) + helper_test_op([(10,20)], lambda x: x.argmin(1, True), forward_only=True) + + def test_einsum(self): + # matrix transpose + helper_test_op([(150,150)], lambda a: torch.einsum('ij->ji', a), lambda a: Tensor.einsum('ij->ji', a)) + helper_test_op([(150,150)], lambda a: torch.einsum('ij -> ji', a), lambda a: Tensor.einsum('ij -> ji', a)) + helper_test_op([(150,150)], lambda a: torch.einsum('ji', a), lambda a: Tensor.einsum('ji', a)) + helper_test_op([(20,30,40)], lambda a: torch.einsum('jki', a), lambda a: Tensor.einsum('jki', a)) + helper_test_op([(20,30,40)], lambda a: torch.einsum('dog', a), lambda a: Tensor.einsum('dog', a)) + # sum all elements + helper_test_op([(20,30,40)], lambda a: torch.einsum('ijk->', a), lambda a: Tensor.einsum('ijk->', a)) + # column sum + helper_test_op([(50,50)], lambda a: torch.einsum('ij->j', a), lambda a: Tensor.einsum('ij->j', a)) + # row sum + helper_test_op([(15,15)], lambda a: torch.einsum('ij->i', a), lambda a: Tensor.einsum('ij->i', a)) + # matrix-vector multiplication + helper_test_op([(15,20), (20,)], lambda a,b: torch.einsum('ik,k->i', a,b), lambda a,b: Tensor.einsum('ik,k->i', a, b)) + # matrix-matrix multiplication + helper_test_op([(15,20), (20,30)], lambda a,b: torch.einsum('ik,kj->ij', a,b), lambda a,b: Tensor.einsum('ik,kj->ij', a, b)) + # dot product + helper_test_op([(30),(30)], lambda a,b: torch.einsum('i,i->i', [a,b]), lambda a,b: Tensor.einsum('i,i->i', [a,b])) + # hadamard product + helper_test_op([(30,40),(30,40)], lambda a,b: torch.einsum('ij,ij->ij', a,b), lambda a,b: 
Tensor.einsum('ij,ij->ij', a,b)) + # outer product + helper_test_op([(15,), (15,)], lambda a,b: torch.einsum('i,j->ij', a,b), lambda a,b: Tensor.einsum('i,j->ij',a,b)) + # batch matrix multiplication + helper_test_op([(10,20,30),(10,30,40)], lambda a,b: torch.einsum('ijk,ikl->ijl', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->ijl', [a, b])) + # batch matrix multiplication, result permuted + helper_test_op([(10,20,25),(10,25,32)], lambda a,b: torch.einsum('ijk,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('ijk,ikl->jil', [a, b])) + # batch matrix multiplication, result & input permuted + helper_test_op([(20,10,25),(10,25,32)], lambda a,b: torch.einsum('jik,ikl->jil', [a, b]), lambda a,b: Tensor.einsum('jik,ikl->jil', [a, b])) + # tensor contraction + helper_test_op([(3,5,8,10),(11,13,5,16,8)], lambda a,b: torch.einsum('pqrs,tuqvr->pstuv', a,b), + lambda a,b: Tensor.einsum('pqrs,tuqvr->pstuv', a,b), atol=1e-5) + # tensor contraction, input permuted + helper_test_op([(3,8,10,5),(11,5,13,16,8)], lambda a,b: torch.einsum('prsq,tquvr->pstuv', a,b), + lambda a,b: Tensor.einsum('prsq,tquvr->pstuv', a,b), atol=1e-5) + # bilinear transformation + helper_test_op([(2,3),(5,3,7),(2,7)], lambda a,b,c: torch.einsum('ik,jkl,il->ij', [a,b,c]), lambda a,b,c: Tensor.einsum('ik,jkl,il->ij', [a,b,c])) + + def test_einsum_shape_check(self): + a = Tensor.zeros(3,8,10,5) + b = Tensor.zeros(11,5,13,16,8) + with self.assertRaises(AssertionError): + Tensor.einsum('pqrs,tuqvr->pstuv',a,b) + + def test_einsum_arity_check1(self): + a = Tensor.zeros(10,15) + b = Tensor.zeros(15,20) + c = Tensor.zeros(20,10) + with self.assertRaises(AssertionError): + Tensor.einsum('ij,jk->ij', a,b,c) + + def test_einsum_arity_check2(self): + a = Tensor.zeros(10,10) + with self.assertRaises(AssertionError): + Tensor.einsum('ij,jk->ij', a) + @unittest.skipIf(IMAGE>0, "no 1d dot for images") def test_dot_1d(self): helper_test_op([(65), (65)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) @@ -420,34 +603,26 @@ def test_dot(self): with self.assertRaises(AssertionError): a = Tensor(3.14) a.matmul(a) - - def test_multinomial(self): - # NOTE: this is random, so it has a very large atol - helper_test_op([(1000,)], lambda x: torch.multinomial(x.clip(0,1), num_samples=1), lambda x: Tensor.multinomial(x.clip(0,1)), forward_only=True, atol=1000.) 
- - def test_small_cumsum(self): - helper_test_op([(10)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - def test_simple_cumsum(self): - helper_test_op([(1022)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - def test_cumsum(self): - helper_test_op([(20)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=0), lambda x: Tensor.cumsum(x, axis=0), atol=1e-6) - helper_test_op([(20,30)], lambda x: torch.cumsum(x, dim=1), lambda x: Tensor.cumsum(x, axis=1), atol=1e-6) - helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=2), lambda x: Tensor.cumsum(x, axis=2), atol=1e-6) - helper_test_op([(20,30,40)], lambda x: torch.cumsum(x, dim=-1), lambda x: Tensor.cumsum(x, axis=-1), atol=1e-6) - - def test_argmax(self): - self.assertEqual(torch.Tensor([2,2]).argmax().numpy(), Tensor([2,2]).argmax().numpy()) # check if returns first index for same max - helper_test_op([(10,20)], lambda x: x.argmax(), lambda x: x.argmax(), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(0, False), lambda x: x.argmax(0, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(1, False), lambda x: x.argmax(1, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmax(1, True), lambda x: x.argmax(1, True), forward_only=True) - def test_argmin(self): - self.assertEqual(torch.Tensor([2, 2]).argmin().numpy(), Tensor([2, 2]).argmin().numpy()) - helper_test_op([(10,20)], lambda x: x.argmin(), lambda x: x.argmin(), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(0, False), lambda x: x.argmin(0, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(1, False), lambda x: x.argmin(1, False), forward_only=True) - helper_test_op([(10,20)], lambda x: x.argmin(1, True), lambda x: x.argmin(1, True), forward_only=True) + def test_mulacc_with_zero_strides(self): + helper_test_op( + [], + lambda: torch.tensor(1.0).reshape((1,1,1)).expand(2,4,3).mul(torch.tensor(1.0).reshape((1,1,1)).expand(2,4,3)).sum(-1), + lambda: Tensor(1.0).reshape((1,1,1)).expand(2,4,3).mul(Tensor(1.0).reshape((1,1,1)).expand(2,4,3)).sum(-1), + forward_only=True + ) + a = [[1.,1.,1.,1.], [1.,1.,1.,1.]] + b = [1.,1.,1.,1.] 
+ helper_test_op( + [], + lambda: torch.tensor(a).reshape((2,4,1)).expand(2,4,3).mul(torch.tensor(b).reshape((1,4,1)).expand(2,4,3)).sum([0,2]), + lambda: Tensor(a).reshape((2,4,1)).expand(2,4,3).mul(Tensor(b).reshape((1,4,1)).expand(2,4,3)).sum([0,2]), + forward_only=True + ) + helper_test_op( + [], + lambda: torch.ones((1,2)).matmul(torch.ones((2,3))), lambda: Tensor.ones((1,2)).dot(Tensor.ones((2,3))), + forward_only=True + ) def test_matmul_simple(self): helper_test_op([(4), (4,4)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) @@ -463,12 +638,24 @@ def test_matmul_batched_vector(self): helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) def test_small_gemm(self): helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3) + def test_small_gemm_range(self): + helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.arange(0,64,dtype=np.float32).reshape(8,8), + np.arange(64,128,dtype=np.float32).reshape(8,8)]) def test_small_gemm_eye(self): helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) def test_gemm(self): helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3) def test_big_gemm(self): helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3) + @unittest.skipIf(IMAGE>0, "no 0 in shape matmul on images") + def test_gemm_with_zeros_shape(self): + helper_test_op([(8,8), (8,0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) + helper_test_op([(0,8), (8,8)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) + helper_test_op([(0,8), (8,0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) + helper_test_op([(8,0), (0,8)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) + helper_test_op([(0,0), (0,0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) + helper_test_op([(0), (0,8)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) + helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) def test_broadcastdot(self): helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) with self.assertRaises(AssertionError): @@ -478,155 +665,221 @@ def test_broadcastdot(self): def test_multidot(self): helper_test_op([(10,45,65), (10,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) helper_test_op([(3,3,45,65), (3,3,65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) + def test_sum_simple(self): - helper_test_op(None, lambda x: x.sum(), Tensor.sum, vals=[[1.,1.]]) + helper_test_op(None, lambda x: x.sum(), vals=[[1.,1.]]) def test_sum_full(self): - helper_test_op([(16384)], lambda x: x.sum(), lambda x: x.sum()) - def test_sum_small_full(self): - helper_test_op([(45,5)], lambda x: x.sum(), Tensor.sum) + helper_test_op([(16384)], lambda x: x.sum()) def test_sum_relu(self): - helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu(), lambda x: x.relu().sum().relu()) + helper_test_op([(3,4,5)], lambda x: x.relu().sum().relu()) def test_sum(self): - helper_test_op([(45,3)], lambda x: x.sum(), Tensor.sum) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3), lambda x: Tensor.sum(x, axis=3)) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3)), lambda x: Tensor.sum(x, axis=(1,3))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2)), lambda x: Tensor.sum(x, axis=(0,2))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2)), lambda x: Tensor.sum(x, axis=(1,2))) - helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1), lambda x: Tensor.sum(x, 
axis=1)) + helper_test_op([(45,3)], lambda x: x.sum()) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=3)) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,3))) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(0,2))) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=(1,2))) + helper_test_op([(3,4,5,6)], lambda x: x.sum(axis=1)) helper_test_op([()], lambda x: x.sum(), Tensor.sum) + def test_sum_with_zeros_shape(self): + helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,))) + helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,))) + helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,1))) + def test_min(self): - helper_test_op([(3,3)], lambda x: x.min(), Tensor.min) - helper_test_op([(45,3)], lambda x: x.min(), Tensor.min) - helper_test_op([(45,3)], lambda x: x.min().mul(0.5), lambda x: Tensor.min(x).mul(0.5)) - helper_test_op([()], lambda x: x.min(), Tensor.min) + helper_test_op([(3,3)], lambda x: x.min()) + helper_test_op([(45,3)], lambda x: x.min()) + helper_test_op([(45,3)], lambda x: x.min().mul(0.5)) + helper_test_op([()], lambda x: x.min()) def test_max(self): - helper_test_op([(45,3)], lambda x: x.max(), Tensor.max) - helper_test_op([(45,3)], lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5)) - helper_test_op(None, lambda x: x.max().mul(0.5), lambda x: Tensor.max(x).mul(0.5), - vals=[ - [[1.0,1.0,0.0,1.0]], - ]) - helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: Tensor.max(x, axis=1)) - helper_test_op([()], lambda x: x.max(), Tensor.max) + helper_test_op([(45,3)], lambda x: x.max()) + helper_test_op([(45,3)], lambda x: x.max().mul(0.5)) + helper_test_op(None, lambda x: x.max().mul(0.5), vals=[[[1.0,1.0,0.0,1.0]],]) + helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1)) + helper_test_op([()], lambda x: x.max()) + def test_mean(self): helper_test_op([(3,4,5,6)], lambda x: x.mean()) helper_test_op([()], lambda x: x.mean()) def test_mean_axis(self): - helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2)), lambda x: Tensor.mean(x, axis=(1,2))) + helper_test_op([(3,4,5,6)], lambda x: x.mean(axis=(1,2))) + def test_mean_zero_axis(self): + helper_test_op([(1,0,3,0,5)], lambda x: x.mean(axis=(1,3))) + + def test_var(self): + helper_test_op([(15, 25, 35)], lambda x: x.var()) + helper_test_op([(15, 25, 35)], lambda x: x.var(correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(correction=5)) + # TODO: fix this + # helper_test_op([(10, 2)], lambda x: x.var(correction=50)) + def test_var_axis(self): + helper_test_op([(15, 25, 35)], lambda x: x.var(0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(2)) + helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2])) + helper_test_op([(15, 25, 35)], lambda x: x.var(0, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var(2, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.var([1, 2], correction=0)) + def test_var_zero_axis(self): + helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3))) + helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=0)) + helper_test_op([(1,0,3,0,5)], lambda x: x.var(axis=(1,3), correction=5)) + def test_var_keepdim(self): + helper_test_op([(15, 25, 35)], lambda x: x.var(keepdim=True)) + helper_test_op([(15, 25, 35)], lambda x: x.var(0, keepdim=True, correction=0)) + def test_std(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x), lambda x: Tensor.std(x)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, correction=0), lambda x: Tensor.std(x, correction=0)) - helper_test_op([(45, 65, 85)], 
lambda x: torch.std(x, dim=None, correction=5), lambda x: Tensor.std(x, correction=5)) + helper_test_op([(15, 25, 35)], lambda x: x.std()) + helper_test_op([(15, 25, 35)], lambda x: x.std(correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(correction=5)) def test_std_axis(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0), lambda x: Tensor.std(x, axis=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=2), lambda x: Tensor.std(x, axis=2)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2])) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None), lambda x: Tensor.std(x, axis=None)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=0), lambda x: Tensor.std(x, axis=0, correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=2), lambda x: Tensor.std(x, axis=2, correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=[1, 2]), lambda x: Tensor.std(x, axis=[1, 2], correction=0)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, correction=0, dim=None), lambda x: Tensor.std(x, axis=None, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(2)) + helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2])) + helper_test_op([(15, 25, 35)], lambda x: x.std(0, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(2, correction=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std([1, 2], correction=0)) + def test_std_zero_axis(self): + helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3))) + helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=0)) + helper_test_op([(1,0,3,0,5)], lambda x: x.std(axis=(1,3), correction=5)) def test_std_keepdim(self): - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=None, keepdim=True), lambda x: Tensor.std(x, keepdim=True)) - helper_test_op([(45, 65, 85)], lambda x: torch.std(x, dim=0, keepdim=True, correction=0), lambda x: Tensor.std(x, keepdim=True, correction=0, axis=0)) + helper_test_op([(15, 25, 35)], lambda x: x.std(keepdim=True)) + helper_test_op([(15, 25, 35)], lambda x: x.std(0, keepdim=True, correction=0)) + + def test_softmax(self): + # exceed per kernel buffer limit with backward + forward_only = (Device.DEFAULT == "WEBGPU") + helper_test_op([(45,65)], torch.nn.Softmax(dim=1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) + helper_test_op([(45)], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) + helper_test_op([()], torch.nn.Softmax(dim=0), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) + helper_test_op([()], torch.nn.Softmax(dim=-1), Tensor.softmax, atol=1e-7, grad_atol=1e-7, forward_only=forward_only) + def test_softmax_other_axis(self): + helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7) + @unittest.skipIf(CI and Device.DEFAULT in ["CLANG", "PYTHON"], "Broken ISSUE #3552") + def test_softmax_argmax(self): + helper_test_op([(45,65)], lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], lambda x: x.softmax(1).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7) def test_log_softmax(self): - helper_test_op([(45,65)], lambda 
x: torch.nn.LogSoftmax(dim=1)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) - helper_test_op([()], lambda x: torch.nn.LogSoftmax(dim=0)(x), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) + helper_test_op([(45,65)], torch.nn.LogSoftmax(dim=1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) + helper_test_op([(45)], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) + helper_test_op([()], torch.nn.LogSoftmax(dim=0), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) + helper_test_op([()], torch.nn.LogSoftmax(dim=-1), Tensor.log_softmax, atol=1e-7, grad_atol=1e-7) def test_log_softmax_other_axis(self): - helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7) - helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7) - helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.log_softmax(0), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.log_softmax(1), atol=1e-7, grad_atol=1e-7) + helper_test_op([(10,10,10)], lambda x: x.log_softmax(2), atol=1e-7, grad_atol=1e-7) + + def test_sinh(self): + helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6) + # TODO: backward nan instead of inf + helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6, low=-300, high=-297, forward_only=True) + helper_test_op([(45,65)], lambda x: x.sinh(), grad_atol=1e-6, low=300, high=303, forward_only=True) + def test_cosh(self): + helper_test_op([(45,65)], lambda x: x.cosh(), grad_atol=1e-6) + # TODO: backward nan instead of inf + helper_test_op([(45,65)], lambda x: x.cosh(), grad_atol=1e-6, low=-300, high=-297, forward_only=True) + helper_test_op([(45,65)], lambda x: x.cosh(), grad_atol=1e-6, low=300, high=303, forward_only=True) def test_tanh(self): - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) - helper_test_op([(45,65)], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6, a=-100) - helper_test_op([()], lambda x: x.tanh(), Tensor.tanh, atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: x.tanh(), grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: x.tanh(), grad_atol=1e-6, low=-300, high=-297) + helper_test_op([(45,65)], lambda x: x.tanh(), grad_atol=1e-6, low=300, high=303) def test_hardtanh(self): for val in range(10, 30, 5): - helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x,-val, val), lambda x: x.hardtanh(-val, val), atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), grad_atol=1e-6) + helper_test_op([()], lambda x: torch.nn.functional.hardtanh(x, -val, val), lambda x: x.hardtanh(-val, val), grad_atol=1e-6) + def test_asinh(self): + helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6) + # NOTE: this one has larger atol + helper_test_op([(45,65)], lambda x: x.asinh(), atol=1e-2, grad_atol=1e-6, low=-300, high=-297) + helper_test_op([(45,65)], lambda x: x.asinh(), grad_atol=1e-6, low=300, high=303) + def test_acosh(self): + helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6, low=-300, high=-297) + helper_test_op([(45,65)], lambda x: x.acosh(), grad_atol=1e-6, 
low=300, high=303) + def test_atanh(self): + helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6, low=-300, high=-297) + helper_test_op([(45,65)], lambda x: x.atanh(), grad_atol=1e-6, low=300, high=303) + def test_topo_sort(self): - helper_test_op([(45,65)], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) - helper_test_op([()], lambda x: (x+x)*x, lambda x: x.add(x).mul(x), atol=1e-6, grad_atol=1e-6) + helper_test_op([(45,65)], lambda x: (x+x)*x, grad_atol=1e-6) + helper_test_op([()], lambda x: (x+x)*x, grad_atol=1e-6) - def test_scalar_mul(self): - helper_test_op([(45,65)], lambda x: x*2, lambda x: x*2) - helper_test_op([()], lambda x: x*2, lambda x: x*2) - def test_scalar_rmul(self): - helper_test_op([(45,65)], lambda x: 2*x, lambda x: 2*x) - helper_test_op([()], lambda x: 2*x, lambda x: 2*x) - def test_scalar_sub(self): - helper_test_op([(45,65)], lambda x: x-2, lambda x: x-2) - helper_test_op([()], lambda x: x-2, lambda x: x-2) - def test_scalar_rsub(self): - helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x) - helper_test_op([()], lambda x: 2-x, lambda x: 2-x) def test_flip_eye_crash(self): helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)), lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs passing the WEBGPU limit") #TODO: remove after #1461 def test_broadcast_full(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), - (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]: + (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((5,13,24,16), (5,1,24,1)), ((1,3,1,7,1), (2,1,5,1,8))]: with self.subTest(op=torch_op.__name__, shapes=shapes): - helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0) + if tinygrad_op != Tensor.pow: + helper_test_op(shapes, torch_op, tinygrad_op) + else: + helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3) def test_broadcast_simple(self): - helper_test_op([(45,65), (45,1)], lambda x,y: x/y, lambda x,y: x/y) - helper_test_op([(45,65), ()], lambda x,y: x/y, lambda x,y: x/y) + helper_test_op([(45,65), (45,1)], lambda x,y: x/y) + helper_test_op([(45,65), ()], lambda x,y: x/y) - @unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs passing the WEBGPU limit") #TODO: remove after #1461 def test_broadcast_partial(self): for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul), - (torch.div, Tensor.div)]: #, (torch.pow, Tensor.pow)]: + (torch.div, Tensor.div), (torch.pow, Tensor.pow)]: for shapes in [((1,32,32,32), (1,32,1,1)), ((5,13,24,16,2), (1,13,24,1,1)), ((4,1), (4,5)), ((1,4), (5,4))]: with self.subTest(op=torch_op.__name__, shapes=shapes): # NOTE: ANE backwards? 
- helper_test_op(shapes, torch_op, tinygrad_op, a=-0.5 if tinygrad_op != Tensor.pow else 0.0) + if tinygrad_op != Tensor.pow: + helper_test_op(shapes, torch_op, tinygrad_op) + else: + helper_test_op(shapes, torch_op, tinygrad_op, low=0, high=3) def test_slice_in_bounds_1dim(self): - helper_test_op([(3)], lambda x: x[1:3], lambda x: x[1:3]) - helper_test_op([(3)], lambda x: x[0:2], lambda x: x[0:2]) - helper_test_op([(3)], lambda x: x[-2:2], lambda x: x[-2:2]) + helper_test_op([(3)], lambda x: x[1:3]) + helper_test_op([(3)], lambda x: x[0:2]) + helper_test_op([(3)], lambda x: x[-2:2]) def test_slice_on_0dim_tensor(self): - helper_test_op([()], lambda x: x[None], lambda x: x[None]) + helper_test_op([()], lambda x: x[None]) with self.assertRaises(IndexError): a = Tensor(3.14) a[0] def test_slice_int_indexing(self): - helper_test_op([(3)], lambda x: x[1], lambda x: x[1]) - helper_test_op([(3)], lambda x: x[-2], lambda x: x[-2]) - helper_test_op([(10,10)], lambda x: x[1], lambda x: x[1]) - helper_test_op([(3,3,3)], lambda x: x[1,1,1], lambda x: x[1,1,1]) + helper_test_op([(3)], lambda x: x[0]) + helper_test_op([(3)], lambda x: x[2]) + helper_test_op([(3)], lambda x: x[-1]) + helper_test_op([(3)], lambda x: x[-3]) + helper_test_op([(10,10)], lambda x: x[1]) + helper_test_op([(3,3,3)], lambda x: x[1,1,1]) def test_slice_in_bounds_multidim(self): - helper_test_op([(3,3,3)], lambda x: x[1:2], lambda x: x[1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 2], lambda x: x[1:2, 2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2], lambda x: x[1:2, 1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1], lambda x: x[1:2, 1:2, 0:-1]) + helper_test_op([(3,3,3)], lambda x: x[1:2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, 0:-1]) def test_slice_with_none(self): - helper_test_op([(3,3,3)], lambda x: x[None], lambda x: x[None]) - helper_test_op([(3,3,3)], lambda x: x[1:2, None], lambda x: x[1:2, None]) - helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2], lambda x: x[1:2, None, 1:2]) - helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1], lambda x: x[1:2, 1:2, None, -1]) + helper_test_op([(3,3,3)], lambda x: x[None]) + helper_test_op([(3,3,3)], lambda x: x[1:2, None]) + helper_test_op([(3,3,3)], lambda x: x[1:2, None, 1:2]) + helper_test_op([(3,3,3)], lambda x: x[1:2, 1:2, None, -1]) + helper_test_op([(3,3,3)], lambda x: x[None, None, 1, None, 2, 0:2]) def test_slice_one_endpoint_out_of_bounds(self): - helper_test_op([(3,3,3)], lambda x: x[0:4], lambda x: x[0:4]) - helper_test_op([(3,3,3)], lambda x: x[-6:4], lambda x: x[-6:4]) - helper_test_op([(3,3,3)], lambda x: x[1:50], lambda x: x[1:50]) - helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1], lambda x: x[1:50, 1:2, -1]) + helper_test_op([(3,3,3)], lambda x: x[0:4]) + helper_test_op([(3,3,3)], lambda x: x[-6:4]) + helper_test_op([(3,3,3)], lambda x: x[1:50]) + helper_test_op([(3,3,3)], lambda x: x[1:50, 1:2, -1]) def test_slice_stride_gt_one(self): - helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4], lambda x: x[::2, ::3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, ::4], lambda x: x[1:5:2, ::3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4], lambda x: x[1:5:2, 3, ::4]) - helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4], lambda x: x[1:5:2, None, None, 3, None, ::4]) + helper_test_op([(7,5,10)], lambda x: x[::2, ::3, ::4]) + helper_test_op([(7,5,10)], lambda x: x[1:5:2, ::3, 
::4]) + helper_test_op([(7,5,10)], lambda x: x[1:5:2, 3, ::4]) + helper_test_op([(7,5,10)], lambda x: x[1:5:2, None, None, 3, None, ::4]) def test_slice_negative_strides(self): # Torch doesn't support slicing with negative steps @@ -637,41 +890,41 @@ def test_slice_negative_strides(self): np.testing.assert_allclose(a[:, 2:0:-1], t[:, 2:0:-1].numpy()) np.testing.assert_allclose(a[:, 2:0:-1, 3:1:-2], t[:, 2:0:-1, 3:1:-2].numpy()) np.testing.assert_allclose(a[4:0:-3, 2:0:-1, -1:-5:-2], t[4:0:-3, 2:0:-1, -1:-5:-2].numpy()) - if Device.DEFAULT != "CPU": - # broken - np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10) - np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10) - np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10) + np.testing.assert_allclose(a[2:5:-1, :, :], t[2:5:-1, :, :].numpy()) # shape = (0, 10, 10) + np.testing.assert_allclose(a[:, 2:5:-1, :], t[:, 2:5:-1, :].numpy()) # shape = (0, 10, 10) + np.testing.assert_allclose(a[:, :, 2:5:-1], t[:, :, 2:5:-1].numpy()) # shape = (0, 10, 10) def test_slice_both_endpoints_out_of_bounds(self): - helper_test_op([(3,3,3)], lambda x: x[5:10], lambda x: x[5:10], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-15:-7], lambda x: x[-15:-7], forward_only=True) + helper_test_op([(3,3,3)], lambda x: x[5:10]) + helper_test_op([(3,3,3)], lambda x: x[-15:-7]) def test_slice_start_gt_end(self): - helper_test_op([(3,3,3)], lambda x: x[-2:2], lambda x: x[-2:2], forward_only=True) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) + helper_test_op([(3,3,3)], lambda x: x[-2:2]) + helper_test_op([(3,3,3)], lambda x: x[-2:-5]) def test_slice_empty(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) + helper_test_op([(10,10)], lambda x: x[1:1]) def test_slice_zero_in_shape(self): - helper_test_op([(10,10)], lambda x: x[1:1], lambda x: x[1:1], forward_only=True) # x.shape = (0, 10) - helper_test_op([(3,3,3)], lambda x: x[-2:-5], lambda x: x[-2:-5], forward_only=True) # x.shape = (0, 3, 3) + helper_test_op([(10,10)], lambda x: x[1:1]) # x.shape = (0, 10) + helper_test_op([(3,3,3)], lambda x: x[-2:-5]) # x.shape = (0, 3, 3) def test_slice_errors(self): a = Tensor.ones(4, 3) - with self.assertRaises(IndexError): - a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds) - a[1, 77] # IndexError: (out of bounds). - a[0, -77] - a[..., ...] # IndexError: only single ellipsis + b = Tensor(2) + with self.assertRaises(IndexError): a[1, 77, 77, 77] # IndexError: (finds too many indices before the out of bounds) + with self.assertRaises(IndexError): a[1, 3] # IndexError: (out of bounds). + with self.assertRaises(IndexError): a[1, -4] + with self.assertRaises(IndexError): a[..., ...] 
# IndexError: only single ellipsis + with self.assertRaises(ValueError): a[::0, 1] # no 0 strides + with self.assertRaises(IndexError): b[:] # slice cannot be applied to a 0-dim tensor def test_slice_ellipsis(self): - helper_test_op([(3,3,3,3)], lambda x: x[..., 0], lambda x: x[..., 0]) - helper_test_op([(3,3,3,3)], lambda x: x[0, ...], lambda x: x[0, ...]) - helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0], lambda x: x[0, ..., 0]) - helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3], lambda x: x[0:3, ..., 2:3]) - helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None], lambda x: x[None, 0:3, ..., 0, None]) + helper_test_op([(3,3,3,3)], lambda x: x[..., 0]) + helper_test_op([(3,3,3,3)], lambda x: x[0, ...]) + helper_test_op([(3,3,3,3)], lambda x: x[0, ..., 0]) + helper_test_op([(3,3,3,3)], lambda x: x[0:3, ..., 2:3]) + helper_test_op([(3,3,3,3)], lambda x: x[None, 0:3, ..., 0, None]) def test_pad2d(self): helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4))) @@ -682,76 +935,118 @@ def test_pad2d(self): def test_pad(self): helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2)))) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5)) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("inf")), lambda x: x.pad(((3,4), (1,2)), value=float("inf"))) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=float("-inf")), lambda x: x.pad(((3,4), (1,2)), value=float("-inf"))) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(((3,4), (1,2)), value=math.inf)) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=-math.inf), lambda x: x.pad(((3,4), (1,2)), value=-math.inf)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1)) + @unittest.skipIf(Device.DEFAULT == "WEBGL", "incorrect result") + def test_pad_slice(self): + for value in 0., 3.456: + helper_test_op([(1)], lambda x: torch.nn.functional.pad(x,(1,0), value=value)[0], lambda x: x.pad(((1,0),), value=value)[0]) + helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(1,0), value=value)[0], lambda x: x.pad(((1,0),), value=value)[0]) + helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(3,0), value=value)[0:1], lambda x: x.pad(((3,0),), value=value)[0:1]) + helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(0,3), value=value)[6], lambda x: x.pad(((0,3),), value=value)[6]) + helper_test_op([(4)], lambda x: torch.nn.functional.pad(x,(0,3), value=value)[4:6], lambda x: x.pad(((0,3),), value=value)[4:6]) + helper_test_op([(5,5)], lambda x: torch.nn.functional.pad(x,(0,0,1,0), value=value)[0], lambda x: x.pad(((1,0),(0,0)), value=value)[0]) + helper_test_op([(2,2)], lambda x: torch.nn.functional.pad(x,(0,1,0,0), value=value)[0,2], lambda x: x.pad(((0,0),(0,1)), value=value)[0,2]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,1,0), value=value)[0,2], lambda x: x.pad(((1,0),(0,0)), value=value)[0,2]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(0,0,0,2), value=value)[5], lambda x: x.pad(((0,2),(0,0)), value=value)[5]) + helper_test_op([(4,4)], lambda x: 
torch.nn.functional.pad(x,(0,0,0,2), value=value)[3:5], lambda x: x.pad(((0,2),(0,0)), value=value)[3:5]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,0,0,0), value=value)[1,0], lambda x: x.pad(((0,0),(3,0)), value=value)[1,0]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,0,0,0), value=value)[1,0:4], lambda x: x.pad(((0,0),(3,0)), value=value)[1,0:4]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[0], lambda x: x.pad(((1,2),(3,4)), value=value)[0]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,1], lambda x: x.pad(((1,2),(3,4)), value=value)[:,1]) + helper_test_op([(4,4)], lambda x: torch.nn.functional.pad(x,(3,4,1,2), value=value)[:,4], lambda x: x.pad(((1,2),(3,4)), value=value)[:,4]) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,3,0,0), value=value)[:,4:6], lambda x: x.pad(((0,0),(0,3)), value=value)[:,4:6]) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x,(0,1,3,2), value=value)[0:2,:], lambda x: x.pad(((3,2),(0,1)), value=value)[0:2,:]) + helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[0:2,:,:], + lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[0:2,:,:]) + helper_test_op([(3,3,3)], lambda x: torch.nn.functional.pad(x,(1,1,0,1,3,2), value=value)[2:4,:,:], + lambda x: x.pad(((3,2),(0,1),(1,1)), value=value)[2:4,:,:]) + + def test_stack_slice(self): + helper_test_op([(4)], lambda x: torch.stack([x for i in range(3)])[0,:], lambda x: Tensor.stack([x for i in range(3)])[0,:]) + helper_test_op([(5)], lambda x: torch.stack([x for i in range(3)])[0,0], lambda x: Tensor.stack([x for i in range(3)])[0,0]) + helper_test_op([(4,4)], lambda x: torch.stack([x for i in range(4)])[3], lambda x: Tensor.stack([x for i in range(4)])[3]) + def test_transpose(self): - helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(1,2)) - helper_test_op([(3,3,3)], lambda x: x.transpose(0,2), lambda x: x.transpose(0,2)) - helper_test_op([(1,2,3,4)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.permute(order=(3,0,2,1))) - helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.permute(order=(3,2,1,0))) - helper_test_op([()], lambda x: x.permute(()), lambda x: x.permute(())) + helper_test_op([(3,3)], lambda x: x.T) + helper_test_op([(3,3,3)], lambda x: x.transpose(1,2)) + helper_test_op([(3,3,3)], lambda x: x.transpose(0,2)) + helper_test_op([(1,2,3,4)], lambda x: x.permute((3,0,2,1))) + helper_test_op([(3,4,5,6)], lambda x: x.permute((3,2,1,0))) + helper_test_op([()], lambda x: x.permute(())) def test_reshape(self): - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6))) - helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,1,6,6)), lambda x: x.reshape(shape=(-1,1,6,6))) - helper_test_op([()], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) - helper_test_op([(1,)], lambda x: torch.reshape(x, []), lambda x: x.reshape([])) - helper_test_op([()], lambda x: torch.reshape(x, [1]), lambda x: x.reshape([1])) + helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,3,6,6))) + helper_test_op([(4,3,6,6)], lambda x: x.reshape((-1,1,6,6))) + helper_test_op([()], lambda x: x.reshape([])) + helper_test_op([(1,)], lambda x: x.reshape([])) + helper_test_op([()], lambda x: x.reshape([1])) + helper_test_op([()], lambda x: x.reshape([1, 1, 1])) with self.assertRaises(ValueError): x = Tensor.ones((4,3,6,6)) x.reshape([]) 
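# Illustrative sketch (assumption, not taken from the patch): the shape rule the reshape
# cases above exercise. A single -1 is inferred from the element count, and only a
# one-element tensor may be reshaped to [], which is why Tensor.ones((4,3,6,6)).reshape([])
# is expected to raise.
def _sketch_infer_reshape(old_shape, new_shape):
  numel = math.prod(old_shape)
  if -1 in new_shape:
    known = math.prod(d for d in new_shape if d != -1)
    new_shape = tuple(numel // known if d == -1 else d for d in new_shape)
  assert math.prod(new_shape) == numel, f"cannot reshape {old_shape} into {new_shape}"
  return new_shape
assert _sketch_infer_reshape((4,3,6,6), (-1,1,6,6)) == (12,1,6,6)
assert _sketch_infer_reshape((1,), ()) == ()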
def test_flip(self): - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,)), lambda x: x.flip(axis=(0,))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1)), lambda x: x.flip(axis=(0,1))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)), lambda x: x.flip(axis=(0,1,3))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(3,))) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (0,1,3)).flip((0,)), lambda x: x.flip(axis=(0,1,3)).flip(0)) - helper_test_op([(4,3,6,6)], lambda x: torch.flip(x, (3,)), lambda x: x.flip(axis=(-1,))) - helper_test_op([()], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - helper_test_op([(1,)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) - helper_test_op([(4, 3, 6, 6)], lambda x: torch.flip(x, ()), lambda x: x.flip(axis=())) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((3,))) + helper_test_op([(4,3,6,6)], lambda x: x.flip((0,1,3)).flip(0)) + helper_test_op([(4,3,6,6)], lambda x: x.flip((-1,))) + helper_test_op([()], lambda x: x.flip(())) + helper_test_op([(1,)], lambda x: x.flip(())) + helper_test_op([(4, 3, 6, 6)], lambda x: x.flip(())) def test_squeeze(self): - helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) - helper_test_op([(4,3,1,6)], lambda x: torch.squeeze(x, 1), lambda x: x.squeeze(dim=1)) - helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x, 3), lambda x: x.squeeze(dim=3)) - self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError, exact=True) - self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError, exact=True) - helper_test_op([(4,3,6,1)], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1)) - helper_test_op([(4,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([(1,3,6,6)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([(2,3,1)], lambda x: torch.squeeze(x), lambda x: x.squeeze()) - helper_test_op([()], lambda x: torch.squeeze(x, -1), lambda x: x.squeeze(dim=-1)) - helper_test_op([()], lambda x: torch.squeeze(x, 0), lambda x: x.squeeze(dim=0)) - self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError, exact=True) - helper_test_op([()], lambda x: torch.squeeze(x), lambda x: x.squeeze()) + helper_test_op([(1,3,6,6)], lambda x: x.squeeze(0)) + helper_test_op([(4,3,1,6)], lambda x: x.squeeze(1)) + helper_test_op([(4,3,6,6)], lambda x: x.squeeze(3)) + self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, 50), lambda x: x.squeeze(dim=50), expected=IndexError) + self.helper_test_exception([(4,3,6,6)], lambda x: torch.squeeze(x, -50), lambda x: x.squeeze(dim=-50), expected=IndexError) + helper_test_op([(4,3,6,1)], lambda x: x.squeeze(-1)) + helper_test_op([(4,3,6,6)], lambda x: x.squeeze()) + helper_test_op([(1,3,6,6)], lambda x: x.squeeze()) + helper_test_op([(2,3,1)], lambda x: x.squeeze()) + helper_test_op([()], lambda x: x.squeeze(-1)) + helper_test_op([()], lambda x: x.squeeze(0)) + helper_test_op([()], lambda x: x.squeeze()) + self.helper_test_exception([()], lambda x: torch.squeeze(x, 10), lambda x: x.squeeze(dim=10), expected=IndexError) + self.helper_test_exception([()], lambda x: torch.squeeze(x, 1), lambda x: 
x.squeeze(dim=1), expected=IndexError) + self.helper_test_exception([()], lambda x: torch.squeeze(x, -2), lambda x: x.squeeze(dim=-2), expected=IndexError) def test_unsqueeze(self): - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, 4), lambda x: x.unsqueeze(dim=4)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -1), lambda x: x.unsqueeze(dim=-1)) - helper_test_op([(4,3,6,6)], lambda x: torch.unsqueeze(x, -3), lambda x: x.unsqueeze(dim=-3)) - helper_test_op([()], lambda x: torch.unsqueeze(x, 0), lambda x: x.unsqueeze(dim=0)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(0)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(4)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-1)) + helper_test_op([(4,3,6,6)], lambda x: x.unsqueeze(-3)) + helper_test_op([()], lambda x: x.unsqueeze(0)) def test_flatten(self): for axis in range(3): - helper_test_op([(4,3,6,6)], lambda x: torch.flatten(x, start_dim=axis), lambda x: x.flatten(axis)) - helper_test_op([()], lambda x: x.flatten(), lambda x: x.flatten()) - helper_test_op([(1,)], lambda x: x.flatten(), lambda x: x.flatten()) + helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=axis)) + for axis in range(3): + helper_test_op([(4,3,6,6)], lambda x: x.flatten(end_dim=axis)) + helper_test_op([(4,3,6,6)], lambda x: x.flatten(start_dim=1, end_dim=3)) + helper_test_op([()], lambda x: x.flatten()) + helper_test_op([(1,)], lambda x: x.flatten()) + + def test_unflatten(self): + helper_test_op([(4,3,6,6)], lambda x: x.unflatten(0, (2, 2))) + helper_test_op([(4,3,6,6)], lambda x: x.unflatten(3, (3, 2))) + helper_test_op([(4,3,6,6)], lambda x: x.unflatten(-1, (3, 2, 1))) def test_detach(self): - helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True) - helper_test_op([()], lambda x: x.detach(), lambda x: x.detach(), forward_only=True) + helper_test_op([(4,3,6,6)], lambda x: x.detach(), forward_only=True) + helper_test_op([()], lambda x: x.detach(), forward_only=True) def test_expand(self): - arg = (4,3,2,6) - helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg)) - helper_test_op([()], lambda x: x.expand([]), lambda x: x.expand(shape=[])) + helper_test_op([(4,3,1,6)], lambda x: x.expand((4,3,2,6))) + helper_test_op([(1,1,1,1)], lambda x: x.expand((4,3,2,6))) + helper_test_op([()], lambda x: x.expand([])) @unittest.skip("very slow") def test_sd_big_conv(self): @@ -794,7 +1089,7 @@ def test_simple_conv3d(self): @unittest.skipIf(IMAGE>0, "no conv3d on images") def test_padded_conv3d(self): - helper_test_op([(1,4,9,9,9), (4,4,3,3,3)], + helper_test_op([(1,4,5,5,5), (4,4,3,3,3)], lambda x,w: torch.nn.functional.conv3d(x,w,padding=1).relu(), lambda x,w: Tensor.conv2d(x,w,padding=[1,1,1,1,1,1]).relu(), atol=1e-4, grad_rtol=1e-5) @@ -865,7 +1160,6 @@ def test_strided_conv_transpose2d(self): lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(), lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5) - @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI") def test_output_padded_conv_transpose2d(self): for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]: helper_test_op([(2,4,6,5), (4,4,3,3),(4,)], @@ -910,7 +1204,7 @@ def test_strided_conv1d_simple(self): @unittest.skipIf(IMAGE>0, "no conv1d on images") def test_asymmetric_padding_conv1d(self): for p in [(0,1), (2,1), (2,0)]: - with 
self.subTest(padding := p): + with self.subTest(p): for n in [3,4]: for k in [2]: helper_test_op([(1,1,n), (1,1,k)], @@ -1026,7 +1320,7 @@ def test_simple_padding_conv2d(self): def test_asymmetric_padding_conv2d(self): for p in [(0,1,0,1), (2,1,2,1), (2,0,2,1)]: - with self.subTest(padding := p): + with self.subTest(p): for n in [3,4]: for k in [2]: helper_test_op([(1,1,n,n), (1,1,k,k)], @@ -1036,14 +1330,12 @@ def test_asymmetric_padding_conv2d(self): lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(), lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4) - @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI") def test_padded_conv2d_p21(self): bs,cin,H,W,padding = 4, 3, 3, 3, (2,1) helper_test_op([(bs,cin,11,28), (4,cin,H,W)], lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(), lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4) - @unittest.skipIf(Device.DEFAULT == "METAL" and CI, "broken in METAL CI") def test_padded_conv2d_p22(self): bs,cin,H,W,padding = 4, 3, 3, 3, (2,2) helper_test_op([(bs,cin,11,28), (4,cin,H,W)], @@ -1099,20 +1391,20 @@ def test_maxpool2d_bigger_stride(self): @unittest.skipIf(Device.DEFAULT == "CUDA", "CUDA fails on this") def test_maxpool2d_unit_stride(self): - helper_test_op([(32,2,110,28)], + helper_test_op([(8, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1)) def test_maxpool2d_smaller_stride(self): for stride in [(2,3), (3,2), 2, 3]: with self.subTest(stride=stride): - helper_test_op([(32,2,110,28)], + helper_test_op([(8, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride)) def test_maxpool2d_dilation(self): for dilation in [(2, 3), (3, 2), 2, 3]: - helper_test_op([(32,2,110,28)], + helper_test_op([(8, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation)) @@ -1133,7 +1425,14 @@ def test_cat(self): for dim in range(-2, 3): helper_test_op([(45,65,9), (45,65,9), (45,65,9)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim)) - with self.assertRaises(AssertionError): + # zero in non-cat axis + helper_test_op([(45,0,9), (45,0,9), (45,0,9)], lambda x,y,z: torch.cat((x,y,z), 0), lambda x,y,z: x.cat(y, z, dim=0)) + + # zero in cat axis + helper_test_op([(45,0,9), (45,1,9), (45,2,9)], lambda x,y,z: torch.cat((x,y,z), 1), lambda x,y,z: x.cat(y, z, dim=1)) + helper_test_op([(45,0,9), (45,0,9), (45,0,9)], lambda x,y,z: torch.cat((x,y,z), 1), lambda x,y,z: x.cat(y, z, dim=1)) + + with self.assertRaises(IndexError): a = Tensor(3.14) a.cat(a) @@ -1142,13 +1441,11 @@ def test_multicat(self): helper_test_op([(45,65), (45,65), (45,65)], lambda x,y,z: torch.cat((x,y,z), dim), lambda x,y,z: x.cat(y, z, dim=dim)) def test_stack(self): - x = Tensor.randn(45, 65, 3) - for dim in range(-1, 3): - helper_test_op([(45, 65, 3), (45, 65, 3), (45, 65, 3)], lambda x, y, z: torch.stack((x, y, z), dim=dim), lambda x, y, z: Tensor.stack([x, y, z], dim=dim)) + helper_test_op([(45,65,3), (45,65,3), (45,65,3)], lambda x, y, z: torch.stack((x, y, z), dim), lambda x, y, z: Tensor.stack([x, y, z], dim)) with self.assertRaises(IndexError): - Tensor.stack([x], dim=77) + Tensor.stack([Tensor.randn(45, 65, 3)], dim=77) a = Tensor(3.14) np.testing.assert_allclose(Tensor.stack([a, 
a]).numpy(), Tensor([3.14, 3.14]).numpy()) @@ -1167,8 +1464,17 @@ def test_repeat(self): np.testing.assert_allclose(x.repeat((2, 0, 4)).numpy(), Tensor.zeros(8, 0, 12).numpy()) + def test_simple_repeat(self): + repeats = [3, 3, 4] + helper_test_op([(3, 3)], lambda x: x.repeat(*repeats), lambda x: x.repeat(repeats)) + def test_clip(self): - helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2), lambda x: x.clip(-2.3, 1.2)) + helper_test_op([(45,65)], lambda x: x.clip(-2.3, 1.2)) + helper_test_op([(45,65)], lambda x: x.clip(0, 0)) + helper_test_op([(45,65)], lambda x: x.clip(10, 100)) + helper_test_op([(45,65)], lambda x: x.clip(0, 0.1)) + helper_test_op([(45,65)], lambda x: x.clip(-0.3, -0.2)) + helper_test_op([(45,65)], lambda x: x.clip(3, 0)) def test_matvecmat(self): helper_test_op([(1,128), (128,128), (128,128)], lambda x,y,z: (x@y).relu()@z, atol=1e-4) @@ -1195,13 +1501,12 @@ def test_inf_where(self): def _get_index_randoms(self): # indices cannot have gradient - # TODO currently does not support IndexError for out of bounds idx values a = torch.randint(low=-1, high=1, size=(2,1,1,1,1,1), dtype=torch.int64, requires_grad=False) b = torch.randint(high=1, size=(1,3,1,1,1,1), dtype=torch.int64, requires_grad=False) c = torch.randint(low=-5, high=5, size=(1,1,4,1,1,1), dtype=torch.int64, requires_grad=False) d = torch.randint(high=4, size=(2,1,1,5,1,1), dtype=torch.int64, requires_grad=False) e = torch.randint(high=1, size=(1,1,1,1,6,1), dtype=torch.int64, requires_grad=False) - i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), dtype=dtypes.int32, requires_grad=False) for tor in [a,b,c,d,e]] + i, j, k, o, p = [Tensor(tor.detach().numpy().astype(np.int32), requires_grad=False) for tor in [a,b,c,d,e]] return a,b,c,d,e,i,j,k,o,p def test_slice_fancy_indexing_no_dim_collapse(self): @@ -1232,16 +1537,58 @@ def test_slice_fancy_indexing_dim_inject_none(self): helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,:,None,d,e], lambda x: x[i,:,None,o,p]) def test_slice_fancy_indexing_dim_inject_and_collapse(self): - a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() + a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() # noqa # dim injection and collapse helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,b,None,d,1], lambda x: x[1,j,None,o,1]) helper_test_op([(2,5,6,5,3,4)], lambda x: x[None,b,2,d,None], lambda x: x[None,j,2,o,None]) helper_test_op([(2,5,6,5,3,4)], lambda x: x[...,1,d,None], lambda x: x[...,1,o,None]) - def test_slice_fancy_indexing_with_idx(self): + def test_slice_fancy_indexing_with_tensors(self): # indexing using idx with different dim - helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)]) - helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])]) + helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor(1)], + lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor(1)]) + helper_test_op([(2,3)], lambda x: x[torch.tensor([1]), torch.tensor([[0,0,0],[0,0,0]])], + lambda x: x[Tensor([1]), Tensor([[0,0,0],[0,0,0]])]) + helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,0,0],[0,0,0]]), torch.tensor([2,1,1])], + lambda x: x[Tensor([[0,0,0],[0,0,0]]), Tensor([2,1,1])]) + helper_test_op([(2,3)], lambda x: x[torch.tensor([[0,1,-1],[-1,-2,0]]), torch.tensor([2,1,-1])], + lambda x: x[Tensor([[0,1,-1],[-1,-2,0]]), Tensor([2,1,-1])]) + + def test_slice_fancy_indexing_list_indices(self): + a,b,c,d,e,i,j,k,o,p = 
self._get_index_randoms() + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[0]]], lambda x: x[[[0]]]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[0],b,c,d,:], lambda x: x[[0],j,k,o,:]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[[0]]],b,c,d,[[1]]], lambda x: x[[[[0]]],j,k,o,[[1]]]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0],b,c,d,:], lambda x: x[[1,0],j,k,o,:]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p]) + + def test_slice_fancy_indexing_tuple_indices(self): + a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() + helper_test_op([(2,5,6,5,3,4)], lambda x: x[(((0,),),)], lambda x: x[(((0,),),)]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[(0,),b,c,d,:], lambda x: x[(0,),j,k,o,:]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[(1,0),b,c,d,:], lambda x: x[(1,0),j,k,o,:]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,(1,2,3),...], lambda x: x[i,j,k,(1,2,3),...]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,((2,),(1,),(0,)),c,(2,1,0)], lambda x: x[i,((2,),(1,),(0,)),k,(2,1,0)]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[1,(2,1,0),None,c,(2,1,0),e], lambda x: x[1,(2,1,0),None,k,(2,1,0),p]) + + def test_slice_fancy_indexing_list_with_tensors(self): + a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a]], lambda x: x[[i]]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,1]], lambda x: x[[i,1]]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,[1,1]]], lambda x: x[[i,[1,1]]]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,(1,1)]], lambda x: x[[i,(1,1)]]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[a,b,c,d,e]], lambda x: x[[i,j,k,o,p]]) + + def test_slice_fancy_indexing_errors(self): + a = Tensor.ones(10,11,12) + # tensors used as indices must be int tensors + with self.assertRaises(IndexError): a[Tensor(1.1)] + with self.assertRaises(IndexError): a[Tensor([True, True])] + # shape mismatch, cannot broadcast + with self.assertRaises(IndexError): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1), Tensor.randint(2,4,4,1)] + with self.assertRaises(IndexError): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1,1)] def test_gather(self): # indices cannot have gradient @@ -1252,19 +1599,37 @@ def test_gather(self): helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1)) helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2)) helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) - self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError)) + self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), + lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError)) + self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), + lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError)) def test_scaled_product_attention(self): - 
-    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z))
-    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
-    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], torch.nn.functional.scaled_dot_product_attention, Tensor.scaled_dot_product_attention)
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)],
+                   lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m),
+                   lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
+
+  def test_scaled_product_attention_causal(self):
+    helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)],
+                   lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True),
+                   lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
 
   def test_binary_crossentropy(self):
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
-    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
+                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)),
+                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)),
+                   lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
+    helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
+                   lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
+
+  def test_one_hot(self):
+    data = [1, 2, 4]
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6), lambda: Tensor(data).one_hot(6), forward_only=True)
+    data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8), lambda: Tensor(data).one_hot(8), forward_only=True)
 
 if __name__ == '__main__':
   np.random.seed(1337)
diff --git a/test/test_optim.py b/test/test_optim.py
index 76f9bcf..3be44f8 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -1,17 +1,22 @@
 import numpy as np
 import torch
 import unittest
-from teenygrad.tensor import Tensor
+from tinygrad import Tensor, Device
 from teenygrad.nn.optim import Adam, SGD, AdamW
-import pytest
-
-pytestmark = pytest.mark.exclude_cuda
+from teenygrad.helpers import CI
 
 np.random.seed(1337)
 x_init = np.random.randn(1,4).astype(np.float32)
 W_init = np.random.randn(4,4).astype(np.float32)
 m_init = np.random.randn(1,4).astype(np.float32)
 
+class TeenyNet:
+  def __init__(self, tensor):
+    self.x = tensor(x_init.copy(), requires_grad=True)
+    self.W = tensor(W_init.copy(), requires_grad=True)
+  def forward(self):
+    return (self.x * self.W).sum()
+
 class TinyNet:
   def __init__(self, tensor):
     self.x = tensor(x_init.copy(), requires_grad=True)
@@ -25,8 +30,8 @@ def forward(self):
     out = out.mul(self.m).add(self.m).sum()
     return out
 
-def step(tensor, optim, steps=1, kwargs={}):
-  net = TinyNet(tensor)
+def step(tensor, optim, steps=1, teeny=False, **kwargs):
+  net = TeenyNet(tensor) if teeny else TinyNet(tensor)
   optim = optim([net.x, net.W], **kwargs)
   for _ in range(steps):
     out = net.forward()
@@ -35,17 +40,21 @@ def step(tensor, optim, steps=1, kwargs={}):
     optim.step()
   return net.x.detach().numpy(), net.W.detach().numpy()
 
+@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
 class TestOptim(unittest.TestCase):
   def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
-    for x,y in zip(step(Tensor, tinygrad_optim, steps, kwargs=opts),
-                   step(torch.tensor, torch_optim, steps, kwargs=opts)):
+    for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),
+                   step(torch.tensor, torch_optim, steps, **opts)):
       np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
   def _test_sgd(self, steps, opts, atol, rtol): self._test_optim(SGD, torch.optim.SGD, steps, opts, atol, rtol)
   def _test_adam(self, steps, opts, atol, rtol): self._test_optim(Adam, torch.optim.Adam, steps, opts, atol, rtol)
   def _test_adamw(self, steps, opts, atol, rtol): self._test_optim(AdamW, torch.optim.AdamW, steps, opts, atol, rtol)
 
+  def test_multistep_sgd_high_lr_teeny(self): self._test_sgd(2, {'lr': 1.1, 'teeny': True}, 1e-6, 1e-5)
+  def test_multistep_adam_high_lr_teeny(self): self._test_adam(2, {'lr': 1.1, 'teeny': True}, 2e-4, 5e-4)
+
   def test_sgd(self): self._test_sgd(1, {'lr': 0.001}, 1e-6, 0)
   def test_sgd_high_lr(self): self._test_sgd(1, {'lr': 10}, 1e-6, 1e-5)
   def test_sgd_wd(self): self._test_sgd(1, {'lr': 0.001, 'weight_decay': 0.1}, 1e-6, 0)
@@ -63,8 +72,10 @@ def test_multistep_sgd_high_lr_momentum_wd(self): self._test_sgd(10, {'lr': 10,
   def test_multistep_sgd_nesterov_momentum(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True}, 1e-5, 0)
   def test_multistep_sgd_high_lr_nesterov_momentum(self): self._test_sgd(10, {'lr': 10, 'momentum': 0.9, 'nesterov': True}, 1e-5, 3e-4)
-  def test_multistep_sgd_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
-  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self): self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
+  def test_multistep_sgd_nesterov_momentum_wd(self):
+    self._test_sgd(10, {'lr': 0.001, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 0)
+  def test_multistep_sgd_high_lr_nesterov_momentum_wd(self):
+    self._test_sgd(10, {'lr': 9, 'momentum': 0.9, 'nesterov': True, 'weight_decay': 0.1}, 1e-5, 3e-4)
 
   def test_adam(self): self._test_adam(1, {'lr': 0.001}, 1e-5, 0)
   def test_adam_high_lr(self): self._test_adam(1, {'lr': 10}, 1e-4, 1e-4)