From b441f979a72c6d320f4891a9f435f8bd66ca9209 Mon Sep 17 00:00:00 2001 From: Arash Date: Wed, 21 Aug 2024 14:33:48 +0200 Subject: [PATCH] Addressing the comments --- client/src/api/schema/schema.ts | 28 +++--- lib/galaxy/schema/visualization.py | 42 ++++---- .../webapps/galaxy/api/visualizations.py | 12 +-- .../webapps/galaxy/services/visualizations.py | 98 +++---------------- 4 files changed, 57 insertions(+), 123 deletions(-) diff --git a/client/src/api/schema/schema.ts b/client/src/api/schema/schema.ts index d0061ecb292b..71172c780780 100644 --- a/client/src/api/schema/schema.ts +++ b/client/src/api/schema/schema.ts @@ -16696,11 +16696,6 @@ export interface components { * @description The database key of the visualization. */ dbkey?: string | null; - /** - * Import ID - * @description The ID of the imported visualization. - */ - import_id?: string | null; /** * Save * @description Whether to save the visualization. @@ -16815,11 +16810,12 @@ export interface components { */ id: string; /** - * Model Class - * @description The model class name for this object. - * @default VisualizationRevision + * Model class + * @description The name of the database model class. + * @constant + * @enum {string} */ - model_class: string; + model_class: "VisualizationRevision"; /** * Title * @description The name of the visualization revision. @@ -16861,11 +16857,12 @@ export interface components { */ latest_revision: components["schemas"]["VisualizationRevisionResponse"]; /** - * Model Class - * @description The model class name for this object. - * @default Visualization + * Model class + * @description The name of the database model class. + * @constant + * @enum {string} */ - model_class: string; + model_class: "Visualization"; /** * Plugin * @description The plugin of this Visualization. @@ -32441,7 +32438,10 @@ export interface operations { }; create_api_visualizations_post: { parameters: { - query?: never; + query?: { + /** @description The encoded database identifier of the Visualization to import. */ + import_id?: string | null; + }; header?: { /** @description The user ID that will be used to effectively make this API call. Only admins and designated users can make API calls on behalf of other users. */ "run-as"?: string | null; diff --git a/lib/galaxy/schema/visualization.py b/lib/galaxy/schema/visualization.py index 67f06f2c884c..8b628a4843e8 100644 --- a/lib/galaxy/schema/visualization.py +++ b/lib/galaxy/schema/visualization.py @@ -9,6 +9,7 @@ from pydantic import ( ConfigDict, Field, + field_validator, RootModel, ) from typing_extensions import Literal @@ -16,16 +17,22 @@ from galaxy.schema.fields import ( DecodedDatabaseIdField, EncodedDatabaseIdField, + ModelClassField, ) from galaxy.schema.schema import ( CreateTimeField, Model, TagCollection, UpdateTimeField, + WithModelClass, ) +from galaxy.util.sanitize_html import sanitize_html VisualizationSortByEnum = Literal["create_time", "title", "update_time", "username"] +VISUALIZATION_MODEL_CLASS = Literal["Visualization"] +VISUALIZATION_REVISION_MODEL_CLASS = Literal["VisualizationRevision"] + class VisualizationIndexQueryPayload(Model): deleted: bool = False @@ -104,12 +111,8 @@ class VisualizationSummaryList(RootModel): ) -class VisualizationRevisionResponse(Model): - model_class: str = Field( - "VisualizationRevision", - title="Model Class", - description="The model class name for this object.", - ) +class VisualizationRevisionResponse(Model, WithModelClass): + model_class: VISUALIZATION_REVISION_MODEL_CLASS = ModelClassField(VISUALIZATION_REVISION_MODEL_CLASS) id: EncodedDatabaseIdField = Field( ..., title="ID", @@ -200,12 +203,8 @@ class VisualizationPluginResponse(Model): ) -class VisualizationShowResponse(Model): - model_class: str = Field( - "Visualization", - title="Model Class", - description="The model class name for this object.", - ) +class VisualizationShowResponse(Model, WithModelClass): + model_class: VISUALIZATION_MODEL_CLASS = ModelClassField(VISUALIZATION_MODEL_CLASS) id: EncodedDatabaseIdField = Field( ..., title="ID", @@ -300,11 +299,6 @@ class VisualizationUpdateResponse(Model): class VisualizationCreatePayload(Model): - import_id: Optional[DecodedDatabaseIdField] = Field( - None, - title="Import ID", - description="The ID of the imported visualization.", - ) type: Optional[str] = Field( None, title="Type", @@ -341,6 +335,13 @@ class VisualizationCreatePayload(Model): description="Whether to save the visualization.", ) + @field_validator("type", "title", "dbkey", "slug", "annotation", mode="before") + @classmethod + def sanitize_html_fields(cls, v): + if isinstance(v, str): + return sanitize_html(v) + return v + class VisualizationUpdatePayload(Model): title: Optional[str] = Field( @@ -363,3 +364,10 @@ class VisualizationUpdatePayload(Model): title="Config", description="The config of the visualization.", ) + + @field_validator("title", "dbkey", mode="before") + @classmethod + def sanitize_html_fields(cls, v): + if isinstance(v, str): + return sanitize_html(v) + return v diff --git a/lib/galaxy/webapps/galaxy/api/visualizations.py b/lib/galaxy/webapps/galaxy/api/visualizations.py index e76ed8a94e4b..39c53b2a6c8b 100644 --- a/lib/galaxy/webapps/galaxy/api/visualizations.py +++ b/lib/galaxy/webapps/galaxy/api/visualizations.py @@ -6,10 +6,7 @@ """ import logging -from typing import ( - Optional, - Union, -) +from typing import Optional from fastapi import ( Body, @@ -253,6 +250,9 @@ def show( def create( self, payload: VisualizationCreatePayload = Body(...), + import_id: Optional[DecodedDatabaseIdField] = Query( + None, title="Import ID", description="The encoded database identifier of the Visualization to import." + ), trans: ProvidesUserContext = DependsOnTrans, ) -> VisualizationCreateResponse: """ @@ -262,7 +262,7 @@ def create( POST /api/visualizations?import_id={encoded_visualization_id} imports a copy of an existing visualization into the user's workspace and does not require the rest of the payload """ - return self.service.create(trans, payload) + return self.service.create(trans, import_id, payload) @router.put( "/api/visualizations/{id}", @@ -273,5 +273,5 @@ def update( id: VisualizationIdPathParam, payload: VisualizationUpdatePayload = Body(...), trans: ProvidesUserContext = DependsOnTrans, - ) -> Union[VisualizationUpdateResponse, None]: + ) -> Optional[VisualizationUpdateResponse]: return self.service.update(trans, id, payload) diff --git a/lib/galaxy/webapps/galaxy/services/visualizations.py b/lib/galaxy/webapps/galaxy/services/visualizations.py index 07f782629184..c7536dd5323b 100644 --- a/lib/galaxy/webapps/galaxy/services/visualizations.py +++ b/lib/galaxy/webapps/galaxy/services/visualizations.py @@ -122,7 +122,7 @@ def show( "slug": visualization.slug, # to_dict only the latest revision (allow older to be fetched elsewhere) "latest_revision": ( - self._get_visualization_revision_dict(visualization.latest_revision) + self._get_visualization_revision(visualization.latest_revision) if visualization.latest_revision else None ), @@ -151,6 +151,7 @@ def show( def create( self, trans: ProvidesUserContext, + import_id: Optional[DecodedDatabaseIdField], payload: VisualizationCreatePayload, ) -> VisualizationCreateResponse: """Returns a dictionary of the created visualization @@ -159,10 +160,11 @@ def create( :returns: dictionary containing Visualization details """ - if payload.import_id: - visualization = self._import_visualization(trans, payload.import_id) + if import_id: + visualization = self._import_visualization(trans, import_id) else: - payload = cast(VisualizationCreatePayload, self._validate_and_parse_payload(payload)) + # custom validator to sanitize the HTML, and assign that type to those fields that require it: type, annotation, title, slug, dbkey + # must have a type (I've taken this to be the visualization name) if not payload.type: raise exceptions.RequestParameterMissingException("key/value 'type' is required") @@ -178,9 +180,7 @@ def create( visualization = self._create_visualization(trans, type, title, dbkey, slug, annotation, save) # Create and save first visualization revision - revision = trans.model.VisualizationRevision( - visualization=visualization, title=title, config=config, dbkey=dbkey - ) + revision = VisualizationRevision(visualization=visualization, title=title, config=config, dbkey=dbkey) visualization.latest_revision = revision if save: @@ -196,7 +196,7 @@ def update( trans: ProvidesUserContext, visualization_id: DecodedDatabaseIdField, payload: VisualizationUpdatePayload, - ) -> Union[VisualizationUpdateResponse, None]: + ) -> Optional[VisualizationUpdateResponse]: """ Update a visualization @@ -204,7 +204,6 @@ def update( :returns: dictionary containing Visualization details """ rval = None - payload = cast(VisualizationUpdatePayload, self._validate_and_parse_payload(payload)) # there's a differentiation here between updating the visualization and creating a new revision # that needs to be handled clearly here or alternately, using a different controller @@ -238,76 +237,6 @@ def update( return VisualizationUpdateResponse(**rval) if rval else None - def _validate_and_parse_payload( - self, - payload: Union[VisualizationCreatePayload, VisualizationUpdatePayload], - ) -> Union[VisualizationCreatePayload, VisualizationUpdatePayload]: - """ - Validate and parse incomming data payload for a visualization. - """ - # This layer handles (most of the stricter idiot proofing): - # - unknown/unallowed keys - # - changing data keys from api key to attribute name - # - protection against bad data form/type - # - protection against malicious data content - # all other conversions and processing (such as permissions, etc.) should happen down the line - - # keys listed here don't error when attempting to set, but fail silently - # this allows PUT'ing an entire model back to the server without attribute errors on uneditable attrs - valid_but_uneditable_keys = ( - "id", - "model_class", - # TODO: fill out when we create to_dict, get_dict, whatevs - ) - # TODO: importable - ValidationError = exceptions.RequestParameterInvalidException - - validated_payload = {} - for key, val in payload.model_dump().items(): - # By adding the pydatnic model there will be some variables that are not set and should be ignored in the validation - if val is None: - continue - # TODO: validate types in VALID_TYPES/registry names at the mixin/model level? - if key == "type": - if not isinstance(val, str): - raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}") - val = sanitize_html(val) - elif key == "config": - if not isinstance(val, dict): - raise ValidationError(f"{key} must be a dictionary: {str(type(val))}") - elif key == "annotation": - if not isinstance(val, str): - raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}") - val = sanitize_html(val) - elif key == "deleted": - if not isinstance(val, bool): - raise ValidationError(f"{key} must be a bool: {str(type(val))}") - - # these are keys that actually only be *updated* at the revision level and not here - # (they are still valid for create, tho) - elif key == "title": - if not isinstance(val, str): - raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}") - val = sanitize_html(val) - elif key == "slug": - if not isinstance(val, str): - raise ValidationError(f"{key} must be a string: {str(type(val))}") - val = sanitize_html(val) - elif key == "dbkey": - if not isinstance(val, str): - raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}") - val = sanitize_html(val) - - elif key not in valid_but_uneditable_keys: - continue - # raise AttributeError( 'unknown key: %s' %( str( key ) ) ) - - validated_payload[key] = val - if isinstance(payload, VisualizationCreatePayload): - return VisualizationCreatePayload(**validated_payload) - elif isinstance(payload, VisualizationUpdatePayload): - return VisualizationUpdatePayload(**validated_payload) - def _get_visualization( self, trans: ProvidesUserContext, @@ -318,7 +247,6 @@ def _get_visualization( """ Get a Visualization from the database by id, verifying ownership. """ - # Load workflow from database try: visualization = trans.sa_session.get(Visualization, visualization_id) except TypeError: @@ -328,7 +256,7 @@ def _get_visualization( else: return security_check(trans, visualization, check_ownership, check_accessible) - def _get_visualization_revision_dict( + def _get_visualization_revision( self, revision: VisualizationRevision, ) -> VisualizationRevisionResponse: @@ -361,9 +289,7 @@ def _add_visualization_revision( """ # precondition: only add new revision on owned vis's # TODO:?? should we default title, dbkey, config? to which: visualization or latest_revision? - revision = trans.model.VisualizationRevision( - visualization=visualization, title=title, dbkey=dbkey, config=config - ) + revision = VisualizationRevision(visualization=visualization, title=title, dbkey=dbkey, config=config) visualization.latest_revision = revision # TODO:?? does this automatically add revision to visualzation.revisions? @@ -391,7 +317,7 @@ def _create_visualization( title_err = "visualization name is required" elif slug and not is_valid_slug(slug): slug_err = "visualization identifier must consist of only lowercase letters, numbers, and the '-' character" - elif slug and slug_exists(trans.sa_session, trans.model.Visualization, user, slug, ignore_deleted=True): + elif slug and slug_exists(trans.sa_session, Visualization, user, slug, ignore_deleted=True): slug_err = "visualization identifier must be unique" if title_err or slug_err: @@ -400,7 +326,7 @@ def _create_visualization( raise exceptions.RequestParameterMissingException(val_err) # Create visualization - visualization = trans.model.Visualization(user=user, title=title, dbkey=dbkey, type=type) + visualization = Visualization(user=user, title=title, dbkey=dbkey, type=type) if slug: visualization.slug = slug else: