Skip to content

Commit

Permalink
Merge pull request #369 from rsokl/mirror-dtype
Browse files Browse the repository at this point in the history
mirror numpy dtypes that are valid for tensors
  • Loading branch information
rsokl authored Mar 8, 2021
2 parents 82bd144 + c7553bb commit bd3f7ee
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/mygrad/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
tensor,
Tensor,
)
from mygrad._dtype_mirrors import *
from mygrad._utils.graph_tracking import no_autodiff
from mygrad._utils.lock_management import (
mem_guard_active,
Expand Down
41 changes: 41 additions & 0 deletions src/mygrad/_dtype_mirrors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import numpy

__all__ = [
"bool8",
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"intp",
"uintp",
"float16",
"float32",
"float64",
"half",
"single",
"double",
"longdouble",
]

bool8 = numpy.bool8
int8 = numpy.int8
int16 = numpy.int16
int32 = numpy.int32
int64 = numpy.int64
uint8 = numpy.uint8
uint16 = numpy.uint16
uint32 = numpy.uint32
uint64 = numpy.uint64
intp = numpy.intp
uintp = numpy.uintp
float16 = numpy.float16
float32 = numpy.float32
float64 = numpy.float64
half = numpy.half
single = numpy.single
double = numpy.double
longdouble = numpy.longdouble
9 changes: 9 additions & 0 deletions tests/test_dtype_mirrors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import pytest

import mygrad as mg
from mygrad._dtype_mirrors import __all__ as all_mirrored_dtyped


@pytest.mark.parametrize("dtype_str", all_mirrored_dtyped)
def test_mirrored_dtype_is_valid(dtype_str):
mg.tensor(1, dtype=getattr(mg, dtype_str))

0 comments on commit bd3f7ee

Please sign in to comment.