diff --git a/comdb2/_ccdb2.pyx b/comdb2/_ccdb2.pyx index 155bacb..bcaca37 100644 --- a/comdb2/_ccdb2.pyx +++ b/comdb2/_ccdb2.pyx @@ -39,6 +39,11 @@ ctypedef fused client_datetime: lib.cdb2_client_datetimeus_t +cdef struct blob_descriptor: + size_t size + char* data + + cdef _string_as_bytes(s): if isinstance(s, unicode): return s.encode('utf-8') @@ -99,13 +104,15 @@ cdef _bind_datetime(obj, client_datetime *val): cdef class _ParameterValue(object): - cdef int type + cdef lib.cdb2_coltype type cdef int size cdef void *data cdef object owner + cdef int list_size def __cinit__(self, obj, param_name): try: + self.list_size = -1 if obj is None: self.type = lib.CDB2_INTEGER self.owner = None @@ -152,6 +159,54 @@ cdef class _ParameterValue(object): self.data = PyMem_Malloc(self.size) _bind_datetime(obj, self.data) return + elif isinstance(obj, (list, tuple)): + self.list_size = len(obj) + if 0 == self.list_size: + raise ValueError(f"empty {type(obj).__name__}s cannot be bound") + + if all(isinstance(ele, int) for ele in obj): + self.type = lib.CDB2_INTEGER + self.size = sizeof(long long) + self.owner = None + self.data = PyMem_Malloc(self.list_size * self.size) + for l_index in range(self.list_size): + (self.data)[l_index] = obj[l_index] + return + elif all(isinstance(ele, float) for ele in obj): + self.type = lib.CDB2_REAL + self.size = sizeof(double) + self.owner = None + self.data = PyMem_Malloc(self.list_size * self.size) + for l_index in range(self.list_size): + (self.data)[l_index] = obj[l_index] + return + elif all(isinstance(ele, bytes) for ele in obj): + self.type = lib.CDB2_BLOB + self.size = sizeof(blob_descriptor) + self.owner = obj + self.data = PyMem_Malloc(self.list_size * self.size) + for l_index in range(self.list_size): + (self.data)[l_index].size = len(obj[l_index]) + (self.data)[l_index].data = obj[l_index] + return + elif all(isinstance(ele, unicode) for ele in obj): + self.type = lib.CDB2_CSTRING + self.size = sizeof(char*) + # Strings need to be converted to bytes + self.owner = [x.encode('utf-8') for x in obj] + self.data = PyMem_Malloc(self.list_size * self.size) + for l_index in range(self.list_size): + (self.data)[l_index] = (self.owner[l_index]) + return + elif not all(isinstance(ele, type(obj[0])) for ele in obj): + raise ValueError( + f"all {type(obj).__name__} elements must be the same type" + ) + else: + raise ValueError( + f"Cannot bind a {type(obj).__name__} of {type(obj[0]).__name__}" + ) + except Exception as e: exc = e else: @@ -177,7 +232,7 @@ cdef class _ParameterValue(object): def __dealloc__(self): - if self.owner is None: + if self.owner is None or self.list_size != -1: PyMem_Free(self.data) @@ -343,8 +398,12 @@ cdef class Handle(object): cval = _ParameterValue(val, key) param_guards.append(ckey) param_guards.append(cval) - rc = lib.cdb2_bind_param(self.hndl, ckey, - cval.type, cval.data, cval.size) + if cval.list_size == -1: + rc = lib.cdb2_bind_param(self.hndl, ckey, + cval.type, cval.data, cval.size) + else: + # Bind Array if cval is an array + rc = lib.cdb2_bind_array(self.hndl, ckey, cval.type, cval.data, cval.list_size, cval.size) _errchk(rc, self.hndl) with nogil: diff --git a/comdb2/_cdb2api.pxd b/comdb2/_cdb2api.pxd index 40c84ab..afae380 100644 --- a/comdb2/_cdb2api.pxd +++ b/comdb2/_cdb2api.pxd @@ -18,7 +18,7 @@ cdef extern from "cdb2api.h" nogil: CDB2ERR_NOTSUPPORTED CDB2ERR_CONV_FAIL - enum: + enum cdb2_coltype: CDB2_INTEGER CDB2_REAL CDB2_CSTRING @@ -75,5 +75,6 @@ cdef extern from "cdb2api.h" nogil: void* cdb2_column_value(cdb2_hndl_tp* hndl, int col); const char* cdb2_errstr(cdb2_hndl_tp* hndl); int cdb2_bind_param(cdb2_hndl_tp *hndl, const char *name, int type, const void *varaddr, int length); + int cdb2_bind_array(cdb2_hndl_tp *hndl, const char *name, cdb2_coltype, const void *varaddr, size_t count, size_t typelen); int cdb2_clearbindings(cdb2_hndl_tp *hndl); int cdb2_clear_ack(cdb2_hndl_tp *hndl); diff --git a/tests/test_dbapi2.py b/tests/test_dbapi2.py index 40e33fa..25409c4 100644 --- a/tests/test_dbapi2.py +++ b/tests/test_dbapi2.py @@ -36,6 +36,7 @@ import pytest import datetime import pytz +import re from functools import partial from unittest.mock import patch @@ -851,3 +852,66 @@ def test_interface_error_reading_result_set_after_commits(): ]) def test_finding_operation(statement, operation): assert _sql_operation(statement) == operation + + +@pytest.mark.parametrize( + "values", + [ + [5, 10], + ("hello", "hi"), + [0.25, 0.35, 0.25], + (b"123", b"456"), + ], +) +def test_parameter_binding_arrays(values): + # GIVEN + conn = connect("mattdb", "dev") + cursor = conn.cursor() + + # WHEN + cursor.execute("select * from carray(%(values)s)", dict(values=values)) + results = cursor.fetchall() + + # THEN + assert results == [[v] for v in values] + conn.close() + + +@pytest.mark.parametrize( + "values,exc_msg", + [ + ( + [], + "Can't bind list value [] for parameter 'values': " + + "ValueError: empty lists cannot be bound", + ), + ( + (), + "Can't bind tuple value () for parameter 'values': " + + "ValueError: empty tuples cannot be bound", + ), + ( + [1, "hello"], + "Can't bind list value [1, 'hello'] for parameter 'values': " + + "ValueError: all list elements must be the same type", + ), + ( + (1j,), + "Can't bind tuple value (1j,) for parameter 'values': " + + "ValueError: Cannot bind a tuple of complex", + ), + ( + ((1, 2, 3),), + "Can't bind tuple value ((1, 2, 3),) for parameter 'values': " + + "ValueError: Cannot bind a tuple of tuple", + ), + ], +) +def test_parameter_binding_invalid_arrays(values, exc_msg): + # GIVEN + conn = connect("mattdb", "dev") + cursor = conn.cursor() + + # WHEN/THEN + with pytest.raises(DataError, match=re.escape(exc_msg)): + cursor.execute("select * from carray(%(values)s)", dict(values=values))