Skip to content

Commit

Permalink
Remove the views from the load script and add actual tables that contain the aggregated counts.
Browse files Browse the repository at this point in the history

Add test for the new aggregated count load
  • Loading branch information
tcezard committed Mar 21, 2024
1 parent cee7cec commit 2713b0a
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from gather_clustering_counts.release_count_models import RSCountCategory, RSCount, get_sql_alchemy_engine
from gather_clustering_counts.release_count_models import RSCountCategory, RSCount, get_sql_alchemy_engine, \
RSCountPerTaxonomy, RSCountPerAssembly

logger = logging_config.get_logger(__name__)

Expand Down Expand Up @@ -228,57 +229,152 @@ def group_descriptor(self, count_groups):
if None not in group_descriptor_list:
return '|'.join(sorted(group_descriptor_list)) + f'_release_{self.release_version}'

def _write_exploded_counts(self, session):
    """
    Persist every group of counts held in self.all_counts_grouped.

    Each group is stored as one RSCount row (identified by its group description) linked to one
    RSCountCategory row per member of the group. Rows that already exist are reused, and an error is
    logged when the stored count differs from the one being loaded. Groups for which
    self.group_descriptor returns nothing are skipped.

    :param session: SQLAlchemy session used to look up existing rows and to register the new ones
    """
    for group in self.all_counts_grouped:
        # Every member of a group is expected to carry the same count
        distinct_counts = {member['count'] for member in group}
        assert len(distinct_counts) == 1
        group_count = distinct_counts.pop()

        # The description uniquely identifies the count of this group, so loading the same value twice
        # does not create duplicates
        description = self.group_descriptor(group)
        if not description:
            # One of the taxonomy annotations is missing: this count must not be loaded
            continue

        existing_count = session.execute(
            select(RSCount).where(RSCount.group_description == description)
        ).fetchone()
        if existing_count:
            rs_count = existing_count.RSCount
        else:
            rs_count = RSCount(count=group_count, group_description=description)
            session.add(rs_count)

        for member in group:
            assembly = member['assembly']
            taxonomy = member['taxonomy']
            rs_type = member['idtype']
            existing_category = session.execute(
                select(RSCountCategory).where(
                    RSCountCategory.assembly_accession == assembly,
                    RSCountCategory.taxonomy_id == taxonomy,
                    RSCountCategory.rs_type == rs_type,
                    RSCountCategory.release_version == self.release_version,
                    RSCountCategory.rs_count_id == rs_count.rs_count_id,
                )
            ).fetchone()

            if existing_category:
                # Only complain when this is not a plain reload of the value stored by a previous run
                stored_count = existing_category.RSCountCategory.rs_count.count
                if stored_count != member['count']:
                    self.error(f"{self.count_descriptor(member)} already has a count entry in the table "
                               f"({stored_count}) different from the one being loaded "
                               f"{member['count']}")
                continue

            self.info(f"Create persistence for {assembly}, {taxonomy}, {rs_type}, {member['count']}")
            session.add(RSCountCategory(
                assembly_accession=assembly,
                taxonomy_id=taxonomy,
                rs_type=rs_type,
                release_version=self.release_version,
                rs_count=rs_count
            ))

