Skip to content

Commit

Permalink
Fix (u)int8 upcasting as per docs and numpy (#650)
Browse files Browse the repository at this point in the history
* fix wrong #if guard in ndarray_inplace_ams

* implement (u)int8 upcasting rules as per documentation

* bump version
  • Loading branch information
s-ol authored Dec 11, 2023
1 parent 4bde4ef commit e329206
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
10 changes: 5 additions & 5 deletions code/ndarray_operators.c
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ mp_obj_t ndarray_binary_add(ndarray_obj_t *lhs, ndarray_obj_t *rhs,

if(lhs->dtype == NDARRAY_UINT8) {
if(rhs->dtype == NDARRAY_UINT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +);
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, +);
} else if(rhs->dtype == NDARRAY_INT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, +);
Expand Down Expand Up @@ -264,8 +264,8 @@ mp_obj_t ndarray_binary_multiply(ndarray_obj_t *lhs, ndarray_obj_t *rhs,

if(lhs->dtype == NDARRAY_UINT8) {
if(rhs->dtype == NDARRAY_UINT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT16);
BINARY_LOOP(results, uint16_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *);
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_UINT8);
BINARY_LOOP(results, uint8_t, uint8_t, uint8_t, larray, lstrides, rarray, rstrides, *);
} else if(rhs->dtype == NDARRAY_INT8) {
results = ndarray_new_dense_ndarray(ndim, shape, NDARRAY_INT16);
BINARY_LOOP(results, int16_t, uint8_t, int8_t, larray, lstrides, rarray, rstrides, *);
Expand Down Expand Up @@ -1059,7 +1059,7 @@ mp_obj_t ndarray_inplace_ams(ndarray_obj_t *lhs, ndarray_obj_t *rhs, int32_t *rs
UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, +=);
}
#endif
#if NDARRAY_HAS_INPLACE_ADD
#if NDARRAY_HAS_INPLACE_MULTIPLY
if(optype == MP_BINARY_OP_INPLACE_MULTIPLY) {
UNWRAP_INPLACE_OPERATOR(lhs, larray, rarray, rstrides, *=);
}
Expand Down
2 changes: 1 addition & 1 deletion code/ulab.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include "user/user.h"
#include "utils/utils.h"

#define ULAB_VERSION 6.4.1
#define ULAB_VERSION 6.4.2
#define xstr(s) str(s)
#define str(s) #s

Expand Down
6 changes: 6 additions & 0 deletions docs/ulab-change-log.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
Thu, 11 Dec 2023

version 6.4.2

fix upcasting with two uint8 operands (#650)

Thu, 10 Aug 2023

version 6.4.1
Expand Down
2 changes: 1 addition & 1 deletion tests/2d/numpy/operators.py.exp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ array([1.0, 2.0, 3.0], dtype=float64)
array([1.0, 32.0, 729.0], dtype=float64)
array([1.0, 32.0, 729.0], dtype=float64)
array([1.0, 32.0, 729.0], dtype=float64)
array([5, 7, 9], dtype=uint16)
array([5, 7, 9], dtype=uint8)
array([5, 7, 9], dtype=int16)
array([5, 7, 9], dtype=int8)
array([5, 7, 9], dtype=uint16)
Expand Down

0 comments on commit e329206

Please sign in to comment.