-
Notifications
You must be signed in to change notification settings - Fork 86
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,972 additions
and
930 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from typing import Optional, Any | ||
from teenygrad.dtype import DType | ||
import numpy as np | ||
|
||
class Device: | ||
DEFAULT = "CPU" | ||
_devices = ["CPU"] | ||
@staticmethod | ||
def canonicalize(device:Optional[str]) -> str: return "CPU" | ||
|
||
class Buffer: | ||
def __init__(self, device:str, size:int, dtype:DType, opaque:Any=None, options=None): | ||
self.device, self.size, self.dtype, self._buf = device, size, dtype, opaque[1] if isinstance(opaque, tuple) else opaque | ||
def copyin(self, buf): self._buf = np.frombuffer(buf, dtype=self.dtype.np) | ||
def as_buffer(self): return self._buf.data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union | ||
from dataclasses import dataclass | ||
import numpy as np # TODO: remove numpy | ||
import functools | ||
|
||
Scalar = Union[float, int, bool] | ||
|
||
@dataclass(frozen=True, order=True) | ||
class DType: | ||
priority: int # this determines when things get upcasted | ||
itemsize: int | ||
name: str | ||
fmt: Optional[str] | ||
count: int | ||
def __repr__(self): return f"dtypes.{'_'*(c:=self.count!=1)}{INVERSE_DTYPES_DICT[self.name if not c else self.scalar().name]}{str(self.count)*c}" | ||
def vec(self, sz:int): | ||
assert sz > 1 and self.count == 1, f"can't vectorize {self} with size {sz}" | ||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) | ||
def scalar(self): return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self | ||
# TODO: someday this will be removed with the "remove numpy" project | ||
@property | ||
def np(self) -> Optional[type]: return np.dtype(self.fmt).type if self.fmt is not None else None | ||
|
||
# dependent typing? | ||
@dataclass(frozen=True, repr=False) | ||
class ImageDType(DType): | ||
shape: Tuple[int, ...] # arbitrary arg for the dtype, used in image for the shape | ||
base: DType | ||
def scalar(self): return self.base | ||
def vec(self, sz:int): return self.base.vec(sz) | ||
def __repr__(self): return f"dtypes.{self.name}({self.shape})" | ||
|
||
# @dataclass(frozen=True, init=False, repr=False, eq=False) | ||
class PtrDType(DType): | ||
def __init__(self, dt:DType): super().__init__(dt.priority, dt.itemsize, dt.name, dt.fmt, dt.count) | ||
def __repr__(self): return f"ptr.{super().__repr__()}" | ||
def __hash__(self): return super().__hash__() | ||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count | ||
def __ne__(self, dt): return not (self == dt) | ||
|
||
def cast_scalar(scalar: Scalar, dtype:DType): | ||
return int(scalar) if dtypes.is_int(dtype) else float(scalar) if dtypes.is_float(dtype) else bool(scalar) | ||
|
||
class dtypes: | ||
@staticmethod | ||
def is_float(x: DType) -> bool: return x.scalar() in (dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64) | ||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool | ||
def is_int(x: DType) -> bool: return x.scalar() in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) or dtypes.is_unsigned(x) | ||
@staticmethod | ||
def is_unsigned(x: DType) -> bool: return x.scalar() in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) | ||
@staticmethod | ||
def from_np(x: type) -> DType: return DTYPES_DICT[np.dtype(x).name] | ||
@staticmethod # NOTE: isinstance(True, int) is True in python | ||
def from_py(x) -> DType: return dtypes.default_float if isinstance(x, float) else dtypes.bool if isinstance(x, bool) else dtypes.default_int | ||
@staticmethod | ||
def fields() -> Dict[str, DType]: return DTYPES_DICT | ||
bool: Final[DType] = DType(0, 1, "bool", '?', 1) | ||
int8: Final[DType] = DType(1, 1, "char", 'b', 1) | ||
uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1) | ||
int16: Final[DType] = DType(3, 2, "short", 'h', 1) | ||
uint16: Final[DType] = DType(4, 2, "unsigned short", 'H', 1) | ||
int32: Final[DType] = DType(5, 4, "int", 'i', 1) | ||
uint32: Final[DType] = DType(6, 4, "unsigned int", 'I', 1) | ||
int64: Final[DType] = DType(7, 8, "long", 'l', 1) | ||
uint64: Final[DType] = DType(8, 8, "unsigned long", 'L', 1) | ||
float16: Final[DType] = DType(9, 2, "half", 'e', 1) | ||
# bfloat16 has higher priority than float16, so least_upper_dtype(dtypes.int64, dtypes.uint64) = dtypes.float16 | ||
bfloat16: Final[DType] = DType(10, 2, "__bf16", None, 1) | ||
float32: Final[DType] = DType(11, 4, "float", 'f', 1) | ||
float64: Final[DType] = DType(12, 8, "double", 'd', 1) | ||
|
||
# dtype aliases | ||
half = float16; float = float32; double = float64 # noqa: E702 | ||
uchar = uint8; ushort = uint16; uint = uint32; ulong = uint64 # noqa: E702 | ||
char = int8; short = int16; int = int32; long = int64 # noqa: E702 | ||
|
||
# NOTE: these are image dtypes | ||
@staticmethod | ||
def imageh(shp): return ImageDType(100, 2, "imageh", 'e', 1, shape=shp, base=dtypes.float32) | ||
@staticmethod | ||
def imagef(shp): return ImageDType(100, 4, "imagef", 'f', 1, shape=shp, base=dtypes.float32) | ||
|
||
default_float: ClassVar[DType] = float32 | ||
default_int: ClassVar[DType] = int32 | ||
|
||
# https://jax.readthedocs.io/en/latest/jep/9407-type-promotion.html | ||
# we don't support weak type and complex type | ||
promo_lattice = { dtypes.bool: [dtypes.int8, dtypes.uint8], dtypes.int8: [dtypes.int16], dtypes.int16: [dtypes.int32], dtypes.int32: [dtypes.int64], | ||
dtypes.int64: [dtypes.float16, dtypes.bfloat16], dtypes.uint8: [dtypes.int16, dtypes.uint16], dtypes.uint16: [dtypes.int32, dtypes.uint32], | ||
dtypes.uint32: [dtypes.int64, dtypes.uint64], dtypes.uint64: [dtypes.float16, dtypes.bfloat16], | ||
dtypes.float16: [dtypes.float32], dtypes.bfloat16: [dtypes.float32], dtypes.float32: [dtypes.float64], } | ||
|
||
@functools.lru_cache(None) | ||
def _get_recursive_parents(dtype:DType) -> Set[DType]: | ||
return set.union(*[_get_recursive_parents(d) for d in promo_lattice[dtype]], {dtype}) if dtype != dtypes.float64 else {dtypes.float64} | ||
@functools.lru_cache(None) | ||
def least_upper_dtype(*ds:DType) -> DType: | ||
return min(set.intersection(*[_get_recursive_parents(d) for d in ds])) if not (images:=[d for d in ds if isinstance(d, ImageDType)]) else images[0] | ||
def least_upper_float(dt:DType) -> DType: return dt if dtypes.is_float(dt) else least_upper_dtype(dt, dtypes.float32) | ||
|
||
# HACK: staticmethods are not callable in 3.8 so we have to compare the class | ||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not (k.startswith(('__', 'default')) or v.__class__ is staticmethod)} | ||
INVERSE_DTYPES_DICT = {v.name:k for k,v in DTYPES_DICT.items()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,25 @@ | ||
from typing import Union, Tuple, Iterator, Optional, Final, Any | ||
import os, functools, platform | ||
import numpy as np | ||
from math import prod # noqa: F401 # pylint:disable=unused-import | ||
from dataclasses import dataclass | ||
from typing import Union, Tuple, Sequence, Any, Iterable, Dict, TypeVar | ||
import os, functools, platform, operator | ||
|
||
T = TypeVar("T") | ||
U = TypeVar("U") | ||
OSX = platform.system() == "Darwin" | ||
def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.mul, x, 1) | ||
def dedup(x): return list(dict.fromkeys(x)) # retains list orderi | ||
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x | ||
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x | ||
def flatten(l:Iterator): return [item for sublist in l for item in sublist] | ||
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist] | ||
def fully_flatten(l): return [item for sublist in l for item in (fully_flatten(sublist) if isinstance(sublist, (tuple, list)) else [sublist])] | ||
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python | ||
def all_int(t: Tuple[Any, ...]) -> bool: return all(isinstance(s, int) for s in t) | ||
def all_int(t: Sequence[Any]) -> bool: return all(isinstance(s, int) for s in t) | ||
def round_up(num, amt:int): return (num+amt-1)//amt * amt | ||
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]: | ||
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501 | ||
return {k:v for d in ds for k,v in d.items()} | ||
def flat_mv(mv:memoryview): return mv if len(mv) == 0 else mv.cast("B", shape=(mv.nbytes,)) | ||
|
||
@functools.lru_cache(maxsize=None) | ||
def getenv(key, default=0): return type(default)(os.getenv(key, default)) | ||
|
||
DEBUG = getenv("DEBUG") | ||
DEBUG, WINO, IMAGE = getenv("DEBUG"), getenv("WINO"), 0 | ||
CI = os.getenv("CI", "") != "" | ||
|
||
@dataclass(frozen=True, order=True) | ||
class DType: | ||
priority: int # this determines when things get upcasted | ||
itemsize: int | ||
name: str | ||
np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project | ||
sz: int = 1 | ||
def __repr__(self): return f"dtypes.{self.name}" | ||
|
||
class dtypes: | ||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool | ||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) | ||
@staticmethod | ||
def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64) | ||
@staticmethod | ||
def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64) | ||
@staticmethod | ||
def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name] | ||
bool: Final[DType] = DType(0, 1, "bool", np.bool_) | ||
float16: Final[DType] = DType(9, 2, "half", np.float16) | ||
half = float16 | ||
float32: Final[DType] = DType(10, 4, "float", np.float32) | ||
float = float32 | ||
float64: Final[DType] = DType(11, 8, "double", np.float64) | ||
double = float64 | ||
int8: Final[DType] = DType(1, 1, "char", np.int8) | ||
int16: Final[DType] = DType(3, 2, "short", np.int16) | ||
int32: Final[DType] = DType(5, 4, "int", np.int32) | ||
int64: Final[DType] = DType(7, 8, "long", np.int64) | ||
uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8) | ||
uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16) | ||
uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32) | ||
uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64) | ||
|
||
# NOTE: bfloat16 isn't supported in numpy | ||
bfloat16: Final[DType] = DType(9, 2, "__bf16", None) | ||
|
||
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod} | ||
|
||
PtrDType, ImageDType, IMAGE = None, None, 0 # junk to remove |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.