Skip to content

Commit

Permalink
Fix persistence
Browse files Browse the repository at this point in the history
Add unique constraint
Support loading the same data multiple times
  • Loading branch information
tcezard committed Nov 1, 2023
1 parent b99dde4 commit bfb6959
Showing 1 changed file with 79 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import os
from collections import defaultdict, Counter
from functools import lru_cache, cached_property
from typing import List
from urllib.parse import urlsplit

from ebi_eva_common_pyutils.command_utils import run_command_with_output
Expand All @@ -13,8 +12,10 @@
from ebi_eva_common_pyutils.logger import logging_config, AppLogger
from ebi_eva_common_pyutils.metadata_utils import get_metadata_connection_handle
from ebi_eva_common_pyutils.pg_utils import get_all_results_for_query
from sqlalchemy import Column, String, Integer, BigInteger, ForeignKey, Engine, create_engine, URL
from sqlalchemy.orm import declarative_base, Mapped, relationship, mapped_column, Session

from sqlalchemy import Column, String, Integer, BigInteger, ForeignKey, create_engine, URL, MetaData, TEXT, \
UniqueConstraint, select
from sqlalchemy.orm import declarative_base, relationship, mapped_column, Session

logger = logging_config.get_logger(__name__)

Expand Down Expand Up @@ -111,7 +112,7 @@ def collect_files_to_count(release_directory, set_of_species):
return all_files


def calculate_all_logs(release_dir, output_dir, species_directories=None):
def run_calculation_script_for_species(release_dir, output_dir, species_directories=None):
all_assemblies_2_species, all_species_2_assemblies = gather_assemblies_and_species_from_directory(release_dir)
all_sets_of_species = set()
# Determine the species that needs to be counted together because they share assemblies
Expand Down Expand Up @@ -145,29 +146,37 @@ def generate_output_tsv(dict_of_counter, output_file, header):
]) + '\n')


# Declarative base for the stats tables; binding the MetaData to the
# "eva_stats" schema makes every table here live in that schema instead of
# the default (public) one.
Base = declarative_base(metadata=MetaData(schema="eva_stats"))


class RSCountCategory(Base):
    """
    Metadata (assembly, taxonomy, RS type, release version) associated with a
    number of RS ids. Several categories can share a single RSCount row when
    they were counted together.
    """
    __tablename__ = 'release_rs_count_category'

    count_category_id = Column(Integer, primary_key=True, autoincrement=True)
    assembly = Column(String)
    taxonomy = Column(Integer)
    rs_type = Column(String)
    release_version = Column(Integer)
    rs_count_id = mapped_column(ForeignKey("release_rs_count.id"))
    rs_count = relationship("RSCount", back_populates="count_categories")
    # Loading the same data twice must not create duplicate category rows.
    __table_args__ = (
        UniqueConstraint('assembly', 'taxonomy', 'rs_type', 'release_version', 'rs_count_id', name='uix_1'),
    )
    # NOTE(review): redundant with the schema set on Base's MetaData —
    # SQLAlchemy reads schema from __table_args__/MetaData, not this attribute.
    schema = 'eva_stats'


class RSCount(Base):
    """
    Count shared by a group of RSCountCategory rows.
    """
    __tablename__ = 'release_rs_count'
    id = Column(Integer, primary_key=True, autoincrement=True)
    count = Column(BigInteger)
    # Uniquely identifies the group of categories this count was computed for,
    # so loading the same data multiple times reuses the existing row.
    group_description = Column(TEXT, unique=True)
    count_categories = relationship("RSCountCategory", back_populates="rs_count")
    # NOTE(review): redundant with the schema set on Base's MetaData —
    # SQLAlchemy reads schema from the MetaData, not this attribute.
    schema = 'eva_stats'


