teenygrad up to date
geohot committed Sep 23, 2023
1 parent acbac2e commit 2da66f2
Showing 5 changed files with 61 additions and 40 deletions.
4 changes: 3 additions & 1 deletion teenygrad/helpers.py
@@ -1,4 +1,5 @@
from typing import Union, Tuple, Iterator, NamedTuple, Optional, Final
from typing import Union, Tuple, Iterator, NamedTuple, Optional, Final, Any
from typing_extensions import TypeGuard
import os, functools
import numpy as np
from math import prod # noqa: F401 # pylint:disable=unused-import
@@ -8,6 +9,7 @@ def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)

@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
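
For readers new to TypeGuard: the point of returning TypeGuard[Tuple[int, ...]] instead of bool is that a type checker can narrow the tuple's element type once the check passes. A minimal sketch of how the new helper is meant to be consumed (the concrete_numel function here is hypothetical, not part of the commit):

from typing import Any, Tuple
from typing_extensions import TypeGuard

def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)

def concrete_numel(shape: Tuple[Any, ...]) -> int:
  # hypothetical consumer: after the assert, mypy treats `shape` as Tuple[int, ...]
  assert all_int(shape), f"cannot count elements of symbolic shape {shape}"
  out = 1
  for s in shape: out *= s
  return out

print(concrete_numel((2, 3, 4)))  # 24
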
22 changes: 11 additions & 11 deletions teenygrad/lazy.py
@@ -18,11 +18,8 @@ def __init__(self, buf): self._np = buf
@property
def shape(self): return self._np.shape

def contiguous(x): return x
def realize(x): return x

def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))

@staticmethod
def fromCPU(x): return LazyBuffer(x)
def toCPU(self): return self._np
@@ -33,13 +30,8 @@ def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg))
else: raise NotImplementedError(op)

# MovementOps
def reshape(self, arg): return LazyBuffer(self._np.reshape(arg))
def expand(self, arg): return LazyBuffer(np.broadcast_to(self._np, arg))
def shrink(self, arg): return LazyBuffer(self._np[tuple(slice(p[0], p[1], None) for p in arg)])
def permute(self, arg): return LazyBuffer(self._np.transpose(arg))
def pad(self, arg): return LazyBuffer(np.pad(self._np, arg))
def stride(self, arg): return LazyBuffer(self._np[tuple(slice(None, None, i) for i in arg)])
def contiguous(x): return x
def const(self, x) -> LazyBuffer: return LazyBuffer(np.full_like(self._np, x))

def e(self, op, *srcs):
if op == UnaryOps.NEG: return LazyBuffer(-self._np)
@@ -56,7 +48,15 @@ def e(self, op, *srcs):
elif op == TernaryOps.WHERE: return LazyBuffer(np.where(self._np, srcs[0]._np, srcs[1]._np))
else: raise NotImplementedError(op)

def reduce_op(self, op, new_shape):
def r(self, op, new_shape):
if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(shape_to_axis(self.shape, new_shape), keepdims=True))
elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(shape_to_axis(self.shape, new_shape), keepdims=True))
else: raise NotImplementedError(op)

# MovementOps
def reshape(self, arg): return LazyBuffer(self._np.reshape(arg))
def expand(self, arg): return LazyBuffer(np.broadcast_to(self._np, arg))
def shrink(self, arg): return LazyBuffer(self._np[tuple(slice(p[0], p[1], None) for p in arg)])
def permute(self, arg): return LazyBuffer(self._np.transpose(arg))
def pad(self, arg): return LazyBuffer(np.pad(self._np, arg))
def stride(self, arg): return LazyBuffer(self._np[tuple(slice(None, None, i) for i in arg)])
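
For context on the renamed r: it reduces exactly the axes where new_shape disagrees with the current shape, keeping dims so the result broadcasts back. shape_to_axis is not shown in this diff, so the sketch below is an assumption about its behavior rather than the real helper:

import numpy as np

def shape_to_axis(old_shape, new_shape):
  # assumed behavior: reduce over every axis where the two shapes differ
  assert len(old_shape) == len(new_shape), "reduce shapes must have the same rank"
  return tuple(i for i, (o, n) in enumerate(zip(old_shape, new_shape)) if o != n)

x = np.ones((2, 3, 4))
axes = shape_to_axis(x.shape, (1, 3, 1))   # (0, 2)
print(x.sum(axes, keepdims=True).shape)    # (1, 3, 1), matching what r(ReduceOps.SUM, (1, 3, 1)) returns
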
17 changes: 10 additions & 7 deletions teenygrad/mlops.py
@@ -1,9 +1,10 @@
import math
from typing import Tuple, Optional
from typing import Tuple, Optional, cast
from teenygrad.helpers import argsort, DType
from teenygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from teenygrad.tensor import Function
from teenygrad.lazy import LazyBuffer
from teenygrad.shape.symbolic import sint

class Contiguous(Function):
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -88,20 +89,20 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
class Sum(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.input_shape = x.shape
return x.reduce_op(ReduceOps.SUM, new_shape)
return x.r(ReduceOps.SUM, new_shape)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.expand(self.input_shape)

class Max(Function):
def forward(self, x:LazyBuffer, new_shape:Tuple[int, ...]) -> LazyBuffer:
self.x, self.ret = x, x.reduce_op(ReduceOps.MAX, new_shape)
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape)))
div = max_is_1s.reduce_op(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(BinaryOps.MUL, grad_output.expand(self.x.shape))
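
The Max backward routes the incoming gradient only to positions that attained the max, splitting it evenly when there are ties. A plain-numpy trace of the same idea (illustrative only; the real code goes through LazyBuffer ops):

import numpy as np
x = np.array([[1., 3., 3.], [2., 0., 2.]])
ret = x.max(axis=1, keepdims=True)            # forward result, shape (2, 1)
max_is_1s = 1.0 - (x < ret)                   # 1.0 wherever the max was attained (ties included)
div = max_is_1s.sum(axis=1, keepdims=True)    # number of tied maxima per row
grad_x = max_is_1s / div * np.ones_like(ret)  # upstream grad of ones, split evenly among the ties
# grad_x == [[0.0, 0.5, 0.5], [0.5, 0.0, 0.5]]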

# ************* binary ops *************
@@ -165,7 +166,7 @@ def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
return x.expand(shape)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.reduce_op(ReduceOps.SUM, self.input_shape)
return grad_output.r(ReduceOps.SUM, self.input_shape)

class Reshape(Function):
def forward(self, x:LazyBuffer, shape:Tuple[int, ...]) -> LazyBuffer:
@@ -192,12 +193,14 @@ def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.shrink(self.narg)

class Shrink(Function):
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
def forward(self, x:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
self.narg = tuple([(p[0], s-p[1]) for s,p in zip(x.shape, arg)])
return x.shrink(arg)

def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
return grad_output.pad(self.narg)
assert all(isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg), "symbolic shrink does not support backward"
# need this cast because mypy cannot narrow the type even with assert
return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
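
A hand check of the narg bookkeeping, since it is easy to misread: shrinking an axis to x[2:5] must be undone in backward by padding back what was cut off on each side (worked sketch, not from the commit):

# illustrative: shrink arg (start, end) -> pad arg (left, right) for one length-8 axis
shape = (8,)
arg   = ((2, 5),)                                            # forward keeps x[2:5]
narg  = tuple((p[0], s - p[1]) for s, p in zip(shape, arg))
assert narg == ((2, 3),)                                     # backward pads 2 on the left, 8-5 = 3 on the right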

class Flip(Function):
def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
1 change: 1 addition & 0 deletions teenygrad/shape/symbolic.py
@@ -0,0 +1 @@
sint = int
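
(Presumably this one-line module mirrors tinygrad, where sint can also be a symbolic shape node and tensor dimensions need not be concrete integers; in teenygrad it simply aliases int, which is why the Tuple[sint, ...] annotations added elsewhere in this commit still type-check as plain integer tuples. This reading of the alias is an inference, not stated in the commit.)
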
57 changes: 36 additions & 21 deletions teenygrad/tensor.py
@@ -7,9 +7,10 @@
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence

from teenygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes
from teenygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int
from teenygrad.lazy import LazyBuffer
from teenygrad.ops import Device, LoadOps
from teenygrad.shape.symbolic import sint

# An instantiation of the Function is the Context
class Function:
@@ -75,7 +76,7 @@ def __hash__(self): return id(self)
def device(self) -> str: return self.lazydata.device

@property
def shape(self) -> Tuple[int, ...]: return self.lazydata.shape
def shape(self) -> Tuple[sint, ...]: return self.lazydata.shape

@property
def dtype(self) -> DType: return self.lazydata.dtype
@@ -90,13 +91,13 @@ def assign(self, x) -> Tensor:
# TODO: this is a hack for writing to DISK
if self.device.startswith("DISK"):
if x.__class__ is not Tensor: x = Tensor(x, device="CPU", dtype=self.dtype)
self.lazydata.realize().realized._copyin(x.numpy()) # type: ignore
self.lazydata.contiguous().realize().realized._copyin(x.numpy()) # type: ignore
return self
if x.__class__ is not Tensor: x = Tensor(x, device=self.device, dtype=self.dtype)
assert self.shape == x.shape and self.device == x.device, f"assign shape mismatch {self.shape} != {x.shape} or device mismatch {self.device} != {x.device}"
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
if self.dtype == x.dtype and self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
self.lazydata = x.lazydata
return self

@@ -121,21 +122,24 @@ def _loadop(op, sz, device:Optional[str]=None, dtype:Optional[DType]=None, arg=N
return Tensor(LazyBuffer.loadop(op, [sz], Tensor.default_type if dtype is None else dtype, Device.canonicalize(device), arg), dtype=dtype, device=device, **kwargs)

@staticmethod
def empty(*shape, **kwargs): return Tensor._loadop(LoadOps.EMPTY, math.prod(shape), **kwargs).reshape(shape)
def empty(*shape, **kwargs):
assert all_int(shape), f"cannot create with symbolic shape {shape}"
return Tensor._loadop(LoadOps.EMPTY, prod(shape), **kwargs).reshape(shape)

_seed: int = int(time.time())
@staticmethod
def manual_seed(seed=0): Tensor._seed = seed

@staticmethod
def rand(*shape, **kwargs):
assert all_int(shape), f"cannot create with symbolic shape {shape}"
Tensor._seed += 1
return Tensor._loadop(LoadOps.RAND, math.prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)
return Tensor._loadop(LoadOps.RAND, prod(shape), arg=Tensor._seed, **kwargs).reshape(shape)

# ***** creation helper functions *****

@staticmethod
def full(shape:Tuple[int, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)
def full(shape:Tuple[sint, ...], fill_value, **kwargs): return Tensor(fill_value, **kwargs).reshape([1]*len(new_shape := argfix(shape))).expand(new_shape)

@staticmethod
def zeros(*shape, **kwargs): return Tensor.full(argfix(*shape), 0, **kwargs)
@@ -173,22 +177,22 @@ def uniform(*shape, low=-1.0, high=1.0, **kwargs) -> Tensor:
return ((high-low) * Tensor.rand(*shape, **kwargs)).cast(dtype) + low

@staticmethod
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(math.prod(shape)**-0.5)
def scaled_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul(prod(shape)**-0.5)

# https://www.tensorflow.org/api_docs/python/tf/keras/initializers/GlorotUniform
@staticmethod
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+math.prod(shape[1:])))**0.5)
def glorot_uniform(*shape, **kwargs) -> Tensor: return Tensor.uniform(*shape, **kwargs).mul((6/(shape[0]+prod(shape[1:])))**0.5)

# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_uniform_
@staticmethod
def kaiming_uniform(*shape, a:float = 0.01, **kwargs) -> Tensor:
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
bound = math.sqrt(3.0) * math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.uniform(*shape, low=-bound, high=bound, **kwargs)

# https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
@staticmethod
def kaiming_normal(*shape, a:float = 0.01, **kwargs) -> Tensor:
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(math.prod(shape[1:]))
std = math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(prod(shape[1:]))
return Tensor.normal(*shape, mean=0.0, std=std, **kwargs)

# ***** toposort and backward pass *****
@@ -224,11 +228,11 @@ def backward(self):
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert 0 not in new_shape, f"zeros not allowed in shape {new_shape}"
return mlops.Reshape.apply(self, shape=tuple([-math.prod(self.shape) // math.prod(new_shape) if s == -1 else s for s in new_shape]))
return mlops.Reshape.apply(self, shape=tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]))
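
A quick arithmetic check of the -1 inference in reshape (illustrative): negating prod(self.shape) cancels the sign that the -1 contributes to prod(new_shape), so floor division recovers the missing dimension exactly.

from math import prod
old_shape, new_shape = (2, 3, 4), (-1, 4)
inferred = -prod(old_shape) // prod(new_shape)   # -24 // -4 == 6
assert inferred == 6                             # so x.reshape(-1, 4) on a (2, 3, 4) tensor gives (6, 4)
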
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple([x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))]))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=[x if x >= 0 else x+len(self.shape) for x in argfix(axis, *args)])
def shrink(self, arg:Tuple[Tuple[int, int], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]) -> Tensor: return mlops.Shrink.apply(self, arg=arg) if any(x != (0,s) for x,s in zip(arg, self.shape)) else self
def pad(self, arg: Tuple[Tuple[int, int], ...], value:float=0) -> Tensor:
ret = mlops.Pad.apply(self, arg=arg) if any(x != (0, 0) for x in arg) else self
return ret if 0 == value else ret + mlops.Pad.apply(Tensor.ones_like(self), arg=arg).where(0, value)
@@ -299,6 +303,7 @@ def normalize_int(e, i, dim_sz):
if isinstance(s, int):
dim_collapsed += 1
else:
assert isinstance(dim_shape, int), f"does not support symbolic shape {dim_shape}"
final_shape.append(dim_shape)
if isinstance(s, Tensor):
tensors.append(s)
@@ -326,7 +331,7 @@ def normalize_int(e, i, dim_sz):
return ret

# NOTE: using slice is discouraged and things should migrate to pad and shrink
def slice(self, arg:Sequence[Optional[Tuple[int, int]]], value:float=0) -> Tensor:
def slice(self, arg:Sequence[Optional[Tuple[int, sint]]], value:float=0) -> Tensor:
arg_ = tuple([a if a is not None else (0,s) for s,a in zip(self.shape, arg)])
padding = tuple([(max(0, -p[0]), max(0, p[1]-self.shape[i])) for i,p in enumerate(arg_)])
return self.pad(padding, value=value).shrink(tuple([(p[0] + padding[i][0], p[1] + padding[i][0]) for i,p in enumerate(arg_)]))
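
To see how slice maps onto pad and shrink (hand-worked, illustrative): out-of-range bounds become padding, then the shrink window is re-expressed in the padded tensor's coordinates.

# one axis of length 4, sliced with (-1, 5): one element of padding lands on each side
s, (lo, hi) = 4, (-1, 5)
padding = (max(0, -lo), max(0, hi - s))            # (1, 1)
shrink_arg = (lo + padding[0], hi + padding[0])    # (0, 6) in padded coordinates
assert padding == (1, 1) and shrink_arg == (0, 6)  # padded length 6, keep all of it -> 6 output elements
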
@@ -367,6 +372,7 @@ def repeat(self, repeats):
return self.reshape(new_shape).expand(expand_shape).reshape(final_shape)

def chunk(self, num:int, dim:int) -> List[Tensor]:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
dim, step = dim + self.ndim if dim < 0 else dim, math.ceil(self.shape[dim]/num)
slice_params = [[slice(None)]*dim + [slice(k, k + step)] for k in range(0, self.shape[dim], step)]
return [self[tuple(sl)] for sl in slice_params]
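
chunk now refuses symbolic shapes because it needs concrete sizes for the ceil-divided step; a hand check of the slicing it produces (illustrative):

import math
dim_size, num = 7, 3
step = math.ceil(dim_size / num)                   # 3
starts = list(range(0, dim_size, step))            # [0, 3, 6]
sizes = [min(step, dim_size - k) for k in starts]  # chunk sizes along that dim
assert sizes == [3, 3, 1]
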
@@ -409,11 +415,13 @@ def max(self, axis=None, keepdim=False): return self._reduce(mlops.Max, axis, ke
def min(self, axis=None, keepdim=False): return -((-self).max(axis=axis, keepdim=keepdim))

def mean(self, axis=None, keepdim=False):
assert all_int(self.shape), "does not support symbolic shape"
out = self.sum(axis=axis, keepdim=keepdim)
return out * (math.prod(out.shape)/math.prod(self.shape))
return out * (prod(out.shape)/prod(self.shape))
def std(self, axis=None, keepdim=False, correction=1):
assert all_int(self.shape), "does not support symbolic shape"
square_sum = ((self - self.mean(axis=axis, keepdim=True)).square()).sum(axis=axis, keepdim=keepdim)
return (square_sum / (math.prod(self.shape)/math.prod(square_sum.shape)-correction)).sqrt()
return (square_sum / (prod(self.shape)/prod(square_sum.shape)-correction)).sqrt()
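
The mean and std scaling factors look odd written in terms of prod, but they reduce to the familiar 1/N and 1/(N - correction); a hand check for a (2, 3) tensor reduced over axis 1 (illustrative):

from math import prod
self_shape, out_shape = (2, 3), (2, 1)        # reduce over axis=1 with keepdim
mean_scale = prod(out_shape) / prod(self_shape)
assert mean_scale == 1/3                      # 1 / N_reduced, with N_reduced = 3
n_over = prod(self_shape) / prod(out_shape)   # std denominator before subtracting correction
assert n_over - 1 == 2                        # Bessel-corrected divisor N - 1
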
def _softmax(self, axis):
m = self - self.max(axis=axis, keepdim=True)
e = m.exp()
@@ -429,8 +437,8 @@ def log_softmax(self, axis=-1):

def argmax(self, axis=None, keepdim=False):
if axis is None:
idx = (self == self.max(axis)) * Tensor.arange(math.prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
return math.prod(self.shape) - idx.max() - 1
idx = (self == self.max(axis)) * Tensor.arange(prod(self.shape)-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape)
return prod(self.shape) - idx.max() - 1
axis = axis + len(self.shape) if axis < 0 else axis
m = self == self.max(axis=axis, keepdim=True)
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, dtype=dtypes.int32, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1))
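
The argmax trick is to multiply the max mask by a reversed arange so that, among ties, the earliest position wins the largest value; a plain-Python trace of the axis=None branch (illustrative):

x = [3, 7, 7, 1]
hits = [float(v == max(x)) for v in x]              # [0., 1., 1., 0.]
countdown = list(range(len(x) - 1, -1, -1))         # [3, 2, 1, 0]
best = max(h * c for h, c in zip(hits, countdown))  # 2.0 -- the first 7 gets the larger countdown value
assert len(x) - best - 1 == 1                       # argmax is index 1, the first occurrence of the max
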
@@ -441,6 +449,7 @@ def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, kee

def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}"
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
s_, d_ = make_pair(stride, len(k_)), make_pair(dilation, len(k_))
assert len(k_) == len(s_) and len(k_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}"
slc_prefix, prefix, i_ = [(0,x) for x in self.shape[0:-len(k_)]], self.shape[0:-len(k_)], self.shape[-len(k_):]
@@ -552,8 +561,12 @@ def tan(self): return self.sin() / self.cos()

@staticmethod
def _tri(r:int, c:int, k:int=0, **kwargs) -> Tensor: return Tensor.arange(r, **kwargs).unsqueeze(1).expand(r,c) <= Tensor.arange(-k, c-k, **kwargs).unsqueeze(0).expand(r,c)
def triu(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor: return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
def triu(self, k:int=0) -> Tensor:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
return Tensor._tri(self.shape[-2], self.shape[-1], k=k, dtype=self.dtype, device=self.device).where(self, Tensor.zeros_like(self))
def tril(self, k:int=0) -> Tensor:
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
return Tensor._tri(self.shape[-2], self.shape[-1], k=k+1, dtype=self.dtype, device=self.device).where(Tensor.zeros_like(self), self)
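
_tri builds the triangular mask by broadcasting a row-index column against a shifted column-index row; a numpy check that this matches the usual triu definition (illustrative):

import numpy as np
r, c, k = 3, 4, 1
rows = np.arange(r).reshape(r, 1)            # row index i, repeated across the columns
cols = np.arange(-k, c - k).reshape(1, c)    # column index j shifted by -k, repeated down the rows
mask = rows <= cols                          # True exactly where j >= i + k
assert (mask == np.triu(np.ones((r, c)), k).astype(bool)).all()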

# ***** math functions (unary) *****
def trunc(self: Tensor) -> Tensor: return self.cast(dtypes.int32).contiguous().cast(self.dtype)
@@ -693,6 +706,8 @@ def dropout(self, p=0.5) -> Tensor:
return self * mask * (1/(1.0 - p))

def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor:
# NOTE: it works if key, value have symbolic shape
assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
if is_causal: attn_mask = Tensor.ones(self.shape[-2], key.shape[-2], requires_grad=False, device=self.device).tril(0).cast(dtypes.bool)
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
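
For reference, the one-liner computes standard scaled dot-product attention, softmax(Q K^T / sqrt(d_k) + mask) V; a minimal numpy version with no mask and no dropout (illustrative only, not the library API):

import numpy as np
def sdpa_reference(q, k, v):
  scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(q.shape[-1])  # Q K^T / sqrt(d_k)
  w = np.exp(scores - scores.max(-1, keepdims=True))
  w = w / w.sum(-1, keepdims=True)                            # row-wise softmax
  return w @ v

q = k = v = np.random.randn(2, 5, 8)   # (batch, seq, d_k)
print(sdpa_reference(q, k, v).shape)   # (2, 5, 8)
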
@@ -716,7 +731,7 @@ def half(self) -> Tensor: return self.cast(dtypes.float16)

@property
def ndim(self) -> int: return len(self.shape)
def numel(self) -> int: return math.prod(self.shape)
def numel(self) -> sint: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
