From b9e361131e7b1fdd6a59f1a0b71f9dc4a5c6968a Mon Sep 17 00:00:00 2001
From: Parth Nobel
Date: Sat, 18 Apr 2020 20:15:10 -0700
Subject: [PATCH] Fixes exponentiation support in auto_diff.

---
 auto_diff/vecvalder.py                  | 33 +++++++++++++++++++++++++
 auto_diff/vecvalder_funcs_and_ufuncs.py | 26 ++++++++++++-------
 2 files changed, 50 insertions(+), 9 deletions(-)

diff --git a/auto_diff/vecvalder.py b/auto_diff/vecvalder.py
index 41c0678..d4230dc 100644
--- a/auto_diff/vecvalder.py
+++ b/auto_diff/vecvalder.py
@@ -4,6 +4,11 @@
 _HANDLED_FUNCS_AND_UFUNCS = {}
 
 
+def _defer_to_val(f):
+    def fn(self, *args, **kwargs):
+        return getattr(self.val, f)(*args, **kwargs)
+    fn.__name__ = f
+    return fn
 class VecValDer(np.lib.mixins.NDArrayOperatorsMixin):
     __slots__ = 'val', 'der'
 
@@ -25,6 +30,34 @@ def transpose(self, *axes):
         axes = None
         return np.transpose(self, axes)
 
+    all = _defer_to_val('all')
+    any = _defer_to_val('any')
+    argmax = _defer_to_val('argmax')
+    argmin = _defer_to_val('argmin')
+    argpartition = _defer_to_val('argpartition')
+    argsort = _defer_to_val('argsort')
+    nonzero = _defer_to_val('nonzero')
+
+    def copy(self):
+        return VecValDer(self.val.copy(), self.der.copy())
+
+    def fill(self, value):
+        if isinstance(value, VecValDer):
+            self.val.fill(value.val)
+            self.der[:] = value.der
+        else:
+            self.val.fill(value)
+            self.der.fill(0.0)
+
+    def reshape(self, shape):
+        der_dim_shape = self.der.shape[len(self.val.shape):]
+        new_der_shape = shape + der_dim_shape
+        self.val.reshape(shape)
+        self.der.reshape(new_der_shape)
+
+    def trace(self, *args, **kwargs):
+        return np.trace(*args, **kwargs)
+
     def __array_ufunc__(self, ufunc, method, *args, **kwargs):
         if method == '__call__' and ufunc in _HANDLED_FUNCS_AND_UFUNCS:
             return _HANDLED_FUNCS_AND_UFUNCS[ufunc](*args, **kwargs)
diff --git a/auto_diff/vecvalder_funcs_and_ufuncs.py b/auto_diff/vecvalder_funcs_and_ufuncs.py
index 7caee51..edac69b 100644
--- a/auto_diff/vecvalder_funcs_and_ufuncs.py
+++ b/auto_diff/vecvalder_funcs_and_ufuncs.py
@@ -43,7 +43,7 @@ def _chain_rule(f_x, dx, out=None):
     if out is None:
         out = np.ndarray(dx.shape)  # Uninitialized memory is fine because
         # we're about to overwrite each element. If we do compression of the for
-        # loop in the future be sure to swtich to np.zeros.
+        # loop in the future be sure to switch to np.zeros.
     for index, y in np.ndenumerate(f_x):
         out[index] = y * dx[index]
     return out
@@ -284,25 +284,33 @@ def true_divide(x1, x2, /, out):
         raise RuntimeError("This should not be occuring.")
 
 
-# Tested
 @register(np.float_power)
 @_add_out_support
-def float_power(x1, x2):
+def float_power(x1, x2, /, out):
     if isinstance(x1, cls) and isinstance(x2, cls):
-        return cls(x1.val ** x2.val, x1.val**(x2.val - 1) * (
-            x2.val * x1.der + x1.val * np.log(x1.val) * x2.der))
+        return cls(np.float_power(x1.val, x2.val, out=out.val),
+                   np.multiply(x1.val**(x2.val - 1), (x2.val * x1.der + x1.val * np.log(x1.val) * x2.der), out=out.der))
     elif isinstance(x1, cls):
-        return cls(x1.val ** x2, x1.val**(x2 - 1) * x2 * x1.der)
+        return cls(np.float_power(x1.val, x2, out=out.val), np.multiply(x1.val**(x2 - 1) * x2, x1.der, out=out.der))
     elif isinstance(x2, cls):
-        return cls(x1.val ** x2.val, x1**(x2.val) * np.log(x1.val) * x2.der)
+        return cls(np.float_power(x1, x2.val, out=out.val), np.multiply(x1**(x2.val) * np.log(x1.val), x2.der, out=out.der))
     else:
         raise RuntimeError("This should not be occuring.")
 
 
 @register(np.power)
 @_add_out_support
-def power(x1, x2):
-    return float_power(x1, x2)
+def power(x1, x2, /, out):
+    if isinstance(x1, cls) and isinstance(x2, cls):
+        return cls(np.power(x1.val, x2.val, out=out.val),
+                   np.multiply(x1.val**(x2.val - 1), (x2.val * x1.der + x1.val * np.log(x1.val) * x2.der), out=out.der))
+    elif isinstance(x1, cls):
+        return cls(np.power(x1.val, x2, out=out.val), np.multiply(x1.val**(x2 - 1) * x2, x1.der, out=out.der))
+    elif isinstance(x2, cls):
+        return cls(np.power(x1, x2.val, out=out.val), np.multiply(x1**(x2.val) * np.log(x1.val), x2.der, out=out.der))
+    else:
+        raise RuntimeError("This should not be occuring.")
+
 
 
 # Partially Tested
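For reference, the derivative rule both rewritten handlers implement is the total derivative of a power with respect to base and exponent: d(x1**x2) = x1**(x2 - 1) * (x2 * dx1 + x1 * log(x1) * dx2). The power handler no longer delegates to float_power, presumably so that each handler applies its own ufunc (np.power vs. np.float_power) to the values; the two differ in how they promote integer inputs. Below is a minimal, self-contained sanity check of that expression against finite differences. It uses plain NumPy only; the power_derivative helper is purely illustrative and is not part of auto_diff.

import numpy as np

def power_derivative(x1, x2, dx1, dx2):
    # Same expression the patched float_power/power handlers use for the
    # derivative part: d(x1**x2) = x1**(x2 - 1) * (x2*dx1 + x1*log(x1)*dx2)
    return x1 ** (x2 - 1) * (x2 * dx1 + x1 * np.log(x1) * dx2)

x1, x2 = 1.7, 2.3
eps = 1e-7

# Forward-difference estimates of the two partial derivatives of x1**x2.
fd_base = ((x1 + eps) ** x2 - x1 ** x2) / eps       # direction (dx1, dx2) = (1, 0)
fd_exp = (x1 ** (x2 + eps) - x1 ** x2) / eps        # direction (dx1, dx2) = (0, 1)

assert np.isclose(power_derivative(x1, x2, 1.0, 0.0), fd_base, rtol=1e-5)
assert np.isclose(power_derivative(x1, x2, 0.0, 1.0), fd_exp, rtol=1e-5)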