diff --git a/mpyc/runtime.py b/mpyc/runtime.py index 97afee22..fe4d246b 100644 --- a/mpyc/runtime.py +++ b/mpyc/runtime.py @@ -2909,9 +2909,16 @@ async def np_sum(self, a, axis=None, keepdims=False, initial=0): else: shape = tuple(s for i, s in enumerate(a.shape) if i not in axes) if shape == (): - await self.returnType(sectype) + if isinstance(a, self.SecureFixedPointArray): + rettype = (sectype, a.integral) + else: + rettype = sectype else: - await self.returnType((type(a), shape)) + if isinstance(a, self.SecureFixedPointArray): + rettype = (type(a), a.integral, shape) + else: + rettype = (type(a), shape) + await self.returnType(rettype) a, initial = await self.gather(a, initial) return np.sum(a, axis=axis, keepdims=keepdims, initial=initial.value) # TODO: handle switch from initial (field elt) to initial.value inside finfields.py diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 61770ecb..b1ef47af 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -325,6 +325,41 @@ def test_secfxp_array(self): self.assertEqual(np.vstack((c1, c1)).integral, False) self.assertEqual(np.vstack((c2, c2)).integral, True) + self.assertEqual(np.sum(c2).integral, True) + self.assertEqual(np.sum(c2, axis=0).integral, True) + self.assertEqual(np.sum(c1).integral, False) + self.assertEqual(np.sum(c1, axis=0).integral, False) + + self.assertEqual(mpc.np_sgn(c1).integral, True) + self.assertEqual(mpc.np_sgn(c2).integral, True) + self.assertEqual(np.absolute(c1).integral, False) + self.assertEqual(np.absolute(c2).integral, True) + + self.assertEqual(np.minimum(c1, c2).integral, False) + self.assertEqual(np.minimum(c1, c1).integral, False) + self.assertEqual(np.minimum(c2, c2).integral, True) + self.assertEqual(np.maximum(c1, c2).integral, False) + self.assertEqual(np.maximum(c1, c1).integral, False) + self.assertEqual(np.maximum(c2, c2).integral, True) + + self.assertEqual(np.amin(c2).integral, True) + self.assertEqual(np.amin(c2, axis=0).integral, True) + self.assertEqual(np.amin(c1).integral, False) + self.assertEqual(np.amin(c1, axis=0).integral, False) + self.assertEqual(np.amax(c2).integral, True) + self.assertEqual(np.amax(c2, axis=0).integral, True) + self.assertEqual(np.amax(c1).integral, False) + self.assertEqual(np.amax(c1, axis=0).integral, False) + + self.assertEqual(np.argmin(c2).integral, True) + self.assertEqual(np.argmin(c2, axis=0).integral, True) + self.assertEqual(np.argmin(c1).integral, True) + self.assertEqual(np.argmin(c1, axis=0).integral, True) + self.assertEqual(np.argmax(c2).integral, True) + self.assertEqual(np.argmax(c2, axis=0).integral, True) + self.assertEqual(np.argmax(c1).integral, True) + self.assertEqual(np.argmax(c1, axis=0).integral, True) + @unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled') def test_secfld_array(self): np.assertEqual = np.testing.assert_array_equal