Skip to content

Commit

Permalink
Add tqdm_map util
Browse files Browse the repository at this point in the history
  • Loading branch information
ProbablyFaiz committed Jul 7, 2024
1 parent 9f59d41 commit a3c6a03
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions rl/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Callable, Iterable, TypeVar

import regex
from tqdm.contrib.concurrent import process_map, thread_map

K = TypeVar("K")
T = TypeVar("T")
Expand Down Expand Up @@ -30,3 +31,18 @@ def safe_extract(pattern: regex.Pattern, text: str, key: str) -> str | None:
return match.group(key)
else:
return None


def tqdm_map(
fn: Callable[[T], Any],
*iterables: Iterable[T],
mode: str = None,
**tqdm_kwargs: Any,
) -> list[Any]:
"""Map a function over iterables with a tqdm progress bar."""
if mode == "thread":
return thread_map(fn, *iterables, **tqdm_kwargs)
elif mode == "process":
return process_map(fn, *iterables, **tqdm_kwargs)
else:
return list(map(fn, *iterables))

0 comments on commit a3c6a03

Please sign in to comment.