From 75c8aad39ae5373a1dc8108dd697dae5379762a1 Mon Sep 17 00:00:00 2001 From: Faiz Surani Date: Thu, 3 Oct 2024 10:52:16 -0700 Subject: [PATCH] Remove spark; add pydantic mode to io readers --- rl/utils/io.py | 47 ++++++++++++++++++----------------------------- 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/rl/utils/io.py b/rl/utils/io.py index dcc8ac6..ed830fe 100644 --- a/rl/utils/io.py +++ b/rl/utils/io.py @@ -1,7 +1,6 @@ import csv import json import os -import shutil from collections.abc import Iterable from pathlib import Path from typing import Any @@ -86,14 +85,21 @@ def getenv(name: str, default=None) -> str: return os.getenv(name, default) -def read_jsonl(filename: str | Path) -> Iterable[Any]: +def read_jsonl( + filename: str | Path, *, pydantic_cls: type[BaseModel] | None = None +) -> Iterable[Any]: filename = Path(filename) with filename.open() as f: for line in f: - yield json.loads(line) + if pydantic_cls: + yield pydantic_cls.model_validate_json(line) + else: + yield json.loads(line) -def write_jsonl(filename: str | Path, records: Iterable[Any], overwrite=False) -> None: +def write_jsonl( + filename: str | Path, records: Iterable[Any], *, overwrite=False +) -> None: filename = Path(filename) if filename.exists() and not overwrite: raise ValueError(f"{filename} already exists and overwrite is not set.") @@ -107,34 +113,17 @@ def write_jsonl(filename: str | Path, records: Iterable[Any], overwrite=False) - f.write(json_record + "\n") -def write_jsonl_spark(filename: str | Path, df, overwrite=False) -> None: - filename = Path(filename) - if filename.exists() and not overwrite: - raise ValueError(f"{filename} already exists and overwrite is not set.") - output_path_dir_name = filename.parent / f"{filename.stem}_dir" - df.coalesce(1).write.json(str(output_path_dir_name), lineSep="\n", mode="overwrite") - output_path_dir = list(output_path_dir_name.glob("*.json"))[0] - shutil.move(output_path_dir, filename) - shutil.rmtree(output_path_dir_name) - - -def write_parquet_spark(filename: str | Path, df, overwrite=False) -> None: - filename = Path(filename) - if filename.exists() and not overwrite: - raise ValueError(f"{filename} already exists and overwrite is not set.") - - output_path_dir_name = filename.parent / f"{filename.stem}_dir" - df.coalesce(1).write.parquet(str(output_path_dir_name), mode="overwrite") - parquet_file = list(output_path_dir_name.glob("*.parquet"))[0] - shutil.move(parquet_file, filename) - shutil.rmtree(output_path_dir_name) - - -def read_csv(filename: str | Path) -> Iterable[dict[str, Any]]: +def read_csv( + filename: str | Path, *, pydantic_cls: type[BaseModel] | None = None +) -> Iterable[dict[str, Any]]: filename = Path(filename) with filename.open() as f: reader = csv.DictReader(f) - yield from reader + for row in reader: + if pydantic_cls: + yield pydantic_cls.model_validate(row) + else: + yield row def write_csv(