def _write_per_taxonomy_counts(self, session):
    """
    Persist the counts aggregated per taxonomy and rs type for the current release.

    The number of new RS is the difference with the row stored for the previous release
    (release_version - 1), which is assumed to have been loaded already; it is 0 when no such row exists.
    When a row already exists for the current release nothing is written, and an error is logged if its
    count differs from the one being loaded.

    :param session: SQLAlchemy session used to look up existing rows and to register the new ones
    """
    taxonomy_counts, assembly_lists = self.generate_per_taxonomy_counts()
    for taxonomy, counts_by_type in taxonomy_counts.items():
        for rs_type, loaded_count in counts_by_type.items():
            current_row = session.execute(
                select(RSCountPerTaxonomy).where(
                    RSCountPerTaxonomy.taxonomy_id == taxonomy,
                    RSCountPerTaxonomy.rs_type == rs_type,
                    RSCountPerTaxonomy.release_version == self.release_version
                )
            ).fetchone()

            if current_row:
                # Only complain when this is not a plain reload of the value stored by a previous run
                stored_count = current_row.RSCountPerTaxonomy.count
                if stored_count != loaded_count:
                    self.error(f"Count for aggregate per taxonomy {taxonomy}, {rs_type} already has a count entry "
                               f"in the table ({stored_count}) different from the one being loaded "
                               f"{loaded_count}")
                continue

            self.info(
                f"Create persistence for aggregate per taxonomy {taxonomy}, {rs_type}: "
                f"{loaded_count}"
            )
            # Look up the entry of the previous release to work out how many RS are new
            previous_row = session.execute(
                select(RSCountPerTaxonomy).where(
                    RSCountPerTaxonomy.taxonomy_id == taxonomy,
                    RSCountPerTaxonomy.rs_type == rs_type,
                    RSCountPerTaxonomy.release_version == self.release_version - 1
                )
            ).fetchone()
            if previous_row:
                count_new = loaded_count - previous_row.RSCountPerTaxonomy.count
            else:
                count_new = 0

            session.add(RSCountPerTaxonomy(
                taxonomy_id=taxonomy,
                assembly_accessions=assembly_lists.get(taxonomy).get(rs_type),
                rs_type=rs_type,
                release_version=self.release_version,
                count=loaded_count,
                new=count_new
            ))

def _write_per_assembly_counts(self, session):
    """
    Persist the counts aggregated per assembly and rs type for the current release.

    The number of new RS is the difference with the row stored for the previous release
    (release_version - 1), which is assumed to have been loaded already; it is 0 when no such row exists.
    When a row already exists for the current release nothing is written, and an error is logged if its
    count differs from the one being loaded.

    :param session: SQLAlchemy session used to look up existing rows and to register the new ones
    """
    assembly_counts, taxonomy_lists = self.generate_per_assembly_counts()
    for assembly, counts_by_type in assembly_counts.items():
        for rs_type, loaded_count in counts_by_type.items():
            current_row = session.execute(
                select(RSCountPerAssembly).where(
                    RSCountPerAssembly.assembly_accession == assembly,
                    RSCountPerAssembly.rs_type == rs_type,
                    RSCountPerAssembly.release_version == self.release_version
                )
            ).fetchone()

            if current_row:
                # Only complain when this is not a plain reload of the value stored by a previous run
                stored_count = current_row.RSCountPerAssembly.count
                if stored_count != loaded_count:
                    self.error(f"Count for aggregate per assembly {assembly}, {rs_type} already has a count entry "
                               f"in the table ({stored_count}) different from the one being loaded "
                               f"{loaded_count}")
                continue

            self.info(
                f"Create persistence for aggregate per assembly {assembly}, {rs_type}: "
                f"{loaded_count}"
            )
            # Look up the entry of the previous release to work out how many RS are new
            previous_row = session.execute(
                select(RSCountPerAssembly).where(
                    RSCountPerAssembly.assembly_accession == assembly,
                    RSCountPerAssembly.rs_type == rs_type,
                    RSCountPerAssembly.release_version == self.release_version - 1
                )
            ).fetchone()
            if previous_row:
                count_new = loaded_count - previous_row.RSCountPerAssembly.count
            else:
                count_new = 0

            session.add(RSCountPerAssembly(
                assembly_accession=assembly,
                taxonomy_ids=taxonomy_lists.get(assembly).get(rs_type),
                rs_type=rs_type,
                release_version=self.release_version,
                count=loaded_count,
                new=count_new
            ))
