Skip to content

Commit

Permalink
Fix colname mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ostafen committed May 18, 2024
1 parent fbc9fb1 commit 172e062
Showing 1 changed file with 42 additions and 49 deletions.
91 changes: 42 additions & 49 deletions immudb/handler/sqlquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,15 @@ def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, exe
return RowIterator(resp, columnNameMode, dbname)

res = next(resp)
columnNames = getColumnNames(res, columnNameMode)
rows = unpack_rows(res, columnNameMode, columnNames)
return fix_colnames(rows, dbname, columnNameMode)
columnNames = getColumnNames(res, dbname, columnNameMode)
return unpack_rows(res, columnNameMode, columnNames)


def fix_colnames(ret, dbname, columnNameMode):
def fix_colnames(cols, dbname, columnNameMode):
if columnNameMode not in [constants.COLUMN_NAME_MODE_DATABASE, constants.COLUMN_NAME_MODE_FULL]:
return ret
return cols

# newer DB version don't insert database name anymore, we need to
# process it manually
for i, t in enumerate(ret):
newkeys = [
x.replace("[@DB]", dbname.decode("utf-8")) for x in t.keys()]
k = dict(zip(newkeys, list(t.values())))
ret[i] = k
return ret
return [x.replace("[@DB]", dbname.decode("utf-8")) for x in cols]


def unpack_rows(resp, columnNameMode, colNames):
Expand All @@ -69,34 +61,36 @@ def unpack_rows(resp, columnNameMode, colNames):
return result


def getColumnNames(resp, columnNameMode):
columnNames = []
if columnNameMode != constants.COLUMN_NAME_MODE_NONE:
for column in resp.columns:
# note that depending on the version parts can be
# '(dbname.tablename.fieldname)' *or*
# '(tablename.fieldname)' without dbnname.
# In that case we mimic the old behavior by using [@DB] as placeholder
# that will be replaced at higher level.
parts = column.name.strip("()").split(".")
if columnNameMode == constants.COLUMN_NAME_MODE_FIELD:
columnNames.append(parts[-1])
continue
if columnNameMode == constants.COLUMN_NAME_MODE_TABLE:
columnNames.append(".".join(parts[-2:]))
continue
print(
"Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated")
if len(parts) == 2:
parts.insert(0, "[@DB]")
if columnNameMode == constants.COLUMN_NAME_MODE_DATABASE:
columnNames.append(".".join(parts))
continue
if columnNameMode == constants.COLUMN_NAME_MODE_FULL:
columnNames.append("("+".".join(parts)+")")
continue
raise ErrPySDKInvalidColumnMode
return columnNames
def getColumnNames(resp, dbname, columnNameMode):
cols = []
if columnNameMode == constants.COLUMN_NAME_MODE_NONE:
return cols

for column in resp.columns:
# note that depending on the version parts can be
# '(dbname.tablename.fieldname)' *or*
# '(tablename.fieldname)' without dbnname.
# In that case we mimic the old behavior by using [@DB] as placeholder
# that will be replaced at higher level.
parts = column.name.strip("()").split(".")
if columnNameMode == constants.COLUMN_NAME_MODE_FIELD:
cols.append(parts[-1])
continue
if columnNameMode == constants.COLUMN_NAME_MODE_TABLE:
cols.append(".".join(parts[-2:]))
continue
print(
"Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated")
if len(parts) == 2:
parts.insert(0, "[@DB]")
if columnNameMode == constants.COLUMN_NAME_MODE_DATABASE:
cols.append(".".join(parts))
continue
if columnNameMode == constants.COLUMN_NAME_MODE_FULL:
cols.append("("+".".join(parts)+")")
continue
raise ErrPySDKInvalidColumnMode
return fix_colnames(cols, dbname, columnNameMode)


class ClosedIterator(BaseException):
Expand All @@ -109,7 +103,7 @@ def __init__(self, grpcIt, colNameMode, dbname) -> None:
self._nextRow = 0
self._rows = []
self._columns = None
self._colNameMode = colNameMode if colNameMode != constants.COLUMN_NAME_MODE_NONE else constants.COLUMN_NAME_MODE_FIELD
self._colNameMode = colNameMode
self._dbname = dbname
self._closed = False

Expand All @@ -132,22 +126,21 @@ def _fetch_next(self):

res = next(self._grpcIt)
if self._columns == None:
self._columns = getColumnNames(res, self._colNameMode)
self._columns = getColumnNames(res, self._dbname, self._colsMode())

self._rows = unpack_rows(
res, constants.COLUMN_NAME_MODE_NONE, self._columns)
res, self._colNameMode, self._columns)
self._nextRow = 0

if len(self._rows) == 0:
raise StopIteration

def _colsMode(self):
return self._colNameMode if self._colNameMode != constants.COLUMN_NAME_MODE_NONE else constants.COLUMN_NAME_MODE_FIELD

def columns(self):
self._fetch_next()

if self._colNameMode not in [constants.COLUMN_NAME_MODE_DATABASE, constants.COLUMN_NAME_MODE_FULL]:
return self._columns

return [x.replace("[@DB]", self._dbname.decode("utf-8")) for x in self._columns]
return self._columns

def close(self):
if self._closed:
Expand Down

0 comments on commit 172e062

Please sign in to comment.