Commit: update to latest tinygrad
geohot committed Mar 9, 2024
1 parent b911d3c commit 48b96d5

Showing 13 changed files with 1,972 additions and 930 deletions.
2 changes: 1 addition & 1 deletion import_from_tinygrad.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 import pathlib
 
-FILES = ["tensor.py", "mlops.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py", "../test/test_optim.py"]
+FILES = ["tensor.py", "mlops.py", "dtype.py", "nn/optim.py", "../test/test_ops.py", "../test/test_dtype.py", "../test/test_optim.py"]
 src = pathlib.Path("../tinygrad/tinygrad")
 dest = pathlib.Path("teenygrad")
 
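The FILES manifest gains dtype.py, which this commit splits out of helpers.py. The rest of the script is collapsed in this view; purely as a hypothetical sketch of what a sync loop like this tends to look like (the loop body and replace rule below are assumptions, not the commit's code):

import pathlib

FILES = ["tensor.py", "mlops.py", "dtype.py"]
src, dest = pathlib.Path("../tinygrad/tinygrad"), pathlib.Path("teenygrad")
for f in FILES:
  # hypothetical: copy each file, rewriting tinygrad imports to teenygrad ones
  (dest/f).write_text((src/f).read_text().replace("tinygrad", "teenygrad"))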
15 changes: 15 additions & 0 deletions teenygrad/device.py
@@ -0,0 +1,15 @@
+from typing import Optional, Any
+from teenygrad.dtype import DType
+import numpy as np
+
+class Device:
+  DEFAULT = "CPU"
+  _devices = ["CPU"]
+  @staticmethod
+  def canonicalize(device:Optional[str]) -> str: return "CPU"
+
+class Buffer:
+  def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options=None):
+    self.device, self.size, self.dtype, self._buf = device, size, dtype, opaque[1] if isinstance(opaque, tuple) else opaque
+  def copyin(self, buf): self._buf = np.frombuffer(buf, dtype=self.dtype.np)
+  def as_buffer(self): return self._buf.data
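Buffer is a thin wrapper over a numpy array: copyin reinterprets raw bytes via the dtype's numpy type, and as_buffer exposes the underlying memory. A minimal sketch of the round trip (illustrative, not part of the diff):

import numpy as np
from teenygrad.device import Buffer
from teenygrad.dtype import dtypes

b = Buffer("CPU", 4, dtypes.float32)                   # opaque=None, so _buf starts empty
b.copyin(np.arange(4, dtype=np.float32).tobytes())     # _buf = np.frombuffer(bytes, np.float32)
print(np.frombuffer(b.as_buffer(), dtype=np.float32))  # [0. 1. 2. 3.]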
103 changes: 103 additions & 0 deletions teenygrad/dtype.py
@@ -0,0 +1,103 @@
+from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
+from dataclasses import dataclass
+import numpy as np # TODO: remove numpy
+import functools
+
+Scalar = Union[float, int, bool]
+
+@dataclass(frozen=True, order=True)
+class DType:
+  priority: int # this determines when things get upcasted
+  itemsize: int
+  name: str
+  fmt: Optional[str]
+  count: int
+  def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}"
+  def vec(self, sz:int):
+    assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}"
+    return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
+  def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
+  # TODO: someday this will be removed with the "remove numpy" project
+  @property
+  def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None
+
+# dependent typing?
+@dataclass(frozen=True, repr=False)
+class ImageDType(DType):
+  shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape
+  base: DType
+  def scalar(self): return self.base
+  def vec(self, sz:int): return self.base.vec(sz)
+  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
+
+# @dataclass(frozen=True, init=False, repr=False, eq=False)
+class PtrDType(DType):
+  def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count)
+  def __repr__(self): return f"ptr.{super().__repr__()}"
+  def __hash__(self): return super().__hash__()
+  def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
+  def __ne__(self, dt): return not (self == dt)
+
+def cast_scalar(scalar: Scalar, dtype:DType):
+  return int(scalar) if dtypes.is_int(dtype) else float(scalar) if dtypes.is_float(dtype) else bool(scalar)
+
+class dtypes:
+  @staticmethod
+  def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64)
+  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
+  def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x)
+  @staticmethod
+  def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
+  @staticmethod
+  def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name]
+  @staticmethod # NOTE: isinstance(True, int) is True in python
+  def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int
+  @staticmethod
+  def fields() -> Dict[str, DType]: return DTYPES_DICT
+  bool: Final[DType] = DType(0, 1, "bool", '?', 1)
+  int8: Final[DType] = DType(1, 1, "char", 'b', 1)
+  uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
+  int16: Final[DType] = DType(3, 2, "short", 'h', 1)
+  uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1)
+  int32: Final[DType] = DType(5, 4, "int", 'i', 1)
+  uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1)
+  int64: Final[DType] = DType(7, 8, "long", 'l', 1)
+  uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1)
+  float16: Final[DType] = DType(9, 2, "half", 'e', 1)
+  # bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16
+  bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1)
+  float32: Final[DType] = DType(11, 4, "float", 'f', 1)
+  float64: Final[DType] = DType(12, 8, "double", 'd', 1)
+
+  # dtype aliases
+  half = float16; float = float32; double = float64 # noqa: E702
+  uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702
+  char = int8; short = int16; int = int32; long = int64 # noqa: E702
+
+  # NOTE: these are image dtypes
+  @staticmethod
+  def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32)
+  @staticmethod
+  def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32)
+
+  default_float: ClassVar[DType] = float32
+  default_int: ClassVar[DType] = int32
+
+# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html
+# we don't support weak type and complex type
+promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64],
+  dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32],
+  dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16],
+  dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], }
+
+@functools.lru_cache(None)
+def _get_recursive_parents(dtype:DType) -> Set[DType]:
+  return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64}
+@functools.lru_cache(None)
+def least_upper_dtype(*ds:DType) -> DType:
+  return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0]
+def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32)
+
+# HACK: staticmethods are not callable in 3.8 so we have to compare the class
+DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)}
+INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()}
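least_upper_dtype implements the JAX-style promotion lattice above: it intersects each operand's set of reachable parents and takes the minimum by priority. A quick illustrative check (not from the commit):

from teenygrad.dtype import dtypes, least_upper_dtype, least_upper_float

print(least_upper_dtype(dtypes.int32, dtypes.uint32))  # int64, the cheapest type both promote to
print(least_upper_dtype(dtypes.int64, dtypes.uint64))  # float16, which has lower priority than bfloat16
print(least_upper_float(dtypes.int8))                  # float32; non-floats are lifted to at least float32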
62 changes: 13 additions & 49 deletions teenygrad/helpers.py
@@ -1,61 +1,25 @@
-from typing import Union, Tuple, Iterator, Optional, Final, Any
-import os, functools, platform
-import numpy as np
-from math import prod # noqa: F401 # pylint:disable=unused-import
-from dataclasses import dataclass
+from typing import Union, Tuple, Sequence, Any, Iterable, Dict, TypeVar
+import os, functools, platform, operator
 
+T = TypeVar("T")
+U = TypeVar("U")
 OSX = platform.system() == "Darwin"
+def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1)
 def dedup(x): return list(dict.fromkeys(x)) # retains list order
 def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
 def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
-def flatten(l:Iterator): return [item for sublist in l for item in sublist]
+def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
 def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])]
 def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
-def all_int(t: Tuple[Any, ...]) -> bool: return all(isinstance(s, int) for s in t)
+def all_int(t: Sequence[Any]) -> bool: return all(isinstance(s, int) for s in t)
 def round_up(num, amt:int): return (num+amt-1)//amt * amt
+def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
+  assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
+  return {k:v for d in ds for k,v in d.items()}
+def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,))
 
 @functools.lru_cache(maxsize=None)
 def getenv(key, default=0): return type(default)(os.getenv(key, default))
 
-DEBUG = getenv("DEBUG")
+DEBUG, WINO, IMAGE = getenv("DEBUG"), getenv("WINO"), 0
 CI = os.getenv("CI", "") != ""
 
-@dataclass(frozen=True, order=True)
-class DType:
-  priority: int # this determines when things get upcasted
-  itemsize: int
-  name: str
-  np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
-  sz: int = 1
-  def __repr__(self): return f"dtypes.{self.name}"
-
-class dtypes:
-  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
-  def is_int(x: DType) -> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-  @staticmethod
-  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64)
-  @staticmethod
-  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
-  @staticmethod
-  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
-  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
-  float16: Final[DType] = DType(9, 2, "half", np.float16)
-  half = float16
-  float32: Final[DType] = DType(10, 4, "float", np.float32)
-  float = float32
-  float64: Final[DType] = DType(11, 8, "double", np.float64)
-  double = float64
-  int8: Final[DType] = DType(1, 1, "char", np.int8)
-  int16: Final[DType] = DType(3, 2, "short", np.int16)
-  int32: Final[DType] = DType(5, 4, "int", np.int32)
-  int64: Final[DType] = DType(7, 8, "long", np.int64)
-  uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
-  uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
-  uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
-  uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
-
-  # NOTE: bfloat16 isn't supported in numpy
-  bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
-
-DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
+PtrDType, ImageDType, IMAGE = None, None, 0 # junk to remove
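helpers.py now defines prod via functools.reduce rather than importing it from math (the Union[T,int] signature suggests it is meant to handle non-int shape entries too), and gains merge_dicts, which refuses to merge dicts that disagree on a key. Illustrative usage (not from the commit):

from teenygrad.helpers import prod, merge_dicts

assert prod([2, 3, 4]) == 24 and prod([]) == 1       # empty product is 1
assert merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}
# merge_dicts([{"a": 1}, {"a": 2}]) raises AssertionError: same key, different values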
25 changes: 14 additions & 11 deletions teenygrad/lazy.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
-from teenygrad.helpers import DType, dtypes, DEBUG
+from teenygrad.helpers import DEBUG, prod
+from teenygrad.dtype import DType, dtypes
+from teenygrad.device import Buffer
 from teenygrad.ops import UnaryOps, BinaryOps, ReduceOps, TernaryOps, LoadOps
 import numpy as np
 
@@ -9,17 +11,16 @@ def toCPU(self): return self.x

 class LazyBuffer:
   device = "CPU"
-
-  def __init__(self, buf: np.ndarray): self._np = buf
+  def __init__(self, buf: np.ndarray): self.realized = Buffer("CPU", buf.size, dtypes.from_np(buf.dtype), buf)
 
   @property
   def base(self): return self
   @property
-  def dtype(self): return dtypes.from_np(self._np.dtype)
+  def dtype(self): return self.realized.dtype
   @property
-  def realized(self): return RawCPUBuffer(self._np)
+  def _np(self): return self.realized._buf
   @property
-  def shape(self): return self._np.shape
+  def shape(self): return self.realized._buf.shape
   def __repr__(self): return f"<LB {self.shape} {self.dtype}>"
 
   def schedule(self, seen=None): return []
@@ -31,7 +32,9 @@ def fromCPU(x): return LazyBuffer(x)

   @staticmethod
   def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
-    if op == LoadOps.RAND: return LazyBuffer(np.random.default_rng(arg).random(size=shape, dtype=dtype.np))
+    if op == LoadOps.CUSTOM:
+      arg(ret := Buffer(device, prod(shape), dtype))
+      return ret._buf.reshape(shape)
     elif op == LoadOps.CONST: return LazyBuffer(np.full(shape, arg, dtype=dtype.np))
     elif op == LoadOps.EMPTY: return LazyBuffer(np.empty(shape, dtype=dtype.np))
     else: raise NotImplementedError(op)
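The RAND loadop is replaced by CUSTOM, which allocates a Buffer, hands it to the arg callback to fill, and returns the reshaped numpy array behind it (presumably so upstream tinygrad's random filler can be reused). A sketch of a caller; fill_ones is a made-up callback, not from the commit:

import numpy as np
from teenygrad.lazy import LazyBuffer
from teenygrad.ops import LoadOps
from teenygrad.dtype import dtypes

def fill_ones(out):  # hypothetical callback: copy raw bytes into the fresh Buffer
  out.copyin(np.ones(out.size, dtype=out.dtype.np).tobytes())

arr = LazyBuffer.loadop(LoadOps.CUSTOM, (2, 2), dtypes.float32, "CPU", arg=fill_ones)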
@@ -52,16 +55,16 @@ def e(self, op, *srcs:LazyBuffer):
     elif op == BinaryOps.SUB: ret = self._np - srcs[0]._np
     elif op == BinaryOps.MUL: ret = self._np * srcs[0]._np
     elif op == BinaryOps.DIV: ret = self._np / srcs[0]._np
+    elif op == BinaryOps.XOR: ret = self._np ^ srcs[0]._np
     elif op == BinaryOps.MAX: ret = np.maximum(self._np, srcs[0]._np)
     elif op == BinaryOps.CMPLT: ret = self._np < srcs[0]._np
+    elif op == BinaryOps.CMPEQ: ret = self._np == srcs[0]._np
     elif op == TernaryOps.WHERE: ret = np.where(self._np, srcs[0]._np, srcs[1]._np)
     else: raise NotImplementedError(op)
     return LazyBuffer(ret.astype(self.dtype.np if len(srcs) == 0 else max(self.dtype, *[x.dtype for x in srcs]).np, copy=False))
 
-  def r(self, op, new_shape):
-    if DEBUG >= 1: print(op, self, new_shape)
-    assert len(self.shape) == len(new_shape), "reduce shapes must have same dimensions"
-    axis = tuple(i for i,(a,b) in enumerate(zip(self.shape, new_shape)) if a != b)
+  def r(self, op, axis):
+    if DEBUG >= 1: print(op, self, axis)
     if op == ReduceOps.SUM: return LazyBuffer(self._np.sum(axis, dtype=self._np.dtype, keepdims=True))
     elif op == ReduceOps.MAX: return LazyBuffer(self._np.max(axis, keepdims=True))
     else: raise NotImplementedError(op)
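r() now takes the reduce axes directly instead of inferring them from a target shape. An illustrative call against the numpy-backed LazyBuffer (not from the commit):

import numpy as np
from teenygrad.lazy import LazyBuffer
from teenygrad.ops import ReduceOps

lb = LazyBuffer(np.ones((2, 3), dtype=np.float32))
out = lb.r(ReduceOps.SUM, (1,))  # sum over axis 1; keepdims=True preserves rank
print(out.shape)                 # (2, 1)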