Skip to content

Commit

Permalink
fix(db): add missing constraints and relationships in TaskJob table (
Browse files Browse the repository at this point in the history
…#1872)

Merge pull request #1872 from AntaresSimulatorTeam/feature/perf-db-taskjob-table-constraints
  • Loading branch information
laurent-laporte-pro authored Dec 22, 2023
2 parents 34c97e0 + 3e8dd93 commit 939a35b
Show file tree
Hide file tree
Showing 11 changed files with 453 additions and 156 deletions.
113 changes: 113 additions & 0 deletions alembic/versions/782a481f3414_fix_task_job_cascade_delete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""fix-task_job_cascade_delete
Revision ID: 782a481f3414
Revises: d495746853cc
Create Date: 2023-12-16 14:26:30.035324
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "782a481f3414"  # unique ID of this migration script
down_revision = "d495746853cc"  # parent migration in the revision chain
branch_labels = None  # no named branch for this revision
depends_on = None  # no cross-branch dependency


def upgrade():
    """
    Clean up the ``taskjob``/``taskjoblog`` tables, then add the missing
    NOT NULL constraints, indexes and foreign keys.

    The data-cleaning statements must run first so that the new constraints
    can be created without violations.

    NOTE(review): ``NOW()`` and ``INTERVAL '1 week'`` are PostgreSQL syntax —
    presumably this migration targets PostgreSQL only; confirm if SQLite
    deployments must also run it.
    """
    # Delete logs of tasks older than one week
    op.execute(
        """
        DELETE FROM taskjoblog
        WHERE task_id IN (SELECT id FROM taskjob WHERE NOW() - creation_date > INTERVAL '1 week');
        """
    )

    # Delete tasks older than one week
    op.execute(""" DELETE FROM taskjob WHERE NOW() - creation_date > INTERVAL '1 week'; """)

    # Set the name "Unknown task" to tasks that have no name, so that the
    # `name` column can be made NOT NULL below.
    op.execute(""" UPDATE taskjob SET name = 'Unknown task' WHERE name IS NULL OR name = ''; """)

    # Attach the user "admin" (id=1) to tasks whose owner is missing or no
    # longer exists, so that the `fk_taskjob_identity_id` foreign key can be
    # created below.
    # BUGFIX: a plain `owner_id NOT IN (...)` never matches rows where
    # `owner_id IS NULL` (SQL three-valued logic: NULL NOT IN (...) evaluates
    # to NULL, not TRUE), so the NULL case must be tested explicitly.
    op.execute(
        """
        UPDATE taskjob
        SET owner_id = 1
        WHERE owner_id IS NULL OR owner_id NOT IN (SELECT id FROM identities);
        """
    )

    # Delete logs of tasks that reference a study that has been deleted
    op.execute(
        """
        DELETE FROM taskjoblog tjl
        WHERE
            tjl.task_id IN (
                SELECT
                    t.id
                FROM
                    taskjob t
                WHERE
                    t.ref_id IS NOT NULL
                    AND t.ref_id NOT IN (SELECT s.id FROM study s));
        """
    )

    # Delete tasks that reference a study that has been deleted (long query),
    # so that the `fk_taskjob_study_id` foreign key can be created below.
    op.execute(
        """
        DELETE FROM taskjob t
        WHERE
            t.ref_id IS NOT NULL
            AND t.ref_id NOT IN (SELECT s.id FROM study s);
        """
    )

    # Delete logs of tasks whose task_id is NULL, so that `task_id` can be
    # made NOT NULL below.
    op.execute(""" DELETE FROM taskjoblog WHERE task_id IS NULL; """)

    # Set the status "CANCELLED" (6) to tasks whose status is not in the list
    # of possible TaskStatus values (1..6).
    op.execute(""" UPDATE taskjob SET status = 6 WHERE status NOT IN (1, 2, 3, 4, 5, 6); """)

    # Infer the type "VARIANT_GENERATION" for tasks whose type is NULL,
    # based on the task name.
    op.execute(""" UPDATE taskjob SET type = 'VARIANT_GENERATION' WHERE type IS NULL AND name LIKE '%Generation%'; """)

    # Infer the type "EXPORT" for tasks whose type is NULL, based on the
    # task name.
    op.execute(""" UPDATE taskjob SET type = 'EXPORT' WHERE type IS NULL AND name LIKE '%export%'; """)

    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("taskjoblog", schema=None) as batch_op:
        # Every log must reference a task; cascade so logs die with their task.
        batch_op.alter_column("task_id", existing_type=sa.VARCHAR(), nullable=False)
        batch_op.drop_constraint("fk_log_taskjob_id", type_="foreignkey")
        batch_op.create_foreign_key("fk_log_taskjob_id", "taskjob", ["task_id"], ["id"], ondelete="CASCADE")

    with op.batch_alter_table("taskjob", schema=None) as batch_op:
        batch_op.alter_column("name", existing_type=sa.VARCHAR(), nullable=False)
        # Index the columns used by TaskListFilter queries.
        batch_op.create_index(batch_op.f("ix_taskjob_creation_date"), ["creation_date"], unique=False)
        batch_op.create_index(batch_op.f("ix_taskjob_name"), ["name"], unique=False)
        batch_op.create_index(batch_op.f("ix_taskjob_owner_id"), ["owner_id"], unique=False)
        batch_op.create_index(batch_op.f("ix_taskjob_ref_id"), ["ref_id"], unique=False)
        batch_op.create_index(batch_op.f("ix_taskjob_status"), ["status"], unique=False)
        batch_op.create_index(batch_op.f("ix_taskjob_type"), ["type"], unique=False)
        # Keep the task history when its owner is deleted; drop tasks when
        # their study is deleted.
        batch_op.create_foreign_key("fk_taskjob_identity_id", "identities", ["owner_id"], ["id"], ondelete="SET NULL")
        batch_op.create_foreign_key("fk_taskjob_study_id", "study", ["ref_id"], ["id"], ondelete="CASCADE")

    # ### end Alembic commands ###


def downgrade():
    """Revert the constraints, indexes and NOT NULL columns added by :func:`upgrade`."""
    # Undo the changes on "taskjob": drop the two foreign keys, remove the
    # per-column indexes, then make the task name optional again.
    with op.batch_alter_table('taskjob', schema=None) as batch:
        batch.drop_constraint('fk_taskjob_study_id', type_='foreignkey')
        batch.drop_constraint('fk_taskjob_identity_id', type_='foreignkey')
        for index_name in (
            'ix_taskjob_type',
            'ix_taskjob_status',
            'ix_taskjob_ref_id',
            'ix_taskjob_owner_id',
            'ix_taskjob_name',
            'ix_taskjob_creation_date',
        ):
            batch.drop_index(batch.f(index_name))
        batch.alter_column('name', existing_type=sa.VARCHAR(), nullable=True)

    # Undo the changes on "taskjoblog": restore the original foreign key
    # (without ON DELETE CASCADE) and make task_id nullable again.
    with op.batch_alter_table("taskjoblog", schema=None) as batch:
        batch.drop_constraint("fk_log_taskjob_id", type_="foreignkey")
        batch.create_foreign_key("fk_log_taskjob_id", "taskjob", ["task_id"], ["id"])
        batch.alter_column("task_id", existing_type=sa.VARCHAR(), nullable=True)

    # ### end Alembic commands ###
111 changes: 72 additions & 39 deletions antarest/core/tasks/model.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import typing as t
import uuid
from datetime import datetime
from enum import Enum
from typing import Any, List, Mapping, Optional

from pydantic import BaseModel, Extra
from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, Sequence, String # type: ignore
from sqlalchemy.engine.base import Engine # type: ignore
from sqlalchemy.orm import Session, relationship, sessionmaker # type: ignore
from sqlalchemy.orm import relationship, sessionmaker # type: ignore

from antarest.core.persistence import Base

if t.TYPE_CHECKING:
# avoid circular import
from antarest.login.model import Identity
from antarest.study.model import Study


class TaskType(str, Enum):
EXPORT = "EXPORT"
Expand Down Expand Up @@ -43,7 +48,7 @@ class TaskResult(BaseModel, extra=Extra.forbid):
success: bool
message: str
# Can be used to store json serialized result
return_value: Optional[str]
return_value: t.Optional[str]


class TaskLogDTO(BaseModel, extra=Extra.forbid):
Expand All @@ -65,25 +70,25 @@ class TaskEventPayload(BaseModel, extra=Extra.forbid):
class TaskDTO(BaseModel, extra=Extra.forbid):
id: str
name: str
owner: Optional[int]
owner: t.Optional[int]
status: TaskStatus
creation_date_utc: str
completion_date_utc: Optional[str]
result: Optional[TaskResult]
logs: Optional[List[TaskLogDTO]]
type: Optional[str] = None
ref_id: Optional[str] = None
completion_date_utc: t.Optional[str]
result: t.Optional[TaskResult]
logs: t.Optional[t.List[TaskLogDTO]]
type: t.Optional[str] = None
ref_id: t.Optional[str] = None


class TaskListFilter(BaseModel, extra=Extra.forbid):
status: List[TaskStatus] = []
name: Optional[str] = None
type: List[TaskType] = []
ref_id: Optional[str] = None
from_creation_date_utc: Optional[float] = None
to_creation_date_utc: Optional[float] = None
from_completion_date_utc: Optional[float] = None
to_completion_date_utc: Optional[float] = None
status: t.List[TaskStatus] = []
name: t.Optional[str] = None
type: t.List[TaskType] = []
ref_id: t.Optional[str] = None
from_creation_date_utc: t.Optional[float] = None
to_creation_date_utc: t.Optional[float] = None
from_completion_date_utc: t.Optional[float] = None
to_completion_date_utc: t.Optional[float] = None


class TaskJobLog(Base): # type: ignore
Expand All @@ -93,10 +98,15 @@ class TaskJobLog(Base): # type: ignore
message = Column(String, nullable=False)
task_id = Column(
String(),
ForeignKey("taskjob.id", name="fk_log_taskjob_id"),
ForeignKey("taskjob.id", name="fk_log_taskjob_id", ondelete="CASCADE"),
nullable=False,
)

def __eq__(self, other: Any) -> bool:
# Define a many-to-one relationship between `TaskJobLog` and `TaskJob`.
# If the TaskJob is deleted, all attached logs must also be deleted in cascade.
job: "TaskJob" = relationship("TaskJob", back_populates="logs", uselist=False)

def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, TaskJobLog):
return False
return bool(other.id == self.id and other.message == self.message and other.task_id == self.task_id)
Expand All @@ -111,19 +121,41 @@ def to_dto(self) -> TaskLogDTO:
class TaskJob(Base): # type: ignore
__tablename__ = "taskjob"

