From e32920645c8c5e6153f29e4e33ca7e3b7087aef1 Mon Sep 17 00:00:00 2001 From: sol <1731279+s-ol@users.noreply.github.com> Date: Mon, 11 Dec 2023 22:34:21 +0100 Subject: [PATCH] Fix (u)int8 upcasting as per docs and numpy (#650) * fix wrong #if guard in ndarray_inplace_ams * implement (u)int8 upcasting rules as per documentation * bump version --- code/ndarray_operators.c | 10 +++++----- code/ulab.c | 2 +- docs/ulab-change-log.md | 6 ++++++ tests/2d/numpy/operators.py.exp | 2 +- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/code/ndarray_operators.c b/code/ndarray_operators.c index b7d61f26..e8de4d48 100644 --- a/code/ndarray_operators.c +++ b/code/ndarray_operators.c @@ -181,8 +181,8 @@ mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs, if(lhs->dtype == NDARRAY_UINT8) { if(rhs->dtype == NDARRAY_UINT8) { - results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16); - BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +); + results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8); + BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +); } else if(rhs->dtype == NDARRAY_INT8) { results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16); BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, +); @@ -264,8 +264,8 @@ mp_obj_t ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs, if(lhs->dtype == NDARRAY_UINT8) { if(rhs->dtype == NDARRAY_UINT8) { - results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16); - BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *); + results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8); + BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *); } else if(rhs->dtype == NDARRAY_INT8) { results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16); BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, *); @@ -1059,7 +1059,7 @@ mp_obj_t ndarray_inplace_ams(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rs UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, +=); } #endif - #if NDARRAY_HAS_INPLACE_ADD + #if NDARRAY_HAS_INPLACE_MULTIPLY if(optype == MP_BINARY_OP_INPLACE_MULTIPLY) { UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, *=); } diff --git a/code/ulab.c b/code/ulab.c index af7ba0ba..5c5067bd 100644 --- a/code/ulab.c +++ b/code/ulab.c @@ -33,7 +33,7 @@ #include "user/user.h" #include "utils/utils.h" -#define ULAB_VERSION 6.4.1 +#define ULAB_VERSION 6.4.2 #define xstr(s) str(s) #define str(s) #s diff --git a/docs/ulab-change-log.md b/docs/ulab-change-log.md index ff921249..4d389f9a 100644 --- a/docs/ulab-change-log.md +++ b/docs/ulab-change-log.md @@ -1,3 +1,9 @@ +Thu, 11 Dec 2023 + +version 6.4.2 + + fix upcasting with two uint8 operands (#650) + Thu, 10 Aug 2023 version 6.4.1 diff --git a/tests/2d/numpy/operators.py.exp b/tests/2d/numpy/operators.py.exp index 319c2073..7517a390 100644 --- a/tests/2d/numpy/operators.py.exp +++ b/tests/2d/numpy/operators.py.exp @@ -94,7 +94,7 @@ array([1.0, 2.0, 3.0], dtype=float64) array([1.0, 32.0, 729.0], dtype=float64) array([1.0, 32.0, 729.0], dtype=float64) array([1.0, 32.0, 729.0], dtype=float64) -array([5, 7, 9], dtype=uint16) +array([5, 7, 9], dtype=uint8) array([5, 7, 9], dtype=int16) array([5, 7, 9], dtype=int8) array([5, 7, 9], dtype=uint16)