Skip to content

Commit

Permalink
Refactor database session management and enhance test structure
Browse files Browse the repository at this point in the history
  • Loading branch information
quang-ng committed Dec 23, 2024
1 parent 22b536b commit 948f7b6
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 120 deletions.
8 changes: 7 additions & 1 deletion dsst_etl/db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from sqlalchemy import inspect
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy_utils import database_exists

from dsst_etl import logger
Expand All @@ -10,6 +10,12 @@ def get_db_session(engine):
return Session()


def get_db_session_new(engine=None, bind=None):
    """Return a new SQLAlchemy ``Session``.

    A truthy ``bind`` takes precedence and is passed as ``Session(bind=bind)``;
    otherwise ``engine`` is handed to ``Session`` positionally.
    """
    return Session(bind=bind) if bind else Session(engine)


def init_db(engine):
logger.info("Checking database initialization")

Expand Down
27 changes: 27 additions & 0 deletions tests/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import unittest
from dsst_etl import get_db_engine
from dsst_etl.db import get_db_session_new, init_db

class BaseTest(unittest.TestCase):
    """Base test case that runs each test inside a connection-level transaction.

    ``setUp`` opens a connection on the test engine, begins a transaction on
    it and binds a session to that connection; ``tearDown`` closes the
    session, rolls the transaction back and closes the connection.
    """

    def setUp(self):
        """Initialise the test database and open a transaction-bound session."""
        self.engine = get_db_engine(is_test=True)
        init_db(self.engine)

        # Start a transaction at the connection level
        self.connection = self.engine.connect()
        self.trans = self.connection.begin()

        # Create session bound to this connection
        # NOTE(review): whether a session.commit() issued inside a test stays
        # within this outer transaction depends on the SQLAlchemy version /
        # join_transaction_mode -- confirm against the installed version.
        self.session = get_db_session_new(bind=self.connection)

    def tearDown(self):
        """Release the session, transaction and connection created in setUp."""
        try:
            # Close unconditionally: ``is_active`` is False after a failed
            # flush, and gating on it would leave such a session unclosed.
            self.session.close()
        finally:
            try:
                # Roll back the transaction
                if self.trans.is_active:
                    self.trans.rollback()
            finally:
                # Close the connection even if the rollback raised
                self.connection.close()
62 changes: 4 additions & 58 deletions tests/test_oddpub.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,21 @@
import logging
import unittest
from unittest.mock import patch, MagicMock
from sqlalchemy.orm import Session
from dsst_etl.oddpub_wrapper import OddpubWrapper
from dsst_etl.models import OddpubMetrics
from dsst_etl import get_db_engine
from dsst_etl.db import get_db_session, init_db
from sqlalchemy import inspect

from tests.base_test import BaseTest # type: ignore
logger = logging.getLogger(__name__)

class TestOddpubWrapper(unittest.TestCase):
class TestOddpubWrapper(BaseTest):

def setUp(self):
# Mock the database session
self.engine = get_db_engine(is_test=True)

init_db(self.engine)
# Create a new session for each test
self.session = get_db_session(self.engine)
super().setUp()

self.wrapper = OddpubWrapper(
db_session=self.session,
oddpub_host_api="http://mock-api"
)

def tearDown(self):
# # Rollback the transaction
# self.session.rollback()

# Check if the Works table exists before attempting to update or delete
inspector = inspect(self.engine)
tables = inspector.get_table_names()
logger.info(f"Tables in the tearDown: {tables}")
if "oddpub_metrics" in tables:
self.session.query(OddpubMetrics).delete()
self.session.commit()

def test_oddpub_wrapper_without_mock_api(self):
self.wrapper.oddpub_host_api = "http://localhost:8071"
self.wrapper.process_pdfs("tests/pdf-test", force_upload=True)
Expand All @@ -46,39 +26,5 @@ def test_oddpub_wrapper_without_mock_api(self):
self.assertIn("test2.txt", articles)