id = Column(String(), default=lambda: str(uuid.uuid4()), primary_key=True)
name = Column(String())
status = Column(Integer(), default=lambda: TaskStatus.PENDING.value)
creation_date = Column(DateTime, default=datetime.utcnow)
completion_date = Column(DateTime, nullable=True)
result_msg = Column(String(), nullable=True)
result = Column(String(), nullable=True)
result_status = Column(Boolean(), nullable=True)
logs = relationship(TaskJobLog, uselist=True, cascade="all, delete, delete-orphan")
# this is not a foreign key to prevent the need to delete the job history if the user is deleted
owner_id = Column(Integer(), nullable=True)
type = Column(String(), nullable=True)
ref_id = Column(String(), nullable=True)
id: str = Column(String(), default=lambda: str(uuid.uuid4()), primary_key=True)
name: str = Column(String(), nullable=False, index=True)
status: int = Column(Integer(), default=lambda: TaskStatus.PENDING.value, index=True)
creation_date: datetime = Column(DateTime, default=datetime.utcnow, index=True)
completion_date: t.Optional[datetime] = Column(DateTime, nullable=True, default=None)
result_msg: t.Optional[str] = Column(String(), nullable=True, default=None)
result: t.Optional[str] = Column(String(), nullable=True, default=None)
result_status: t.Optional[bool] = Column(Boolean(), nullable=True, default=None)
type: t.Optional[str] = Column(String(), nullable=True, default=None, index=True)
owner_id: int = Column(
Integer(),
ForeignKey("identities.id", name="fk_taskjob_identity_id", ondelete="SET NULL"),
nullable=True,
default=None,
index=True,
)
ref_id: t.Optional[str] = Column(
String(),
ForeignKey("study.id", name="fk_taskjob_study_id", ondelete="CASCADE"),
nullable=True,
default=None,
index=True,
)