def write_counts_to_db(self):
    """
    For all the counts gathered in self.all_counts_grouped, write them to the db if they do not exist already.
    Log an error if a count already exists with a different value.

    Three sets of rows are written, in this order and within a single transaction: the exploded counts,
    the counts aggregated per taxonomy and the counts aggregated per assembly.
    """
    # The session context manager closes the session when the block exits; session.begin() commits when the
    # block completes and rolls back if one of the writers raises.
    with Session(self.sqlalchemy_engine) as session, session.begin():
        self._write_exploded_counts(session)
        self._write_per_taxonomy_counts(session)
        self._write_per_assembly_counts(session)

def get_assembly_counts_from_database(self):
"""
Expand Down Expand Up @@ -317,19 +413,27 @@ def parse_count_script_logs(self, all_logs):
)
self.all_counts_grouped.append(all_groups)

def generate_per_taxonomy_counts(self):
    """
    Aggregate the counts held in self.all_counts_grouped per taxonomy and rs type.

    :return: a tuple of two mappings, both keyed by taxonomy then by rs type (the 'idtype' of the count):
             the first one holds the sum of the counts, the second one the set of assemblies that
             contributed to that sum.
    """
    taxonomy_counts = defaultdict(Counter)
    assembly_lists = defaultdict(dict)
    for count_groups in self.all_counts_grouped:
        for count_dict in count_groups:
            taxonomy = count_dict['taxonomy']
            rs_type = count_dict['idtype']
            taxonomy_counts[taxonomy][rs_type] += count_dict['count']
            # setdefault creates the set the first time this taxonomy/rs type pair is seen
            assembly_lists[taxonomy].setdefault(rs_type, set()).add(count_dict['assembly'])
    return taxonomy_counts, assembly_lists

def generate_per_assembly_counts(self):
    """
    Aggregate the counts held in self.all_counts_grouped per assembly and rs type.

    :return: a tuple of two mappings, both keyed by assembly then by rs type (the 'idtype' of the count):
             the first one holds the sum of the counts, the second one the set of taxonomies that
             contributed to that sum.
    """
    assembly_counts = defaultdict(Counter)
    taxonomy_lists = defaultdict(dict)
    for count_groups in self.all_counts_grouped:
        for count_dict in count_groups:
            assembly = count_dict['assembly']
            rs_type = count_dict['idtype']
            assembly_counts[assembly][rs_type] += count_dict['count']
            # setdefault creates the set the first time this assembly/rs type pair is seen
            taxonomy_lists[assembly].setdefault(rs_type, set()).add(count_dict['taxonomy'])
    return assembly_counts, taxonomy_lists

def generate_per_species_assembly_counts(self):
species_assembly_counts = defaultdict(Counter)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,62 +1,13 @@
import sqlalchemy
from sqlalchemy import MetaData, Column, Integer, String, ForeignKey, UniqueConstraint, BigInteger, TEXT, create_engine, \
URL, text, schema
URL, text, schema, ARRAY
from sqlalchemy.orm import declarative_base, mapped_column, relationship


metadata = MetaData(schema="eva_stats")
Base = declarative_base(metadata=metadata)


create_view_taxonomy = """CREATE OR REPLACE VIEW eva_stats.release_rs_count_per_taxonomy AS
WITH count_per_taxonomy AS (
SELECT taxonomy_id, rs_type, release_version, ARRAY_AGG(DISTINCT assembly_accession) AS assembly_accessions, SUM(count) AS count
FROM eva_stats.release_rs_count_category cc
JOIN eva_stats.release_rs_count c ON c.rs_count_id=cc.rs_count_id
GROUP BY taxonomy_id, release_version, rs_type
)
SELECT current.taxonomy_id AS taxonomy_id,
t.scientific_name AS scientific_name,
t.common_name AS common_name,
current.rs_type AS rs_Type,
current.release_version AS release_version,
current.assembly_accessions as assembly_accessions,
current.count AS count,
coalesce(current.count-previous.count, 0) as new
FROM count_per_taxonomy current
LEFT JOIN count_per_taxonomy previous
ON current.release_version=previous.release_version+1 AND current.taxonomy_id=previous.taxonomy_id AND current.rs_type=previous.rs_type
LEFT JOIN evapro.taxonomy t
ON current.taxonomy_id=t.taxonomy_id;
"""

create_view_assembly = """CREATE OR REPLACE VIEW eva_stats.release_rs_count_per_assembly AS
WITH count_per_assembly AS (
SELECT assembly_accession, rs_type, release_version, ARRAY_AGG(DISTINCT taxonomy_id) AS taxonomy_ids, SUM(count) AS count
FROM eva_stats.release_rs_count_category cc
JOIN eva_stats.release_rs_count c ON c.rs_count_id=cc.rs_count_id
GROUP BY assembly_accession, release_version, rs_type
)
SELECT current.assembly_accession AS assembly_accession,
current.rs_type AS rs_Type,
current.release_version AS release_version,
current.taxonomy_ids as taxonomy_ids,
current.count AS count,
coalesce(current.count-previous.count, 0) as new
FROM count_per_assembly current
LEFT JOIN count_per_assembly previous
ON current.release_version=previous.release_version+1 AND current.assembly_accession=previous.assembly_accession AND current.rs_type=previous.rs_type;
"""


def create_views_from_sql(engine):
with engine.begin() as conn:
conn.execute(text(create_view_taxonomy))
with engine.begin() as conn:
conn.execute(text(create_view_assembly))


class RSCountCategory(Base):
"""
Table that provide the metadata associated with a number of RS id
Expand Down Expand Up @@ -88,6 +39,36 @@ class RSCount(Base):
schema = 'eva_stats'


