Skip to content

Commit

Permalink
Merge pull request tensor-compiler#350 from RawnH/pytaco_negation
Browse files Browse the repository at this point in the history
Adds negation to pytaco tensor interface
  • Loading branch information
stephenchouca authored Dec 22, 2020
2 parents 81ebb64 + e48e4ec commit fb2ed72
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
38 changes: 37 additions & 1 deletion python_bindings/pytaco/pytensor/taco_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,10 @@ def __pow__(self, power, modulo=None):
return tensor_pow(self, power, default_mode)

def __abs__(self):
return tensor_abs(self, default_mode)
return tensor_abs(self, self.format)

def __neg__(self):
return tensor_neg(self, self.format)

def __array__(self):
if not _cm.is_dense(self.format):
Expand Down Expand Up @@ -1482,6 +1485,39 @@ def tensor_logical_not(t1, out_format, dtype=None):
"""
return _compute_unary_elt_eise_op(_cm.logical_not, t1, out_format, dtype)

def tensor_neg(t1, out_format, dtype=None):
"""
Negates every value in the tensor.
The tensor class implements ``__neg__`` using this method.
Parameters
------------
t1: tensor, array_like
input tensor or array_like object
out_format: format, mode_format, optional
* If a :class:`format` is specified, the result tensor is stored in the format out_format.
* If a :class:`mode_format` is specified, the result the result tensor has a with all of the dimensions
stored in the :class:`mode_format` passed in.
dtype: Datatype
The datatype of the output tensor.
Examples
----------
>>> import pytaco as pt
>>> pt.tensor_neg([1, -2, 0], out_format=pt.dense).toarray()
array([-1, 2, 0], dtype=int64)
Returns
--------
neg: tensor
The element wise negation of the input tensor.
"""
return _compute_unary_elt_eise_op(_cm.neg, t1, out_format, dtype)

def tensor_abs(t1, out_format, dtype=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions python_bindings/unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,10 @@ def test_mod(self):
t1[i, j] = pt.remainder(t[i, j], 2)
self.assertEqual(t1, arr % 2)

def test_neg(self):
arr = np.arange(1, 5).reshape([2, 2])
t = pt.from_array(arr)
self.assertEqual(-t, -arr)

class testParsers(unittest.TestCase):

Expand Down

0 comments on commit fb2ed72

Please sign in to comment.