# Define a one-to-many relationship between `TaskJob` and `TaskJobLog`.
# If the TaskJob is deleted, all attached logs must also be deleted in cascade.
logs: t.List["TaskJobLog"] = relationship("TaskJobLog", back_populates="job", cascade="all, delete, delete-orphan")

# Define a many-to-one relationship between `TaskJob` and `Identity`.
# If the Identity is deleted, all attached TaskJob must be preserved.
owner: "Identity" = relationship("Identity", back_populates="owned_jobs", uselist=False)

# Define a many-to-one relationship between `TaskJob` and `Study`.
# If the Study is deleted, all attached TaskJob must be deleted in cascade.
study: "Study" = relationship("Study", back_populates="jobs", uselist=False)

def to_dto(self, with_logs: bool = False) -> TaskDTO:
return TaskDTO(
Expand All @@ -140,12 +172,12 @@ def to_dto(self, with_logs: bool = False) -> TaskDTO:
)
if self.completion_date
else None,
logs=sorted([log.to_dto() for log in self.logs], key=lambda l: l.id) if with_logs else None,
logs=sorted([log.to_dto() for log in self.logs], key=lambda log: log.id) if with_logs else None,
type=self.type,
ref_id=self.ref_id,
)

def __eq__(self, other: Any) -> bool:
def __eq__(self, other: t.Any) -> bool:
if not isinstance(other, TaskJob):
return False
return bool(
Expand Down Expand Up @@ -174,7 +206,7 @@ def __repr__(self) -> str:
)


