Skip to content

Commit

Permalink
add ignore_list() to node
Browse files Browse the repository at this point in the history
  • Loading branch information
estshorter committed Aug 18, 2021
1 parent b0e8807 commit 50d2635
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 69 deletions.
80 changes: 60 additions & 20 deletions lwpipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import logging
from enum import IntEnum, auto
from inspect import signature
from typing import Callable, Optional
from typing import Callable, Iterable, Optional

from .utils import _assert_same_length

logger = logging.getLogger(__name__)


__version__ = "5.0.0"
__version__ = "5.1.0"


class DumpType(IntEnum):
Expand Down Expand Up @@ -170,18 +170,62 @@ def convert_callables_to_nodes(self, funcs, names):
def get_node_names(self) -> list[str]:
return [node.name for node in self.nodes]

def run(self, from_: int | str = 0, to_: int | str | None = None):
def _handle_ignore_list(
self,
from_: int | str | None,
to_: int | str | None,
ignore_list: Iterable[int | str],
):
if isinstance(from_, int):
raise ValueError("from_ must be string or None")
if isinstance(to_, int):
raise ValueError("to_ must be string or None")

ignore_set = set()
for idx, ignore in enumerate(ignore_list):
ignore = self._get_node_index(ignore, "ignore")
if ignore in ignore_set:
raise ValueError(f"ignore index: {ignore} is duplicated")
ignore_set.add(ignore)

nodes = []
names = []
for idx, node in enumerate(self.nodes):
if idx not in ignore_set:
nodes.append(node)
names.append(node.name)

if len(nodes) == 0:
raise ValueError(
"number of nodes becomes zero after considering ignore_list"
)
if nodes[0].inputs is None:
nodes[0].inputs = self.nodes[0].inputs

return Pipeline(nodes, names).run(from_, to_)

def run(
self,
from_: int | str | None = None,
to_: int | str | None = None,
ignore_list: Iterable[int | str] | None = None,
):
"""pipelineを実行する。戻り値はlist。
Parameters
----------------
start: どのノードからパイプラインを開始するか。インデックスかnameで指定可能。
end: どのノードのまでパイプラインを実行するか。インデックスかnameで指定可能。
start: どのノードからパイプラインを開始するか。インデックスかnameで指定可能だが、ignore_listを指定する際は必ずnameで指定する。
end: どのノードのまでパイプラインを実行するか。インデックスかnameで指定可能だが、ignore_listを指定する際は必ずnameで指定する。
ignore_list: 実行しないノードのリスト。
"""
if ignore_list is not None:
return self._handle_ignore_list(from_, to_, ignore_list)
if from_ is None:
from_ = 0
if to_ is None:
to_ = len(self.nodes) - 1
idx_from = self._get_start_or_end_index(from_, "start")
idx_from = self._get_node_index(from_, "start")
self.idx_from = idx_from
idx_to = self._get_start_or_end_index(to_, "end")
idx_to = self._get_node_index(to_, "end")
if idx_from > idx_to:
raise ValueError(
f"idx_from must satisfy idx_from ({idx_from}) <= idx_to ({idx_to})"
Expand Down Expand Up @@ -224,16 +268,20 @@ def run(self, from_: int | str = 0, to_: int | str | None = None):
)
return outputs

def _get_start_or_end_index(self, start_or_end: int | str, start_or_end_str: str):
def _get_node_index(self, start_or_end: int | str, start_or_end_str: str):
if isinstance(start_or_end, int):
idx = start_or_end
elif isinstance(start_or_end, str):
try:
idx = self.name_to_idx[start_or_end]
except KeyError:
raise ValueError(
f"specified {start_or_end_str} node ({start_or_end}) is not found"
f'specified {start_or_end_str} node "{start_or_end}" is not found'
)
else:
raise ValueError(
f'specified "{start_or_end_str}" node must be int or str, but {type(start_or_end)}',
)

if idx < 0 or idx >= len(self.nodes):
raise ValueError(
Expand Down Expand Up @@ -267,9 +315,7 @@ def _load_interim_output(self):
if idx_ > self.idx_from:
continue
last_output = self._load_last_output(
self.nodes[idx_ - 1],
loaded_files_set,
idx_ - 1,
self.nodes[idx_ - 1], loaded_files_set, idx_ - 1,
)
else: # 前段より前の出力を入力にする場合
self._load_past_output(self.nodes[idx_], loaded_files_set)
Expand All @@ -291,10 +337,7 @@ def _load_last_output(self, node_prev, loaded_files_set, idx):
results = node_prev.outputs_loader(*args)
if node_prev.outputs is not None:
_assert_same_length(
node_prev.outputs,
results,
"node_prev.outputs_path",
"results",
node_prev.outputs, results, "node_prev.outputs_path", "results",
)
for output, result in zip(node_prev.outputs, results):
if output is not None:
Expand Down Expand Up @@ -346,10 +389,7 @@ def _load_past_output(self, node, loaded_files_set):
*node_dep.outputs_path, node_dep.outputs
)
_assert_same_length(
node_dep.outputs,
outputs,
"node_dep.outputs",
"outputs",
node_dep.outputs, outputs, "node_dep.outputs", "outputs",
)
for output_key, output in zip(node_dep.outputs, outputs):
self.results[output_key] = output
Expand Down
119 changes: 70 additions & 49 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,8 @@ def np_array_2d():

def test_simple(np_array_2d):
nodes = [
Node(
func=np.add,
inputs=(np_array_2d, 1),
),
Node(
func=np.sum,
),
Node(func=np.add, inputs=(np_array_2d, 1),),
Node(func=np.sum,),
]

pipe = Pipeline(nodes)
Expand All @@ -59,15 +54,10 @@ def test_output(np_array_2d, tmp_path):
inputs=np_array_2d,
outputs=("mean1", "mean2"),
outputs_dumper=dump_pickle,
outputs_path=(
result1,
result2,
),
outputs_path=(result1, result2,),
outputs_loader=load_pickle,
),
Node(
func=lambda x, y: (x.mean(), y.mean()),
),
Node(func=lambda x, y: (x.mean(), y.mean()),),
]

pipe = Pipeline(nodes)
Expand All @@ -90,15 +80,10 @@ def test_ensure_read(np_array_2d, tmp_path):
inputs=np_array_2d,
outputs=("mean1", "mean2"),
outputs_dumper=dump_pickle,
outputs_path=(
result1,
result2,
),
outputs_path=(result1, result2,),
outputs_loader=load_pickle,
),
Node(
func=lambda x, y: (x.mean(), y.mean()),
),
Node(func=lambda x, y: (x.mean(), y.mean()),),
]

pipe = Pipeline(nodes)
Expand Down Expand Up @@ -138,10 +123,7 @@ def test_base(np_array_2d, tmp_path):
outputs_path=result2,
outputs_loader=load_npy,
),
Node(
func=lambda x: x + 10,
inputs="mean",
),
Node(func=lambda x: x + 10, inputs="mean",),
]

pipe = Pipeline(nodes)
Expand All @@ -168,20 +150,14 @@ def test_tuple_output(np_array_2d, tmp_path):
inputs=np_array_2d,
outputs=("divide1", "divide2"),
outputs_dumper=dump_pickle,
outputs_path=(
mean1,
mean2,
),
outputs_path=(mean1, mean2,),
outputs_loader=load_pickle,
),
Node(
func=ten_times_two_inputs,
inputs=("divide1", "divide2"),
outputs_dumper=dump_npy,
outputs_path=(
mul1,
mul2,
),
outputs_path=(mul1, mul2,),
),
]

Expand All @@ -196,10 +172,7 @@ def test_tuple_output(np_array_2d, tmp_path):
def test_no_input_at_initial_node():
nodes = [
Node(func=lambda: 100),
Node(
func=lambda x: 10 * x,
name="multiply",
),
Node(func=lambda x: 10 * x, name="multiply",),
]

pipe = Pipeline(nodes)
Expand All @@ -210,11 +183,7 @@ def test_no_input_at_initial_node():
def test_none_outputs():
nodes = [
Node(func=lambda: 100, outputs=[None]),
Node(
func=lambda x: 10 * x,
name="multiply",
inputs=[None],
),
Node(func=lambda x: 10 * x, name="multiply", inputs=[None],),
]

pipe = Pipeline(nodes)
Expand Down Expand Up @@ -281,9 +250,7 @@ def test_batch(np_array_2d, tmp_path):
outputs_path=result3,
outputs_loader=load_savez_compressed,
),
Node(
func=lambda x, y: (x, y),
),
Node(func=lambda x, y: (x, y),),
]

pipe = Pipeline(nodes)
Expand Down Expand Up @@ -402,10 +369,7 @@ def test_kidou(tmp_path):

def test_inputs_not_found():
nodes = [
Node(
func=np.add,
inputs=(1, 2),
),
Node(func=np.add, inputs=(1, 2),),
Node(func=lambda x: x * 10, inputs="hoge"),
]

Expand Down Expand Up @@ -571,3 +535,60 @@ def test_previous_result_to_results_dict_batch():
outputs = pipe.run(1)
assert outputs[0] == "INPUT_proc1_proc3"
assert pipe.results["proc2"] == "INPUT_proc1_proc2"


def test_ignore(tmp_path):
result1 = tmp_path / "result1.pickle"
result2 = tmp_path / "result2.pickle"
result3 = tmp_path / "result3.pickle"
nodes = [
Node(
func=lambda x: x + 10,
inputs=0,
name="1",
outputs_dumper=dump_pickle,
outputs_loader=load_pickle,
outputs_path=result1,
),
Node(
func=lambda x: x + 100,
name="2",
outputs_dumper=dump_pickle,
outputs_loader=load_pickle,
outputs_path=result2,
),
Node(
func=lambda x: x + 100,
name="3",
outputs_dumper=dump_pickle,
outputs_loader=load_pickle,
outputs_path=result3,
),
]
pipe = Pipeline(nodes)
pipe.run(ignore_list=[0])
pipe.clear()
pipe.run(ignore_list=[1])
pipe.clear()
pipe.run(ignore_list=["2"])
pipe.clear()
pipe.run(from_="3", ignore_list=["2"])
pipe.clear()
with pytest.raises(ValueError):
pipe.run(ignore_list=(0, 1, 2))
with pytest.raises(ValueError):
pipe.run(from_="1", ignore_list=(0,))
with pytest.raises(ValueError):
pipe.run(from_=2, ignore_list=(0,))

funcs = [no_op for _ in range(3)]
pipe = Pipeline(funcs, ["1", "2", "3"])
pipe.run()
pipe.run(ignore_list=[0])
pipe.clear()
pipe.run(ignore_list=[1])
with pytest.raises(ValueError):
pipe.run(ignore_list=(0, 1, 2))
with pytest.raises(ValueError):
pipe.run(from_="1", ignore_list=(0,))

0 comments on commit 50d2635

Please sign in to comment.