-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
122 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,5 @@ hydra-submitit-launcher | |
lightning | ||
gitpython | ||
seaborn | ||
joblib | ||
ujson |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import bz2 | ||
import json | ||
import tarfile | ||
import unittest | ||
import tempfile | ||
from wrangl.data.io import AutoDataset | ||
|
||
|
||
class TestAutoDataset(unittest.TestCase):
    """Round-trip tests: write JSON fixtures to temporary files and load them back with `AutoDataset`."""

    def setUp(self):
        # Small JSON-serialisable fixture shared by every test.
        self.examples = [
            dict(x='a', y=1),
            dict(x='b', y=2),
            dict(x='c', y=3),
        ]

    def test_process_file(self):
        """`process_file` parses an already-open text file."""
        with tempfile.TemporaryFile('wt+') as f:
            json.dump(self.examples, f)
            f.flush()
            # Rewind: the same handle that was written is the one being read.
            f.seek(0)
            dataset = AutoDataset.process_file(f)
            self.assertListEqual(self.examples, dataset)

    def test_load_json_from_disk(self):
        """A path without a `.bz2` suffix is read as plain JSON."""
        with tempfile.NamedTemporaryFile('wt+') as f:
            json.dump(self.examples, f)
            # Flush so the content is on disk when the file is reopened by name.
            f.flush()
            dataset = AutoDataset.load_from_disk(f.name)
            self.assertListEqual(self.examples, dataset)

    def test_load_bz2_from_disk(self):
        """A path ending in `.bz2` is decompressed before parsing."""
        with tempfile.NamedTemporaryFile('wb+', suffix='.bz2') as f:
            s = json.dumps(self.examples)
            f.write(bz2.compress(s.encode('utf8')))
            # Flush so the content is on disk when the file is reopened by name.
            f.flush()
            dataset = AutoDataset.load_from_disk(f.name)
            self.assertListEqual(self.examples, dataset)

    def test_load_tar_bz2_from_disk(self):
        """Every member of a `.tar.bz2` archive is loaded and the results are concatenated in archive order."""
        reversed_examples = list(reversed(self.examples))
        with tempfile.NamedTemporaryFile(suffix='.tar.bz2') as forig:
            # Leaving this `with` block closes the archive, which finalises it on disk
            # before it is read back below.
            with tarfile.open(forig.name, 'w:bz2') as tar:
                with tempfile.NamedTemporaryFile('wt', suffix='.json') as fa, \
                        tempfile.NamedTemporaryFile('wt', suffix='.json') as fb:
                    json.dump(self.examples, fa)
                    fa.flush()
                    json.dump(reversed_examples, fb)
                    fb.flush()
                    # Both members must be added while the temporary files still exist.
                    tar.add(fa.name)
                    tar.add(fb.name)
            dataset = AutoDataset.load_from_disk(forig.name)
            self.assertListEqual(self.examples + reversed_examples, dataset)
|
||
|
||
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
""" | ||
Utilities to save/load data to disk. | ||
""" | ||
import bz2 | ||
import typing | ||
import tarfile | ||
import ujson as json | ||
from tqdm import auto as tqdm | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class AutoDataset(list, Dataset):
    """
    A `list` of examples that is also a torch `Dataset`.

    File parsing goes through `cls.process_file`, so a subclass can override that
    one method to change how an individual file is decoded.
    """

    @classmethod
    def process_file(cls, file: typing.TextIO):
        """
        Loads a dataset from a single open file.
        Assumes that the data is a JSON file with a list of dictionaries.

        Note that `load_from_disk` passes binary file objects for tar members.
        """
        return json.load(file)

    @classmethod
    def merge(cls, datasets):
        """Merge multiple datasets, in order, into one new `cls` instance."""
        out = cls()
        for d in datasets:
            out.extend(d)
        return out

    @classmethod
    def load_from_disk(cls, path: typing.Union[str, typing.Iterable], verbose=False):
        """
        Loads a dataset from given `path`, where `path` can be either one file path or an iterable of file paths.

        If `path` is a single file ending in `.tar.bz2`, it is opened as a bzip2-compressed tar archive,
        `cls.process_file` is applied to every file member, and the results are merged in archive order.
        Members that cannot be extracted as files (e.g. directories) are skipped.
        If `path` is a single file ending in `.bz2`, it is opened with bzip2 decompression in text mode
        and `cls.process_file` is applied.
        Any other single file is opened in text mode and `cls.process_file` is applied.
        If `path` is an iterable, each entry is loaded with the rules above and the results are merged in order.

        Archives and iterables return a `cls` instance (via `cls.merge`); a single plain or `.bz2` file
        returns whatever `cls.process_file` returns.
        If `verbose` is true, a tqdm progress bar is shown over the archive members or over the paths.
        """
        if isinstance(path, str):
            if path.endswith('.tar.bz2'):
                datasets = []
                with tarfile.open(path, 'r:bz2') as tar:
                    members = tar.getmembers()
                    if verbose:
                        members = tqdm.tqdm(members)
                    for member in members:
                        extracted = tar.extractfile(member)
                        # extractfile returns None for members with no file content
                        # (such as directories); there is nothing to parse in those.
                        if extracted is None:
                            continue
                        with extracted:
                            datasets.append(cls.process_file(extracted))
                return cls.merge(datasets)
            if path.endswith('.bz2'):
                with bz2.open(path, 'rt') as f:
                    return cls.process_file(f)
            with open(path, 'rt') as f:
                return cls.process_file(f)

        paths = tqdm.tqdm(path) if verbose else path
        # Each entry is a path, not an open file, so it goes back through
        # load_from_disk to be opened (and decompressed) according to its suffix.
        return cls.merge(cls.load_from_disk(p) for p in paths)