diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index 9103009c9..7ed045bc3 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -64,11 +64,13 @@ def unpack(value: _NestedJobSequenceType) -> t.Generator[Job, None, None]: """Unpack any iterable input in order to obtain a - single sequence of values + single sequence of values. :param value: Sequence containing elements of type Job or other - sequences that are also of type _NestedJobSequenceType - :return: flattened list of Jobs""" + sequences that are also of type `_NestedJobSequenceType`. + :raises TypeError: If the value is not a nested sequence of jobs. + :return: A flattened list of `Jobs`. + """ from smartsim.launchable.job import Job for item in value: diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 2af726959..d5aafc3b6 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -42,7 +42,7 @@ from smartsim._core.control.launch_history import LaunchHistory as _LaunchHistory from smartsim._core.utils import helpers as _helpers from smartsim.error import errors -from smartsim.launchable.job import Job +from smartsim.launchable.job import Job, Record from smartsim.status import TERMINAL_STATUSES, InvalidJobStatus, JobStatus from ._core import Generator, Manifest @@ -52,7 +52,6 @@ from .log import ctx_exp_path, get_logger, method_contextualizer if t.TYPE_CHECKING: - from smartsim.launchable.job import Job from smartsim.types import LaunchedJobID logger = get_logger(__name__) @@ -158,26 +157,24 @@ def __init__(self, name: str, exp_path: str | None = None): experiment """ - def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]: + def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[Record, ...]: """Execute a collection of `Job` instances. :param jobs: A collection of other job instances to start - :raises TypeError: If jobs provided are not the correct type :raises ValueError: No Jobs were provided. - :returns: A sequence of ids with order corresponding to the sequence of - jobs that can be used to query or alter the status of that - particular execution of the job. + :returns: A sequence of records with order corresponding to the + sequence of jobs that can be used to query or alter the status of + that particular execution of the job. """ if not jobs: raise ValueError("No jobs provided to start") - # Create the run id jobs_ = list(_helpers.unpack(jobs)) - run_id = datetime.datetime.now().replace(microsecond=0).isoformat() root = pathlib.Path(self.exp_path, run_id) - return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_) + ids = self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_) + return tuple(Record(id_, job) for id_, job in zip(ids, jobs_)) def _dispatch( self, @@ -233,15 +230,13 @@ def execute_dispatch(generator: Generator, job: Job, idx: int) -> LaunchedJobID: execute_dispatch(generator, job, idx) for idx, job in enumerate(jobs, 1) ) - def get_status( - self, *ids: LaunchedJobID - ) -> tuple[JobStatus | InvalidJobStatus, ...]: - """Get the status of jobs launched through the `Experiment` from their - launched job id returned when calling `Experiment.start`. + def get_status(self, *records: Record) -> tuple[JobStatus | InvalidJobStatus, ...]: + """Get the status of jobs launched through the `Experiment` from the + record returned when calling `Experiment.start`. - The `Experiment` will map the launched ID back to the launcher that - started the job and request a status update. 
The order of the returned - statuses exactly matches the order of the launched job ids. + The `Experiment` will map the launched id of the record back to the + launcher that started the job and request a status update. The order of + the returned statuses exactly matches the order of the records. If the `Experiment` cannot find any launcher that started the job associated with the launched job id, then a @@ -252,16 +247,17 @@ def get_status( launched job ids issued by user defined launcher are not sufficiently unique. - :param ids: A sequence of launched job ids issued by the experiment. + :param records: A sequence of records issued by the experiment. :raises TypeError: If ids provided are not the correct type :raises ValueError: No IDs were provided. :returns: A tuple of statuses with order respective of the order of the calling arguments. """ - if not ids: - raise ValueError("No job ids provided to get status") - if not all(isinstance(id, str) for id in ids): - raise TypeError("ids argument was not of type LaunchedJobID") + if not records: + raise ValueError("No records provided to get status") + if not all(isinstance(record, Record) for record in records): + raise TypeError("record argument was not of type Record") + ids = tuple(record.launched_id for record in records) to_query = self._launch_history.group_by_launcher( set(ids), unknown_ok=True @@ -272,39 +268,38 @@ def get_status( return tuple(stats) def wait( - self, *ids: LaunchedJobID, timeout: float | None = None, verbose: bool = True + self, *records: Record, timeout: float | None = None, verbose: bool = True ) -> None: """Block execution until all of the provided launched jobs, represented by an ID, have entered a terminal status. - :param ids: The ids of the launched jobs to wait for. + :param records: The records of the launched jobs to wait for. :param timeout: The max time to wait for all of the launched jobs to end. :param verbose: Whether found statuses should be displayed in the console. :raises TypeError: If IDs provided are not the correct type :raises ValueError: No IDs were provided. """ - if ids: - if not all(isinstance(id, str) for id in ids): - raise TypeError("ids argument was not of type LaunchedJobID") - else: - raise ValueError("No job ids to wait on provided") + if not records: + raise ValueError("No records to wait on provided") + if not all(isinstance(record, Record) for record in records): + raise TypeError("record argument was not of type Record") self._poll_for_statuses( - ids, TERMINAL_STATUSES, timeout=timeout, verbose=verbose + records, TERMINAL_STATUSES, timeout=timeout, verbose=verbose ) def _poll_for_statuses( self, - ids: t.Sequence[LaunchedJobID], + records: t.Sequence[Record], statuses: t.Collection[JobStatus], timeout: float | None = None, interval: float = 5.0, verbose: bool = True, - ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]: + ) -> dict[Record, JobStatus | InvalidJobStatus]: """Poll the experiment's launchers for the statuses of the launched jobs with the provided ids, until the status of the changes to one of the provided statuses. - :param ids: The ids of the launched jobs to wait for. + :param records: The records of the launched jobs to wait for. :param statuses: A collection of statuses to poll for. :param timeout: The minimum amount of time to spend polling all jobs to reach one of the supplied statuses. 
If not supplied or `None`, the @@ -320,12 +315,10 @@ def _poll_for_statuses( log = logger.info if verbose else lambda *_, **__: None method_timeout = _interval.SynchronousTimeInterval(timeout) iter_timeout = _interval.SynchronousTimeInterval(interval) - final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {} + final: dict[Record, JobStatus | InvalidJobStatus] = {} - def is_finished( - id_: LaunchedJobID, status: JobStatus | InvalidJobStatus - ) -> bool: - job_title = f"Job({id_}): " + def is_finished(record: Record, status: JobStatus | InvalidJobStatus) -> bool: + job_title = f"Job({record.launched_id}, {record.job.name}): " if done := status in terminal: log(f"{job_title}Finished with status '{status.value}'") else: @@ -334,22 +327,22 @@ def is_finished( if iter_timeout.infinite: raise ValueError("Polling interval cannot be infinite") - while ids and not method_timeout.expired: + while records and not method_timeout.expired: iter_timeout = iter_timeout.new_interval() - stats = zip(ids, self.get_status(*ids)) + stats = zip(records, self.get_status(*records)) is_done = _helpers.group_by(_helpers.pack_params(is_finished), stats) final |= dict(is_done.get(True, ())) - ids = tuple(id_ for id_, _ in is_done.get(False, ())) - if ids: + records = tuple(rec for rec, _ in is_done.get(False, ())) + if records: ( iter_timeout if iter_timeout.remaining < method_timeout.remaining else method_timeout ).block() - if ids: + if records: raise TimeoutError( - f"Job ID(s) {', '.join(map(str, ids))} failed to reach " - "terminal status before timeout" + f"Job ID(s) {', '.join(rec.launched_id for rec in records)} " + "failed to reach terminal status before timeout" ) return final @@ -445,20 +438,20 @@ def summary(self, style: str = "github") -> str: disable_numparse=True, ) - def stop(self, *ids: LaunchedJobID) -> tuple[JobStatus | InvalidJobStatus, ...]: + def stop(self, *records: Record) -> tuple[JobStatus | InvalidJobStatus, ...]: """Cancel the execution of a previously launched job. - :param ids: The ids of the launched jobs to stop. + :param records: The records of the launched jobs to stop. :raises TypeError: If ids provided are not the correct type :raises ValueError: No job ids were provided. :returns: A tuple of job statuses upon cancellation with order respective of the order of the calling arguments. 
""" - if ids: - if not all(isinstance(id, str) for id in ids): - raise TypeError("ids argument was not of type LaunchedJobID") - else: - raise ValueError("No job ids provided") + if not records: + raise ValueError("No records provided") + if not all(isinstance(record, Record) for record in records): + raise TypeError("record argument was not of type Record") + ids = tuple(record.launched_id for record in records) by_launcher = self._launch_history.group_by_launcher(set(ids), unknown_ok=True) id_to_stop_stat = ( launcher.stop_jobs(*launched).items() diff --git a/smartsim/launchable/job.py b/smartsim/launchable/job.py index 6082ba61d..9ce7dc87e 100644 --- a/smartsim/launchable/job.py +++ b/smartsim/launchable/job.py @@ -26,6 +26,7 @@ from __future__ import annotations +import textwrap import typing as t from copy import deepcopy @@ -39,6 +40,7 @@ if t.TYPE_CHECKING: from smartsim.entity.entity import SmartSimEntity + from smartsim.types import LaunchedJobID @t.final @@ -158,3 +160,44 @@ def __str__(self) -> str: # pragma: no cover string = f"SmartSim Entity: {self.entity}\n" string += f"Launch Settings: {self.launch_settings}" return string + + +@t.final +class Record: + """A Record object to track a launched job along with its assigned + launch ID. + """ + + def __init__(self, launch_id: LaunchedJobID, job: Job) -> None: + """Initialize a new record of a launched job + + :param launch_id: A unique identifier for the launch of the job. + :param job: The job that was launched. + """ + self._id = launch_id + self._job = deepcopy(job) + + @property + def launched_id(self) -> LaunchedJobID: + """The unique identifier for the launched job. + + :returns: A unique identifier for the launched job. + """ + return self._id + + @property + def job(self) -> Job: + """A deep copy of the job that was launched. + + :returns: A deep copy of the launched job. 
+ """ + return deepcopy(self._job) + + def __str__(self) -> str: + return textwrap.dedent(f"""\ + Launch Record: + Launched Job ID: + {self.launched_id} + Launched Job: + {textwrap.indent(str(self._job), " ")} + """) diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 45f3ecf8e..f3494de61 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -111,6 +111,14 @@ def iter_jobs(): yield lambda: next(jobs) +@pytest.fixture +def record_maker(job_maker): + def impl(id_): + return job.Record(id_, job_maker()) + + yield impl + + JobMakerType: t.TypeAlias = t.Callable[[], job.Job] @@ -247,7 +255,7 @@ def test_start_can_launch_jobs( assert ( len(list(experiment._launch_history.iter_past_launchers())) == 0 ), "Initialized w/ launchers" - launched_ids = experiment.start(*jobs) + records = experiment.start(*jobs) assert ( len(list(experiment._launch_history.iter_past_launchers())) == 1 ), "Unexpected number of launchers" @@ -257,20 +265,26 @@ def test_start_can_launch_jobs( assert isinstance(launcher, NoOpRecordLauncher), "Unexpected launcher type" assert launcher.created_by_experiment is experiment, "Not created by experiment" assert ( - len(jobs) == len(launcher.launched_order) == len(launched_ids) == num_jobs + len(jobs) == len(launcher.launched_order) == len(records) == num_jobs ), "Inconsistent number of jobs/launched jobs/launched ids/expected number of jobs" expected_launched = [LaunchRecord.from_job(job) for job in jobs] # Check that `job_a, job_b, job_c, ...` are started in that order when # calling `experiemnt.start(job_a, job_b, job_c, ...)` assert expected_launched == list(launcher.launched_order), "Unexpected launch order" - assert sorted(launched_ids) == sorted(exp_cached_ids), "Exp did not cache ids" + assert sorted(rec.launched_id for rec in records) == sorted( + exp_cached_ids + ), "Exp did not cache ids" - # Similarly, check that `id_a, id_b, id_c, ...` corresponds to + # Similarly, check that `rec_a, rec_b, rec_c, ...` corresponds to # `job_a, job_b, job_c, ...` when calling - # `id_a, id_b, id_c, ... = experiemnt.start(job_a, job_b, job_c, ...)` - expected_id_map = dict(zip(launched_ids, expected_launched)) - assert expected_id_map == launcher.ids_to_launched, "IDs returned in wrong order" + # `rec_a, rec_b, rec_c, ...
= experiemnt.start(job_a, job_b, job_c, ...)` + expected_record_map = dict( + zip((rec.launched_id for rec in records), expected_launched) + ) + assert ( + expected_record_map == launcher.ids_to_launched + ), "records returned in wrong order" @pytest.mark.parametrize( @@ -285,7 +299,8 @@ def test_start_can_start_a_job_multiple_times_accross_multiple_calls( ), "Initialized w/ launchers" job = job_maker() ids_to_launches = { - experiment.start(job)[0]: LaunchRecord.from_job(job) for _ in range(num_starts) + experiment.start(job)[0].launched_id: LaunchRecord.from_job(job) + for _ in range(num_starts) } assert ( len(list(experiment._launch_history.iter_past_launchers())) == 1 @@ -297,12 +312,12 @@ def test_start_can_start_a_job_multiple_times_accross_multiple_calls( assert len(launcher.launched_order) == num_starts, "Unexpected number launches" # Check that a single `job` instance can be launched and re-launched and - # that `id_a, id_b, id_c, ...` corresponds to + # that `rec_a, rec_b, rec_c, ...` corresponds to # `"start_a", "start_b", "start_c", ...` when calling # ```py - # id_a = experiment.start(job) # "start_a" - # id_b = experiment.start(job) # "start_b" - # id_c = experiment.start(job) # "start_c" + # rec_a = experiment.start(job) # "start_a" + # rec_b = experiment.start(job) # "start_b" + # rec_c = experiment.start(job) # "start_c" # ... # ``` assert ids_to_launches == launcher.ids_to_launched, "Job was not re-launched" @@ -356,11 +371,11 @@ def impl(num_active_launchers): yield impl -def test_experiment_can_get_statuses(make_populated_experiment): +def test_experiment_can_get_statuses(make_populated_experiment, record_maker): exp = make_populated_experiment(num_active_launchers=1) (launcher,) = exp._launch_history.iter_past_launchers() ids = tuple(launcher.known_ids) - recieved_stats = exp.get_status(*ids) + recieved_stats = exp.get_status(*map(record_maker, ids)) assert len(recieved_stats) == len(ids), "Unexpected number of statuses" assert ( dict(zip(ids, recieved_stats)) == launcher.id_to_status @@ -372,7 +387,7 @@ def test_experiment_can_get_statuses(make_populated_experiment): [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], ) def test_experiment_can_get_statuses_from_many_launchers( - make_populated_experiment, num_launchers + make_populated_experiment, num_launchers, record_maker ): exp = make_populated_experiment(num_active_launchers=num_launchers) launcher_and_rand_ids = ( @@ -383,13 +398,13 @@ def test_experiment_can_get_statuses_from_many_launchers( id_: launcher.id_to_status[id_] for launcher, id_ in launcher_and_rand_ids } query_ids = tuple(expected_id_to_stat) - stats = exp.get_status(*query_ids) + stats = exp.get_status(*map(record_maker, query_ids)) assert len(stats) == len(expected_id_to_stat), "Unexpected number of statuses" assert dict(zip(query_ids, stats)) == expected_id_to_stat, "Statuses in wrong order" def test_get_status_returns_not_started_for_unrecognized_ids( - monkeypatch, make_populated_experiment + monkeypatch, make_populated_experiment, record_maker ): exp = make_populated_experiment(num_active_launchers=1) brand_new_id = create_job_id() @@ -399,12 +414,14 @@ def test_get_status_returns_not_started_for_unrecognized_ids( new_history = LaunchHistory({id_: launcher for id_ in rest}) monkeypatch.setattr(exp, "_launch_history", new_history) expected_stats = (InvalidJobStatus.NEVER_STARTED,) * 2 - actual_stats = exp.get_status(brand_new_id, id_not_known_by_exp) + actual_stats = exp.get_status( + record_maker(brand_new_id), 
record_maker(id_not_known_by_exp) + ) assert expected_stats == actual_stats def test_get_status_de_dups_ids_passed_to_launchers( - monkeypatch, make_populated_experiment + monkeypatch, make_populated_experiment, record_maker ): def track_calls(fn): calls = [] @@ -419,7 +436,8 @@ def impl(*a, **kw): ((launcher, (id_, *_)),) = exp._launch_history.group_by_launcher().items() calls, tracked_get_status = track_calls(launcher.get_status) monkeypatch.setattr(launcher, "get_status", tracked_get_status) - stats = exp.get_status(id_, id_, id_) + record = record_maker(id_) + stats = exp.get_status(record, record, record) assert len(stats) == 3, "Unexpected number of statuses" assert all(stat == stats[0] for stat in stats), "Statuses are not eq" assert len(calls) == 1, "Launcher's `get_status` was called more than once" @@ -429,20 +447,20 @@ def impl(*a, **kw): def test_wait_handles_empty_call_args(experiment): """An exception is raised when there are no jobs to complete""" - with pytest.raises(ValueError, match="No job ids"): + with pytest.raises(ValueError, match="No records"): experiment.wait() -def test_wait_does_not_block_unknown_id(experiment): +def test_wait_does_not_block_unknown_id(experiment, record_maker): """If an experiment does not recognize a job id, it should not wait for its completion """ now = time.perf_counter() - experiment.wait(create_job_id()) + experiment.wait(record_maker(create_job_id())) assert time.perf_counter() - now < 1 -def test_wait_calls_prefered_impl(make_populated_experiment, monkeypatch): +def test_wait_calls_prefered_impl(make_populated_experiment, record_maker, monkeypatch): """Make wait is calling the expected method for checking job statuses. Right now we only have the "polling" impl, but in future this might change to an event based system. @@ -456,7 +474,7 @@ def mocked_impl(*args, **kwargs): was_called = True monkeypatch.setattr(exp, "_poll_for_statuses", mocked_impl) - exp.wait(id_) + exp.wait(record_maker(id_)) assert was_called @@ -469,7 +487,7 @@ def mocked_impl(*args, **kwargs): ) @pytest.mark.parametrize("verbose", [True, False]) def test_poll_status_blocks_until_job_is_completed( - monkeypatch, make_populated_experiment, num_polls, verbose + monkeypatch, make_populated_experiment, record_maker, num_polls, verbose ): """Make sure that the polling based implementation blocks the calling thread. 
Use varying number of polls to simulate varying lengths of job time @@ -501,17 +519,25 @@ def __call__(self, *args, **kwargs): monkeypatch.setattr( "smartsim.experiment.logger.info", lambda s: mock_log.write(f"{s}\n") ) + record = record_maker(id_) final_statuses = exp._poll_for_statuses( - [id_], different_statuses, timeout=10, interval=0, verbose=verbose + [record], + different_statuses, + timeout=10, + interval=0, + verbose=verbose, ) - assert final_statuses == {id_: new_status} + assert final_statuses == {record: new_status} expected_log = io.StringIO() + name = record.job.name expected_log.writelines( - f"Job({id_}): Running with status '{current_status.value}'\n" + f"Job({id_}, {name}): Running with status '{current_status.value}'\n" for _ in range(num_polls - 1) ) - expected_log.write(f"Job({id_}): Finished with status '{new_status.value}'\n") + expected_log.write( + f"Job({id_}, {name}): Finished with status '{new_status.value}'\n" + ) assert mock_get_status.num_calls == num_polls assert mock_log.getvalue() == (expected_log.getvalue() if verbose else "") @@ -534,7 +560,7 @@ def test_poll_status_raises_when_called_with_infinite_iter_wait( def test_poll_for_status_raises_if_ids_not_found_within_timeout( - make_populated_experiment, + make_populated_experiment, record_maker ): """If there is a timeout, a timeout error should be raised when it is exceeded""" exp = make_populated_experiment(1) @@ -548,7 +574,7 @@ def test_poll_for_status_raises_if_ids_not_found_within_timeout( ), ): exp._poll_for_statuses( - [id_], + [record_maker(id_)], different_statuses, timeout=1, interval=0, @@ -584,14 +610,18 @@ def test_poll_for_status_raises_if_ids_not_found_within_timeout( ), ], ) -def test_experiment_can_stop_jobs(make_populated_experiment, num_launchers, select_ids): +def test_experiment_can_stop_jobs( + make_populated_experiment, record_maker, num_launchers, select_ids +): exp = make_populated_experiment(num_launchers) ids = (launcher.known_ids for launcher in exp._launch_history.iter_past_launchers()) ids = tuple(itertools.chain.from_iterable(ids)) - before_stop_stats = exp.get_status(*ids) + records = tuple(map(record_maker, ids)) + ids_to_records = dict(zip(ids, records)) + before_stop_stats = exp.get_status(*records) to_cancel = tuple(select_ids(exp._launch_history)) - stats = exp.stop(*to_cancel) - after_stop_stats = exp.get_status(*ids) + stats = exp.stop(*(ids_to_records[id_] for id_ in to_cancel)) + after_stop_stats = exp.get_status(*records) assert stats == (JobStatus.CANCELLED,) * len(to_cancel) assert dict(zip(ids, before_stop_stats)) | dict(zip(to_cancel, stats)) == dict( zip(ids, after_stop_stats) @@ -599,7 +629,7 @@ def test_experiment_can_stop_jobs(make_populated_experiment, num_launchers, sele def test_experiment_raises_if_asked_to_stop_no_jobs(experiment): - with pytest.raises(ValueError, match="No job ids provided"): + with pytest.raises(ValueError, match="No records provided"): experiment.stop() @@ -607,16 +637,17 @@ def test_experiment_raises_if_asked_to_stop_no_jobs(experiment): "num_launchers", [pytest.param(i, id=f"{i} launcher(s)") for i in (2, 3, 5, 10, 20, 100)], ) -def test_experiment_stop_does_not_raise_on_unknown_job_id( - make_populated_experiment, num_launchers +def test_experiment_stop_does_not_raise_on_unknown_record( + make_populated_experiment, num_launchers, record_maker ): exp = make_populated_experiment(num_launchers) - new_id = create_job_id() all_known_ids = tuple(exp._launch_history._id_to_issuer) - before_cancel = exp.get_status(*all_known_ids) - 
(stat,) = exp.stop(new_id) + all_known_records = tuple(map(record_maker, all_known_ids)) + unknown_record = record_maker(create_job_id()) + before_cancel = exp.get_status(*all_known_records) + (stat,) = exp.stop(unknown_record) assert stat == InvalidJobStatus.NEVER_STARTED - after_cancel = exp.get_status(*all_known_ids) + after_cancel = exp.get_status(*all_known_records) assert before_cancel == after_cancel @@ -628,27 +659,28 @@ def test_start_raises_if_no_args_supplied(test_dir): def test_stop_raises_if_no_args_supplied(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) - with pytest.raises(ValueError, match="No job ids provided"): + with pytest.raises(ValueError, match="No records provided"): exp.stop() def test_get_status_raises_if_no_args_supplied(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) - with pytest.raises(ValueError, match="No job ids provided"): + with pytest.raises(ValueError, match="No records provided"): exp.get_status() def test_poll_raises_if_no_args_supplied(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) with pytest.raises( - TypeError, match="missing 2 required positional arguments: 'ids' and 'statuses'" + TypeError, + match="missing 2 required positional arguments: 'records' and 'statuses'", ): exp._poll_for_statuses() def test_wait_raises_if_no_args_supplied(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) - with pytest.raises(ValueError, match="No job ids to wait on provided"): + with pytest.raises(ValueError, match="No records to wait on provided"): exp.wait() @@ -665,19 +697,19 @@ def test_type_start_parameters(test_dir): def test_type_get_status_parameters(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) - with pytest.raises(TypeError, match="ids argument was not of type LaunchedJobID"): + with pytest.raises(TypeError, match="record argument was not of type Record"): exp.get_status(2) def test_type_wait_parameter(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) - with pytest.raises(TypeError, match="ids argument was not of type LaunchedJobID"): + with pytest.raises(TypeError, match="record argument was not of type Record"): exp.wait(2) def test_type_stop_parameter(test_dir): exp = Experiment(name="exp_name", exp_path=test_dir) - with pytest.raises(TypeError, match="ids argument was not of type LaunchedJobID"): + with pytest.raises(TypeError, match="record argument was not of type Record"): exp.stop(2) diff --git a/tests/test_record.py b/tests/test_record.py new file mode 100644 index 000000000..3197de7c0 --- /dev/null +++ b/tests/test_record.py @@ -0,0 +1,64 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import itertools + +import pytest + +from smartsim._core.utils.helpers import expand_exe_path +from smartsim._core.utils.launcher import create_job_id +from smartsim.entity.application import Application +from smartsim.launchable.job import Job, Record +from smartsim.settings.launch_settings import LaunchSettings + +pytestmark = pytest.mark.group_a + + +def test_cannot_mutate_record_job(): + app = Application("my-test-app", "echo", ["spam", "eggs"]) + settings = LaunchSettings("local") + job = Job(app, settings) + + id_ = create_job_id() + record = Record(id_, job) + assert record.launched_id == id_ + assert all( + x is not y for x, y in itertools.combinations([job, record.job, record._job], 2) + ) + + app.name = "Modified original app name" + job.name = "Modified original job name" + record.job.name = "Modified reference to job off record name" + record.job.entity.name = "Modified reference to app off job off record name" + assert record.job.name == record.job.entity.name == "my-test-app" + + record.job.entity.exe = "sleep" + app.exe_args = ["120"] + assert [record.job.entity.exe] + record.job.entity.exe_args == [ + expand_exe_path("echo"), + "spam", + "eggs", + ]
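For reviewers, a minimal usage sketch of the record-based API introduced by this patch. The experiment name, path, and application arguments below are illustrative only (not taken from the patch); the call pattern mirrors the updated `Experiment.start`/`get_status`/`wait`/`stop` signatures above.

    from smartsim.entity.application import Application
    from smartsim.experiment import Experiment
    from smartsim.launchable.job import Job
    from smartsim.settings.launch_settings import LaunchSettings

    # Hypothetical experiment; only the Record-based call pattern matters here.
    exp = Experiment(name="record-demo", exp_path="/tmp/record-demo")
    job = Job(Application("echo-app", "echo", ["hello"]), LaunchSettings("local"))

    # `start` now returns Record objects rather than bare launched job ids.
    (record,) = exp.start(job)
    print(record.launched_id)  # the id assigned at launch
    print(record.job.name)     # a deep copy of the job as it was launched

    # The same records are handed back to query, block on, or cancel a launch,
    # e.g. exp.stop(record).
    exp.wait(record, timeout=60)
    (status,) = exp.get_status(record)

Because `Record` deep-copies the job at construction (and `Record.job` returns a fresh copy), the record continues to describe the job as it was launched even if the original `Job` or its entity is mutated afterwards, which is exactly what `tests/test_record.py` exercises.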