@patch("dsst_etl.oddpub_wrapper.requests.post")
def test_process_pdfs_success(self, mock_post):
# Mock the response from the API
mock_response = MagicMock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
'article': 'test1.txt',
'is_open_data': False,
'open_data_category': '',
'is_reuse': False,
'is_open_code': False,
'is_open_data_das': False,
'is_open_code_cas': False,
'das': None,
'open_data_statements': '',
'cas': None,
'open_code_statements': ''
}

mock_post.return_value = mock_response

# Mock the PDF files
pdf_folder = "tests/pdf-test"
pdf_paths = [
pdf_folder + "/test1.pdf",
]

# Call the method
self.wrapper.process_pdfs(pdf_folder, force_upload=True)

# Assertions
self.assertEqual(self.session.query(OddpubMetrics).count(), len(pdf_paths))

if __name__ == "__main__":
unittest.main()
47 changes: 4 additions & 43 deletions tests/test_pdf_uploader.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,23 @@
import unittest
from unittest.mock import patch, MagicMock

from sqlalchemy import update, inspect
from dsst_etl import get_db_engine
from dsst_etl.upload_pdfs import PDFUploader
from dsst_etl.db import get_db_session
from dsst_etl.models import Documents, Provenance, Works
from pathlib import Path
from dsst_etl.db import init_db

from tests.base_test import BaseTest # type: ignore

class TestPDFUploader(BaseTest):

class TestPDFUploader(unittest.TestCase):

@patch("dsst_etl.upload_pdfs.boto3.client")
def setUp(self, mock_boto_client):
super().setUp()
self.mock_s3_client = MagicMock()
mock_boto_client.return_value = self.mock_s3_client
self.engine = get_db_engine(is_test=True)

init_db(self.engine)

# Create a new session for each test
self.session = get_db_session(self.engine)

# Initialize PDFUploader with the session
self.uploader = PDFUploader(self.session)

def tearDown(self):
# Rollback the transaction


# Check if the Works table exists before attempting to update or delete
inspector = inspect(self.session.bind)
if "works" in inspector.get_table_names():
# Ensure all data is removed
self.session.execute(update(Works).values(provenance_id=None))
self.session.execute(update(Works).values(initial_document_id=None))
self.session.execute(update(Works).values(primary_document_id=None))
self.session.commit()

# Check if the Documents table exists before attempting to update or delete
if "documents" in inspector.get_table_names():
self.session.execute(update(Documents).values(provenance_id=None))
self.session.commit()

# Check if the Provenance table exists before attempting to delete
if "provenance" in inspector.get_table_names():
self.session.query(Provenance).delete()

if "documents" in inspector.get_table_names():
self.session.query(Documents).delete()

if "works" in inspector.get_table_names():
self.session.query(Works).delete()

self.session.commit()
self.session.close()

def test_upload_pdfs_success(self):
# Mock successful upload
self.mock_s3_client.upload_file.return_value = None
Expand Down
22 changes: 4 additions & 18 deletions tests/test_upload_rtransparent_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from dsst_etl.db import get_db_session, init_db
import logging

from tests.base_test import BaseTest # type: ignore

logger = logging.getLogger(__name__)

class TestRTransparentDataUploader(unittest.TestCase):
class TestRTransparentDataUploader(BaseTest):

def mock_data(self):
logger.info("Creating mock data")
Expand All @@ -28,25 +30,9 @@ def mock_data(self):


def setUp(self):
self.engine = get_db_engine(is_test=True)

init_db(self.engine)
# Create a new session for each test
self.session = get_db_session(self.engine)

super().setUp()
self.uploader = RTransparentDataUploader(self.session)

def tearDown(self):
# # Rollback the transaction
# self.session.rollback()

# Check if the Works table exists before attempting to update or delete
inspector = inspect(self.engine)
tables = inspector.get_table_names()
logger.info(f"Tables in the tearDown: {tables}")
if "rtransparent_publication" in tables:
self.session.query(RTransparentPublication).delete()
self.session.commit()

@patch('pandas.read_feather')
def test_read_file_feather(self, mock_read_feather):
Expand Down

0 comments on commit 948f7b6

Please sign in to comment.