def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> None:
def cancel_orphan_tasks(engine: Engine, session_args: t.Mapping[str, bool]) -> None:
"""
Cancel all tasks that are currently running or pending.
Expand All @@ -193,8 +225,9 @@ def cancel_orphan_tasks(engine: Engine, session_args: Mapping[str, bool]) -> Non
TaskJob.result_msg: "Task was interrupted due to server restart",
TaskJob.completion_date: datetime.utcnow(),
}
with sessionmaker(bind=engine, **session_args)() as session:
session.query(TaskJob).filter(TaskJob.status.in_([TaskStatus.RUNNING.value, TaskStatus.PENDING.value])).update(
updated_values, synchronize_session=False
)
orphan_status = [TaskStatus.RUNNING.value, TaskStatus.PENDING.value]
make_session = sessionmaker(bind=engine, **session_args)
with make_session() as session:
q = session.query(TaskJob).filter(TaskJob.status.in_(orphan_status)) # type: ignore
q.update(updated_values, synchronize_session=False)
session.commit()
69 changes: 21 additions & 48 deletions antarest/core/tasks/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from fastapi import HTTPException
from sqlalchemy.orm import Session # type: ignore

from antarest.core.tasks.model import TaskJob, TaskListFilter, TaskStatus
from antarest.core.tasks.model import TaskJob, TaskListFilter
from antarest.core.utils.fastapi_sqlalchemy import db
from antarest.core.utils.utils import assert_this


class TaskJobRepository:
Expand Down Expand Up @@ -59,52 +58,35 @@ def get_or_raise(self, id: str) -> TaskJob:
raise HTTPException(HTTPStatus.NOT_FOUND, f"Task {id} not found")
return task

@staticmethod
def _combine_clauses(where_clauses: t.List[t.Any]) -> t.Any:
assert_this(len(where_clauses) > 0)
if len(where_clauses) > 1:
return and_(
where_clauses[0],
TaskJobRepository._combine_clauses(where_clauses[1:]),
)
else:
return where_clauses[0]

def list(self, filter: TaskListFilter, user: t.Optional[int] = None) -> t.List[TaskJob]:
query = self.session.query(TaskJob)
where_clauses: t.List[t.Any] = []
q = self.session.query(TaskJob)
if user:
where_clauses.append(TaskJob.owner_id == user)
q = q.filter(TaskJob.owner_id == user)
if len(filter.status) > 0:
where_clauses.append(TaskJob.status.in_([status.value for status in filter.status]))
_values = [status.value for status in filter.status]
q = q.filter(TaskJob.status.in_(_values)) # type: ignore
if filter.name:
where_clauses.append(TaskJob.name.ilike(f"%{filter.name}%"))
q = q.filter(TaskJob.name.ilike(f"%{filter.name}%")) # type: ignore
if filter.to_creation_date_utc:
where_clauses.append(
TaskJob.creation_date.__le__(datetime.datetime.fromtimestamp(filter.to_creation_date_utc))
)
_date = datetime.datetime.fromtimestamp(filter.to_creation_date_utc)
q = q.filter(TaskJob.creation_date <= _date)
if filter.from_creation_date_utc:
where_clauses.append(
TaskJob.creation_date.__ge__(datetime.datetime.fromtimestamp(filter.from_creation_date_utc))
)
_date = datetime.datetime.fromtimestamp(filter.from_creation_date_utc)
q = q.filter(TaskJob.creation_date >= _date)
if filter.to_completion_date_utc:
where_clauses.append(
TaskJob.completion_date.__le__(datetime.datetime.fromtimestamp(filter.to_completion_date_utc))
)
_date = datetime.datetime.fromtimestamp(filter.to_completion_date_utc)
_clause = and_(TaskJob.completion_date.isnot(None), TaskJob.completion_date <= _date) # type: ignore
q = q.filter(_clause)
if filter.from_completion_date_utc:
where_clauses.append(
TaskJob.completion_date.__ge__(datetime.datetime.fromtimestamp(filter.from_completion_date_utc))
)
_date = datetime.datetime.fromtimestamp(filter.from_completion_date_utc)
_clause = and_(TaskJob.completion_date.isnot(None), TaskJob.completion_date >= _date) # type: ignore
q = q.filter(_clause)
if filter.ref_id is not None:
where_clauses.append(TaskJob.ref_id.__eq__(filter.ref_id))
if len(filter.type) > 0:
where_clauses.append(TaskJob.type.in_([task_type.value for task_type in filter.type]))
if len(where_clauses) > 1:
query = query.where(TaskJobRepository._combine_clauses(where_clauses))
elif len(where_clauses) == 1:
query = query.where(*where_clauses)

tasks: t.List[TaskJob] = query.all()
q = q.filter(TaskJob.ref_id == filter.ref_id)
if filter.type:
_types = [task_type.value for task_type in filter.type]
q = q.filter(TaskJob.type.in_(_types)) # type: ignore
tasks: t.List[TaskJob] = q.all()
return tasks

def delete(self, tid: str) -> None:
Expand All @@ -113,12 +95,3 @@ def delete(self, tid: str) -> None:
if task:
session.delete(task)
session.commit()

def update_timeout(self, task_id: str, timeout: int) -> None:
"""Update task status to TIMEOUT."""
session = self.session
task: TaskJob = session.get(TaskJob, task_id)
task.status = TaskStatus.TIMEOUT
task.result_msg = f"Task '{task_id}' timeout after {timeout} seconds"
task.result_status = False
session.commit()
Loading

0 comments on commit 939a35b

Please sign in to comment.