Skip to content

Commit

Permalink
Finish upgrading to SQLAlchemy 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Hadley committed Mar 29, 2023
1 parent e84ccc6 commit c3720cb
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 70 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import uuid
from uuid import UUID
from pathlib import Path
from typing import Optional, Set, List

Expand All @@ -17,7 +18,7 @@
TempTable,
)

TEMP_COMMIT_ID = "00000000-0000-0000-0000-000000000000"
TEMP_COMMIT_ID = UUID("00000000-0000-0000-0000-000000000000")


class SqlAlchemyPersistence(Persistence):
Expand Down Expand Up @@ -61,12 +62,12 @@ def close(self):

def get_commit(
self, commit_id: Optional[str] = None, commit_num: Optional[int] = None
):
):
if commit_id:
with self.__session_maker() as session:
commit = (
session.query(CommitTable)
.filter(CommitTable.id == commit_id)
.filter(CommitTable.id == UUID(commit_id))
.one_or_none()
)
if commit:
Expand Down Expand Up @@ -108,15 +109,15 @@ def commit(
# TODO: Perhaps create the timestamp here?
self.stamp_dirty_params_with_commit(commit, dirty_keys)
commit_table = CommitTable()
commit_table.id = commit.id
commit_table.id = UUID(commit.id)
commit_table.timestamp = commit.timestamp
commit_table.label = commit.label
commit_table.params = commit.params
commit_table.tags = commit.tags
with self.__session_maker() as session:
session.add(commit_table)
session.commit()
return commit_table.id
return commit.id

@staticmethod
def __generate_commit_id() -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_get_commit_when_commit_id_exists_then_commit_is_returned(target):
connection.execute(
text(
"INSERT INTO 'commit' VALUES "
f"('{commit_id}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
f"('{UUID(commit_id).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
)
)
actual = target.get_commit(commit_id)
Expand All @@ -54,7 +54,7 @@ def test_get_commit_when_commit_id_exists_then_commit_is_returned(target):

def test_get_commit_when_commit_id_does_not_exist_then_error_is_raised(target):
with pytest.raises(EntropyError):
target.get_commit("foo")
target.get_commit("f74c808e-2388-4b0a-a051-17eb9eb14339")


def test_get_commit_when_commit_num_exists_then_commit_is_returned(target):
Expand All @@ -65,9 +65,9 @@ def test_get_commit_when_commit_num_exists_then_commit_is_returned(target):
connection.execute(
text(
"INSERT INTO 'commit' VALUES "
f"('{commit_id1}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{commit_id2}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{commit_id3}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
f"('{UUID(commit_id1).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{UUID(commit_id2).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0'),"
f"('{UUID(commit_id3).hex}', '{pd.Timestamp.now()}', 'bar', '0', '0');"
)
)
actual = target.get_commit(commit_num=2)
Expand Down
11 changes: 6 additions & 5 deletions entropylab/pipeline/results_backend/sqlalchemy/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from warnings import warn

import jsonpickle
import pandas as pd
from pandas import DataFrame
from plotly import graph_objects as go
from sqlalchemy import desc
from sqlalchemy import text, desc
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.sql import Selectable
Expand Down Expand Up @@ -336,11 +335,12 @@ def __get_last_result_of_experiment_from_sqlalchemy(
def custom_query(self, query: Union[str, Selectable]) -> DataFrame:
with self._session_maker() as sess:
if isinstance(query, str):
selectable = query
selectable = text(query)
else:
selectable = query.statement

return pd.read_sql(selectable, sess.bind)
result = sess.execute(selectable)
return DataFrame(result.all(), columns=result.keys())

def _execute_transaction(self, transaction):
with self._session_maker() as sess:
Expand All @@ -350,7 +350,8 @@ def _execute_transaction(self, transaction):

@staticmethod
def _query_pandas(query):
return pd.read_sql(query.statement, query.session.bind)
result = query.session.execute(query.statement)
return DataFrame(result.all(), columns=result.keys())

@contextmanager
def _session_maker(self) -> ContextManager[Session]:
Expand Down
Loading

0 comments on commit c3720cb

Please sign in to comment.