Skip to content

Commit

Permalink
Allow parameter binding Arrays and Tuples
Browse files Browse the repository at this point in the history
Signed-off-by: asardesai2 <[email protected]>
  • Loading branch information
asardesai2 committed Nov 28, 2023
1 parent 7b81430 commit 203f0f2
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 5 deletions.
67 changes: 63 additions & 4 deletions comdb2/_ccdb2.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ ctypedef fused client_datetime:
lib.cdb2_client_datetimeus_t


cdef struct blob_descriptor:
size_t size
char* data


cdef _string_as_bytes(s):
if isinstance(s, unicode):
return s.encode('utf-8')
Expand Down Expand Up @@ -99,13 +104,15 @@ cdef _bind_datetime(obj, client_datetime *val):


cdef class _ParameterValue(object):
cdef int type
cdef lib.cdb2_coltype type
cdef int size
cdef void *data
cdef object owner
cdef int list_size

def __cinit__(self, obj, param_name):
try:
self.list_size = -1
if obj is None:
self.type = lib.CDB2_INTEGER
self.owner = None
Expand Down Expand Up @@ -152,6 +159,54 @@ cdef class _ParameterValue(object):
self.data = PyMem_Malloc(self.size)
_bind_datetime(obj, <lib.cdb2_client_datetime_t*>self.data)
return
elif isinstance(obj, (list, tuple)):
self.list_size = len(obj)
if 0 == self.list_size:
raise ValueError(f"empty {type(obj).__name__}s cannot be bound")

if all(isinstance(ele, int) for ele in obj):
self.type = lib.CDB2_INTEGER
self.size = sizeof(long long)
self.owner = None
self.data = PyMem_Malloc(self.list_size * self.size)
for l_index in range(self.list_size):
(<long long*>self.data)[l_index] = obj[l_index]
return
elif all(isinstance(ele, float) for ele in obj):
self.type = lib.CDB2_REAL
self.size = sizeof(double)
self.owner = None
self.data = PyMem_Malloc(self.list_size * self.size)
for l_index in range(self.list_size):
(<double*>self.data)[l_index] = obj[l_index]
return
elif all(isinstance(ele, bytes) for ele in obj):
self.type = lib.CDB2_BLOB
self.size = sizeof(blob_descriptor)
self.owner = obj
self.data = PyMem_Malloc(self.list_size * self.size)
for l_index in range(self.list_size):
(<blob_descriptor*>self.data)[l_index].size = len(obj[l_index])
(<blob_descriptor*>self.data)[l_index].data = obj[l_index]
return
elif all(isinstance(ele, unicode) for ele in obj):
self.type = lib.CDB2_CSTRING
self.size = sizeof(char*)
# Strings need to be converted to bytes
self.owner = [x.encode('utf-8') for x in obj]
self.data = PyMem_Malloc(self.list_size * self.size)
for l_index in range(self.list_size):
(<char**>self.data)[l_index] = <char*>(self.owner[l_index])
return
elif not all(isinstance(ele, type(obj[0])) for ele in obj):
raise ValueError(
f"all {type(obj).__name__} elements must be the same type"
)
else:
raise ValueError(
f"Cannot bind a {type(obj).__name__} of {type(obj[0]).__name__}"
)

except Exception as e:
exc = e
else:
Expand All @@ -177,7 +232,7 @@ cdef class _ParameterValue(object):


def __dealloc__(self):
if self.owner is None:
if self.owner is None or self.list_size != -1:
PyMem_Free(self.data)


Expand Down Expand Up @@ -343,8 +398,12 @@ cdef class Handle(object):
cval = _ParameterValue(val, key)
param_guards.append(ckey)
param_guards.append(cval)
rc = lib.cdb2_bind_param(self.hndl, <char*>ckey,
cval.type, cval.data, cval.size)
if cval.list_size == -1:
rc = lib.cdb2_bind_param(self.hndl, <char*>ckey,
cval.type, cval.data, cval.size)
else:
# Bind Array if cval is an array
rc = lib.cdb2_bind_array(self.hndl, <char*>ckey, cval.type, cval.data, cval.list_size, cval.size)
_errchk(rc, self.hndl)

with nogil:
Expand Down
3 changes: 2 additions & 1 deletion comdb2/_cdb2api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ cdef extern from "cdb2api.h" nogil:
CDB2ERR_NOTSUPPORTED
CDB2ERR_CONV_FAIL

enum:
enum cdb2_coltype:
CDB2_INTEGER
CDB2_REAL
CDB2_CSTRING
Expand Down Expand Up @@ -75,5 +75,6 @@ cdef extern from "cdb2api.h" nogil:
void* cdb2_column_value(cdb2_hndl_tp* hndl, int col);
const char* cdb2_errstr(cdb2_hndl_tp* hndl);
int cdb2_bind_param(cdb2_hndl_tp *hndl, const char *name, int type, const void *varaddr, int length);
int cdb2_bind_array(cdb2_hndl_tp *hndl, const char *name, cdb2_coltype, const void *varaddr, size_t count, size_t typelen);
int cdb2_clearbindings(cdb2_hndl_tp *hndl);
int cdb2_clear_ack(cdb2_hndl_tp *hndl);
64 changes: 64 additions & 0 deletions tests/test_dbapi2.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import pytest
import datetime
import pytz
import re
from functools import partial

from unittest.mock import patch
Expand Down Expand Up @@ -851,3 +852,66 @@ def test_interface_error_reading_result_set_after_commits():
])
def test_finding_operation(statement, operation):
assert _sql_operation(statement) == operation


@pytest.mark.parametrize(
"values",
[
[5, 10],
("hello", "hi"),
[0.25, 0.35, 0.25],
(b"123", b"456"),
],
)
def test_parameter_binding_arrays(values):
# GIVEN
conn = connect("mattdb", "dev")
cursor = conn.cursor()

# WHEN
cursor.execute("select * from carray(%(values)s)", dict(values=values))
results = cursor.fetchall()

# THEN
assert results == [[v] for v in values]
conn.close()


@pytest.mark.parametrize(
"values,exc_msg",
[
(
[],
"Can't bind list value [] for parameter 'values': "
+ "ValueError: empty lists cannot be bound",
),
(
(),
"Can't bind tuple value () for parameter 'values': "
+ "ValueError: empty tuples cannot be bound",
),
(
[1, "hello"],
"Can't bind list value [1, 'hello'] for parameter 'values': "
+ "ValueError: all list elements must be the same type",
),
(
(1j,),
"Can't bind tuple value (1j,) for parameter 'values': "
+ "ValueError: Cannot bind a tuple of complex",
),
(
((1, 2, 3),),
"Can't bind tuple value ((1, 2, 3),) for parameter 'values': "
+ "ValueError: Cannot bind a tuple of tuple",
),
],
)
def test_parameter_binding_invalid_arrays(values, exc_msg):
# GIVEN
conn = connect("mattdb", "dev")
cursor = conn.cursor()

# WHEN/THEN
with pytest.raises(DataError, match=re.escape(exc_msg)):
cursor.execute("select * from carray(%(values)s)", dict(values=values))

0 comments on commit 203f0f2

Please sign in to comment.