class RSCountPerTaxonomy(Base):
    """
    Table that provides the aggregated count per taxonomy, rs type and release version
    """
    __tablename__ = 'release_rs_count_per_taxonomy'

    # taxonomy_id, rs_type and release_version together form the primary key
    taxonomy_id = Column(Integer, primary_key=True)
    rs_type = Column(String, primary_key=True)
    release_version = Column(Integer, primary_key=True)
    # Assembly accessions associated with this aggregated count
    assembly_accessions = Column(ARRAY(String))
    # Aggregated count for this taxonomy and rs type in this release
    count = Column(BigInteger)
    # Difference with the count stored for the previous release (the loader sets 0 when there is none)
    new = Column(BigInteger)
    # Name of the database schema, kept as a plain class attribute
    schema = 'eva_stats'


class RSCountPerAssembly(Base):
    """
    Table that provides the aggregated count per assembly, rs type and release version
    """
    __tablename__ = 'release_rs_count_per_assembly'

    # assembly_accession, rs_type and release_version together form the primary key
    assembly_accession = Column(String, primary_key=True)
    rs_type = Column(String, primary_key=True)
    release_version = Column(Integer, primary_key=True)
    # Taxonomy ids associated with this aggregated count
    taxonomy_ids = Column(ARRAY(Integer))
    # Aggregated count for this assembly and rs type in this release
    count = Column(BigInteger)
    # Difference with the count stored for the previous release (the loader sets 0 when there is none)
    new = Column(BigInteger)
    # Name of the database schema, kept as a plain class attribute
    schema = 'eva_stats'


def get_sql_alchemy_engine(dbtype, username, password, host_url, database, port):
engine = create_engine(URL.create(
dbtype + '+psycopg2',
Expand All @@ -101,7 +82,8 @@ def get_sql_alchemy_engine(dbtype, username, password, host_url, database, port)
conn.execute(schema.CreateSchema(RSCount.schema, if_not_exists=True))
RSCount.__table__.create(bind=engine, checkfirst=True)
RSCountCategory.__table__.create(bind=engine, checkfirst=True)
create_views_from_sql(engine)
RSCountPerAssembly.__table__.create(bind=engine, checkfirst=True)
RSCountPerTaxonomy.__table__.create(bind=engine, checkfirst=True)
return engine


Expand Down

0 comments on commit 2713b0a

Please sign in to comment.