diff --git a/src/zenml/client.py b/src/zenml/client.py index 3b50dc757ac..0441db7b973 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -1702,6 +1702,7 @@ def list_services( updated: Optional[datetime] = None, type: Optional[str] = None, flavor: Optional[str] = None, + user: Optional[Union[UUID, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, @@ -1727,6 +1728,7 @@ def list_services( flavor: Use the service flavor for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. running: Use the running status for filtering @@ -1753,6 +1755,7 @@ def list_services( flavor=flavor, workspace_id=workspace_id, user_id=user_id, + user=user, running=running, name=service_name, pipeline_name=pipeline_name, @@ -2249,6 +2252,7 @@ def list_flavors( type: Optional[str] = None, integration: Optional[str] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[FlavorResponse]: """Fetches all the flavor models. @@ -2262,6 +2266,7 @@ def list_flavors( created: Use to flavors by time of creation updated: Use the last updated date for filtering user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the flavor to filter by. type: The type of the flavor to filter by. integration: The integration of the flavor to filter by. @@ -2277,6 +2282,7 @@ def list_flavors( sort_by=sort_by, logical_operator=logical_operator, user_id=user_id, + user=user, name=name, type=type, integration=integration, @@ -2661,6 +2667,7 @@ def list_builds( updated: Optional[Union[datetime, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, container_registry_id: Optional[Union[UUID, str]] = None, @@ -2684,6 +2691,7 @@ def list_builds( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. container_registry_id: The id of the container registry to @@ -2710,6 +2718,7 @@ def list_builds( updated=updated, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, stack_id=stack_id, container_registry_id=container_registry_id, @@ -2778,7 +2787,7 @@ def get_event_source( allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> EventSourceResponse: - """Get a event source by name, ID or prefix. + """Get an event source by name, ID or prefix. Args: name_id_or_prefix: The name, ID or prefix of the stack. @@ -2811,6 +2820,7 @@ def list_event_sources( event_source_type: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[EventSourceResponse]: """Lists all event_sources. @@ -2825,6 +2835,7 @@ def list_event_sources( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the event_source to filter by. flavor: The flavor of the event_source to filter by. event_source_type: The subtype of the event_source to filter by. @@ -2841,6 +2852,7 @@ def list_event_sources( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, flavor=flavor, plugin_subtype=event_source_type, @@ -3008,6 +3020,7 @@ def list_actions( action_type: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[ActionResponse]: """List actions. @@ -3022,6 +3035,7 @@ def list_actions( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the action to filter by. flavor: The flavor of the action to filter by. action_type: The type of the action to filter by. @@ -3038,6 +3052,7 @@ def list_actions( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, id=id, flavor=flavor, @@ -3186,6 +3201,7 @@ def list_triggers( action_subtype: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[TriggerResponse]: """Lists all triggers. @@ -3200,6 +3216,7 @@ def list_triggers( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the trigger to filter by. event_source_id: The event source associated with the trigger. action_id: The action associated with the trigger. @@ -3222,6 +3239,7 @@ def list_triggers( logical_operator=logical_operator, workspace_id=workspace_id, user_id=user_id, + user=user, name=name, event_source_id=event_source_id, action_id=action_id, @@ -3372,6 +3390,7 @@ def list_deployments( updated: Optional[Union[datetime, str]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, stack_id: Optional[Union[str, UUID]] = None, build_id: Optional[Union[str, UUID]] = None, @@ -3390,6 +3409,7 @@ def list_deployments( updated: Use the last updated date for filtering workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. stack_id: The id of the stack to filter by. build_id: The id of the build to filter by. @@ -3410,6 +3430,7 @@ def list_deployments( updated=updated, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, stack_id=stack_id, build_id=build_id, @@ -3660,6 +3681,7 @@ def list_schedules( name: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, pipeline_id: Optional[Union[str, UUID]] = None, orchestrator_id: Optional[Union[str, UUID]] = None, active: Optional[Union[str, bool]] = None, @@ -3684,6 +3706,7 @@ def list_schedules( name: The name of the stack to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_id: The id of the pipeline to filter by. orchestrator_id: The id of the orchestrator to filter by. active: Use to filter by active status. @@ -3710,6 +3733,7 @@ def list_schedules( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, pipeline_id=pipeline_id, orchestrator_id=orchestrator_id, active=active, @@ -3950,6 +3974,7 @@ def list_run_steps( original_step_run_id: Optional[Union[str, UUID]] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, model_version_id: Optional[Union[str, UUID]] = None, model: Optional[Union[UUID, str]] = None, hydrate: bool = False, @@ -3968,6 +3993,7 @@ def list_run_steps( end_time: Use to filter by the time when the step finished running workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. pipeline_run_id: The id of the pipeline run to filter by. deployment_id: The id of the deployment to filter by. original_step_run_id: The id of the original step run to filter by. @@ -4002,6 +4028,7 @@ def list_run_steps( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, model_version_id=model_version_id, model=model, ) @@ -4674,6 +4701,7 @@ def list_secrets( scope: Optional[SecretScope] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[SecretResponse]: """Fetches all the secret models. @@ -4693,6 +4721,7 @@ def list_secrets( scope: The scope of the secret to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -4709,6 +4738,7 @@ def list_secrets( sort_by=sort_by, logical_operator=logical_operator, user_id=user_id, + user=user, workspace_id=workspace_id, name=name, scope=scope, @@ -5023,6 +5053,7 @@ def list_code_repositories( name: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[CodeRepositoryResponse]: """List all code repositories. @@ -5038,6 +5069,7 @@ def list_code_repositories( name: The name of the code repository to filter by. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -5055,6 +5087,7 @@ def list_code_repositories( name=name, workspace_id=workspace_id, user_id=user_id, + user=user, ) filter_model.set_scope_workspace(self.active_workspace.id) return self.zen_store.list_code_repositories( @@ -5415,6 +5448,7 @@ def list_service_connectors( resource_id: Optional[str] = None, workspace_id: Optional[Union[str, UUID]] = None, user_id: Optional[Union[str, UUID]] = None, + user: Optional[Union[UUID, str]] = None, labels: Optional[Dict[str, Optional[str]]] = None, secret_id: Optional[Union[str, UUID]] = None, hydrate: bool = False, @@ -5437,6 +5471,7 @@ def list_service_connectors( they can give access to. workspace_id: The id of the workspace to filter by. user_id: The id of the user to filter by. + user: Filter by user name/ID. name: The name of the service connector to filter by. labels: The labels of the service connector to filter by. secret_id: Filter by the id of the secret that is referenced by the @@ -5454,6 +5489,7 @@ def list_service_connectors( logical_operator=logical_operator, workspace_id=workspace_id or self.active_workspace.id, user_id=user_id, + user=user, name=name, connector_type=connector_type, auth_method=auth_method, @@ -6606,6 +6642,7 @@ def list_authorized_devices( client_id: Union[UUID, str, None] = None, status: Union[OAuthDeviceStatus, str, None] = None, trusted_device: Union[bool, str, None] = None, + user: Optional[Union[UUID, str]] = None, failed_auth_attempts: Union[int, str, None] = None, last_login: Optional[Union[datetime, str, None]] = None, hydrate: bool = False, @@ -6623,6 +6660,7 @@ def list_authorized_devices( expires: Use the expiration date for filtering. client_id: Use the client id for filtering. status: Use the status for filtering. + user: Filter by user name/ID. trusted_device: Use the trusted device flag for filtering. failed_auth_attempts: Use the failed auth attempts for filtering. last_login: Use the last login date for filtering. @@ -6642,6 +6680,7 @@ def list_authorized_devices( updated=updated, expires=expires, client_id=client_id, + user=user, status=status, trusted_device=trusted_device, failed_auth_attempts=failed_auth_attempts, @@ -6740,7 +6779,7 @@ def get_trigger_execution( trigger_execution_id: UUID, hydrate: bool = True, ) -> TriggerExecutionResponse: - """Get an trigger execution by ID. + """Get a trigger execution by ID. Args: trigger_execution_id: The ID of the trigger execution to get. @@ -6761,6 +6800,7 @@ def list_trigger_executions( size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, trigger_id: Optional[UUID] = None, + user: Optional[Union[UUID, str]] = None, hydrate: bool = False, ) -> Page[TriggerExecutionResponse]: """List all trigger executions matching the given filter criteria. @@ -6771,6 +6811,7 @@ def list_trigger_executions( size: The maximum size of all pages. logical_operator: Which logical operator to use [and, or]. trigger_id: ID of the trigger to filter by. + user: Filter by user name/ID. hydrate: Flag deciding whether to hydrate the output model(s) by including metadata fields in the response. @@ -6782,6 +6823,7 @@ def list_trigger_executions( sort_by=sort_by, page=page, size=size, + user=user, logical_operator=logical_operator, ) filter_model.set_scope_workspace(self.active_workspace.id) diff --git a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py index 0b7b01b546d..52b19af2afe 100644 --- a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +++ b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py @@ -33,7 +33,6 @@ from zenml.step_operators import BaseStepOperator if TYPE_CHECKING: - from zenml.config.base_settings import BaseSettings from zenml.config.step_run_info import StepRunInfo from zenml.models import PipelineDeploymentBase diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 1c4d2cccfb5..1b79696134a 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -436,7 +436,6 @@ class BaseFilter(BaseModel): le=PAGE_SIZE_MAXIMUM, description="Page size", ) - id: Optional[Union[UUID, str]] = Field( default=None, description="Id for this resource", @@ -491,13 +490,13 @@ def validate_sort_by(cls, value: Any) -> Any: ) value = column - if column in cls.FILTER_EXCLUDE_FIELDS: + if column in cls.CUSTOM_SORTING_OPTIONS: + return value + elif column in cls.FILTER_EXCLUDE_FIELDS: raise ValueError( f"This resource can not be sorted by this field: '{value}'" ) - elif column in cls.model_fields: - return value - elif column in cls.CUSTOM_SORTING_OPTIONS: + if column in cls.model_fields: return value else: raise ValueError( @@ -759,7 +758,7 @@ def offset(self) -> int: return self.size * (self.page - 1) def generate_filter( - self, table: Type[SQLModel] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. @@ -779,7 +778,7 @@ def generate_filter( filters.append( column_filter.generate_query_conditions(table=table) ) - for custom_filter in self.get_custom_filters(): + for custom_filter in self.get_custom_filters(table): filters.append(custom_filter) if self.logical_operator == LogicalOperators.OR: return or_(False, *filters) @@ -788,12 +787,17 @@ def generate_filter( else: raise RuntimeError("No valid logical operator was supplied.") - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. This can be overridden by subclasses to define custom filters that are not based on the columns of the underlying table. + Args: + table: The query table. + Returns: A list of custom filters. """ diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index f563b6dc81c..f5267f4840d 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -23,6 +23,7 @@ Optional, Type, TypeVar, + Union, ) from uuid import UUID @@ -151,16 +152,32 @@ class UserScopedFilter(BaseFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, + "user", "scope_user", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *BaseFilter.CLI_EXCLUDE_FIELDS, + "user_id", "scope_user", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *BaseFilter.CUSTOM_SORTING_OPTIONS, + "user", + ] + scope_user: Optional[UUID] = Field( default=None, description="The user to scope this query to.", ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, + description="UUID of the user that created the entity.", + union_mode="left_to_right", + ) + user: Optional[Union[UUID, str]] = Field( + default=None, + description="Name/ID of the user that created the entity.", + ) def set_scope_user(self, user_id: UUID) -> None: """Set the user that is performing the filtering to scope the response. @@ -170,6 +187,73 @@ def set_scope_user(self, user_id: UUID) -> None: """ self.scope_user = user_id + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Args: + table: The query table. + + Returns: + A list of custom filters. + """ + custom_filters = super().get_custom_filters(table) + + from sqlmodel import and_ + + from zenml.zen_stores.schemas import UserSchema + + if self.user: + user_filter = and_( + getattr(table, "user_id") == UserSchema.id, + self.generate_name_or_id_query_conditions( + value=self.user, + table=UserSchema, + additional_columns=["full_name"], + ), + ) + custom_filters.append(user_filter) + + return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import UserSchema + + sort_by, operand = self.sorting_params + + if sort_by == "user": + column = UserSchema.name + + query = query.join( + UserSchema, getattr(table, "user_id") == UserSchema.id + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by(asc(column)) + else: + query = query.order_by(desc(column)) + + return query + + return super().apply_sorting(query=query, table=table) + def apply_filter( self, query: AnyQuery, @@ -240,21 +324,37 @@ def workspace(self) -> "WorkspaceResponse": return self.get_metadata().workspace -class WorkspaceScopedFilter(BaseFilter): +class WorkspaceScopedFilter(UserScopedFilter): """Model to enable advanced scoping with workspace.""" FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *BaseFilter.FILTER_EXCLUDE_FIELDS, + *UserScopedFilter.FILTER_EXCLUDE_FIELDS, + "workspace", "scope_workspace", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *BaseFilter.CLI_EXCLUDE_FIELDS, + *UserScopedFilter.CLI_EXCLUDE_FIELDS, + "workspace_id", + "workspace", "scope_workspace", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *UserScopedFilter.CUSTOM_SORTING_OPTIONS, + "workspace", + ] scope_workspace: Optional[UUID] = Field( default=None, description="The workspace to scope this query to.", ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, + description="UUID of the workspace that this entity belongs to.", + union_mode="left_to_right", + ) + workspace: Optional[Union[UUID, str]] = Field( + default=None, + description="Name/ID of the workspace that this entity belongs to.", + ) def set_scope_workspace(self, workspace_id: UUID) -> None: """Set the workspace to scope this response. @@ -264,6 +364,35 @@ def set_scope_workspace(self, workspace_id: UUID) -> None: """ self.scope_workspace = workspace_id + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: + """Get custom filters. + + Args: + table: The query table. + + Returns: + A list of custom filters. + """ + custom_filters = super().get_custom_filters(table) + + from sqlmodel import and_ + + from zenml.zen_stores.schemas import WorkspaceSchema + + if self.workspace: + workspace_filter = and_( + getattr(table, "workspace_id") == WorkspaceSchema.id, + self.generate_name_or_id_query_conditions( + value=self.workspace, + table=WorkspaceSchema, + ), + ) + custom_filters.append(workspace_filter) + + return custom_filters + def apply_filter( self, query: AnyQuery, @@ -291,6 +420,44 @@ def apply_filter( return query + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import WorkspaceSchema + + sort_by, operand = self.sorting_params + + if sort_by == "workspace": + column = WorkspaceSchema.name + + query = query.join( + WorkspaceSchema, + getattr(table, "workspace_id") == WorkspaceSchema.id, + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by(asc(column)) + else: + query = query.order_by(desc(column)) + + return query + + return super().apply_sorting(query=query, table=table) + class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): """Model to enable advanced scoping with workspace and tagging.""" @@ -304,6 +471,11 @@ class WorkspaceScopedTaggableFilter(WorkspaceScopedFilter): "tag", ] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedFilter.CUSTOM_SORTING_OPTIONS, + "tag", + ] + def apply_filter( self, query: AnyQuery, @@ -330,15 +502,20 @@ def apply_filter( return query - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom tag filters. + Args: + table: The query table. + Returns: A list of custom filters. """ from zenml.zen_stores.schemas import TagSchema - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) if self.tag: custom_filters.append( self.generate_custom_query_conditions_for_column( @@ -347,3 +524,79 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: ) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + sort_by, operand = self.sorting_params + + if sort_by == "tag": + from sqlmodel import and_, asc, desc, func + + from zenml.enums import SorterOps, TaggableResourceTypes + from zenml.zen_stores.schemas import ( + ArtifactSchema, + ArtifactVersionSchema, + ModelSchema, + ModelVersionSchema, + PipelineRunSchema, + PipelineSchema, + RunTemplateSchema, + TagResourceSchema, + TagSchema, + ) + + resource_type_mapping = { + ArtifactSchema: TaggableResourceTypes.ARTIFACT, + ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, + ModelSchema: TaggableResourceTypes.MODEL, + ModelVersionSchema: TaggableResourceTypes.MODEL_VERSION, + PipelineSchema: TaggableResourceTypes.PIPELINE, + PipelineRunSchema: TaggableResourceTypes.PIPELINE_RUN, + RunTemplateSchema: TaggableResourceTypes.RUN_TEMPLATE, + } + + query = ( + query.outerjoin( + TagResourceSchema, + and_( + table.id == TagResourceSchema.resource_id, + TagResourceSchema.resource_type + == resource_type_mapping[table], + ), + ) + .outerjoin(TagSchema, TagResourceSchema.tag_id == TagSchema.id) + .group_by(table.id) + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc( + func.group_concat(TagSchema.name, ",").label( + "tags_list" + ) + ) + ) + else: + query = query.order_by( + desc( + func.group_concat(TagSchema.name, ",").label( + "tags_list" + ) + ) + ) + + return query + + return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index cd5089a3db4..a6998b92b3c 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -20,6 +20,8 @@ Dict, List, Optional, + Type, + TypeVar, Union, ) from uuid import UUID @@ -58,6 +60,10 @@ ) from zenml.models.v2.core.pipeline_run import PipelineRunResponse from zenml.models.v2.core.step_run import StepRunResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + logger = get_logger(__name__) @@ -471,7 +477,6 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): "name", "only_unused", "has_custom_name", - "user", "model", "pipeline_run", "model_version_id", @@ -516,19 +521,10 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): description="Artifact store for this artifact", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace for this artifact", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that produced this artifact", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, - description="ID of the model version that is associated with this artifact version.", + description="ID of the model version that is associated with this " + "artifact version.", union_mode="left_to_right", ) only_unused: Optional[bool] = Field( @@ -559,13 +555,18 @@ class ArtifactVersionFilter(WorkspaceScopedTaggableFilter): model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List[Union["ColumnElement[bool]"]]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, or_, select @@ -581,7 +582,6 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: StepRunInputArtifactSchema, StepRunOutputArtifactSchema, StepRunSchema, - UserSchema, ) if self.name: @@ -629,17 +629,6 @@ def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: ) custom_filters.append(custom_name_filter) - if self.user: - user_filter = and_( - ArtifactVersionSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.model: model_filter = and_( ArtifactVersionSchema.id diff --git a/src/zenml/models/v2/core/code_repository.py b/src/zenml/models/v2/core/code_repository.py index c0a5430468b..485f710b7de 100644 --- a/src/zenml/models/v2/core/code_repository.py +++ b/src/zenml/models/v2/core/code_repository.py @@ -13,8 +13,7 @@ # permissions and limitations under the License. """Models representing code repositories.""" -from typing import Any, Dict, Optional, Union -from uuid import UUID +from typing import Any, Dict, Optional from pydantic import Field @@ -189,13 +188,3 @@ class CodeRepositoryFilter(WorkspaceScopedFilter): description="Name of the code repository.", default=None, ) - workspace_id: Optional[Union[UUID, str]] = Field( - description="Workspace of the code repository.", - default=None, - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - description="User that created the code repository.", - default=None, - union_mode="left_to_right", - ) diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index a4f52be884c..98418589222 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -21,6 +21,7 @@ List, Optional, Type, + TypeVar, Union, ) from uuid import UUID @@ -42,9 +43,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement - from sqlmodel import SQLModel from zenml.models import FlavorResponse, ServiceConnectorResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Base Model ------------------ @@ -356,7 +359,6 @@ class ComponentFilter(WorkspaceScopedFilter): *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "scope_type", "stack_id", - "user", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, @@ -366,7 +368,6 @@ class ComponentFilter(WorkspaceScopedFilter): default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( default=None, description="Name of the stack component", @@ -379,16 +380,6 @@ class ComponentFilter(WorkspaceScopedFilter): default=None, description="Type of the stack component", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack component", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack component", - union_mode="left_to_right", - ) connector_id: Optional[Union[UUID, str]] = Field( default=None, description="Connector linked to the stack component", @@ -399,10 +390,6 @@ class ComponentFilter(WorkspaceScopedFilter): description="Stack of the stack component", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the component.", - ) def set_scope_type(self, component_type: str) -> None: """Set the type of component on which to perform the filtering to scope the response. @@ -413,7 +400,7 @@ def set_scope_type(self, component_type: str) -> None: self.scope_type = component_type def generate_filter( - self, table: Type["SQLModel"] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. @@ -449,31 +436,3 @@ def generate_filter( base_filter = operator(base_filter, stack_filter) return base_filter - - def get_custom_filters(self) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - StackComponentSchema, - UserSchema, - ) - - custom_filters = super().get_custom_filters() - - if self.user: - user_filter = and_( - StackComponentSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters diff --git a/src/zenml/models/v2/core/flavor.py b/src/zenml/models/v2/core/flavor.py index fd4110300c3..77fe774c073 100644 --- a/src/zenml/models/v2/core/flavor.py +++ b/src/zenml/models/v2/core/flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing flavors.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional from uuid import UUID from pydantic import Field @@ -428,13 +428,3 @@ class FlavorFilter(WorkspaceScopedFilter): default=None, description="Integration associated with the flavor", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack", - union_mode="left_to_right", - ) diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 0eb3b749c88..0b5272ab7e6 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing models.""" -from typing import TYPE_CHECKING, ClassVar, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional from uuid import UUID from pydantic import BaseModel, Field @@ -30,8 +30,6 @@ from zenml.utils.pagination_utils import depaginate if TYPE_CHECKING: - from sqlalchemy.sql.elements import ColumnElement - from zenml.model.model import Model from zenml.models.v2.core.tag import TagResponse @@ -318,61 +316,7 @@ def versions(self) -> List["Model"]: class ModelFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, - "workspace_id", - "user_id", - ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", - ] - name: Optional[str] = Field( default=None, description="Name of the Model", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Model", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the Model", - union_mode="left_to_right", - ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the model.", - ) - - def get_custom_filters( - self, - ) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - custom_filters = super().get_custom_filters() - - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - ModelSchema, - UserSchema, - ) - - if self.user: - user_filter = and_( - ModelSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index d1a7a951978..949d9ce1d15 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -585,7 +585,6 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", "run_metadata", ] @@ -597,25 +596,11 @@ class ModelVersionFilter(WorkspaceScopedTaggableFilter): default=None, description="The number of the Model Version", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The workspace of the Model Version", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="The user of the Model Version", - union_mode="left_to_right", - ) stage: Optional[Union[str, ModelStages]] = Field( description="The model version stage", default=None, union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the model version.", - ) run_metadata: Optional[Dict[str, str]] = Field( default=None, description="The run_metadata to filter the model versions by.", @@ -639,14 +624,17 @@ def set_scope_model(self, model_name_or_id: Union[str, UUID]) -> None: self._model_id = model_id def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ @@ -654,20 +642,8 @@ def get_custom_filters( ModelVersionSchema, RunMetadataResourceSchema, RunMetadataSchema, - UserSchema, ) - if self.user: - user_filter = and_( - ModelVersionSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.run_metadata is not None: from zenml.enums import MetadataResourceTypes diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index f3a677a86e9..6c9514b9735 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing the link between model versions and artifacts.""" -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -32,6 +32,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.artifact_version import ArtifactVersionResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -164,13 +167,18 @@ class ModelVersionArtifactFilter(BaseFilter): # careful we might overwrite some fields protected by pydantic. model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List[Union["ColumnElement[bool]"]]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List[Union["ColumnElement[bool]"]]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, col diff --git a/src/zenml/models/v2/core/model_version_pipeline_run.py b/src/zenml/models/v2/core/model_version_pipeline_run.py index 6181c2ffbb1..40e7f823d9c 100644 --- a/src/zenml/models/v2/core/model_version_pipeline_run.py +++ b/src/zenml/models/v2/core/model_version_pipeline_run.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing the link between model versions and pipeline runs.""" -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -30,6 +30,12 @@ from zenml.models.v2.base.filter import BaseFilter, StrFilter from zenml.models.v2.core.pipeline_run import PipelineRunResponse +if TYPE_CHECKING: + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + # ------------------ Request Model ------------------ @@ -147,13 +153,18 @@ class ModelVersionPipelineRunFilter(BaseFilter): # careful we might overwrite some fields protected by pydantic. model_config = ConfigDict(protected_namespaces=()) - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 5166e0abb9c..03a81fbb23c 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -21,7 +21,6 @@ Optional, Type, TypeVar, - Union, ) from uuid import UUID @@ -45,8 +44,6 @@ from zenml.models.v2.core.tag import TagResponse if TYPE_CHECKING: - from sqlalchemy.sql.elements import ColumnElement - from zenml.models.v2.core.pipeline_run import PipelineRunResponse from zenml.zen_stores.schemas import BaseSchema @@ -258,10 +255,12 @@ def tags(self) -> List[TagResponse]: class PipelineFilter(WorkspaceScopedTaggableFilter): """Pipeline filter model.""" - CUSTOM_SORTING_OPTIONS = [SORT_PIPELINES_BY_LATEST_RUN_KEY] + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + SORT_PIPELINES_BY_LATEST_RUN_KEY, + ] FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, - "user", "latest_run_status", ] @@ -274,20 +273,6 @@ class PipelineFilter(WorkspaceScopedTaggableFilter): description="Filter by the status of the latest run of a pipeline. " "This will always be applied as an `AND` filter for now.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Pipeline", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the Pipeline", - union_mode="left_to_right", - ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the pipeline.", - ) def apply_filter( self, query: AnyQuery, table: Type["AnySchema"] @@ -343,36 +328,6 @@ def apply_filter( return query - def get_custom_filters( - self, - ) -> List["ColumnElement[bool]"]: - """Get custom filters. - - Returns: - A list of custom filters. - """ - custom_filters = super().get_custom_filters() - - from sqlmodel import and_ - - from zenml.zen_stores.schemas import ( - PipelineSchema, - UserSchema, - ) - - if self.user: - user_filter = and_( - PipelineSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - - return custom_filters - def apply_sorting( self, query: AnyQuery, @@ -387,12 +342,45 @@ def apply_sorting( Returns: The query with sorting applied. """ - column, _ = self.sorting_params + from sqlmodel import asc, case, col, desc, func, select + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import PipelineRunSchema, PipelineSchema + + sort_by, operand = self.sorting_params + + if sort_by == SORT_PIPELINES_BY_LATEST_RUN_KEY: + # Subquery to find the latest run per pipeline + latest_run_subquery = ( + select( + PipelineRunSchema.pipeline_id, + case( + ( + func.max(PipelineRunSchema.created).is_(None), + PipelineSchema.created, + ), + else_=func.max(PipelineRunSchema.created), + ).label("latest_run"), + ) + .group_by(col(PipelineRunSchema.pipeline_id)) + .subquery() + ) + + # Join the subquery with the pipelines + query = query.outerjoin( + latest_run_subquery, + PipelineSchema.id == latest_run_subquery.c.pipeline_id, + ) + + if operand == SorterOps.ASCENDING: + query = query.order_by( + asc(latest_run_subquery.c.latest_run) + ).order_by(col(PipelineSchema.id)) + else: + query = query.order_by( + desc(latest_run_subquery.c.latest_run) + ).order_by(col(PipelineSchema.id)) - if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: - # If sorting by the latest run, the sorting is already done in the - # base query in `SqlZenStore.list_pipelines(...)` and we don't need - # to to anything here return query else: return super().apply_sorting(query=query, table=table) diff --git a/src/zenml/models/v2/core/pipeline_build.py b/src/zenml/models/v2/core/pipeline_build.py index 93c0ff63a8c..19dc89ccbf0 100644 --- a/src/zenml/models/v2/core/pipeline_build.py +++ b/src/zenml/models/v2/core/pipeline_build.py @@ -14,7 +14,17 @@ """Models representing pipeline builds.""" import json -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field @@ -35,6 +45,9 @@ from zenml.models.v2.core.pipeline import PipelineResponse from zenml.models.v2.core.stack import StackResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -453,16 +466,6 @@ class PipelineBuildFilter(WorkspaceScopedFilter): "container_registry_id", ] - workspace_id: Optional[Union[UUID, str]] = Field( - description="Workspace for this pipeline build.", - default=None, - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - description="User that produced this pipeline build.", - default=None, - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( description="Pipeline associated with the pipeline build.", default=None, @@ -502,13 +505,17 @@ class PipelineBuildFilter(WorkspaceScopedFilter): def get_custom_filters( self, + table: Type["AnySchema"], ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/pipeline_deployment.py b/src/zenml/models/v2/core/pipeline_deployment.py index 760f65f1a35..94dbc431507 100644 --- a/src/zenml/models/v2/core/pipeline_deployment.py +++ b/src/zenml/models/v2/core/pipeline_deployment.py @@ -358,16 +358,6 @@ def template_id(self) -> Optional[UUID]: class PipelineDeploymentFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all pipeline deployments.""" - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace for this deployment.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created this deployment.", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline associated with the deployment.", diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 958d662a515..3a22f642953 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -16,10 +16,13 @@ from datetime import datetime from typing import ( TYPE_CHECKING, + Any, ClassVar, Dict, List, Optional, + Type, + TypeVar, Union, cast, ) @@ -55,6 +58,11 @@ from zenml.models.v2.core.schedule import ScheduleResponse from zenml.models.v2.core.stack import StackResponse from zenml.models.v2.core.step_run import StepRunResponse + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + +AnyQuery = TypeVar("AnyQuery", bound=Any) # ------------------ Request Model ------------------ @@ -584,6 +592,15 @@ def tags(self) -> List[TagResponse]: class PipelineRunFilter(WorkspaceScopedTaggableFilter): """Model to enable advanced filtering of all Workspaces.""" + CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CUSTOM_SORTING_OPTIONS, + "tag", + "stack", + "pipeline", + "model", + "model_version", + ] + FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedTaggableFilter.FILTER_EXCLUDE_FIELDS, "unlisted", @@ -592,7 +609,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): "schedule_id", "stack_id", "template_id", - "user", "pipeline", "stack", "code_repository", @@ -615,16 +631,6 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): description="Pipeline associated with the Pipeline Run", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Pipeline Run", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the Pipeline Run", - union_mode="left_to_right", - ) stack_id: Optional[Union[UUID, str]] = Field( default=None, description="Stack used for the Pipeline Run", @@ -675,16 +681,12 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): union_mode="left_to_right", ) unlisted: Optional[bool] = None - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the run.", - ) run_metadata: Optional[Dict[str, str]] = Field( default=None, description="The run_metadata to filter the pipeline runs by.", ) # TODO: Remove once frontend is ready for it. This is replaced by the more - # generic `pipeline` filter below. + # generic `pipeline` filter below. pipeline_name: Optional[str] = Field( default=None, description="Name of the pipeline associated with the run", @@ -716,13 +718,17 @@ class PipelineRunFilter(WorkspaceScopedTaggableFilter): def get_custom_filters( self, + table: Type["AnySchema"], ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_, col, or_ @@ -741,7 +747,6 @@ def get_custom_filters( StackComponentSchema, StackCompositionSchema, StackSchema, - UserSchema, ) if self.unlisted is not None: @@ -792,17 +797,6 @@ def get_custom_filters( ) custom_filters.append(run_template_filter) - if self.user: - user_filter = and_( - PipelineRunSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.pipeline: pipeline_filter = and_( PipelineRunSchema.pipeline_id == PipelineSchema.id, @@ -926,3 +920,71 @@ def get_custom_filters( custom_filters.append(additional_filter) return custom_filters + + def apply_sorting( + self, + query: AnyQuery, + table: Type["AnySchema"], + ) -> AnyQuery: + """Apply sorting to the query. + + Args: + query: The query to which to apply the sorting. + table: The query table. + + Returns: + The query with sorting applied. + """ + from sqlmodel import asc, desc + + from zenml.enums import SorterOps + from zenml.zen_stores.schemas import ( + ModelSchema, + ModelVersionSchema, + PipelineDeploymentSchema, + PipelineRunSchema, + PipelineSchema, + StackSchema, + ) + + sort_by, operand = self.sorting_params + + if sort_by == "pipeline": + query = query.join( + PipelineSchema, + PipelineRunSchema.pipeline_id == PipelineSchema.id, + ) + column = PipelineSchema.name + elif sort_by == "stack": + query = query.join( + PipelineDeploymentSchema, + PipelineRunSchema.deployment_id == PipelineDeploymentSchema.id, + ).join( + StackSchema, + PipelineDeploymentSchema.stack_id == StackSchema.id, + ) + column = StackSchema.name + elif sort_by == "model": + query = query.join( + ModelVersionSchema, + PipelineRunSchema.model_version_id == ModelVersionSchema.id, + ).join( + ModelSchema, + ModelVersionSchema.model_id == ModelSchema.id, + ) + column = ModelSchema.name + elif sort_by == "model_version": + query = query.join( + ModelVersionSchema, + PipelineRunSchema.model_version_id == ModelVersionSchema.id, + ) + column = ModelVersionSchema.name + else: + return super().apply_sorting(query=query, table=table) + + if operand == SorterOps.ASCENDING: + query = query.order_by(asc(column)) + else: + query = query.order_by(desc(column)) + + return query diff --git a/src/zenml/models/v2/core/run_template.py b/src/zenml/models/v2/core/run_template.py index b1aae8a325a..2bc177c043e 100644 --- a/src/zenml/models/v2/core/run_template.py +++ b/src/zenml/models/v2/core/run_template.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. """Models representing pipeline templates.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field @@ -45,6 +55,11 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import ColumnElement + from zenml.zen_stores.schemas.base_schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + + # ------------------ Request Model ------------------ @@ -310,16 +325,6 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): default=None, description="Name of the run template.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace associated with the template.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the template.", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline associated with the template.", @@ -340,10 +345,6 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): description="Code repository associated with the template.", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the template.", - ) pipeline: Optional[Union[UUID, str]] = Field( default=None, description="Name/ID of the pipeline associated with the template.", @@ -354,14 +355,17 @@ class RunTemplateFilter(WorkspaceScopedTaggableFilter): ) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ @@ -371,7 +375,6 @@ def get_custom_filters( PipelineSchema, RunTemplateSchema, StackSchema, - UserSchema, ) if self.code_repository_id: @@ -409,17 +412,6 @@ def get_custom_filters( ) custom_filters.append(pipeline_filter) - if self.user: - user_filter = and_( - RunTemplateSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.pipeline: pipeline_filter = and_( RunTemplateSchema.source_deployment_id diff --git a/src/zenml/models/v2/core/schedule.py b/src/zenml/models/v2/core/schedule.py index af838f17ccc..0e7dc01c421 100644 --- a/src/zenml/models/v2/core/schedule.py +++ b/src/zenml/models/v2/core/schedule.py @@ -279,16 +279,6 @@ def pipeline_id(self) -> Optional[UUID]: class ScheduleFilter(WorkspaceScopedFilter): """Model to enable advanced filtering of all Users.""" - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace scope of the schedule.", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the schedule", - union_mode="left_to_right", - ) pipeline_id: Optional[Union[UUID, str]] = Field( default=None, description="Pipeline that the schedule is attached to.", diff --git a/src/zenml/models/v2/core/secret.py b/src/zenml/models/v2/core/secret.py index 79e50cd1841..3f29b57de22 100644 --- a/src/zenml/models/v2/core/secret.py +++ b/src/zenml/models/v2/core/secret.py @@ -15,7 +15,6 @@ from datetime import datetime from typing import Any, ClassVar, Dict, List, Optional, Union -from uuid import UUID from pydantic import Field, SecretStr @@ -253,25 +252,12 @@ class SecretFilter(WorkspaceScopedFilter): default=None, description="Name of the secret", ) - scope: Optional[Union[SecretScope, str]] = Field( default=None, description="Scope in which to filter secrets", union_mode="left_to_right", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the Secret", - union_mode="left_to_right", - ) - - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that created the Secret", - union_mode="left_to_right", - ) - @staticmethod def _get_filtering_value(value: Optional[Any]) -> str: """Convert the value to a string that can be used for lexicographical filtering and sorting. diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py index c3dcbd7cfc8..2ad9724b20a 100644 --- a/src/zenml/models/v2/core/service.py +++ b/src/zenml/models/v2/core/service.py @@ -15,19 +15,20 @@ from datetime import datetime from typing import ( + TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Type, + TypeVar, Union, ) from uuid import UUID from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.sql.elements import ColumnElement -from sqlmodel import SQLModel from zenml.constants import STR_FIELD_MAX_LENGTH from zenml.models.v2.base.scoped import ( @@ -37,11 +38,15 @@ WorkspaceScopedResponseBody, WorkspaceScopedResponseMetadata, WorkspaceScopedResponseResources, - WorkspaceScopedTaggableFilter, ) from zenml.services.service_status import ServiceState from zenml.services.service_type import ServiceType +if TYPE_CHECKING: + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) + # ------------------ Request Model ------------------ @@ -376,16 +381,6 @@ class ServiceFilter(WorkspaceScopedFilter): description="Name of the service. Use this to filter services by " "their name.", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the service", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the service", - union_mode="left_to_right", - ) type: Optional[str] = Field( default=None, description="Type of the service. Filter services by their type.", @@ -457,9 +452,7 @@ def set_flavor(self, flavor: str) -> None: "config", ] CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ - *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, - "workspace_id", - "user_id", + *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS, "flavor", "type", "pipeline_step_name", @@ -468,7 +461,7 @@ def set_flavor(self, flavor: str) -> None: ] def generate_filter( - self, table: Type["SQLModel"] + self, table: Type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. diff --git a/src/zenml/models/v2/core/service_connector.py b/src/zenml/models/v2/core/service_connector.py index 806e6100072..8c71106ae22 100644 --- a/src/zenml/models/v2/core/service_connector.py +++ b/src/zenml/models/v2/core/service_connector.py @@ -801,7 +801,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( default=None, description="The name to filter by", @@ -810,16 +809,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter): default=None, description="The type of service connector to filter by", ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace to filter by", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User to filter by", - union_mode="left_to_right", - ) auth_method: Optional[str] = Field( default=None, title="Filter by the authentication method configured for the " diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index 3d8ad20a2c1..1e49eb1544b 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -14,7 +14,17 @@ """Models representing stacks.""" import json -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, model_validator @@ -39,6 +49,9 @@ from sqlalchemy.sql.elements import ColumnElement from zenml.models.v2.core.component import ComponentResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -323,7 +336,6 @@ class StackFilter(WorkspaceScopedFilter): FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, "component_id", - "user", "component", ] @@ -334,42 +346,32 @@ class StackFilter(WorkspaceScopedFilter): description: Optional[str] = Field( default=None, description="Description of the stack" ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of the stack", - union_mode="left_to_right", - ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User of the stack", - union_mode="left_to_right", - ) component_id: Optional[Union[UUID, str]] = Field( default=None, description="Component in the stack", union_mode="left_to_right", ) - user: Optional[Union[UUID, str]] = Field( - default=None, - description="Name/ID of the user that created the stack.", - ) component: Optional[Union[UUID, str]] = Field( default=None, description="Name/ID of a component in the stack." ) - def get_custom_filters(self) -> List["ColumnElement[bool]"]: + def get_custom_filters( + self, table: Type["AnySchema"] + ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from zenml.zen_stores.schemas import ( StackComponentSchema, StackCompositionSchema, StackSchema, - UserSchema, ) if self.component_id: @@ -379,17 +381,6 @@ def get_custom_filters(self) -> List["ColumnElement[bool]"]: ) custom_filters.append(component_id_filter) - if self.user: - user_filter = and_( - StackSchema.user_id == UserSchema.id, - self.generate_name_or_id_query_conditions( - value=self.user, - table=UserSchema, - additional_columns=["full_name"], - ), - ) - custom_filters.append(user_filter) - if self.component: component_filter = and_( StackCompositionSchema.stack_id == StackSchema.id, diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index d9ac5e0354a..0a505539d07 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -14,7 +14,16 @@ """Models representing steps runs.""" from datetime import datetime -from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import BaseModel, ConfigDict, Field @@ -41,6 +50,9 @@ LogsRequest, LogsResponse, ) + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) class StepRunInputResponse(ArtifactVersionResponse): @@ -553,16 +565,6 @@ class StepRunFilter(WorkspaceScopedFilter): description="Original id for this step run", union_mode="left_to_right", ) - user_id: Optional[Union[UUID, str]] = Field( - default=None, - description="User that produced this step run", - union_mode="left_to_right", - ) - workspace_id: Optional[Union[UUID, str]] = Field( - default=None, - description="Workspace of this step run", - union_mode="left_to_right", - ) model_version_id: Optional[Union[UUID, str]] = Field( default=None, description="Model version associated with the step run.", @@ -576,18 +578,20 @@ class StepRunFilter(WorkspaceScopedFilter): default=None, description="The run_metadata to filter the step runs by.", ) - model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) from sqlmodel import and_ diff --git a/src/zenml/models/v2/core/trigger.py b/src/zenml/models/v2/core/trigger.py index daef211ed7b..45fc23a501c 100644 --- a/src/zenml/models/v2/core/trigger.py +++ b/src/zenml/models/v2/core/trigger.py @@ -13,7 +13,17 @@ # permissions and limitations under the License. """Collection of all models concerning triggers.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + TypeVar, + Union, +) from uuid import UUID from pydantic import Field, model_validator @@ -39,6 +49,9 @@ ActionResponse, ) from zenml.models.v2.core.event_source import EventSourceResponse + from zenml.zen_stores.schemas import BaseSchema + + AnySchema = TypeVar("AnySchema", bound=BaseSchema) # ------------------ Request Model ------------------ @@ -358,10 +371,13 @@ class TriggerFilter(WorkspaceScopedFilter): ) def get_custom_filters( - self, + self, table: Type["AnySchema"] ) -> List["ColumnElement[bool]"]: """Get custom filters. + Args: + table: The query table. + Returns: A list of custom filters. """ @@ -373,7 +389,7 @@ def get_custom_filters( TriggerSchema, ) - custom_filters = super().get_custom_filters() + custom_filters = super().get_custom_filters(table) if self.event_source_flavor: event_source_flavor_filter = and_( diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 464293515b3..ce20d6687f6 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -55,7 +55,7 @@ field_validator, model_validator, ) -from sqlalchemy import asc, case, desc, func +from sqlalchemy import func from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( ArgumentError, @@ -100,7 +100,6 @@ ENV_ZENML_SERVER, FINISHED_ONBOARDING_SURVEY_KEY, MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION, - SORT_PIPELINES_BY_LATEST_RUN_KEY, SQL_STORE_BACKUP_DIRECTORY_NAME, TEXT_FIELD_MAX_LENGTH, handle_bool_env_var, @@ -117,7 +116,6 @@ OnboardingStep, SecretScope, SecretsStoreType, - SorterOps, StackComponentType, StackDeploymentProvider, StepRunInputArtifactType, @@ -4358,69 +4356,14 @@ def list_pipelines( Returns: A list of all pipelines matching the filter criteria. """ - query: Union[Select[Any], SelectOfScalar[Any]] = select(PipelineSchema) - _custom_conversion: Optional[Callable[[Any], PipelineResponse]] = None - - column, operand = pipeline_filter_model.sorting_params - if column == SORT_PIPELINES_BY_LATEST_RUN_KEY: - with Session(self.engine) as session: - max_date_subquery = ( - # If no run exists for the pipeline yet, we use the pipeline - # creation date as a fallback, otherwise newly created - # pipeline would always be at the top/bottom - select( - PipelineSchema.id, - case( - ( - func.max(PipelineRunSchema.created).is_(None), - PipelineSchema.created, - ), - else_=func.max(PipelineRunSchema.created), - ).label("run_or_created"), - ) - .outerjoin( - PipelineRunSchema, - PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type] - ) - .group_by(col(PipelineSchema.id)) - .subquery() - ) - - if operand == SorterOps.DESCENDING: - sort_clause = desc - else: - sort_clause = asc - - query = ( - # We need to include the subquery in the select here to - # make this query work with the distinct statement. This - # result will be removed in the custom conversion function - # applied later - select(PipelineSchema, max_date_subquery.c.run_or_created) - .where(PipelineSchema.id == max_date_subquery.c.id) - .order_by(sort_clause(max_date_subquery.c.run_or_created)) - # We always add the `id` column as a tiebreaker to ensure a - # stable, repeatable order of items, otherwise subsequent - # pages might contain the same items. - .order_by(col(PipelineSchema.id)) - ) - - def _custom_conversion(row: Any) -> PipelineResponse: - return cast( - PipelineResponse, - row[0].to_model( - include_metadata=hydrate, include_resources=True - ), - ) - with Session(self.engine) as session: + query = select(PipelineSchema) return self.filter_and_paginate( session=session, query=query, table=PipelineSchema, filter_model=pipeline_filter_model, hydrate=hydrate, - custom_schema_to_model_conversion=_custom_conversion, ) def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int: