Skip to content

Commit

Permalink
Addressing the comments
Browse files Browse the repository at this point in the history
  • Loading branch information
arash77 committed Aug 21, 2024
1 parent b73333a commit b441f97
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 123 deletions.
28 changes: 14 additions & 14 deletions client/src/api/schema/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16696,11 +16696,6 @@ export interface components {
* @description The database key of the visualization.
*/
dbkey?: string | null;
/**
* Import ID
* @description The ID of the imported visualization.
*/
import_id?: string | null;
/**
* Save
* @description Whether to save the visualization.
Expand Down Expand Up @@ -16815,11 +16810,12 @@ export interface components {
*/
id: string;
/**
* Model Class
* @description The model class name for this object.
* @default VisualizationRevision
* Model class
* @description The name of the database model class.
* @constant
* @enum {string}
*/
model_class: string;
model_class: "VisualizationRevision";
/**
* Title
* @description The name of the visualization revision.
Expand Down Expand Up @@ -16861,11 +16857,12 @@ export interface components {
*/
latest_revision: components["schemas"]["VisualizationRevisionResponse"];
/**
* Model Class
* @description The model class name for this object.
* @default Visualization
* Model class
* @description The name of the database model class.
* @constant
* @enum {string}
*/
model_class: string;
model_class: "Visualization";
/**
* Plugin
* @description The plugin of this Visualization.
Expand Down Expand Up @@ -32441,7 +32438,10 @@ export interface operations {
};
create_api_visualizations_post: {
parameters: {
query?: never;
query?: {
/** @description The encoded database identifier of the Visualization to import. */
import_id?: string | null;
};
header?: {
/** @description The user ID that will be used to effectively make this API call. Only admins and designated users can make API calls on behalf of other users. */
"run-as"?: string | null;
Expand Down
42 changes: 25 additions & 17 deletions lib/galaxy/schema/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,30 @@
from pydantic import (
ConfigDict,
Field,
field_validator,
RootModel,
)
from typing_extensions import Literal

from galaxy.schema.fields import (
DecodedDatabaseIdField,
EncodedDatabaseIdField,
ModelClassField,
)
from galaxy.schema.schema import (
CreateTimeField,
Model,
TagCollection,
UpdateTimeField,
WithModelClass,
)
from galaxy.util.sanitize_html import sanitize_html

VisualizationSortByEnum = Literal["create_time", "title", "update_time", "username"]

VISUALIZATION_MODEL_CLASS = Literal["Visualization"]
VISUALIZATION_REVISION_MODEL_CLASS = Literal["VisualizationRevision"]


class VisualizationIndexQueryPayload(Model):
deleted: bool = False
Expand Down Expand Up @@ -104,12 +111,8 @@ class VisualizationSummaryList(RootModel):
)


class VisualizationRevisionResponse(Model):
model_class: str = Field(
"VisualizationRevision",
title="Model Class",
description="The model class name for this object.",
)
class VisualizationRevisionResponse(Model, WithModelClass):
model_class: VISUALIZATION_REVISION_MODEL_CLASS = ModelClassField(VISUALIZATION_REVISION_MODEL_CLASS)
id: EncodedDatabaseIdField = Field(
...,
title="ID",
Expand Down Expand Up @@ -200,12 +203,8 @@ class VisualizationPluginResponse(Model):
)


