Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
arash77 committed Aug 20, 2024
1 parent 1c908d6 commit c2a1014
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
3 changes: 2 additions & 1 deletion lib/galaxy/schema/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Dict,
List,
Optional,
Union,
)

from pydantic import (
Expand Down Expand Up @@ -357,7 +358,7 @@ class VisualizationUpdatePayload(Model):
title="Deleted",
description="Whether this Visualization has been deleted.",
)
config: Optional[dict] = Field(
config: Optional[Union[dict, bytes]] = Field(
{},
title="Config",
description="The config of the visualization.",
Expand Down
44 changes: 24 additions & 20 deletions lib/galaxy/webapps/galaxy/services/visualizations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
from typing import (
cast,
Optional,
Tuple,
Union,
Expand Down Expand Up @@ -120,7 +121,11 @@ def show(
"dbkey": visualization.dbkey,
"slug": visualization.slug,
# to_dict only the latest revision (allow older to be fetched elsewhere)
"latest_revision": self._get_visualization_revision_dict(visualization.latest_revision),
"latest_revision": (
self._get_visualization_revision_dict(visualization.latest_revision)
if visualization.latest_revision
else None
),
# need to encode ids in revisions as well
# NOTE: does not encode ids inside the configs
"revisions": [r.id for r in visualization.revisions],
Expand All @@ -136,10 +141,10 @@ def show(
dictionary["email_hash"] = md5_hash_str(visualization.user.email)
dictionary["tags"] = visualization.make_tag_string_list()
dictionary["annotation"] = get_item_annotation_str(trans.sa_session, trans.user, visualization)
app: StructuredApp = trans.app
app = cast(StructuredApp, trans.app)
if app.visualizations_registry:
visualizations_registry: VisualizationsRegistry = app.visualizations_registry
visualization_plugin: VisualizationPlugin = visualizations_registry.get_plugin(dictionary["type"])
visualizations_registry = cast(VisualizationsRegistry, app.visualizations_registry)
visualization_plugin = cast(VisualizationPlugin, visualizations_registry.get_plugin(dictionary["type"]))
dictionary["plugin"] = visualization_plugin.to_dict()
return VisualizationShowResponse(**dictionary)

Expand All @@ -157,7 +162,7 @@ def create(
if payload.import_id:
visualization = self._import_visualization(trans, payload.import_id)
else:
payload = self._validate_and_parse_payload(payload)
payload = cast(VisualizationCreatePayload, self._validate_and_parse_payload(payload))
# must have a type (I've taken this to be the visualization name)
if not payload.type:
raise exceptions.RequestParameterMissingException("key/value 'type' is required")
Expand All @@ -184,9 +189,7 @@ def create(
with transaction(session):
session.commit()

rval = {"id": visualization.id}

return VisualizationCreateResponse(**rval)
return VisualizationCreateResponse(id=str(visualization.id))

def update(
self,
Expand All @@ -201,7 +204,7 @@ def update(
:returns: dictionary containing Visualization details
"""
rval = None
payload = self._validate_and_parse_payload(payload)
payload = cast(VisualizationUpdatePayload, self._validate_and_parse_payload(payload))

# there's a differentiation here between updating the visualization and creating a new revision
# that needs to be handled clearly here or alternately, using a different controller
Expand All @@ -212,19 +215,20 @@ def update(

# only update owned visualizations
visualization = self._get_visualization(trans, visualization_id, check_ownership=True)
title = payload.title or visualization.latest_revision.title
dbkey = payload.dbkey or visualization.latest_revision.dbkey
latest_revision = cast(VisualizationRevision, visualization.latest_revision)
title = payload.title or latest_revision.title
dbkey = payload.dbkey or latest_revision.dbkey
deleted = payload.deleted or visualization.deleted
config = payload.config or visualization.latest_revision.config
config = payload.config or latest_revision.config

latest_config = visualization.latest_revision.config
latest_config = latest_revision.config
if (
(title != visualization.latest_revision.title)
or (dbkey != visualization.latest_revision.dbkey)
(title != latest_revision.title)
or (dbkey != latest_revision.dbkey)
or (json.dumps(config) != json.dumps(latest_config))
):
revision = self._add_visualization_revision(trans, visualization, config, title, dbkey)
rval = {"id": visualization_id, "revision": revision.id}
rval = {"id": str(visualization_id), "revision": str(revision.id)}

# allow updating vis title
visualization.title = title
Expand Down Expand Up @@ -346,9 +350,9 @@ def _add_visualization_revision(
self,
trans: ProvidesUserContext,
visualization: Visualization,
config: dict,
title: str,
dbkey: str,
config: Optional[Union[dict, bytes]],
title: Optional[str],
dbkey: Optional[str],
) -> VisualizationRevision:
"""
Adds a new `VisualizationRevision` to the given `visualization` with
Expand Down Expand Up @@ -376,7 +380,7 @@ def _create_visualization(
dbkey: Optional[str] = None,
slug: Optional[str] = None,
annotation: Optional[str] = None,
save: bool = True,
save: Optional[bool] = True,
) -> Visualization:
"""Create visualization but not first revision. Returns Visualization object."""
user = trans.get_user()
Expand Down

0 comments on commit c2a1014

Please sign in to comment.