Remove spark; add pydantic mode to io readers
ProbablyFaiz committed Oct 3, 2024
1 parent c96c201 commit 75c8aad
Showing 1 changed file with 18 additions and 29 deletions.
47 changes: 18 additions & 29 deletions rl/utils/io.py
@@ -1,7 +1,6 @@
 import csv
 import json
 import os
-import shutil
 from collections.abc import Iterable
 from pathlib import Path
 from typing import Any
@@ -86,14 +85,21 @@ def getenv(name: str, default=None) -> str:
     return os.getenv(name, default)


-def read_jsonl(filename: str | Path) -> Iterable[Any]:
+def read_jsonl(
+    filename: str | Path, *, pydantic_cls: type[BaseModel] | None = None
+) -> Iterable[Any]:
     filename = Path(filename)
     with filename.open() as f:
         for line in f:
-            yield json.loads(line)
+            if pydantic_cls:
+                yield pydantic_cls.model_validate_json(line)
+            else:
+                yield json.loads(line)


-def write_jsonl(filename: str | Path, records: Iterable[Any], overwrite=False) -> None:
+def write_jsonl(
+    filename: str | Path, records: Iterable[Any], *, overwrite=False
+) -> None:
     filename = Path(filename)
     if filename.exists() and not overwrite:
         raise ValueError(f"{filename} already exists and overwrite is not set.")
@@ -107,34 +113,17 @@ def write_jsonl(filename: str | Path, records: Iterable[Any], overwrite=False) -> None:
         f.write(json_record + "\n")


-def write_jsonl_spark(filename: str | Path, df, overwrite=False) -> None:
-    filename = Path(filename)
-    if filename.exists() and not overwrite:
-        raise ValueError(f"{filename} already exists and overwrite is not set.")
-    output_path_dir_name = filename.parent / f"{filename.stem}_dir"
-    df.coalesce(1).write.json(str(output_path_dir_name), lineSep="\n", mode="overwrite")
-    output_path_dir = list(output_path_dir_name.glob("*.json"))[0]
-    shutil.move(output_path_dir, filename)
-    shutil.rmtree(output_path_dir_name)
-
-
-def write_parquet_spark(filename: str | Path, df, overwrite=False) -> None:
-    filename = Path(filename)
-    if filename.exists() and not overwrite:
-        raise ValueError(f"{filename} already exists and overwrite is not set.")
-
-    output_path_dir_name = filename.parent / f"{filename.stem}_dir"
-    df.coalesce(1).write.parquet(str(output_path_dir_name), mode="overwrite")
-    parquet_file = list(output_path_dir_name.glob("*.parquet"))[0]
-    shutil.move(parquet_file, filename)
-    shutil.rmtree(output_path_dir_name)
-
-
-def read_csv(filename: str | Path) -> Iterable[dict[str, Any]]:
+def read_csv(
+    filename: str | Path, *, pydantic_cls: type[BaseModel] | None = None
+) -> Iterable[dict[str, Any]]:
     filename = Path(filename)
     with filename.open() as f:
         reader = csv.DictReader(f)
-        yield from reader
+        for row in reader:
+            if pydantic_cls:
+                yield pydantic_cls.model_validate(row)
+            else:
+                yield row


 def write_csv(
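For context, a minimal sketch of how the new pydantic mode could be used. The Record model and the file names below are hypothetical, for illustration only; the readers themselves live in rl/utils/io.py as shown in the diff above.

from pydantic import BaseModel

from rl.utils import io


class Record(BaseModel):
    # Hypothetical schema, chosen only to illustrate the new mode.
    id: int
    text: str


# Default behavior is unchanged: plain parsed-JSON objects and
# csv.DictReader row dicts.
raw = list(io.read_jsonl("data.jsonl"))

# With pydantic_cls, each JSONL line is parsed and validated in one
# step via Record.model_validate_json(line).
records = list(io.read_jsonl("data.jsonl", pydantic_cls=Record))

# read_csv mirrors this: each DictReader row dict is passed through
# Record.model_validate(row), which coerces the string fields to the
# declared types (e.g. "3" -> 3 for id).
rows = list(io.read_csv("data.csv", pydantic_cls=Record))

Note that pydantic_cls is keyword-only (the bare * in the new signatures), so callers must pass it by name, which keeps the existing positional API backward compatible.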
