Skip to content

Commit

Permalink
Merge pull request #361 from PainterQubits/main
Browse files Browse the repository at this point in the history
Update to SQLAlchemy 2.0
  • Loading branch information
nikolaqm authored Mar 30, 2023
2 parents aee1969 + b1e68de commit 49575cd
Show file tree
Hide file tree
Showing 12 changed files with 148 additions and 151 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
* Updated SQLAlchemy from version 1.4 to 2.0

## [0.15.8]
### Changed
* supporting python versions 3.8, 3.9, 3.10, 3.11
Expand Down
24 changes: 4 additions & 20 deletions entropylab/pipeline/params/persistence/sqlalchemy/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from logging.config import fileConfig

from alembic import context
from sqlalchemy import engine_from_config
from sqlalchemy import pool

from entropylab.pipeline.params.persistence.sqlalchemy.model import Base

Expand Down Expand Up @@ -59,24 +57,10 @@ def run_migrations_online() -> None:
and associate a connection with the context.
"""
connectable = config.attributes.get("connection", None)

if connectable is None:
# only create Engine if we don't have a Connection
# from the outside
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)

# when connectable is already a Connection object, calling
# connect() gives us a *branched connection*.
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)

with context.begin_transaction():
context.run_migrations()
connection = config.attributes["connection"]
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()


if context.is_offline_mode():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import uuid
from uuid import UUID
from pathlib import Path
from typing import Optional, Set, List

Expand All @@ -8,7 +9,7 @@
from alembic.config import Config
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Connection
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import sessionmaker, close_all_sessions

from entropylab.pipeline.api.errors import EntropyError
from entropylab.pipeline.params.persistence.persistence import Persistence, Commit
Expand All @@ -17,7 +18,7 @@
TempTable,
)

TEMP_COMMIT_ID = "00000000-0000-0000-0000-000000000000"
TEMP_COMMIT_ID = UUID("00000000-0000-0000-0000-000000000000")


class SqlAlchemyPersistence(Persistence):
Expand Down Expand Up @@ -57,7 +58,7 @@ def _abs_path_to(rel_path: str) -> str:
return os.path.join(source_dir, rel_path)

def close(self):
self.__session_maker.close_all()
close_all_sessions()

def get_commit(
self, commit_id: Optional[str] = None, commit_num: Optional[int] = None
Expand All @@ -66,7 +67,7 @@ def get_commit(
with self.__session_maker() as session:
commit = (
session.query(CommitTable)
.filter(CommitTable.id == commit_id)
.filter(CommitTable.id == UUID(commit_id))
.one_or_none()
)
if commit:
Expand Down Expand Up @@ -108,15 +109,15 @@ def commit(
# TODO: Perhaps create the timestamp here?
self.stamp_dirty_params_with_commit(commit, dirty_keys)
commit_table = CommitTable()
commit_table.id = commit.id
commit_table.id = UUID(commit.id)
commit_table.timestamp = commit.timestamp
commit_table.label = commit.label
commit_table.params = commit.params
commit_table.tags = commit.tags
with self.__session_maker() as session:
session.add(commit_table)
session.commit()
return commit_table.id
return commit.id

@staticmethod
def __generate_commit_id() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,29 @@ def target(tmp_path) -> SqlAlchemyPersistence:


def test_ctor_creates_schema(target):
cursor = target.engine.execute("SELECT sql FROM sqlite_master WHERE type = 'table'")
assert len(cursor.fetchall()) == 3
with target.engine.connect() as connection:
cursor = connection.execute(
text("SELECT sql FROM sqlite_master WHERE type = 'table'")
)
assert len(cursor.fetchall()) == 3


def test_ctor_stamps_head(target):
cursor = target.engine.execute("SELECT version_num FROM alembic_version")
assert cursor.first() == ("000c6a88457f",)
with target.engine.connect() as connection:
cursor = connection.execute(text("SELECT version_num FROM alembic_version"))
assert cursor.first() == ("000c6a88457f",)


""" get_commit """


def test_get_commit_when_commit_id_exists_then_commit_is_returned(target):
commit_id = "f74c808e-2388-4b0a-a051-17eb9eb14339"
with target.engine.connect() as connection:
with target.engine.begin() as connection:
connection.execute(
text(
"INSERT INTO 'commit' VALUES "
f"('{commit_id}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
f"('{UUID(commit_id).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
)
)
actual = target.get_commit(commit_id)
Expand All @@ -50,20 +54,20 @@ def test_get_commit_when_commit_id_exists_then_commit_is_returned(target):

def test_get_commit_when_commit_id_does_not_exist_then_error_is_raised(target):
with pytest.raises(EntropyError):
target.get_commit("foo")
target.get_commit("f74c808e-2388-4b0a-a051-17eb9eb14339")


def test_get_commit_when_commit_num_exists_then_commit_is_returned(target):
commit_id1 = "f74c808e-2388-4b0a-a051-17eb9eb11111"
commit_id2 = "f74c808e-2388-4b0a-a051-17eb9eb22222"
commit_id3 = "f74c808e-2388-4b0a-a051-17eb9eb33333"
with target.engine.connect() as connection:
with target.engine.begin() as connection:
connection.execute(
text(
"INSERT INTO 'commit' VALUES "
f"('{commit_id1}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{commit_id2}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{commit_id3}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
f"('{UUID(commit_id1).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{UUID(commit_id2).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{UUID(commit_id3).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
)
)
actual = target.get_commit(commit_num=2)
Expand Down
24 changes: 4 additions & 20 deletions entropylab/pipeline/results_backend/sqlalchemy/alembic/env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from logging.config import fileConfig

from alembic import context
from sqlalchemy import engine_from_config
from sqlalchemy import pool

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down Expand Up @@ -61,24 +59,10 @@ def run_migrations_online():
https://alembic.sqlalchemy.org/en/latest/cookbook.html#connection-sharing
"""
connectable = config.attributes.get("connection", None)

