diff --git a/src/nextline_rdb/write/write_run_table.py b/src/nextline_rdb/write/write_run_table.py index 41b3430..c765321 100644 --- a/src/nextline_rdb/write/write_run_table.py +++ b/src/nextline_rdb/write/write_run_table.py @@ -1,8 +1,8 @@ from datetime import timezone from logging import getLogger -from nextline.plugin.spec import Context, hookimpl -from nextline.spawned import RunArg +from nextline.events import OnEndRun, OnStartRun +from nextline.plugin.spec import hookimpl from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -18,24 +18,23 @@ def __init__(self, db: DB) -> None: self._logger = getLogger(__name__) @hookimpl - async def on_start_run(self, context: Context) -> None: - assert (run_arg := context.run_arg) - assert (running_process := context.running_process) - started_at = running_process.process_created_at - assert started_at.tzinfo is timezone.utc - started_at = started_at.replace(tzinfo=None) - run_no = run_arg.run_no + async def on_start_run(self, event: OnStartRun) -> None: + assert event.started_at.tzinfo is timezone.utc + started_at = event.started_at.replace(tzinfo=None) async with self._db.session.begin() as session: - script = await self._find_script(run_arg, session) + script = await self._find_script(event, session) run = Run( - run_no=run_no, state='running', started_at=started_at, script=script + run_no=event.run_no, + state='running', + started_at=started_at, + script=script, ) session.add(run) async def _find_script( - self, run_arg: RunArg, session: AsyncSession + self, event: OnStartRun, session: AsyncSession ) -> Script | None: - statement = self._str_statement_or_none(run_arg) + statement = self._str_statement_or_none(event) current_script = await self._load_current_script(session) match statement, current_script: case None, None: @@ -60,9 +59,9 @@ async def _find_script( return cs.script return None - def _str_statement_or_none(self, run_arg: RunArg) -> str | None: - if isinstance(run_arg.statement, str): - return run_arg.statement + def _str_statement_or_none(self, event: OnStartRun) -> str | None: + if isinstance(event.statement, str): + return event.statement return None async def _load_current_script(self, session: AsyncSession) -> CurrentScript | None: @@ -70,17 +69,12 @@ async def _load_current_script(self, session: AsyncSession) -> CurrentScript | N return (await session.execute(stmt)).scalar_one_or_none() @hookimpl - async def on_end_run(self, context: Context) -> None: - assert (run_arg := context.run_arg) - assert (exited_process := context.exited_process) - assert (returned := exited_process.returned) - ended_at = exited_process.process_exited_at - assert ended_at.tzinfo is timezone.utc - ended_at = ended_at.replace(tzinfo=None) - run_no = run_arg.run_no + async def on_end_run(self, event: OnEndRun) -> None: + assert event.ended_at.tzinfo is timezone.utc + ended_at = event.ended_at.replace(tzinfo=None) async with self._db.session.begin() as session: - stmt = select(Run).filter_by(run_no=run_no) + stmt = select(Run).filter_by(run_no=event.run_no) run = await until_scalar_one(session, stmt) run.state = 'finished' run.ended_at = ended_at - run.exception = returned.fmt_exc + run.exception = event.raised diff --git a/src/nextline_rdb/write/write_trace_table.py b/src/nextline_rdb/write/write_trace_table.py index bf3522c..407ac6e 100644 --- a/src/nextline_rdb/write/write_trace_table.py +++ b/src/nextline_rdb/write/write_trace_table.py @@ -1,7 +1,7 @@ from datetime import timezone -from nextline.events import OnEndTrace, OnStartTrace -from nextline.plugin.spec import Context, hookimpl +from nextline.events import OnEndRun, OnEndTrace, OnStartTrace +from nextline.plugin.spec import hookimpl from nextline.types import TraceNo from sqlalchemy import select @@ -49,13 +49,10 @@ async def on_start_run(self) -> None: self._running_trace_nos.clear() @hookimpl - async def on_end_run(self, context: Context) -> None: - assert (run_arg := context.run_arg) - assert (exited_process := context.exited_process) - ended_at = exited_process.process_exited_at - assert ended_at.tzinfo is timezone.utc - ended_at = ended_at.replace(tzinfo=None) - run_no = run_arg.run_no + async def on_end_run(self, event: OnEndRun) -> None: + assert event.ended_at.tzinfo is timezone.utc + ended_at = event.ended_at.replace(tzinfo=None) + run_no = event.run_no async with self._db.session.begin() as session: stmt = ( select(Trace) diff --git a/tests/write/test_write.py b/tests/write/test_write.py index b4d71a7..db3fd89 100644 --- a/tests/write/test_write.py +++ b/tests/write/test_write.py @@ -1,3 +1,4 @@ +import json from datetime import timezone from pathlib import Path from unittest.mock import Mock @@ -7,15 +8,17 @@ from nextline import Nextline from nextline.events import ( OnEndPrompt, + OnEndRun, OnEndTrace, OnEndTraceCall, OnStartPrompt, + OnStartRun, OnStartTrace, OnStartTraceCall, OnWriteStdout, ) from nextline.plugin import spec -from nextline.spawned import RunArg, RunResult +from nextline.spawned import RunArg from nextline.types import ( PromptNo, RunNo, @@ -25,7 +28,6 @@ TraceCallNo, TraceNo, ) -from nextline.utils import ExitedProcess, RunningProcess from nextline_rdb.db import DB from nextline_rdb.models import ( @@ -126,15 +128,16 @@ async def _handle_run(context: spec.Context, run: Run) -> None: if run_started_at is None: return - # running_process = Mock(spec=RunningProcess[RunResult]) - running_process = Mock(spec=RunningProcess) - running_process.process_created_at = run_started_at.replace(tzinfo=timezone.utc) - context.running_process = running_process + on_start_run = OnStartRun( + started_at=run_started_at.replace(tzinfo=timezone.utc), + run_no=RunNo(run.run_no), + statement=statement, + ) # ic(run_started_at) # ic(run.started_at == run_started_at.replace(tzinfo=None)) - await context.hook.ahook.on_start_run(context=context) + await context.hook.ahook.on_start_run(context=context, event=on_start_run) run.state = 'running' for trace in run.traces: @@ -148,18 +151,16 @@ async def _handle_run(context: spec.Context, run: Run) -> None: run.exception = None return - returned = Mock(spec=RunResult) - returned.fmt_exc = run.exception = run.exception or '' + raised = run.exception = run.exception or '' - exited_process = ExitedProcess[RunResult]( - returned=returned, - raised=None, - process=Mock(), - process_created_at=running_process.process_created_at, - process_exited_at=run_ended_at.replace(tzinfo=timezone.utc), + on_end_run = OnEndRun( + ended_at=run_ended_at.replace(tzinfo=timezone.utc), + run_no=RunNo(run.run_no), + returned=json.dumps(None), + raised=raised, ) - context.exited_process = exited_process - await context.hook.ahook.on_end_run(context=context) + + await context.hook.ahook.on_end_run(context=context, event=on_end_run) run.state = 'finished'