Skip to content

Commit

Permalink
Merge pull request #83 from simonsobs/dev
Browse files Browse the repository at this point in the history
Update the `on_start_run()` and `on_end_run()` hooks with the new `event` argument
  • Loading branch information
TaiSakuma authored Jun 14, 2024
2 parents e0a2c07 + af174bd commit b7ecec6
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 52 deletions.
46 changes: 20 additions & 26 deletions src/nextline_rdb/write/write_run_table.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from datetime import timezone
from logging import getLogger

from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import RunArg
from nextline.events import OnEndRun, OnStartRun
from nextline.plugin.spec import hookimpl
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
Expand All @@ -18,24 +18,23 @@ def __init__(self, db: DB) -> None:
self._logger = getLogger(__name__)

@hookimpl
async def on_start_run(self, context: Context) -> None:
assert (run_arg := context.run_arg)
assert (running_process := context.running_process)
started_at = running_process.process_created_at
assert started_at.tzinfo is timezone.utc
started_at = started_at.replace(tzinfo=None)
run_no = run_arg.run_no
async def on_start_run(self, event: OnStartRun) -> None:
assert event.started_at.tzinfo is timezone.utc
started_at = event.started_at.replace(tzinfo=None)
async with self._db.session.begin() as session:
script = await self._find_script(run_arg, session)
script = await self._find_script(event, session)
run = Run(
run_no=run_no, state='running', started_at=started_at, script=script
run_no=event.run_no,
state='running',
started_at=started_at,
script=script,
)
session.add(run)

async def _find_script(
self, run_arg: RunArg, session: AsyncSession
self, event: OnStartRun, session: AsyncSession
) -> Script | None:
statement = self._str_statement_or_none(run_arg)
statement = self._str_statement_or_none(event)
current_script = await self._load_current_script(session)
match statement, current_script:
case None, None:
Expand All @@ -60,27 +59,22 @@ async def _find_script(
return cs.script
return None

def _str_statement_or_none(self, run_arg: RunArg) -> str | None:
if isinstance(run_arg.statement, str):
return run_arg.statement
def _str_statement_or_none(self, event: OnStartRun) -> str | None:
if isinstance(event.statement, str):
return event.statement
return None

async def _load_current_script(self, session: AsyncSession) -> CurrentScript | None:
stmt = select(CurrentScript).options(selectinload(CurrentScript.script))
return (await session.execute(stmt)).scalar_one_or_none()

@hookimpl
async def on_end_run(self, context: Context) -> None:
assert (run_arg := context.run_arg)
assert (exited_process := context.exited_process)
assert (returned := exited_process.returned)
ended_at = exited_process.process_exited_at
assert ended_at.tzinfo is timezone.utc
ended_at = ended_at.replace(tzinfo=None)
run_no = run_arg.run_no
async def on_end_run(self, event: OnEndRun) -> None:
assert event.ended_at.tzinfo is timezone.utc
ended_at = event.ended_at.replace(tzinfo=None)
async with self._db.session.begin() as session:
stmt = select(Run).filter_by(run_no=run_no)
stmt = select(Run).filter_by(run_no=event.run_no)
run = await until_scalar_one(session, stmt)
run.state = 'finished'
run.ended_at = ended_at
run.exception = returned.fmt_exc
run.exception = event.raised
15 changes: 6 additions & 9 deletions src/nextline_rdb/write/write_trace_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timezone

from nextline.events import OnEndTrace, OnStartTrace
from nextline.plugin.spec import Context, hookimpl
from nextline.events import OnEndRun, OnEndTrace, OnStartTrace
from nextline.plugin.spec import hookimpl
from nextline.types import TraceNo
from sqlalchemy import select

Expand Down Expand Up @@ -49,13 +49,10 @@ async def on_start_run(self) -> None:
self._running_trace_nos.clear()

@hookimpl
async def on_end_run(self, context: Context) -> None:
assert (run_arg := context.run_arg)
assert (exited_process := context.exited_process)
ended_at = exited_process.process_exited_at
assert ended_at.tzinfo is timezone.utc
ended_at = ended_at.replace(tzinfo=None)
run_no = run_arg.run_no
async def on_end_run(self, event: OnEndRun) -> None:
assert event.ended_at.tzinfo is timezone.utc
ended_at = event.ended_at.replace(tzinfo=None)
run_no = event.run_no
async with self._db.session.begin() as session:
stmt = (
select(Trace)
Expand Down
35 changes: 18 additions & 17 deletions tests/write/test_write.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from datetime import timezone
from pathlib import Path
from unittest.mock import Mock
Expand All @@ -7,15 +8,17 @@
from nextline import Nextline
from nextline.events import (
OnEndPrompt,
OnEndRun,
OnEndTrace,
OnEndTraceCall,
OnStartPrompt,
OnStartRun,
OnStartTrace,
OnStartTraceCall,
OnWriteStdout,
)
from nextline.plugin import spec
from nextline.spawned import RunArg, RunResult
from nextline.spawned import RunArg
from nextline.types import (
PromptNo,
RunNo,
Expand All @@ -25,7 +28,6 @@
TraceCallNo,
TraceNo,
)
from nextline.utils import ExitedProcess, RunningProcess

from nextline_rdb.db import DB
from nextline_rdb.models import (
Expand Down Expand Up @@ -126,15 +128,16 @@ async def _handle_run(context: spec.Context, run: Run) -> None:
if run_started_at is None:
return

# running_process = Mock(spec=RunningProcess[RunResult])
running_process = Mock(spec=RunningProcess)
running_process.process_created_at = run_started_at.replace(tzinfo=timezone.utc)
context.running_process = running_process
on_start_run = OnStartRun(
started_at=run_started_at.replace(tzinfo=timezone.utc),
run_no=RunNo(run.run_no),
statement=statement,
)

# ic(run_started_at)
# ic(run.started_at == run_started_at.replace(tzinfo=None))

await context.hook.ahook.on_start_run(context=context)
await context.hook.ahook.on_start_run(context=context, event=on_start_run)
run.state = 'running'

for trace in run.traces:
Expand All @@ -148,18 +151,16 @@ async def _handle_run(context: spec.Context, run: Run) -> None:
run.exception = None
return

returned = Mock(spec=RunResult)
returned.fmt_exc = run.exception = run.exception or ''
raised = run.exception = run.exception or ''

exited_process = ExitedProcess[RunResult](
returned=returned,
raised=None,
process=Mock(),
process_created_at=running_process.process_created_at,
process_exited_at=run_ended_at.replace(tzinfo=timezone.utc),
on_end_run = OnEndRun(
ended_at=run_ended_at.replace(tzinfo=timezone.utc),
run_no=RunNo(run.run_no),
returned=json.dumps(None),
raised=raised,
)
context.exited_process = exited_process
await context.hook.ahook.on_end_run(context=context)

await context.hook.ahook.on_end_run(context=context, event=on_end_run)
run.state = 'finished'


Expand Down

0 comments on commit b7ecec6

Please sign in to comment.