From c2a1014725535158b25ad49983ba436669044557 Mon Sep 17 00:00:00 2001 From: Arash Date: Tue, 20 Aug 2024 12:23:39 +0200 Subject: [PATCH] Fix mypy errors --- lib/galaxy/schema/visualization.py | 3 +- .../webapps/galaxy/services/visualizations.py | 44 ++++++++++--------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/lib/galaxy/schema/visualization.py b/lib/galaxy/schema/visualization.py index 9efd344eb4f0..fe859a34b204 100644 --- a/lib/galaxy/schema/visualization.py +++ b/lib/galaxy/schema/visualization.py @@ -3,6 +3,7 @@ Dict, List, Optional, + Union, ) from pydantic import ( @@ -357,7 +358,7 @@ class VisualizationUpdatePayload(Model): title="Deleted", description="Whether this Visualization has been deleted.", ) - config: Optional[dict] = Field( + config: Optional[Union[dict, bytes]] = Field( {}, title="Config", description="The config of the visualization.", diff --git a/lib/galaxy/webapps/galaxy/services/visualizations.py b/lib/galaxy/webapps/galaxy/services/visualizations.py index 562b9ed289de..07f782629184 100644 --- a/lib/galaxy/webapps/galaxy/services/visualizations.py +++ b/lib/galaxy/webapps/galaxy/services/visualizations.py @@ -1,6 +1,7 @@ import json import logging from typing import ( + cast, Optional, Tuple, Union, @@ -120,7 +121,11 @@ def show( "dbkey": visualization.dbkey, "slug": visualization.slug, # to_dict only the latest revision (allow older to be fetched elsewhere) - "latest_revision": self._get_visualization_revision_dict(visualization.latest_revision), + "latest_revision": ( + self._get_visualization_revision_dict(visualization.latest_revision) + if visualization.latest_revision + else None + ), # need to encode ids in revisions as well # NOTE: does not encode ids inside the configs "revisions": [r.id for r in visualization.revisions], @@ -136,10 +141,10 @@ def show( dictionary["email_hash"] = md5_hash_str(visualization.user.email) dictionary["tags"] = visualization.make_tag_string_list() dictionary["annotation"] = get_item_annotation_str(trans.sa_session, trans.user, visualization) - app: StructuredApp = trans.app + app = cast(StructuredApp, trans.app) if app.visualizations_registry: - visualizations_registry: VisualizationsRegistry = app.visualizations_registry - visualization_plugin: VisualizationPlugin = visualizations_registry.get_plugin(dictionary["type"]) + visualizations_registry = cast(VisualizationsRegistry, app.visualizations_registry) + visualization_plugin = cast(VisualizationPlugin, visualizations_registry.get_plugin(dictionary["type"])) dictionary["plugin"] = visualization_plugin.to_dict() return VisualizationShowResponse(**dictionary) @@ -157,7 +162,7 @@ def create( if payload.import_id: visualization = self._import_visualization(trans, payload.import_id) else: - payload = self._validate_and_parse_payload(payload) + payload = cast(VisualizationCreatePayload, self._validate_and_parse_payload(payload)) # must have a type (I've taken this to be the visualization name) if not payload.type: raise exceptions.RequestParameterMissingException("key/value 'type' is required") @@ -184,9 +189,7 @@ def create( with transaction(session): session.commit() - rval = {"id": visualization.id} - - return VisualizationCreateResponse(**rval) + return VisualizationCreateResponse(id=str(visualization.id)) def update( self, @@ -201,7 +204,7 @@ def update( :returns: dictionary containing Visualization details """ rval = None - payload = self._validate_and_parse_payload(payload) + payload = cast(VisualizationUpdatePayload, self._validate_and_parse_payload(payload)) # there's a differentiation here between updating the visualization and creating a new revision # that needs to be handled clearly here or alternately, using a different controller @@ -212,19 +215,20 @@ def update( # only update owned visualizations visualization = self._get_visualization(trans, visualization_id, check_ownership=True) - title = payload.title or visualization.latest_revision.title - dbkey = payload.dbkey or visualization.latest_revision.dbkey + latest_revision = cast(VisualizationRevision, visualization.latest_revision) + title = payload.title or latest_revision.title + dbkey = payload.dbkey or latest_revision.dbkey deleted = payload.deleted or visualization.deleted - config = payload.config or visualization.latest_revision.config + config = payload.config or latest_revision.config - latest_config = visualization.latest_revision.config + latest_config = latest_revision.config if ( - (title != visualization.latest_revision.title) - or (dbkey != visualization.latest_revision.dbkey) + (title != latest_revision.title) + or (dbkey != latest_revision.dbkey) or (json.dumps(config) != json.dumps(latest_config)) ): revision = self._add_visualization_revision(trans, visualization, config, title, dbkey) - rval = {"id": visualization_id, "revision": revision.id} + rval = {"id": str(visualization_id), "revision": str(revision.id)} # allow updating vis title visualization.title = title @@ -346,9 +350,9 @@ def _add_visualization_revision( self, trans: ProvidesUserContext, visualization: Visualization, - config: dict, - title: str, - dbkey: str, + config: Optional[Union[dict, bytes]], + title: Optional[str], + dbkey: Optional[str], ) -> VisualizationRevision: """ Adds a new `VisualizationRevision` to the given `visualization` with @@ -376,7 +380,7 @@ def _create_visualization( dbkey: Optional[str] = None, slug: Optional[str] = None, annotation: Optional[str] = None, - save: bool = True, + save: Optional[bool] = True, ) -> Visualization: """Create visualization but not first revision. Returns Visualization object.""" user = trans.get_user()