update to match tinygrad and use built in sparse_categorical_crossentropy
geohot committed Aug 21, 2023
1 parent 7bbf6cc commit 918f632
Showing 3 changed files with 55 additions and 45 deletions.
19 changes: 4 additions & 15 deletions mnist.py
@@ -2,30 +2,19 @@
import numpy as np
from teenygrad.tensor import Tensor
from tqdm import trange
from teenygrad.nn import optim
import gzip, os

# sorted in order of increasing complexity
from teenygrad.nn import optim
from teenygrad.helpers import getenv

def sparse_categorical_crossentropy(out, Y):
num_classes = out.shape[-1]
YY = Y.flatten().astype(np.int32)
y = np.zeros((YY.shape[0], num_classes), np.float32)
# correct loss for NLL, torch NLL loss returns one per row
y[range(y.shape[0]),YY] = -1.0*num_classes
y = y.reshape(list(Y.shape)+[num_classes])
y = Tensor(y)
return out.mul(y).mean()

def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categorical_crossentropy,
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
Tensor.training = True
losses, accuracies = [], []
for i in (t := trange(steps, disable=getenv('CI', False))):
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x = Tensor(transform(X_train[samp]), requires_grad=False)
y = target_transform(Y_train[samp])
y = Tensor(target_transform(Y_train[samp]))

# network
out = model.forward(x) if hasattr(model, 'forward') else model(x)
@@ -39,7 +28,7 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
# printing
if not noloss:
cat = np.argmax(out.numpy(), axis=-1)
accuracy = (cat == y).mean()
accuracy = (cat == y.numpy()).mean()

loss = loss.detach().numpy()
losses.append(loss)
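Aside: the removed NumPy helper built a scaled one-hot target by hand; the new default lossfn defers to the Tensor.sparse_categorical_crossentropy method added further down in this commit. A minimal sketch of calling it directly, assuming teenygrad is importable and the made-up logits and integer labels below load as Tensors:

import numpy as np
from teenygrad.tensor import Tensor

logits = Tensor(np.random.randn(4, 10).astype(np.float32))      # stand-in for model output
labels = Tensor(np.array([3, 1, 7, 0], dtype=np.int32))         # integer class targets
lossfn = lambda out, y: out.sparse_categorical_crossentropy(y)  # the new train() default
print(lossfn(logits, labels).numpy())                           # scalar mean loss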
9 changes: 8 additions & 1 deletion teenygrad/lazy.py
@@ -1,11 +1,12 @@
from __future__ import annotations
from typing import Tuple
from teenygrad.helpers import dtypes
from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps
from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
import numpy as np

class Device:
DEFAULT = "CPU"
_buffers = ["CPU"]
def canonicalize(x): return "CPU"

def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
@@ -61,6 +62,12 @@ def binary_op(self, op, y:LazyBuffer):
else:
raise NotImplementedError(op)

def ternary_op(self, op, y:LazyBuffer, z:LazyBuffer):
if op == TernaryOps.WHERE:
return LazyBuffer(np.where(self._np, y._np, z._np))
else:
raise NotImplementedError(op)

def reduce_op(self, op, new_shape):
if op == ReduceOps.SUM:
return LazyBuffer(self._np.sum(shape_to_axis(self.shape, new_shape), keepdims=True))
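Aside: the WHERE branch above maps straight onto NumPy semantics: np.where picks elementwise from the second argument where the condition is truthy and from the third elsewhere. Plain-NumPy illustration (not teenygrad code):

import numpy as np

cond = np.array([True, False, True])
a = np.array([1.0, 2.0, 3.0])
b = np.array([10.0, 20.0, 30.0])
print(np.where(cond, a, b))   # -> [ 1. 20.  3.]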
72 changes: 43 additions & 29 deletions teenygrad/tensor.py
@@ -1,8 +1,8 @@
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
from __future__ import annotations
import time, operator
import time
from functools import partialmethod, reduce
from itertools import accumulate, filterfalse
from itertools import accumulate
import numpy as np
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast
from math import ceil, pi, prod, sqrt, log, cos, copysign, isinf
@@ -277,27 +277,23 @@ def __getitem__(self, val):
def normalize_int(e, i, dim_sz):
if -dim_sz <= e < dim_sz: return e if e != -1 else dim_sz-1
raise IndexError(f"index {e} is out of bounds for dimension {i} with size {self.shape[i]}")
val = list(val) if isinstance(val, tuple) else [val]
if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in val)) > len(self.shape):
orig_slices = list(val) if isinstance(val, tuple) else [val]
if (num_slices := sum(isinstance(v, (slice, int, Tensor)) for v in orig_slices)) > len(self.shape):
raise IndexError(f"too many indices for tensor of dimension {len(self.shape)}")
orig_slices = list(val)
ellipses_found = [i for i, v in enumerate(val) if v is Ellipsis]
if len(ellipses_found) > 0:
if len(ellipses_found) != 1:
raise IndexError("an index can only have a single ellipsis ('...')")
ellipsis_idx = ellipses_found[0]
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)
else:
orig_slices += [slice(None)] * (len(self.shape) - num_slices)
ellipses_found = [i for i, v in enumerate(orig_slices) if v is Ellipsis]
if len(ellipses_found) > 1: raise IndexError("an index can only have a single ellipsis ('...')")
ellipsis_idx = len(orig_slices) if len(ellipses_found) == 0 else ellipses_found[0]
orig_slices[ellipsis_idx:ellipsis_idx+1] = [slice(None)] * (len(self.shape) - num_slices)

tensor_found = [(i,v) for i, v in enumerate(orig_slices) if isinstance(v, Tensor)]
orig_slices = [slice(None, None, None) if isinstance(v, Tensor) else v for v in orig_slices]
valid_slices = list(filterfalse(lambda x: x is None, orig_slices))
orig_slices = [slice(None) if isinstance(v, Tensor) else v for v in orig_slices]
valid_slices = [s for s in orig_slices if s is not None]
valid_slices = [v if isinstance(v, slice) else slice(y := normalize_int(v, i, dim_sz), y+1) for i, (v, dim_sz) in enumerate(zip(valid_slices, self.shape))]
start, stop, strides = zip(*y) if (y := [s.indices(dim_sz) for s, dim_sz in zip(valid_slices, self.shape)]) else ((), (), ())
new_slice = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in zip(start, stop, strides))
new_shape = tuple(e - s for s, e in new_slice)
# Shrink
sliced_tensor = self.shrink(new_slice)
new_shape = sliced_tensor.shape
# Flip
if (flip_axes := tuple(i for i, s in enumerate(strides) if s < 0)):
sliced_tensor = sliced_tensor.flip(axis=flip_axes)
@@ -309,15 +305,14 @@ def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0
paddings = tuple((0, num_zeros(s, dim_sz)) for s, dim_sz in zip(strides, sliced_tensor.shape))
padded_tensor = sliced_tensor.pad(paddings)
# Reshape: [dim_sz_padded] -> [dim_sz_padded // s, s]
new_shape = reduce(operator.add, [[sh // s, s] for sh, s in zip(padded_tensor.shape, strides)], []) # type: ignore
new_shape = flatten([sh // s, s] for sh, s in zip(padded_tensor.shape, strides))
reshaped_tensor = padded_tensor.reshape(new_shape)
# Shrink: do [:, 0]
new_shape = new_shape[::2]
final_slice = reduce(operator.add, (((0, sh), (0, 1)) for sh in new_shape), ())
final_slice = tuple(flatten(((0, sh), (0, 1)) for sh in new_shape))
sliced_tensor = reshaped_tensor.shrink(final_slice)
final_shape = []
final_shape, it_shape = [], iter(new_shape)
sub = [0] * len(tensor_found)
it_shape = iter(new_shape)
for i,s in enumerate(orig_slices):
if isinstance(s, (int, slice)):
dim_shape = next(it_shape)
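Aside: the strided-slice path above avoids a strided copy by padding the axis to a multiple of the step, reshaping to [dim // step, step], and shrinking to column 0. The same idea in plain NumPy (illustration only, not teenygrad code):

import numpy as np

x = np.arange(7)                                         # want x[::3] -> [0 3 6]
step = 3
pad = 0 if len(x) % step == 0 else step - len(x) % step  # mirrors num_zeros()
padded = np.concatenate([x, np.zeros(pad, x.dtype)])     # Pad
print(padded.reshape(-1, step)[:, 0])                    # Reshape, then keep column 0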
@@ -332,14 +327,14 @@ def num_zeros(step, dim_sz): return 0 if step == 1 or (y := dim_sz % step) == 0
for i,s in enumerate(sub): tensor_found[i] = (tensor_found[i][0]+s, tensor_found[i][1])
dim = [i[0] for i in tensor_found]
idx = [i[1].sign().contiguous().__neg__().contiguous().relu() * ret.shape[i[0]] + i[1] for i in tensor_found] # TODO first contiguous fixes torch+cpu_only CI, but it causes llvm to fail. Second one fixes llvm
max_dim = max(idx, key=lambda i: i.ndim).ndim
idx = [i if i.ndim == max_dim else i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
sum_dim = [d if n==0 else d+i.ndim-n for n,(d,i) in enumerate(zip(dim,idx))]
new_idx = idx[0].reshape(*[1]*sum_dim[0], 1, *idx[0].shape, *[1]*(ret.ndim-sum_dim[0]-1))
arange = Tensor.arange(ret.shape[sum_dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*sum_dim[0], ret.shape[sum_dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-sum_dim[0]-1))
ret = (ret.reshape(*ret.shape[:sum_dim[0]+1], *[1]*idx[0].ndim, *ret.shape[sum_dim[0]+1:]) * (arange == new_idx)).sum(sum_dim[0])
max_dim = max(i.ndim for i in idx)
idx = [i.reshape(*[1]*(max_dim-i.ndim), *i.shape) for i in idx]
sum_dim = [d+max_dim-n for n,d in enumerate(dim)]
new_idx = idx[0].reshape(*[1]*dim[0], 1,*idx[0].shape, *[1]*(ret.ndim-dim[0]-1))
arange = Tensor.arange(ret.shape[dim[0]], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim[0], ret.shape[dim[0]], *[1]*idx[0].ndim, *[1]*(ret.ndim-dim[0]-1))
ret = (ret.reshape(*ret.shape[:dim[0]+1], *[1]*idx[0].ndim, *ret.shape[dim[0]+1:]) * (arange == new_idx)).sum(dim[0])
for idx_,d in zip(idx[1:],sum_dim[1:]):
new_idx = idx_.reshape(*[1]*sum_dim[0], *idx_.shape, *[1]*(ret.ndim-sum_dim[0]-idx_.ndim))
new_idx = idx_.reshape(*[1]*dim[0], *idx_.shape, *[1]*(ret.ndim-dim[0]-idx_.ndim))
arange = Tensor.arange(ret.shape[d], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*(d), ret.shape[d], *[1]*(ret.ndim-d-1))
ret = ((new_idx == arange) * ret).sum(d)
if dim[0] != 0 and dim != list(range(dim[0], dim[-1]+1)) and len(dim) != 1: # special permute case
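Aside: the Tensor-index branch above gathers values with a one-hot mask rather than a dedicated gather op: multiplying by (arange == idx) and summing over the indexed axis selects the element at idx. The same identity in plain NumPy:

import numpy as np

x = np.array([10.0, 20.0, 30.0, 40.0])
idx = 2
onehot = (np.arange(x.shape[0]) == idx).astype(x.dtype)
print((x * onehot).sum())   # -> 30.0, i.e. x[idx]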
@@ -454,6 +449,14 @@ def log_softmax(self, axis=-1):
m, _, ss = self._softmax(axis)
return m - ss.log()

def argmax(self, axis=None, keepdim=False):
if axis is None: return prod(self.shape) - ((self == self.max(axis)).flatten() * Tensor.arange(prod(self.shape)-1,-1,-1)).max() - 1
axis = axis + self.ndim if axis < 0 else axis
m = self == (self.max(axis=axis, keepdim=keepdim) if keepdim else self.max(axis=axis, keepdim=keepdim).unsqueeze(axis))
idx = m * Tensor.arange(self.shape[axis]-1,-1,-1).reshape(*[1]*axis, self.shape[axis], *[1]*(self.ndim-(axis+1)))
return self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1
def argmin(self, axis=None, keepdim=False): return (-self).argmax(axis=axis, keepdim=keepdim)
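Aside: the new argmax marks every position equal to the maximum, multiplies by a reversed arange so earlier positions score higher, and maps the winning score back to an index, which yields the first occurrence of the maximum. The same arithmetic in plain NumPy (illustration only):

import numpy as np

x = np.array([2.0, 5.0, 5.0, 1.0])
n = x.shape[0]
mask = (x == x.max()).astype(np.int32)     # 1 at every maximum
scores = mask * np.arange(n - 1, -1, -1)   # earlier maxima get larger scores
print(n - scores.max() - 1)                # -> 1, index of the first maximum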

# ***** processing ops *****

def _pool(self, k_:Tuple[int, ...], stride:Union[Tuple[int, ...], int]=1, dilation:Union[Tuple[int, ...], int]=1) -> Tensor:
@@ -604,7 +607,7 @@ def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
if x == 2.0: return self*self
if x == 1.0: return self
if x == 0.5: return self.sqrt()
if x.__class__ is not Tensor and reverse and x > 0: return self.mul(log(x)).exp()
if not isinstance(x, Tensor) and reverse and x > 0: return self.mul(log(x)).exp()
ar = self.abs().log().mul(x).exp() if not reverse or isinstance(x, Tensor) else self.mul(log(abs(x))).exp()
# correct sign of negative numbers raised to a power (cos has a period of 2pi so we use it here to get the oddness of the power)
sign = (x * pi).cos() if isinstance(x, Tensor) else cos(x * pi) if not reverse else (self * pi).cos()
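Aside on the sign comment above: for an integer exponent k, cos(k * pi) is +1 when k is even and -1 when k is odd, which is exactly the sign of a negative base raised to k. Tiny NumPy check (not teenygrad code):

import numpy as np

base, k = -2.0, 3
magnitude = np.exp(np.log(abs(base)) * k)  # |base| ** k via exp/log, as in pow()
print(magnitude * np.cos(k * np.pi))       # ~ -8.0, matching (-2.0) ** 3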
@@ -625,7 +628,7 @@ def minimum(self, x:Union[Tensor, float]) -> Tensor: return -((-self).maximum(-x

def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
# TODO: ensure self is non-differentiable, could mess with ceil/float though
dtype = self.dtype if self.dtype != dtypes.bool and self.dtype.__class__ is not ImageDType else dtypes.float32
dtype = self.dtype if self.dtype != dtypes.bool else dtypes.float32
x: Tensor = self
y: Tensor = Tensor(cast(float, input_), device=self.device, requires_grad=False, dtype=dtype) if input_.__class__ is not Tensor else cast(Tensor, input_)
z: Tensor = Tensor(cast(float, other), device=self.device, requires_grad=False, dtype=dtype) if other.__class__ is not Tensor else cast(Tensor, other)
@@ -705,6 +708,12 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optio
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value

def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
loss_mask = Y != ignore_index
y_counter = Tensor.arange(self.shape[-1], requires_grad=False).unsqueeze(0).expand(Y.numel(), self.shape[-1])
y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
return self.log_softmax().mul(y).sum() / loss_mask.sum()
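Aside: written out in NumPy, the method above computes the mean negative log-softmax at each target class, with ignore_index entries dropped from both the numerator and the denominator. A reference sketch for checking, assuming 2-D logits and integer labels (not part of teenygrad):

import numpy as np

def reference_sce(logits, labels, ignore_index=-1):
    shifted = logits - logits.max(axis=-1, keepdims=True)             # stable log-softmax
    logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    keep = labels != ignore_index
    picked = logp[np.arange(len(labels)), np.where(keep, labels, 0)]  # log-prob of target class
    return -(picked * keep).sum() / keep.sum()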

# ***** cast ops *****

def cast(self, dtype:DType) -> Tensor: return mlops.Cast.apply(self, dtype=dtype) if self.dtype != dtype else self
@@ -719,3 +728,8 @@ def numel(self) -> int: return prod(self.shape)
def element_size(self) -> int: return self.dtype.itemsize
def nbytes(self) -> int: return self.numel() * self.element_size()
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)

# register functions to move between devices
for device in Device._buffers:
setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
