From 42a4d24bf66c0692553f29a9e10b75610420f921 Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Mon, 20 May 2024 14:02:33 +0200 Subject: [PATCH] Refactoring --- immudb/client.py | 10 ++- immudb/handler/sqlquery.py | 15 ++-- immudb/handler/transaction.py | 9 ++- tests/immu/test_session.py | 132 +++++++++++++++++++++------------- 4 files changed, 104 insertions(+), 62 deletions(-) diff --git a/immudb/client.py b/immudb/client.py index a0eb870..4fbcecb 100644 --- a/immudb/client.py +++ b/immudb/client.py @@ -326,7 +326,7 @@ def openSession(self, username, password, database=b"defaultdb"): session_response = self._stub.OpenSession( req) self._stub = self._set_session_id_interceptor(session_response) - return transaction.Tx(self._stub, session_response, self.channel) + return transaction.Tx(self._stub, database, session_response, self.channel) def closeSession(self): """Closes unmanaged session @@ -1624,7 +1624,12 @@ def sqlQuery(self, query, params={}, columnNameMode=constants.COLUMN_NAME_MODE_N ['table1', 'table2'] """ - return sqlquery.call(self._stub, self._rs, query, params, columnNameMode, self._currentdb, acceptStream) + it = sqlquery.call(self._stub, self._rs, query, + params, columnNameMode, self._currentdb) + if acceptStream: + return it + + return list(it) def listTables(self): """List all tables in the current database @@ -1700,6 +1705,7 @@ def verifiableSQLGet(self, table: str, primaryKeys: List[datatypesv2.PrimaryKey] # immudb-py only + def getAllValues(self, keys: list): # immudb-py only resp = batchGet.call(self._stub, self._rs, keys) return resp diff --git a/immudb/handler/sqlquery.py b/immudb/handler/sqlquery.py index 6593ec6..9e24904 100644 --- a/immudb/handler/sqlquery.py +++ b/immudb/handler/sqlquery.py @@ -19,11 +19,11 @@ from immudb.exceptions import ErrPySDKInvalidColumnMode -def call(service: schema_pb2_grpc.ImmuServiceStub, rs: RootService, query, params, columnNameMode, dbname, acceptStream): - return _call_with_executor(query, params, columnNameMode, dbname, acceptStream, service.SQLQuery) +def call(service: schema_pb2_grpc.ImmuServiceStub, rs: RootService, query, params, columnNameMode, dbname): + return _call_with_executor(query, params, columnNameMode, dbname, service.SQLQuery) -def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, executor): +def _call_with_executor(query, params, columnNameMode, dbname, executor): paramsObj = [] for key, value in params.items(): paramsObj.append(schema_pb2.NamedParam( @@ -31,16 +31,11 @@ def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, exe request = schema_pb2.SQLQueryRequest( sql=query, - acceptStream=acceptStream, + acceptStream=True, params=paramsObj) resp = executor(request) - if acceptStream: - return RowIterator(resp, columnNameMode, dbname) - - res = next(resp) - columnNames = getColumnNames(res, dbname, columnNameMode) - return unpack_rows(res, columnNameMode, columnNames) + return RowIterator(resp, columnNameMode, dbname) def fix_colnames(cols, dbname, columnNameMode): diff --git a/immudb/handler/transaction.py b/immudb/handler/transaction.py index ddc7c68..87f740c 100644 --- a/immudb/handler/transaction.py +++ b/immudb/handler/transaction.py @@ -21,8 +21,9 @@ class Tx: - def __init__(self, stub, session, channel): + def __init__(self, stub, dbname, session, channel): self.stub = stub + self.dbname = dbname self.session = session self.channel = channel self.txStub = None @@ -53,7 +54,11 @@ def rollback(self): return resp def sqlQuery(self, query, params=dict(), columnNameMode=constants.COLUMN_NAME_MODE_NONE, acceptStream=False): - return executeSQLQuery(query, params, columnNameMode,'', acceptStream, self.txStub.TxSQLQuery) + it = executeSQLQuery(query, params, columnNameMode, + self.dbname, self.txStub.TxSQLQuery) + if acceptStream: + return it + return list(it) def sqlExec(self, stmt, params=dict(), noWait=False): return executeSQLExec(stmt, params, noWait, self.txStub.TxSQLExec) diff --git a/tests/immu/test_session.py b/tests/immu/test_session.py index 63c7323..8dd6a63 100644 --- a/tests/immu/test_session.py +++ b/tests/immu/test_session.py @@ -19,33 +19,41 @@ class TestSessionTransaction: def test_simple_unmanaged_session(self, wrappedClient: ImmuTestClient): - if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): + if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): pytest.skip("Version too low") - txInterface = wrappedClient.client.openSession("immudb", "immudb", b"defaultdb") + txInterface = wrappedClient.client.openSession( + "immudb", "immudb", b"defaultdb") try: newTx = txInterface.newTx() table = wrappedClient._generateTableName() - newTx.sqlExec(f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)") + newTx.sqlExec( + f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)") commit = newTx.commit() assert commit.header.id != None newTx = txInterface.newTx() - newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "123"}) - what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) - if wrappedClient.serverVersionEqual("1.9DOM.0") and what==[]: + newTx.sqlExec( + f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "123"}) + what = newTx.sqlQuery( + f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) + if wrappedClient.serverVersionEqual("1.9DOM.0") and what == []: pytest.xfail("Known bug #1854") assert what == [{"id": 1, "tester": '123'}] commit = newTx.commit() assert commit.header.id != None newTx = txInterface.newTx() - newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "321"}) - what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) - assert what == [{"id": 1, "tester": '123'}, {"id": 2, "tester": '321'}] + newTx.sqlExec( + f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "321"}) + what = newTx.sqlQuery( + f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) + assert what == [{"id": 1, "tester": '123'}, + {"id": 2, "tester": '321'}] commit = newTx.rollback() newTx = txInterface.newTx() - what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) + what = newTx.sqlQuery( + f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) assert what == [{"id": 1, "tester": '123'}] commit = newTx.commit() wrappedClient.closeSession() @@ -56,40 +64,48 @@ def test_simple_unmanaged_session(self, wrappedClient: ImmuTestClient): pass def test_simple_managed_session(self, wrappedClient: ImmuTestClient): - if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): + if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): pytest.skip("Version too low") with wrappedClient.client.openManagedSession("immudb", "immudb", b"defaultdb") as session: newTx = session.newTx() table = wrappedClient._generateTableName() - newTx.sqlExec(f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)") + newTx.sqlExec( + f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)") commit = newTx.commit() assert commit.header.id != None newTx = session.newTx() - newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "123"}) - what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) - if wrappedClient.serverVersionEqual("1.9DOM.0") and what==[]: + newTx.sqlExec( + f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "123"}) + what = newTx.sqlQuery(f"SELECT * FROM {table}", dict( + ), columnNameMode=constants.COLUMN_NAME_MODE_FIELD, acceptStream=True) + if wrappedClient.serverVersionEqual("1.9DOM.0") and what == []: pytest.xfail("Known bug #1854") - assert what == [{"id": 1, "tester": '123'}] + assert list(what) == [{"id": 1, "tester": '123'}] commit = newTx.commit() assert commit.header.id != None with wrappedClient.client.openManagedSession("immudb", "immudb", b"defaultdb") as session: newTx = session.newTx() - newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "321"}) - what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) - assert what == [{"id": 1, "tester": '123'}, {"id": 2, "tester": '321'}] + newTx.sqlExec( + f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "321"}) + what = newTx.sqlQuery(f"SELECT * FROM {table}", dict( + ), columnNameMode=constants.COLUMN_NAME_MODE_FIELD, acceptStream=True) + assert list(what) == [{"id": 1, "tester": '123'}, + {"id": 2, "tester": '321'}] commit = newTx.rollback() newTx = session.newTx() - what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD) - assert what == [{"id": 1, "tester": '123'}] + what = newTx.sqlQuery(f"SELECT * FROM {table}", dict( + ), columnNameMode=constants.COLUMN_NAME_MODE_FIELD, acceptStream=True) + assert list(what) == [{"id": 1, "tester": '123'}] commit = newTx.commit() def test_unmanaged_session(self, wrappedClient: ImmuTestClient): - if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): + if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): pytest.skip("Version too low") - currentTxInterface = wrappedClient.openSession("immudb", "immudb", b"defaultdb") + currentTxInterface = wrappedClient.openSession( + "immudb", "immudb", b"defaultdb") try: wrappedClient.currentTx = currentTxInterface key = wrappedClient.generateKeyName().encode("utf-8") @@ -101,13 +117,18 @@ def test_unmanaged_session(self, wrappedClient: ImmuTestClient): a = wrappedClient.get(key) assert a.value == b'1' interface = wrappedClient.newTx() - table = wrappedClient.createTestTable("id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id") - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "3"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "4"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "5"}) + table = wrappedClient.createTestTable( + "id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id") + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "3"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "4"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "5"}) interface.commit() wrappedClient.closeSession() - currentTxInterface = wrappedClient.openSession("immudb", "immudb", b"defaultdb") + currentTxInterface = wrappedClient.openSession( + "immudb", "immudb", b"defaultdb") wrappedClient.currentTx = currentTxInterface interface = wrappedClient.newTx() @@ -117,16 +138,20 @@ def test_unmanaged_session(self, wrappedClient: ImmuTestClient): interface.commit() interface = wrappedClient.newTx() - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "6"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "7"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "8"}) interface.rollback() finally: try: wrappedClient.closeSession() except: pass - wrappedClient.currentTx = wrappedClient.openSession("immudb", "immudb", b"defaultdb") + wrappedClient.currentTx = wrappedClient.openSession( + "immudb", "immudb", b"defaultdb") interface = wrappedClient.newTx() what = wrappedClient.simpleSelect(table, ["tester"], dict()) concatenated = [item[0] for item in what] @@ -141,7 +166,7 @@ def test_unmanaged_session(self, wrappedClient: ImmuTestClient): assert a.value == b'1' def test_managed_session(self, wrappedClient: ImmuTestClient): - if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): + if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")): pytest.skip("Version too low") with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session: @@ -156,10 +181,14 @@ def test_managed_session(self, wrappedClient: ImmuTestClient): with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session: wrappedClient.currentTx = session interface = wrappedClient.newTx() - table = wrappedClient.createTestTable("id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id") - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "3"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "4"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "5"}) + table = wrappedClient.createTestTable( + "id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id") + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "3"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "4"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "5"}) interface.commit() with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session: @@ -172,9 +201,12 @@ def test_managed_session(self, wrappedClient: ImmuTestClient): with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session: wrappedClient.currentTx = session interface = wrappedClient.newTx() - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "6"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "7"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "8"}) interface.rollback() with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session: @@ -187,14 +219,20 @@ def test_managed_session(self, wrappedClient: ImmuTestClient): with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session: wrappedClient.currentTx = session interface = wrappedClient.newTx() - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "6"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "7"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "8"}) rollbackAs = interface.rollback() interface = wrappedClient.newTx() - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"}) - wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "6"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "7"}) + wrappedClient.insertToTable( + table, ["tester"], ["@blabla"], {"blabla": "8"}) commitAs = interface.commit() assert commitAs.header.id != None interface = wrappedClient.newTx() @@ -206,5 +244,3 @@ def test_managed_session(self, wrappedClient: ImmuTestClient): concatenated = [item[0] for item in what] assert concatenated == ["3", "4", "5", "6", "7", "8"] what = wrappedClient.commit() - -