if connectable is None:
# only create Engine if we don't have a Connection
# from the outside
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)

# when connectable is already a Connection object, calling
# connect() gives us a *branched connection*.
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)

with context.begin_transaction():
context.run_migrations()
connection = config.attributes["connection"]
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()


if context.is_offline_mode():
Expand Down
13 changes: 7 additions & 6 deletions entropylab/pipeline/results_backend/sqlalchemy/db.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from datetime import datetime
from contextlib import contextmanager
from typing import List, TypeVar, Optional, ContextManager, Iterable, Union, Any
from typing import Set
from warnings import warn

import jsonpickle
import pandas as pd
from pandas import DataFrame
from plotly import graph_objects as go
from sqlalchemy import desc
from sqlalchemy import text, desc
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import Selectable
from sqlalchemy.util.compat import contextmanager

from entropylab.components.instrument_driver import Function, Parameter
from entropylab.components.lab_model import (
Expand Down Expand Up @@ -336,11 +335,12 @@ def __get_last_result_of_experiment_from_sqlalchemy(
def custom_query(self, query: Union[str, Selectable]) -> DataFrame:
with self._session_maker() as sess:
if isinstance(query, str):
selectable = query
selectable = text(query)
else:
selectable = query.statement

return pd.read_sql(selectable, sess.bind)
result = sess.execute(selectable)
return DataFrame(result.all(), columns=result.keys())

def _execute_transaction(self, transaction):
with self._session_maker() as sess:
Expand All @@ -350,7 +350,8 @@ def _execute_transaction(self, transaction):

@staticmethod
def _query_pandas(query):
return pd.read_sql(query.statement, query.session.bind)
result = query.session.execute(query.statement)
return DataFrame(result.all(), columns=result.keys())

@contextmanager
def _session_maker(self) -> ContextManager[Session]:
Expand Down
11 changes: 6 additions & 5 deletions entropylab/pipeline/results_backend/sqlalchemy/db_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TypeVar, Type, Tuple

import sqlalchemy.engine
from sqlalchemy import create_engine
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

from entropylab.logger import logger
Expand Down Expand Up @@ -117,10 +117,11 @@ def _validate_path(path):
)

def _db_is_empty(self) -> bool:
cursor = self._engine.execute(
"SELECT sql FROM sqlite_master WHERE type = 'table'"
)
return len(cursor.fetchall()) == 0
with self._engine.connect() as connection:
cursor = connection.execute(
text("SELECT sql FROM sqlite_master WHERE type = 'table'")
)
return len(cursor.fetchall()) == 0


class _DbUpgrader:
Expand Down
3 changes: 1 addition & 2 deletions entropylab/pipeline/results_backend/sqlalchemy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
Enum,
Boolean,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import declarative_base, relationship

from entropylab.logger import logger
from entropylab.pipeline.api.data_reader import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import shutil

import pytest
from sqlalchemy import create_engine
from sqlalchemy import create_engine, text

from entropylab import SqlAlchemyDB, RawResultData
from entropylab.conftest import _copy_template
Expand Down Expand Up @@ -76,9 +76,12 @@ def test_upgrade_db_when_initial_db_is_empty(initialized_project_dir_path):
engine = create_engine(
f"sqlite:///{initialized_project_dir_path}/{_ENTROPY_DIRNAME}/{_DB_FILENAME}"
)
cur = engine.execute("SELECT sql FROM sqlite_master WHERE name = 'Results'")
res = cur.fetchone()
cur.close()
with engine.connect() as connection:
cur = connection.execute(
text("SELECT sql FROM sqlite_master WHERE name = 'Results'")
)
res = cur.fetchone()
cur.close()
assert "saved_in_hdf5" in res[0]


Expand All @@ -88,9 +91,12 @@ def test_upgrade_db_when_db_is_in_memory():
# act
target.upgrade_db()
# assert
cur = target._engine.execute("SELECT sql FROM sqlite_master WHERE name = 'Results'")
res = cur.fetchone()
cur.close()
with target._engine.connect() as connection:
cur = connection.execute(
text("SELECT sql FROM sqlite_master WHERE name = 'Results'")
)
res = cur.fetchone()
cur.close()
assert "saved_in_hdf5" in res[0]


Expand All @@ -112,8 +118,9 @@ def test__migrate_results_to_hdf5(initialized_project_dir_path):
)
hdf5_results = storage.get_result_records()
assert len(list(hdf5_results)) == 5
cur = target._engine.execute("SELECT * FROM Results WHERE saved_in_hdf5 = 1")
res = cur.all()
with target._engine.connect() as connection:
cur = connection.execute(text("SELECT * FROM Results WHERE saved_in_hdf5 = 1"))
res = cur.all()
assert len(res) == 5