class VisualizationShowResponse(Model):
model_class: str = Field(
"Visualization",
title="Model Class",
description="The model class name for this object.",
)
class VisualizationShowResponse(Model, WithModelClass):
model_class: VISUALIZATION_MODEL_CLASS = ModelClassField(VISUALIZATION_MODEL_CLASS)
id: EncodedDatabaseIdField = Field(
...,
title="ID",
Expand Down Expand Up @@ -300,11 +299,6 @@ class VisualizationUpdateResponse(Model):


class VisualizationCreatePayload(Model):
import_id: Optional[DecodedDatabaseIdField] = Field(
None,
title="Import ID",
description="The ID of the imported visualization.",
)
type: Optional[str] = Field(
None,
title="Type",
Expand Down Expand Up @@ -341,6 +335,13 @@ class VisualizationCreatePayload(Model):
description="Whether to save the visualization.",
)

@field_validator("type", "title", "dbkey", "slug", "annotation", mode="before")
@classmethod
def sanitize_html_fields(cls, v):
if isinstance(v, str):
return sanitize_html(v)
return v


class VisualizationUpdatePayload(Model):
title: Optional[str] = Field(
Expand All @@ -363,3 +364,10 @@ class VisualizationUpdatePayload(Model):
title="Config",
description="The config of the visualization.",
)

@field_validator("title", "dbkey", mode="before")
@classmethod
def sanitize_html_fields(cls, v):
if isinstance(v, str):
return sanitize_html(v)
return v
12 changes: 6 additions & 6 deletions lib/galaxy/webapps/galaxy/api/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
"""

import logging
from typing import (
Optional,
Union,
)
from typing import Optional

from fastapi import (
Body,
Expand Down Expand Up @@ -253,6 +250,9 @@ def show(
def create(
self,
payload: VisualizationCreatePayload = Body(...),
import_id: Optional[DecodedDatabaseIdField] = Query(
None, title="Import ID", description="The encoded database identifier of the Visualization to import."
),
trans: ProvidesUserContext = DependsOnTrans,
) -> VisualizationCreateResponse:
"""
Expand All @@ -262,7 +262,7 @@ def create(
POST /api/visualizations?import_id={encoded_visualization_id}
imports a copy of an existing visualization into the user's workspace and does not require the rest of the payload
"""
return self.service.create(trans, payload)
return self.service.create(trans, import_id, payload)

@router.put(
"/api/visualizations/{id}",
Expand All @@ -273,5 +273,5 @@ def update(
id: VisualizationIdPathParam,
payload: VisualizationUpdatePayload = Body(...),
trans: ProvidesUserContext = DependsOnTrans,
) -> Union[VisualizationUpdateResponse, None]:
) -> Optional[VisualizationUpdateResponse]:
return self.service.update(trans, id, payload)
98 changes: 12 additions & 86 deletions lib/galaxy/webapps/galaxy/services/visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def show(
"slug": visualization.slug,
# to_dict only the latest revision (allow older to be fetched elsewhere)
"latest_revision": (
self._get_visualization_revision_dict(visualization.latest_revision)
self._get_visualization_revision(visualization.latest_revision)
if visualization.latest_revision
else None
),
Expand Down Expand Up @@ -151,6 +151,7 @@ def show(
def create(
self,
trans: ProvidesUserContext,
import_id: Optional[DecodedDatabaseIdField],
payload: VisualizationCreatePayload,
) -> VisualizationCreateResponse:
"""Returns a dictionary of the created visualization
Expand All @@ -159,10 +160,11 @@ def create(
:returns: dictionary containing Visualization details
"""

if payload.import_id:
visualization = self._import_visualization(trans, payload.import_id)
if import_id:
visualization = self._import_visualization(trans, import_id)
else:
payload = cast(VisualizationCreatePayload, self._validate_and_parse_payload(payload))
# custom validator to sanitize the HTML, and assign that type to those fields that require it: type, annotation, title, slug, dbkey

# must have a type (I've taken this to be the visualization name)
if not payload.type:
raise exceptions.RequestParameterMissingException("key/value 'type' is required")
Expand All @@ -178,9 +180,7 @@ def create(
visualization = self._create_visualization(trans, type, title, dbkey, slug, annotation, save)

# Create and save first visualization revision
revision = trans.model.VisualizationRevision(
visualization=visualization, title=title, config=config, dbkey=dbkey
)
revision = VisualizationRevision(visualization=visualization, title=title, config=config, dbkey=dbkey)
visualization.latest_revision = revision

if save:
Expand All @@ -196,15 +196,14 @@ def update(
trans: ProvidesUserContext,
visualization_id: DecodedDatabaseIdField,
payload: VisualizationUpdatePayload,
) -> Union[VisualizationUpdateResponse, None]:
) -> Optional[VisualizationUpdateResponse]:
"""
Update a visualization
:rtype: dictionary
:returns: dictionary containing Visualization details
"""
rval = None
payload = cast(VisualizationUpdatePayload, self._validate_and_parse_payload(payload))

# there's a differentiation here between updating the visualization and creating a new revision
# that needs to be handled clearly here or alternately, using a different controller
Expand Down Expand Up @@ -238,76 +237,6 @@ def update(

return VisualizationUpdateResponse(**rval) if rval else None

def _validate_and_parse_payload(
self,
payload: Union[VisualizationCreatePayload, VisualizationUpdatePayload],
) -> Union[VisualizationCreatePayload, VisualizationUpdatePayload]:
"""
Validate and parse incomming data payload for a visualization.
"""
# This layer handles (most of the stricter idiot proofing):
# - unknown/unallowed keys
# - changing data keys from api key to attribute name
# - protection against bad data form/type
# - protection against malicious data content
# all other conversions and processing (such as permissions, etc.) should happen down the line

# keys listed here don't error when attempting to set, but fail silently
# this allows PUT'ing an entire model back to the server without attribute errors on uneditable attrs
valid_but_uneditable_keys = (
"id",
"model_class",
# TODO: fill out when we create to_dict, get_dict, whatevs
)
# TODO: importable
ValidationError = exceptions.RequestParameterInvalidException

validated_payload = {}
for key, val in payload.model_dump().items():
# By adding the pydatnic model there will be some variables that are not set and should be ignored in the validation
if val is None:
continue
# TODO: validate types in VALID_TYPES/registry names at the mixin/model level?
if key == "type":
if not isinstance(val, str):
raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}")
val = sanitize_html(val)
elif key == "config":
if not isinstance(val, dict):
raise ValidationError(f"{key} must be a dictionary: {str(type(val))}")
elif key == "annotation":
if not isinstance(val, str):
raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}")
val = sanitize_html(val)
elif key == "deleted":
if not isinstance(val, bool):
raise ValidationError(f"{key} must be a bool: {str(type(val))}")

# these are keys that actually only be *updated* at the revision level and not here
# (they are still valid for create, tho)
elif key == "title":
if not isinstance(val, str):
raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}")
val = sanitize_html(val)
elif key == "slug":
if not isinstance(val, str):
raise ValidationError(f"{key} must be a string: {str(type(val))}")
val = sanitize_html(val)
elif key == "dbkey":
if not isinstance(val, str):
raise ValidationError(f"{key} must be a string or unicode: {str(type(val))}")
val = sanitize_html(val)

elif key not in valid_but_uneditable_keys:
continue
# raise AttributeError( 'unknown key: %s' %( str( key ) ) )

validated_payload[key] = val
if isinstance(payload, VisualizationCreatePayload):
return VisualizationCreatePayload(**validated_payload)
elif isinstance(payload, VisualizationUpdatePayload):
return VisualizationUpdatePayload(**validated_payload)

def _get_visualization(
self,
trans: ProvidesUserContext,
Expand All @@ -318,7 +247,6 @@ def _get_visualization(
"""
Get a Visualization from the database by id, verifying ownership.
"""
# Load workflow from database
try:
visualization = trans.sa_session.get(Visualization, visualization_id)
except TypeError:
Expand All @@ -328,7 +256,7 @@ def _get_visualization(
else:
return security_check(trans, visualization, check_ownership, check_accessible)

def _get_visualization_revision_dict(
def _get_visualization_revision(
self,
revision: VisualizationRevision,
) -> VisualizationRevisionResponse:
Expand Down Expand Up @@ -361,9 +289,7 @@ def _add_visualization_revision(
"""
# precondition: only add new revision on owned vis's
# TODO:?? should we default title, dbkey, config? to which: visualization or latest_revision?
revision = trans.model.VisualizationRevision(
visualization=visualization, title=title, dbkey=dbkey, config=config
)
revision = VisualizationRevision(visualization=visualization, title=title, dbkey=dbkey, config=config)

visualization.latest_revision = revision
# TODO:?? does this automatically add revision to visualzation.revisions?
Expand Down Expand Up @@ -391,7 +317,7 @@ def _create_visualization(
title_err = "visualization name is required"
elif slug and not is_valid_slug(slug):
slug_err = "visualization identifier must consist of only lowercase letters, numbers, and the '-' character"
elif slug and slug_exists(trans.sa_session, trans.model.Visualization, user, slug, ignore_deleted=True):
elif slug and slug_exists(trans.sa_session, Visualization, user, slug, ignore_deleted=True):
slug_err = "visualization identifier must be unique"

if title_err or slug_err:
Expand All @@ -400,7 +326,7 @@ def _create_visualization(
raise exceptions.RequestParameterMissingException(val_err)

# Create visualization
visualization = trans.model.Visualization(user=user, title=title, dbkey=dbkey, type=type)
visualization = Visualization(user=user, title=title, dbkey=dbkey, type=type)
if slug:
visualization.slug = slug
else:
Expand Down

0 comments on commit b441f97

Please sign in to comment.