Skip to content

Commit

Permalink
Fix SA2.0 usage in managers.workflows (4)
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Oct 19, 2023
1 parent ba15434 commit 104b97b
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions lib/galaxy/managers/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Union,
)

import sqlalchemy
from gxformat2 import (
from_galaxy_native,
ImporterGalaxyInterface,
Expand Down Expand Up @@ -57,7 +58,9 @@
ImplicitCollectionJobsJobAssociation,
Job,
StoredWorkflow,
StoredWorkflowTagAssociation,
StoredWorkflowUserShareAssociation,
User,
Workflow,
WorkflowInvocation,
WorkflowInvocationStep,
Expand Down Expand Up @@ -144,7 +147,7 @@ def __init__(self, app: MinimalManagerApp):

def index_query(
self, trans: ProvidesUserContext, payload: WorkflowIndexQueryPayload, include_total_count: bool = False
) -> Tuple[Query, Optional[int]]:
) -> Tuple[sqlalchemy.engine.Result, Optional[int]]:
show_published = payload.show_published
show_hidden = payload.show_hidden
show_deleted = payload.show_deleted
Expand All @@ -161,99 +164,98 @@ def index_query(
raise exceptions.RequestParameterInvalidException(message)

filters = [
model.StoredWorkflow.user == trans.user,
StoredWorkflow.user == trans.user,
]
user = trans.user
if user and show_shared:
filters.append(model.StoredWorkflowUserShareAssociation.user == user)

if show_published or user is None and show_published is None:
filters.append(model.StoredWorkflow.published == true())
filters.append(StoredWorkflow.published == true())

query = trans.sa_session.query(model.StoredWorkflow)
stmt = select(StoredWorkflow)
if show_shared:
query = query.outerjoin(model.StoredWorkflow.users_shared_with)
query = query.outerjoin(model.StoredWorkflow.tags)
stmt = stmt.outerjoin(StoredWorkflow.users_shared_with)
stmt = stmt.outerjoin(StoredWorkflow.tags)

latest_workflow_load = joinedload(model.StoredWorkflow.latest_workflow)
latest_workflow_load = joinedload(StoredWorkflow.latest_workflow)
if not payload.skip_step_counts:
latest_workflow_load = latest_workflow_load.undefer("step_count")
latest_workflow_load = latest_workflow_load.lazyload(model.Workflow.steps)
latest_workflow_load = latest_workflow_load.lazyload(Workflow.steps)

query = query.options(joinedload(model.StoredWorkflow.annotations))
query = query.options(latest_workflow_load)
query = query.filter(or_(*filters))
query = query.filter(model.StoredWorkflow.table.c.hidden == (true() if show_hidden else false()))
stmt = stmt.options(joinedload(StoredWorkflow.annotations))
stmt = stmt.options(latest_workflow_load)
stmt = stmt.where(or_(*filters))
stmt = stmt.where(StoredWorkflow.hidden == (true() if show_hidden else false()))
if payload.search:
search_query = payload.search
parsed_search = parse_filters_structured(search_query, INDEX_SEARCH_FILTERS)

def w_tag_filter(term_text: str, quoted: bool):
nonlocal query
alias = aliased(model.StoredWorkflowTagAssociation)
query = query.outerjoin(model.StoredWorkflow.tags.of_type(alias))
def w_tag_filter(stmt, term_text: str, quoted: bool):
alias = aliased(StoredWorkflowTagAssociation)
stmt = stmt.outerjoin(StoredWorkflow.tags.of_type(alias))
return tag_filter(alias, term_text, quoted)

def name_filter(term):
return text_column_filter(model.StoredWorkflow.name, term)
return text_column_filter(StoredWorkflow.name, term)

for term in parsed_search.terms:
if isinstance(term, FilteredTerm):
key = term.filter
q = term.text
if key == "tag":
tf = w_tag_filter(term.text, term.quoted)
query = query.filter(tf)
tf = w_tag_filter(stmt, term.text, term.quoted)
stmt = stmt.where(tf)
elif key == "name":
query = query.filter(name_filter(term))
stmt = stmt.where(name_filter(term))
elif key == "user":
query = append_user_filter(query, model.StoredWorkflow, term)
stmt = append_user_filter(stmt, StoredWorkflow, term)
elif key == "is":
if q == "published":
query = query.filter(model.StoredWorkflow.published == true())
stmt = stmt.where(StoredWorkflow.published == true())
elif q == "importable":
query = query.filter(model.StoredWorkflow.importable == true())
stmt = stmt.where(StoredWorkflow.importable == true())
elif q == "deleted":
query = query.filter(model.StoredWorkflow.deleted == true())
stmt = stmt.where(StoredWorkflow.deleted == true())
show_deleted = true
elif q == "shared_with_me":
if not show_shared:
message = "Can only use tag is:shared_with_me if show_shared parameter also true."
raise exceptions.RequestParameterInvalidException(message)
query = query.filter(model.StoredWorkflowUserShareAssociation.user == user)
stmt = stmt.where(StoredWorkflowUserShareAssociation.user == user)
elif isinstance(term, RawTextTerm):
tf = w_tag_filter(term.text, False)
alias = aliased(model.User)
query = query.outerjoin(model.StoredWorkflow.user.of_type(alias))
query = query.filter(
tf = w_tag_filter(stmt, term.text, False)
alias = aliased(User)
stmt = stmt.outerjoin(StoredWorkflow.user.of_type(alias))
stmt = stmt.where(
raw_text_column_filter(
[
model.StoredWorkflow.name,
StoredWorkflow.name,
tf,
alias.username,
],
term,
)
)
query = query.filter(model.StoredWorkflow.table.c.deleted == (true() if show_deleted else false()))
stmt = stmt.where(StoredWorkflow.deleted == (true() if show_deleted else false()))
if include_total_count:
total_matches = query.count()
total_matches = get_count(trans.sa_session, stmt)
else:
total_matches = None
if payload.sort_by is None:
if user:
query = query.order_by(desc(model.StoredWorkflow.user == user))
query = query.order_by(desc(model.StoredWorkflow.table.c.update_time))
stmt = stmt.order_by(desc(StoredWorkflow.user == user))
stmt = stmt.order_by(desc(StoredWorkflow.update_time))
else:
sort_column = getattr(model.StoredWorkflow, payload.sort_by)
sort_column = getattr(StoredWorkflow, payload.sort_by)
if payload.sort_desc:
sort_column = sort_column.desc()
query = query.order_by(sort_column)
stmt = stmt.order_by(sort_column)
if payload.limit is not None:
query = query.limit(payload.limit)
stmt = stmt.limit(payload.limit)
if payload.offset is not None:
query = query.offset(payload.offset)
return query, total_matches
stmt = stmt.offset(payload.offset)
return trans.sa_session.scalars(stmt), total_matches

def get_stored_workflow(self, trans, workflow_id, by_stored_id=True) -> StoredWorkflow:
"""Use a supplied ID (UUID or encoded stored workflow ID) to find
Expand Down

0 comments on commit 104b97b

Please sign in to comment.