Expand All @@ -135,10 +142,11 @@ def test__migrate_metadata_to_hdf5(initialized_project_dir_path):
)
hdf5_metadata = storage.get_metadata_records()
assert len(list(hdf5_metadata)) == 5
cur = target._engine.execute(
"SELECT * FROM ExperimentMetadata WHERE saved_in_hdf5 = 1"
)
res = cur.all()
with target._engine.connect() as connection:
cur = connection.execute(
text("SELECT * FROM ExperimentMetadata WHERE saved_in_hdf5 = 1")
)
res = cur.all()
assert len(res) == 5


Expand Down Expand Up @@ -205,14 +213,16 @@ def test_upgrade_db_deletes_results_and_metadata_from_sqlite(
# act
target.upgrade_db()
# assert for results
cur = target._engine.execute("SELECT * FROM Results WHERE saved_in_hdf5 = 1")
res = cur.all()
with target._engine.connect() as connection:
cur = connection.execute(text("SELECT * FROM Results WHERE saved_in_hdf5 = 1"))
res = cur.all()
assert len(res) == 0
# assert for metadata
cur = target._engine.execute(
"SELECT * FROM ExperimentMetadata WHERE saved_in_hdf5 = 1"
)
res = cur.all()
with target._engine.connect() as connection:
cur = connection.execute(
text("SELECT * FROM ExperimentMetadata WHERE saved_in_hdf5 = 1")
)
res = cur.all()
assert len(res) == 0


Expand All @@ -230,8 +240,11 @@ def test_upgrade_db_adds_favorite_column_to_experiments_table(
target = _DbUpgrader(initialized_project_dir_path)
# act
target.upgrade_db()
cur = target._engine.execute(
"SELECT COUNT(*) FROM pragma_table_info('Experiments') WHERE name='favorite'; "
)
res = cur.all()
with target._engine.connect() as connection:
cur = connection.execute(
text(
"SELECT COUNT(*) FROM pragma_table_info('Experiments') WHERE name='favorite'; "
)
)
res = cur.all()
assert res[0][0] == 1
Loading

0 comments on commit 49575cd

Please sign in to comment.