From 172e062a936ad21ab22cc6bdd5f0e1c0d46a572d Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Sat, 18 May 2024 13:45:57 +0200 Subject: [PATCH] Fix colname mode --- immudb/handler/sqlquery.py | 91 ++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 49 deletions(-) diff --git a/immudb/handler/sqlquery.py b/immudb/handler/sqlquery.py index c378d90..6593ec6 100644 --- a/immudb/handler/sqlquery.py +++ b/immudb/handler/sqlquery.py @@ -39,23 +39,15 @@ def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, exe return RowIterator(resp, columnNameMode, dbname) res = next(resp) - columnNames = getColumnNames(res, columnNameMode) - rows = unpack_rows(res, columnNameMode, columnNames) - return fix_colnames(rows, dbname, columnNameMode) + columnNames = getColumnNames(res, dbname, columnNameMode) + return unpack_rows(res, columnNameMode, columnNames) -def fix_colnames(ret, dbname, columnNameMode): +def fix_colnames(cols, dbname, columnNameMode): if columnNameMode not in [constants.COLUMN_NAME_MODE_DATABASE, constants.COLUMN_NAME_MODE_FULL]: - return ret + return cols - # newer DB version don't insert database name anymore, we need to - # process it manually - for i, t in enumerate(ret): - newkeys = [ - x.replace("[@DB]", dbname.decode("utf-8")) for x in t.keys()] - k = dict(zip(newkeys, list(t.values()))) - ret[i] = k - return ret + return [x.replace("[@DB]", dbname.decode("utf-8")) for x in cols] def unpack_rows(resp, columnNameMode, colNames): @@ -69,34 +61,36 @@ def unpack_rows(resp, columnNameMode, colNames): return result -def getColumnNames(resp, columnNameMode): - columnNames = [] - if columnNameMode != constants.COLUMN_NAME_MODE_NONE: - for column in resp.columns: - # note that depending on the version parts can be - # '(dbname.tablename.fieldname)' *or* - # '(tablename.fieldname)' without dbnname. - # In that case we mimic the old behavior by using [@DB] as placeholder - # that will be replaced at higher level. - parts = column.name.strip("()").split(".") - if columnNameMode == constants.COLUMN_NAME_MODE_FIELD: - columnNames.append(parts[-1]) - continue - if columnNameMode == constants.COLUMN_NAME_MODE_TABLE: - columnNames.append(".".join(parts[-2:])) - continue - print( - "Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated") - if len(parts) == 2: - parts.insert(0, "[@DB]") - if columnNameMode == constants.COLUMN_NAME_MODE_DATABASE: - columnNames.append(".".join(parts)) - continue - if columnNameMode == constants.COLUMN_NAME_MODE_FULL: - columnNames.append("("+".".join(parts)+")") - continue - raise ErrPySDKInvalidColumnMode - return columnNames +def getColumnNames(resp, dbname, columnNameMode): + cols = [] + if columnNameMode == constants.COLUMN_NAME_MODE_NONE: + return cols + + for column in resp.columns: + # note that depending on the version parts can be + # '(dbname.tablename.fieldname)' *or* + # '(tablename.fieldname)' without dbnname. + # In that case we mimic the old behavior by using [@DB] as placeholder + # that will be replaced at higher level. + parts = column.name.strip("()").split(".") + if columnNameMode == constants.COLUMN_NAME_MODE_FIELD: + cols.append(parts[-1]) + continue + if columnNameMode == constants.COLUMN_NAME_MODE_TABLE: + cols.append(".".join(parts[-2:])) + continue + print( + "Use of COLUMN_NAME_MODE_DATABASE and COLUMN_NAME_MODE_FULL is deprecated") + if len(parts) == 2: + parts.insert(0, "[@DB]") + if columnNameMode == constants.COLUMN_NAME_MODE_DATABASE: + cols.append(".".join(parts)) + continue + if columnNameMode == constants.COLUMN_NAME_MODE_FULL: + cols.append("("+".".join(parts)+")") + continue + raise ErrPySDKInvalidColumnMode + return fix_colnames(cols, dbname, columnNameMode) class ClosedIterator(BaseException): @@ -109,7 +103,7 @@ def __init__(self, grpcIt, colNameMode, dbname) -> None: self._nextRow = 0 self._rows = [] self._columns = None - self._colNameMode = colNameMode if colNameMode != constants.COLUMN_NAME_MODE_NONE else constants.COLUMN_NAME_MODE_FIELD + self._colNameMode = colNameMode self._dbname = dbname self._closed = False @@ -132,22 +126,21 @@ def _fetch_next(self): res = next(self._grpcIt) if self._columns == None: - self._columns = getColumnNames(res, self._colNameMode) + self._columns = getColumnNames(res, self._dbname, self._colsMode()) self._rows = unpack_rows( - res, constants.COLUMN_NAME_MODE_NONE, self._columns) + res, self._colNameMode, self._columns) self._nextRow = 0 if len(self._rows) == 0: raise StopIteration + def _colsMode(self): + return self._colNameMode if self._colNameMode != constants.COLUMN_NAME_MODE_NONE else constants.COLUMN_NAME_MODE_FIELD + def columns(self): self._fetch_next() - - if self._colNameMode not in [constants.COLUMN_NAME_MODE_DATABASE, constants.COLUMN_NAME_MODE_FULL]: - return self._columns - - return [x.replace("[@DB]", self._dbname.decode("utf-8")) for x in self._columns] + return self._columns def close(self): if self._closed: