diff --git a/rl/utils/core.py b/rl/utils/core.py index a225339..e89d34a 100644 --- a/rl/utils/core.py +++ b/rl/utils/core.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Iterable, TypeVar import regex +from tqdm.contrib.concurrent import process_map, thread_map K = TypeVar("K") T = TypeVar("T") @@ -30,3 +31,18 @@ def safe_extract(pattern: regex.Pattern, text: str, key: str) -> str | None: return match.group(key) else: return None + + +def tqdm_map( + fn: Callable[[T], Any], + *iterables: Iterable[T], + mode: str = None, + **tqdm_kwargs: Any, +) -> list[Any]: + """Map a function over iterables with a tqdm progress bar.""" + if mode == "thread": + return thread_map(fn, *iterables, **tqdm_kwargs) + elif mode == "process": + return process_map(fn, *iterables, **tqdm_kwargs) + else: + return list(map(fn, *iterables))