Expand All @@ -178,7 +187,7 @@ def __init__(self, private_config_xml_file, config_profile, release_version, log
self.config_profile = config_profile
self.release_version = release_version
self.all_counts_grouped = []
self.parse_logs(logs)
self.parse_count_script_logs(logs)
self.add_annotations()

@lru_cache
Expand All @@ -201,14 +210,17 @@ def sqlalchemy_engine(self):
pg_url, pg_user, pg_pass = get_metadata_creds_for_profile(self.config_profile, self.private_config_xml_file)
dbtype, host_url, port_and_db = urlsplit(pg_url).path.split(':')
port, db = port_and_db.split('/')
return create_engine(URL.create(
engine = create_engine(URL.create(
dbtype + '+psycopg2',
username=pg_user,
password=pg_pass,
host=host_url.split('/')[-1],
database=db,
port=port
))
RSCount.__table__.create(bind=engine, checkfirst=True)
RSCountCategory.__table__.create(bind=engine, checkfirst=True)
return engine

def add_annotations(self):
for count_groups in self.all_counts_grouped:
Expand All @@ -220,24 +232,52 @@ def add_annotations(self):
def write_to_db(self):
    """
    Persist all parsed counts (self.all_counts_grouped) to the eva_stats
    tables. Loading the same data multiple times is idempotent: existing
    RSCount rows are found via their unique group_description and existing
    RSCountCategory rows via their unique-constraint columns, so reruns add
    nothing new. A warning is logged if a previously stored count differs
    from the one being loaded.
    """
    session = Session(self.sqlalchemy_engine)
    with session.begin():
        for count_groups in self.all_counts_grouped:
            # All the parts of the group should have the same count
            count_set = set((count_dict['count'] for count_dict in count_groups))
            assert len(count_set) == 1
            count = count_set.pop()

            def count_descriptor(count_dict):
                return f"{count_dict['assembly']},{count_dict['taxonomy']},{count_dict['idtype']}"
            # This is used to uniquely identify the count for this group so that
            # loading the same value twice does not result in duplicates
            group_description = '_'.join([
                count_descriptor(count_dict)
                for count_dict in sorted(count_groups, key=count_descriptor)]) + f'_release_{self.release_version}'
            query = select(RSCount).where(RSCount.group_description == group_description)
            result = session.execute(query).fetchone()
            if result:
                rs_count = result.RSCount
            else:
                rs_count = RSCount(count=count, group_description=group_description)
                session.add(rs_count)
            for count_dict in count_groups:
                # Look for an existing category matching the unique constraint
                query = select(RSCountCategory).where(
                    RSCountCategory.assembly == count_dict['assembly'],
                    RSCountCategory.taxonomy == count_dict['taxonomy'],
                    RSCountCategory.rs_type == count_dict['idtype'],
                    RSCountCategory.release_version == self.release_version,
                    RSCountCategory.rs_count_id == rs_count.id,
                )
                result = session.execute(query).fetchone()
                if not result:
                    self.info(f"Create persistence for {count_dict['assembly']}, {count_dict['taxonomy']}, {count_dict['idtype']}, {count_dict['count']}")
                    rs_category = RSCountCategory(
                        assembly=count_dict['assembly'],
                        taxonomy=count_dict['taxonomy'],
                        rs_type=count_dict['idtype'],
                        release_version=self.release_version,
                        rs_count=rs_count
                    )
                    session.add(rs_category)
                else:
                    rs_category = result.RSCountCategory
                    # Check if we were just loading the same value as a previous run
                    if rs_category.rs_count.count != count_dict['count']:
                        self.warning(f"{count_descriptor(count_dict)} already has a count entry in the table "
                                     f"({rs_category.rs_count.count}) different from the one being loaded "
                                     f"{count_dict['count']}")

def get_assembly_counts_from_database(self):
"""
Expand All @@ -260,7 +300,7 @@ def get_assembly_counts_from_database(self):
results[assembly][metric] = row[index + 1]
return results

def parse_logs(self, all_logs):
def parse_count_script_logs(self, all_logs):
for log_file in all_logs:
with open(log_file) as open_file:
for line in open_file:
Expand All @@ -271,7 +311,8 @@ def parse_logs(self, all_logs):
for annotation in set_of_annotations:
assembly, release_folder, idtype = annotation.split('-')
all_groups.append(
{'count': count, 'release_folder': release_folder, 'assembly': assembly, 'idtype': idtype}
{'count': count, 'release_folder': release_folder, 'assembly': assembly, 'idtype': idtype,
'annotation': annotation}
)
self.all_counts_grouped.append(all_groups)

Expand Down Expand Up @@ -323,7 +364,7 @@ def compare_assembly_counts_with_db(self, threshold=0, output_csv=None):
header = ('Assembly', 'Metric', 'File', 'DB', 'Diff (file-db)')
rows = []
count_per_assembly_from_files = self.generate_per_assembly_counts()
counts_per_assembly_from_db = self.get_counts_assembly_from_database()
counts_per_assembly_from_db = self.get_assembly_counts_from_database()
all_asms = set(count_per_assembly_from_files.keys()).union(counts_per_assembly_from_db.keys())
for asm in all_asms:
asm_counts_from_files = count_per_assembly_from_files.get(asm, {})
Expand Down Expand Up @@ -363,9 +404,9 @@ def main():
args = parser.parse_args()
logging_config.add_stdout_handler()
logger.info(f'Analyse {args.release_root_path}')
logs = calculate_all_logs(args.release_root_path, args.output_dir, args.species_directories)
log_files = run_calculation_script_for_species(args.release_root_path, args.output_dir, args.species_directories)
counter = ReleaseCounter(args.private_config_xml_file,
config_profile=args.config_profile, release_version=args.release_version, logs=logs)
config_profile=args.config_profile, release_version=args.release_version, logs=log_files)
counter.write_to_db()
counter.detect_inconsistent_types()
generate_output_tsv(counter.generate_per_species_counts(), os.path.join(args.output_dir, 'species_counts.tsv'), 'Taxonomy')
Expand Down

0 comments on commit bfb6959

Please sign in to comment.