diff --git a/lwpipe/__init__.py b/lwpipe/__init__.py index 953ec27..23fb078 100644 --- a/lwpipe/__init__.py +++ b/lwpipe/__init__.py @@ -3,14 +3,14 @@ import logging from enum import IntEnum, auto from inspect import signature -from typing import Callable, Optional +from typing import Callable, Iterable, Optional from .utils import _assert_same_length logger = logging.getLogger(__name__) -__version__ = "5.0.0" +__version__ = "5.1.0" class DumpType(IntEnum): @@ -170,18 +170,62 @@ def convert_callables_to_nodes(self, funcs, names): def get_node_names(self) -> list[str]: return [node.name for node in self.nodes] - def run(self, from_: int | str = 0, to_: int | str | None = None): + def _handle_ignore_list( + self, + from_: int | str | None, + to_: int | str | None, + ignore_list: Iterable[int | str], + ): + if isinstance(from_, int): + raise ValueError("from_ must be string or None") + if isinstance(to_, int): + raise ValueError("to_ must be string or None") + + ignore_set = set() + for idx, ignore in enumerate(ignore_list): + ignore = self._get_node_index(ignore, "ignore") + if ignore in ignore_set: + raise ValueError(f"ignore index: {ignore} is duplicated") + ignore_set.add(ignore) + + nodes = [] + names = [] + for idx, node in enumerate(self.nodes): + if idx not in ignore_set: + nodes.append(node) + names.append(node.name) + + if len(nodes) == 0: + raise ValueError( + "number of nodes becomes zero after considering ignore_list" + ) + if nodes[0].inputs is None: + nodes[0].inputs = self.nodes[0].inputs + + return Pipeline(nodes, names).run(from_, to_) + + def run( + self, + from_: int | str | None = None, + to_: int | str | None = None, + ignore_list: Iterable[int | str] | None = None, + ): """pipelineを実行する。戻り値はlist。 Parameters ---------------- - start: どのノードからパイプラインを開始するか。インデックスかnameで指定可能。 - end: どのノードのまでパイプラインを実行するか。インデックスかnameで指定可能。 + start: どのノードからパイプラインを開始するか。インデックスかnameで指定可能だが、ignore_listを指定する際は必ずnameで指定する。 + end: どのノードのまでパイプラインを実行するか。インデックスかnameで指定可能だが、ignore_listを指定する際は必ずnameで指定する。 + ignore_list: 実行しないノードのリスト。 """ + if ignore_list is not None: + return self._handle_ignore_list(from_, to_, ignore_list) + if from_ is None: + from_ = 0 if to_ is None: to_ = len(self.nodes) - 1 - idx_from = self._get_start_or_end_index(from_, "start") + idx_from = self._get_node_index(from_, "start") self.idx_from = idx_from - idx_to = self._get_start_or_end_index(to_, "end") + idx_to = self._get_node_index(to_, "end") if idx_from > idx_to: raise ValueError( f"idx_from must satisfy idx_from ({idx_from}) <= idx_to ({idx_to})" @@ -224,7 +268,7 @@ def run(self, from_: int | str = 0, to_: int | str | None = None): ) return outputs - def _get_start_or_end_index(self, start_or_end: int | str, start_or_end_str: str): + def _get_node_index(self, start_or_end: int | str, start_or_end_str: str): if isinstance(start_or_end, int): idx = start_or_end elif isinstance(start_or_end, str): @@ -232,8 +276,12 @@ def _get_start_or_end_index(self, start_or_end: int | str, start_or_end_str: str idx = self.name_to_idx[start_or_end] except KeyError: raise ValueError( - f"specified {start_or_end_str} node ({start_or_end}) is not found" + f'specified {start_or_end_str} node "{start_or_end}" is not found' ) + else: + raise ValueError( + f'specified "{start_or_end_str}" node must be int or str, but {type(start_or_end)}', + ) if idx < 0 or idx >= len(self.nodes): raise ValueError( @@ -267,9 +315,7 @@ def _load_interim_output(self): if idx_ > self.idx_from: continue last_output = self._load_last_output( - self.nodes[idx_ - 1], - loaded_files_set, - idx_ - 1, + self.nodes[idx_ - 1], loaded_files_set, idx_ - 1, ) else: # 前段より前の出力を入力にする場合 self._load_past_output(self.nodes[idx_], loaded_files_set) @@ -291,10 +337,7 @@ def _load_last_output(self, node_prev, loaded_files_set, idx): results = node_prev.outputs_loader(*args) if node_prev.outputs is not None: _assert_same_length( - node_prev.outputs, - results, - "node_prev.outputs_path", - "results", + node_prev.outputs, results, "node_prev.outputs_path", "results", ) for output, result in zip(node_prev.outputs, results): if output is not None: @@ -346,10 +389,7 @@ def _load_past_output(self, node, loaded_files_set): *node_dep.outputs_path, node_dep.outputs ) _assert_same_length( - node_dep.outputs, - outputs, - "node_dep.outputs", - "outputs", + node_dep.outputs, outputs, "node_dep.outputs", "outputs", ) for output_key, output in zip(node_dep.outputs, outputs): self.results[output_key] = output diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index f044661..32ee77d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -35,13 +35,8 @@ def np_array_2d(): def test_simple(np_array_2d): nodes = [ - Node( - func=np.add, - inputs=(np_array_2d, 1), - ), - Node( - func=np.sum, - ), + Node(func=np.add, inputs=(np_array_2d, 1),), + Node(func=np.sum,), ] pipe = Pipeline(nodes) @@ -59,15 +54,10 @@ def test_output(np_array_2d, tmp_path): inputs=np_array_2d, outputs=("mean1", "mean2"), outputs_dumper=dump_pickle, - outputs_path=( - result1, - result2, - ), + outputs_path=(result1, result2,), outputs_loader=load_pickle, ), - Node( - func=lambda x, y: (x.mean(), y.mean()), - ), + Node(func=lambda x, y: (x.mean(), y.mean()),), ] pipe = Pipeline(nodes) @@ -90,15 +80,10 @@ def test_ensure_read(np_array_2d, tmp_path): inputs=np_array_2d, outputs=("mean1", "mean2"), outputs_dumper=dump_pickle, - outputs_path=( - result1, - result2, - ), + outputs_path=(result1, result2,), outputs_loader=load_pickle, ), - Node( - func=lambda x, y: (x.mean(), y.mean()), - ), + Node(func=lambda x, y: (x.mean(), y.mean()),), ] pipe = Pipeline(nodes) @@ -138,10 +123,7 @@ def test_base(np_array_2d, tmp_path): outputs_path=result2, outputs_loader=load_npy, ), - Node( - func=lambda x: x + 10, - inputs="mean", - ), + Node(func=lambda x: x + 10, inputs="mean",), ] pipe = Pipeline(nodes) @@ -168,20 +150,14 @@ def test_tuple_output(np_array_2d, tmp_path): inputs=np_array_2d, outputs=("divide1", "divide2"), outputs_dumper=dump_pickle, - outputs_path=( - mean1, - mean2, - ), + outputs_path=(mean1, mean2,), outputs_loader=load_pickle, ), Node( func=ten_times_two_inputs, inputs=("divide1", "divide2"), outputs_dumper=dump_npy, - outputs_path=( - mul1, - mul2, - ), + outputs_path=(mul1, mul2,), ), ] @@ -196,10 +172,7 @@ def test_tuple_output(np_array_2d, tmp_path): def test_no_input_at_initial_node(): nodes = [ Node(func=lambda: 100), - Node( - func=lambda x: 10 * x, - name="multiply", - ), + Node(func=lambda x: 10 * x, name="multiply",), ] pipe = Pipeline(nodes) @@ -210,11 +183,7 @@ def test_no_input_at_initial_node(): def test_none_outputs(): nodes = [ Node(func=lambda: 100, outputs=[None]), - Node( - func=lambda x: 10 * x, - name="multiply", - inputs=[None], - ), + Node(func=lambda x: 10 * x, name="multiply", inputs=[None],), ] pipe = Pipeline(nodes) @@ -281,9 +250,7 @@ def test_batch(np_array_2d, tmp_path): outputs_path=result3, outputs_loader=load_savez_compressed, ), - Node( - func=lambda x, y: (x, y), - ), + Node(func=lambda x, y: (x, y),), ] pipe = Pipeline(nodes) @@ -402,10 +369,7 @@ def test_kidou(tmp_path): def test_inputs_not_found(): nodes = [ - Node( - func=np.add, - inputs=(1, 2), - ), + Node(func=np.add, inputs=(1, 2),), Node(func=lambda x: x * 10, inputs="hoge"), ] @@ -571,3 +535,60 @@ def test_previous_result_to_results_dict_batch(): outputs = pipe.run(1) assert outputs[0] == "INPUT_proc1_proc3" assert pipe.results["proc2"] == "INPUT_proc1_proc2" + + +def test_ignore(tmp_path): + result1 = tmp_path / "result1.pickle" + result2 = tmp_path / "result2.pickle" + result3 = tmp_path / "result3.pickle" + nodes = [ + Node( + func=lambda x: x + 10, + inputs=0, + name="1", + outputs_dumper=dump_pickle, + outputs_loader=load_pickle, + outputs_path=result1, + ), + Node( + func=lambda x: x + 100, + name="2", + outputs_dumper=dump_pickle, + outputs_loader=load_pickle, + outputs_path=result2, + ), + Node( + func=lambda x: x + 100, + name="3", + outputs_dumper=dump_pickle, + outputs_loader=load_pickle, + outputs_path=result3, + ), + ] + pipe = Pipeline(nodes) + pipe.run(ignore_list=[0]) + pipe.clear() + pipe.run(ignore_list=[1]) + pipe.clear() + pipe.run(ignore_list=["2"]) + pipe.clear() + pipe.run(from_="3", ignore_list=["2"]) + pipe.clear() + with pytest.raises(ValueError): + pipe.run(ignore_list=(0, 1, 2)) + with pytest.raises(ValueError): + pipe.run(from_="1", ignore_list=(0,)) + with pytest.raises(ValueError): + pipe.run(from_=2, ignore_list=(0,)) + + funcs = [no_op for _ in range(3)] + pipe = Pipeline(funcs, ["1", "2", "3"]) + pipe.run() + pipe.run(ignore_list=[0]) + pipe.clear() + pipe.run(ignore_list=[1]) + with pytest.raises(ValueError): + pipe.run(ignore_list=(0, 1, 2)) + with pytest.raises(ValueError): + pipe.run(from_="1", ignore_list=(0,)) +