Skip to content

Commit

Permalink
add AutoDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
vzhong committed Nov 1, 2024
1 parent f222147 commit 5e739e3
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 45 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ hydra-submitit-launcher
lightning
gitpython
seaborn
joblib
ujson
60 changes: 60 additions & 0 deletions tests/test_autodataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import bz2
import json
import tarfile
import unittest
import tempfile
from wrangl.data.io import AutoDataset


class TestAutoDataset(unittest.TestCase):
    """Tests for AutoDataset loading from a file handle, plain JSON, .bz2, and .tar.bz2 sources."""

    def setUp(self):
        # Small fixture shared by every test.
        self.examples = [
            dict(x='a', y=1),
            dict(x='b', y=2),
            dict(x='c', y=3),
        ]

    def test_process_file(self):
        """process_file parses a JSON list straight from an open file handle."""
        with tempfile.TemporaryFile('wt+') as f:
            json.dump(self.examples, f)
            f.flush()
            f.seek(0)
            dataset = AutoDataset.process_file(f)
            self.assertListEqual(self.examples, dataset)

    def test_load_json_from_disk(self):
        """load_from_disk reads an uncompressed JSON file given its path."""
        with tempfile.NamedTemporaryFile('wt+') as f:
            json.dump(self.examples, f)
            f.flush()
            f.seek(0)
            dataset = AutoDataset.load_from_disk(f.name)
            self.assertListEqual(self.examples, dataset)

    def test_load_bz2_from_disk(self):
        """load_from_disk transparently decompresses a .bz2 file given its path."""
        with tempfile.NamedTemporaryFile('wb+', suffix='.bz2') as f:
            s = json.dumps(self.examples)
            f.write(bz2.compress(s.encode('utf8')))
            f.flush()
            f.seek(0)
            dataset = AutoDataset.load_from_disk(f.name)
            self.assertListEqual(self.examples, dataset)

    def test_load_tar_bz2_from_disk(self):
        """load_from_disk unpacks every member of a .tar.bz2 archive and merges them in order."""
        with tempfile.NamedTemporaryFile(suffix='.tar.bz2') as forig:
            with tarfile.open(forig.name, 'w:bz2') as tar:
                with tempfile.NamedTemporaryFile('wt', suffix='.json') as fa:
                    json.dump(self.examples, fa)
                    fa.flush()
                    with tempfile.NamedTemporaryFile('wt', suffix='.json') as fb:
                        json.dump(list(reversed(self.examples)), fb)
                        fb.flush()
                        tar.add(fa.name)
                        tar.add(fb.name)
                # NOTE: the explicit tar.close() was redundant — the context
                # manager closes (and finalizes) the archive on exit.
            dataset = AutoDataset.load_from_disk(forig.name)
            self.assertListEqual(self.examples + list(reversed(self.examples)), dataset)


if __name__ == '__main__':
    unittest.main()
13 changes: 0 additions & 13 deletions tests/test_load_jsonl_files.py

This file was deleted.

13 changes: 0 additions & 13 deletions tests/test_load_sql_db.py

This file was deleted.

18 changes: 0 additions & 18 deletions tests/test_repeat_string.py

This file was deleted.

2 changes: 1 addition & 1 deletion wrangl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
.. include:: ../README.md
.. include:: examples/learn/README.md
"""
__all__ = ['learn', 'examples']
__all__ = ['learn', 'data', 'examples']
Empty file added wrangl/data/__init__.py
Empty file.
59 changes: 59 additions & 0 deletions wrangl/data/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Utilities to save/load data to disk.
"""
import bz2
import typing
import tarfile
import ujson as json
from tqdm import auto as tqdm
from torch.utils.data import Dataset


class AutoDataset(list, Dataset):
    """A ``list``-backed ``torch.utils.data.Dataset`` with helpers to load JSON data from disk."""

    @classmethod
    def process_file(cls, file: typing.TextIO):
        """
        Loads a dataset from a single file.
        Assumes that the data is a JSON file with a list of dictionaries.
        """
        return json.load(file)

    @classmethod
    def merge(cls, datasets):
        """Merge multiple datasets into a single new dataset, preserving order."""
        out = cls()
        for d in datasets:
            out.extend(d)
        return out

    @classmethod
    def load_from_disk(cls, path: typing.Union[str, typing.Iterable], verbose=False):
        """
        Loads a dataset from given `path`, where `path` can be either one file path or an iterator over file paths.
        If `path` is a single file, then `cls.process_file` will be applied.
        If `path` is a single file ending in `.bz2`, then bzip2 decompression will be applied, followed by `cls.process_file`.
        If `path` is a single file ending in `.tar.bz2`, then tar decompression will be applied, followed by bzip2 decompression on every file, followed by `cls.process_file`.
        If `verbose` is True, a tqdm progress bar is shown over archive members / iterated paths.
        """
        if isinstance(path, str):
            if path.endswith('.tar.bz2'):
                datasets = []
                with tarfile.open(path, 'r:bz2') as tar:
                    iterator = tar.getmembers()
                    if verbose:
                        iterator = tqdm.tqdm(iterator)
                    for member in iterator:
                        # extractfile() returns None for non-regular members
                        # (directories, links, devices) — skip them instead of
                        # crashing in json.load(None).
                        if not member.isfile():
                            continue
                        file = tar.extractfile(member)
                        try:
                            datasets.append(cls.process_file(file))
                        finally:
                            # close the extracted member handle; the tar context
                            # manager does not close handles it hands out.
                            file.close()
                return cls.merge(datasets)
            elif path.endswith('.bz2'):
                with bz2.open(path, 'rt') as f:
                    return cls.process_file(f)
            else:
                with open(path, 'rt') as f:
                    return cls.process_file(f)
        else:
            iterator = path
            if verbose:
                iterator = tqdm.tqdm(path)
            # Each element may be a file path (as documented) or an already-open
            # file object. Strings are routed through load_from_disk so that
            # per-path compression handling still applies; file objects are fed
            # straight to process_file as before.
            return cls.merge(
                cls.load_from_disk(p) if isinstance(p, str) else cls.process_file(p)
                for p in iterator
            )

0 comments on commit 5e739e3

Please sign in to comment.