Skip to content

Commit

Permalink
Refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ostafen committed May 20, 2024
1 parent 172e062 commit 42a4d24
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 62 deletions.
10 changes: 8 additions & 2 deletions immudb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def openSession(self, username, password, database=b"defaultdb"):
session_response = self._stub.OpenSession(
req)
self._stub = self._set_session_id_interceptor(session_response)
return transaction.Tx(self._stub, session_response, self.channel)
return transaction.Tx(self._stub, database, session_response, self.channel)

def closeSession(self):
"""Closes unmanaged session
Expand Down Expand Up @@ -1624,7 +1624,12 @@ def sqlQuery(self, query, params={}, columnNameMode=constants.COLUMN_NAME_MODE_N
['table1', 'table2']
"""
return sqlquery.call(self._stub, self._rs, query, params, columnNameMode, self._currentdb, acceptStream)
it = sqlquery.call(self._stub, self._rs, query,
params, columnNameMode, self._currentdb)
if acceptStream:
return it

return list(it)

def listTables(self):
"""List all tables in the current database
Expand Down Expand Up @@ -1700,6 +1705,7 @@ def verifiableSQLGet(self, table: str, primaryKeys: List[datatypesv2.PrimaryKey]

# immudb-py only


def getAllValues(self, keys: list): # immudb-py only
resp = batchGet.call(self._stub, self._rs, keys)
return resp
Expand Down
15 changes: 5 additions & 10 deletions immudb/handler/sqlquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,23 @@
from immudb.exceptions import ErrPySDKInvalidColumnMode


def call(service: schema_pb2_grpc.ImmuServiceStub, rs: RootService, query, params, columnNameMode, dbname, acceptStream):
return _call_with_executor(query, params, columnNameMode, dbname, acceptStream, service.SQLQuery)
def call(service: schema_pb2_grpc.ImmuServiceStub, rs: RootService, query, params, columnNameMode, dbname):
return _call_with_executor(query, params, columnNameMode, dbname, service.SQLQuery)


def _call_with_executor(query, params, columnNameMode, dbname, acceptStream, executor):
def _call_with_executor(query, params, columnNameMode, dbname, executor):
paramsObj = []
for key, value in params.items():
paramsObj.append(schema_pb2.NamedParam(
name=key, value=py_to_sqlvalue(value)))

request = schema_pb2.SQLQueryRequest(
sql=query,
acceptStream=acceptStream,
acceptStream=True,
params=paramsObj)

resp = executor(request)
if acceptStream:
return RowIterator(resp, columnNameMode, dbname)

res = next(resp)
columnNames = getColumnNames(res, dbname, columnNameMode)
return unpack_rows(res, columnNameMode, columnNames)
return RowIterator(resp, columnNameMode, dbname)


def fix_colnames(cols, dbname, columnNameMode):
Expand Down
9 changes: 7 additions & 2 deletions immudb/handler/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@


class Tx:
def __init__(self, stub, session, channel):
def __init__(self, stub, dbname, session, channel):
self.stub = stub
self.dbname = dbname
self.session = session
self.channel = channel
self.txStub = None
Expand Down Expand Up @@ -53,7 +54,11 @@ def rollback(self):
return resp

def sqlQuery(self, query, params=dict(), columnNameMode=constants.COLUMN_NAME_MODE_NONE, acceptStream=False):
return executeSQLQuery(query, params, columnNameMode,'', acceptStream, self.txStub.TxSQLQuery)
it = executeSQLQuery(query, params, columnNameMode,
self.dbname, self.txStub.TxSQLQuery)
if acceptStream:
return it
return list(it)

def sqlExec(self, stmt, params=dict(), noWait=False):
return executeSQLExec(stmt, params, noWait, self.txStub.TxSQLExec)
132 changes: 84 additions & 48 deletions tests/immu/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,41 @@
class TestSessionTransaction:

def test_simple_unmanaged_session(self, wrappedClient: ImmuTestClient):
if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
pytest.skip("Version too low")
txInterface = wrappedClient.client.openSession("immudb", "immudb", b"defaultdb")
txInterface = wrappedClient.client.openSession(
"immudb", "immudb", b"defaultdb")
try:
newTx = txInterface.newTx()
table = wrappedClient._generateTableName()
newTx.sqlExec(f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)")
newTx.sqlExec(
f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)")
commit = newTx.commit()
assert commit.header.id != None

newTx = txInterface.newTx()
newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "123"})
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
if wrappedClient.serverVersionEqual("1.9DOM.0") and what==[]:
newTx.sqlExec(
f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "123"})
what = newTx.sqlQuery(
f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
if wrappedClient.serverVersionEqual("1.9DOM.0") and what == []:
pytest.xfail("Known bug #1854")
assert what == [{"id": 1, "tester": '123'}]
commit = newTx.commit()
assert commit.header.id != None

newTx = txInterface.newTx()
newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "321"})
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
assert what == [{"id": 1, "tester": '123'}, {"id": 2, "tester": '321'}]
newTx.sqlExec(
f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "321"})
what = newTx.sqlQuery(
f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
assert what == [{"id": 1, "tester": '123'},
{"id": 2, "tester": '321'}]
commit = newTx.rollback()

newTx = txInterface.newTx()
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
what = newTx.sqlQuery(
f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
assert what == [{"id": 1, "tester": '123'}]
commit = newTx.commit()
wrappedClient.closeSession()
Expand All @@ -56,40 +64,48 @@ def test_simple_unmanaged_session(self, wrappedClient: ImmuTestClient):
pass

def test_simple_managed_session(self, wrappedClient: ImmuTestClient):
if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
pytest.skip("Version too low")
with wrappedClient.client.openManagedSession("immudb", "immudb", b"defaultdb") as session:
newTx = session.newTx()
table = wrappedClient._generateTableName()
newTx.sqlExec(f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)")
newTx.sqlExec(
f"CREATE TABLE {table} (id INTEGER AUTO_INCREMENT, tester VARCHAR[10], PRIMARY KEY id)")
commit = newTx.commit()
assert commit.header.id != None

newTx = session.newTx()
newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "123"})
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
if wrappedClient.serverVersionEqual("1.9DOM.0") and what==[]:
newTx.sqlExec(
f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "123"})
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(
), columnNameMode=constants.COLUMN_NAME_MODE_FIELD, acceptStream=True)
if wrappedClient.serverVersionEqual("1.9DOM.0") and what == []:
pytest.xfail("Known bug #1854")
assert what == [{"id": 1, "tester": '123'}]
assert list(what) == [{"id": 1, "tester": '123'}]
commit = newTx.commit()
assert commit.header.id != None

with wrappedClient.client.openManagedSession("immudb", "immudb", b"defaultdb") as session:
newTx = session.newTx()
newTx.sqlExec(f"INSERT INTO {table} (tester) VALUES(@testParam)", params = {"testParam": "321"})
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
assert what == [{"id": 1, "tester": '123'}, {"id": 2, "tester": '321'}]
newTx.sqlExec(
f"INSERT INTO {table} (tester) VALUES(@testParam)", params={"testParam": "321"})
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(
), columnNameMode=constants.COLUMN_NAME_MODE_FIELD, acceptStream=True)
assert list(what) == [{"id": 1, "tester": '123'},
{"id": 2, "tester": '321'}]
commit = newTx.rollback()

newTx = session.newTx()
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(), columnNameMode=constants.COLUMN_NAME_MODE_FIELD)
assert what == [{"id": 1, "tester": '123'}]
what = newTx.sqlQuery(f"SELECT * FROM {table}", dict(
), columnNameMode=constants.COLUMN_NAME_MODE_FIELD, acceptStream=True)
assert list(what) == [{"id": 1, "tester": '123'}]
commit = newTx.commit()

def test_unmanaged_session(self, wrappedClient: ImmuTestClient):
if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
pytest.skip("Version too low")
currentTxInterface = wrappedClient.openSession("immudb", "immudb", b"defaultdb")
currentTxInterface = wrappedClient.openSession(
"immudb", "immudb", b"defaultdb")
try:
wrappedClient.currentTx = currentTxInterface
key = wrappedClient.generateKeyName().encode("utf-8")
Expand All @@ -101,13 +117,18 @@ def test_unmanaged_session(self, wrappedClient: ImmuTestClient):
a = wrappedClient.get(key)
assert a.value == b'1'
interface = wrappedClient.newTx()
table = wrappedClient.createTestTable("id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id")
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "3"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "4"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "5"})
table = wrappedClient.createTestTable(
"id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id")
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "3"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "4"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "5"})
interface.commit()
wrappedClient.closeSession()
currentTxInterface = wrappedClient.openSession("immudb", "immudb", b"defaultdb")
currentTxInterface = wrappedClient.openSession(
"immudb", "immudb", b"defaultdb")
wrappedClient.currentTx = currentTxInterface
interface = wrappedClient.newTx()

Expand All @@ -117,16 +138,20 @@ def test_unmanaged_session(self, wrappedClient: ImmuTestClient):

interface.commit()
interface = wrappedClient.newTx()
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "8"})
interface.rollback()
finally:
try:
wrappedClient.closeSession()
except:
pass
wrappedClient.currentTx = wrappedClient.openSession("immudb", "immudb", b"defaultdb")
wrappedClient.currentTx = wrappedClient.openSession(
"immudb", "immudb", b"defaultdb")
interface = wrappedClient.newTx()
what = wrappedClient.simpleSelect(table, ["tester"], dict())
concatenated = [item[0] for item in what]
Expand All @@ -141,7 +166,7 @@ def test_unmanaged_session(self, wrappedClient: ImmuTestClient):
assert a.value == b'1'

def test_managed_session(self, wrappedClient: ImmuTestClient):
if(not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
if (not wrappedClient.serverHigherOrEqualsToVersion("1.2.0")):
pytest.skip("Version too low")

with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session:
Expand All @@ -156,10 +181,14 @@ def test_managed_session(self, wrappedClient: ImmuTestClient):
with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session:
wrappedClient.currentTx = session
interface = wrappedClient.newTx()
table = wrappedClient.createTestTable("id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id")
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "3"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "4"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "5"})
table = wrappedClient.createTestTable(
"id INTEGER AUTO_INCREMENT", "tester VARCHAR[10]", "PRIMARY KEY id")
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "3"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "4"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "5"})
interface.commit()

with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session:
Expand All @@ -172,9 +201,12 @@ def test_managed_session(self, wrappedClient: ImmuTestClient):
with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session:
wrappedClient.currentTx = session
interface = wrappedClient.newTx()
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "8"})
interface.rollback()

with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session:
Expand All @@ -187,14 +219,20 @@ def test_managed_session(self, wrappedClient: ImmuTestClient):
with wrappedClient.openManagedSession("immudb", "immudb", b"defaultdb") as session:
wrappedClient.currentTx = session
interface = wrappedClient.newTx()
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "8"})
rollbackAs = interface.rollback()
interface = wrappedClient.newTx()
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(table, ["tester"], ["@blabla"], {"blabla": "8"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "6"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "7"})
wrappedClient.insertToTable(
table, ["tester"], ["@blabla"], {"blabla": "8"})
commitAs = interface.commit()
assert commitAs.header.id != None
interface = wrappedClient.newTx()
Expand All @@ -206,5 +244,3 @@ def test_managed_session(self, wrappedClient: ImmuTestClient):
concatenated = [item[0] for item in what]
assert concatenated == ["3", "4", "5", "6", "7", "8"]
what = wrappedClient.commit()


0 comments on commit 42a4d24

Please sign in to comment.