-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add initial implementation of MultiCSVFile
- Loading branch information
Showing
4 changed files
with
384 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
|
||
import csv | ||
from typing import TextIO, Optional, Type, List, MutableMapping, Iterator | ||
import shutil | ||
from .subtextio import SubTextIO | ||
from .exceptions import OpOnClosedCSVFileError, CSVFileBaseIOClosed, \ | ||
SectionNotFound | ||
from .section import MultiCSVSection | ||
|
||
|
||
class MultiCSVFile(MutableMapping[str, TextIO]):
    """
    Mutable mapping view over a text file made of several CSV sections.

    A section starts at a header row whose first cell has the form
    ``[name]`` (all remaining cells empty) and extends up to the next
    header row or the end of the file. Keys are section names; values are
    the text streams registered for each section. Sections keep the order
    in which they were found or added.

    Changes are only written back to the underlying file by ``flush()``
    (which ``close()`` and leaving a ``with`` block also trigger).
    """

    def __init__(self, file: TextIO):
        """
        Scan ``file`` for section headers and register the sections found.

        The position of ``file`` is restored after scanning.
        """
        # Set first so that __del__ stays harmless if scanning raises.
        self._initialized = False
        self._file = file
        self._closed = self._file.closed
        self._sections: List["MultiCSVSection"] = []
        self._initialize_sections()
        self._initialized = True

    def __getitem__(self, key: str) -> TextIO:
        """
        Return the stream registered for section ``key``.

        Raises:
            SectionNotFound: if no section has that name.
        """
        for section in self._sections:
            if section.name == key:
                return section.descriptor

        raise SectionNotFound("MultiCSVFile does not "
                              f"have section named {key!r}.")

    def __setitem__(self, key: str, value: TextIO) -> None:
        """
        Register ``value`` as the content of section ``key``.

        An existing section with the same name is replaced in place (its
        position in the file order is kept); otherwise the section is
        appended at the end.
        """
        replacement = MultiCSVSection(name=key, descriptor=value)

        for index, section in enumerate(self._sections):
            if section.name == key:
                self._sections[index] = replacement
                return

        self._sections.append(replacement)

    def __delitem__(self, key: str) -> None:
        """
        Remove section ``key``.

        Raises:
            SectionNotFound: if no section has that name.
        """
        for index, section in enumerate(self._sections):
            if section.name == key:
                del self._sections[index]
                return

        raise SectionNotFound("MultiCSVFile does not "
                              f"have section named {key!r}.")

    def __iter__(self) -> Iterator[str]:
        """Iterate over section names in file order."""
        return (section.name for section in self._sections)

    def __len__(self) -> int:
        """Return the number of sections."""
        return len(self._sections)

    def __contains__(self, key: object) -> bool:
        """Return True if a section named ``key`` exists."""
        return any(section.name == key for section in self._sections)

    def close(self) -> None:
        """
        Flush pending sections and mark this object as closed.

        The object is marked closed even when the flush fails. The
        underlying file itself is not closed here.
        """
        if not self._closed:
            try:
                self.flush()
            finally:
                self._closed = True

    def _write_section(self, section: "MultiCSVSection") -> None:
        """
        Write the ``[name]`` header and the full content of ``section``
        at the current position of the underlying file.
        """
        self._file.write(f"[{section.name}]\n")

        # Copy from the start of the section stream, then put its cursor
        # back so that whoever holds the stream is not disturbed.
        initial_section_pos = section.descriptor.tell()
        try:
            section.descriptor.seek(0)
            shutil.copyfileobj(section.descriptor, self._file)
        finally:
            section.descriptor.seek(initial_section_pos)

    def _write_file(self) -> None:
        """Rewrite the underlying file from scratch with all sections."""
        self._file.seek(0)
        self._file.truncate()

        for section in self._sections:
            self._write_section(section)

    def flush(self) -> None:
        """
        Write all sections to the underlying file.

        Does nothing when this object has already been closed. The
        position of the underlying file is restored afterwards.

        Raises:
            CSVFileBaseIOClosed: if the underlying file is closed.
        """
        if self._file.closed:
            raise CSVFileBaseIOClosed("Base file is closed in flush.")

        if self._closed:
            return

        initial_file_pos = self._file.tell()
        try:
            self._write_file()
        finally:
            self._file.seek(initial_file_pos)

    def __enter__(self) -> 'MultiCSVFile':
        """Return self so the object can be used in a ``with`` block."""
        return self

    def __exit__(self,
                 exc_type: Optional[Type[BaseException]],
                 exc_val: Optional[BaseException],
                 exc_tb: Optional[object]) -> None:
        """Close on leaving the ``with`` block; exceptions propagate."""
        self.close()

    def _initialize_sections_wrapped(self) -> None:
        """
        Scan the underlying file from the start and register one section
        per header row found. Rows before the first header are ignored.

        Leaves the file position wherever scanning stopped; use
        ``_initialize_sections`` to have it restored.
        """
        self._file.seek(0)

        # Feed the reader through readline() instead of iterating the file
        # object directly: real text files refuse tell() while they are
        # being iterated with next(), and tell() is needed after each row.
        reader = csv.reader(iter(self._file.readline, ""))

        current_section: Optional[str] = None
        section_start = 0
        previous_position = 0

        def end_section() -> None:
            # Register the section being scanned (if any); it spans from
            # just after its header row to the end of the previous row.
            if current_section is not None:
                descriptor = SubTextIO(self._file,
                                       start=section_start,
                                       end=previous_position)
                section = MultiCSVSection(name=current_section,
                                          descriptor=descriptor)
                self._sections.append(section)

        for row in reader:
            current_position = self._file.tell()
            if row:
                first, rest = row[0], row[1:]

                # A header row is "[name]" followed only by empty cells.
                if first.startswith("[") and \
                   first.endswith("]") and \
                   not any(rest):

                    end_section()
                    current_section = first[1:-1]
                    section_start = current_position

            previous_position = current_position

        end_section()

    def _initialize_sections(self) -> None:
        """Scan for sections, restoring the file position afterwards."""
        initial_file_pos = self._file.tell()
        try:
            self._initialize_sections_wrapped()
        finally:
            self._file.seek(initial_file_pos)

    def _check_closed(self) -> None:
        """
        Helper method to verify if the IO object is closed.

        Raises:
            OpOnClosedCSVFileError: if this object has been closed.
        """
        if self._closed:
            raise OpOnClosedCSVFileError("I/O operation on closed file.")

    def __del__(self) -> None:
        """Best-effort close when the object is garbage collected."""
        # getattr guards against instances whose __init__ never ran.
        if getattr(self, "_initialized", False):
            try:
                self.close()
            except CSVFileBaseIOClosed:
                # The underlying file is already gone; nothing to flush to.
                pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
|
||
from dataclasses import dataclass | ||
from typing import TextIO | ||
|
||
|
||
@dataclass(frozen=True)
class MultiCSVSection:
    """Immutable pairing of a section name with its text stream."""

    # Section name, without any surrounding brackets.
    name: str
    # Text stream holding the section's content.
    descriptor: TextIO
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import io | ||
import pytest | ||
from typing import TextIO | ||
from multicsv.file import MultiCSVFile | ||
from multicsv.exceptions import SectionNotFound, CSVFileBaseIOClosed | ||
|
||
|
||
@pytest.fixture
def simple_csv() -> TextIO:
    """Provide an in-memory file with two sections of two rows each."""
    content = (
        "[section1]\n"
        "a,b,c\n"
        "1,2,3\n"
        "[section2]\n"
        "d,e,f\n"
        "4,5,6\n"
    )
    return io.StringIO(content)
|
||
|
||
@pytest.fixture
def empty_csv() -> TextIO:
    """Provide an empty in-memory file."""
    return io.StringIO("")
|
||
|
||
@pytest.fixture
def no_section_csv() -> TextIO:
    """Provide an in-memory CSV file that has rows but no section header."""
    content = (
        "a,b,c\n"
        "1,2,3\n"
        "d,e,f\n"
        "4,5,6\n"
    )
    return io.StringIO(content)
|
||
|
||
def test_read_section(simple_csv):
    """Each section exposes exactly the rows between its headers."""
    csv_file = MultiCSVFile(simple_csv)

    section1 = csv_file["section1"]
    assert section1.read() == "a,b,c\n1,2,3\n"

    section2 = csv_file["section2"]
    assert section2.read() == "d,e,f\n4,5,6\n"
|
||
|
||
def test_write_section(simple_csv):
    """A newly assigned section is appended to the file on flush."""
    csv_file = MultiCSVFile(simple_csv)
    section3 = io.StringIO("g,h,i\n7,8,9\n")
    csv_file["section3"] = section3

    csv_file.flush()
    simple_csv.seek(0)
    assert simple_csv.read() == (
        "[section1]\na,b,c\n1,2,3\n"
        "[section2]\nd,e,f\n4,5,6\n"
        "[section3]\ng,h,i\n7,8,9\n"
    )
|
||
|
||
def test_delete_section(simple_csv):
    """A deleted section no longer appears in the file after flush."""
    csv_file = MultiCSVFile(simple_csv)
    del csv_file["section1"]

    csv_file.flush()
    simple_csv.seek(0)
    assert simple_csv.read() == "[section2]\nd,e,f\n4,5,6\n"
|
||
|
||
def test_iterate_sections(simple_csv):
    """Iteration yields the section names in file order."""
    csv_file = MultiCSVFile(simple_csv)
    sections = list(iter(csv_file))
    assert sections == ["section1", "section2"]
|
||
|
||
def test_getitem_not_found(simple_csv):
    """Looking up an unknown section raises SectionNotFound."""
    csv_file = MultiCSVFile(simple_csv)
    with pytest.raises(SectionNotFound):
        _ = csv_file["section3"]
|
||
|
||
def test_contains(simple_csv):
    """The ``in`` operator reports which section names exist."""
    csv_file = MultiCSVFile(simple_csv)
    assert "section1" in csv_file
    assert "section2" in csv_file
    assert "section3" not in csv_file
|
||
|
||
def test_len(simple_csv):
    """len() counts the sections found in the file."""
    csv_file = MultiCSVFile(simple_csv)
    assert len(csv_file) == 2
|
||
|
||
def test_close(simple_csv):
    """flush() raises once the underlying file has been closed."""
    csv_file = MultiCSVFile(simple_csv)

    simple_csv.close()
    assert simple_csv.closed is True

    with pytest.raises(CSVFileBaseIOClosed):
        csv_file.flush()
|
||
|
||
def test_context_manager(simple_csv):
    """Leaving the ``with`` block marks the MultiCSVFile as closed."""
    with MultiCSVFile(simple_csv) as csv_file:
        assert len(csv_file) == 2

    assert csv_file._closed is True
|
||
|
||
def test_flush_on_closed_file(simple_csv):
    """flush() after close() must not raise while the base file is open."""
    csv_file = MultiCSVFile(simple_csv)
    csv_file.close()
    # Passing means no exception was raised here.
    csv_file.flush()
|
||
|
||
def test_initialize_sections_on_empty_file(empty_csv):
    """An empty file yields no sections."""
    csv_file = MultiCSVFile(empty_csv)
    assert len(csv_file) == 0
|
||
|
||
def test_initialize_sections_on_no_section_file(no_section_csv):
    """Rows that are not preceded by any header belong to no section."""
    csv_file = MultiCSVFile(no_section_csv)
    assert len(csv_file) == 0  # No sections should be found
|
||
|
||
def test_update_existing_section(simple_csv):
    """Reassigning a section replaces its content and keeps its position."""
    csv_file = MultiCSVFile(simple_csv)
    new_section = io.StringIO("new,data\n")
    csv_file["section1"] = new_section

    csv_file.flush()
    simple_csv.seek(0)
    expected_content = "[section1]\nnew,data\n[section2]\nd,e,f\n4,5,6\n"
    assert simple_csv.read() == expected_content
|
||
|
||
def test_del_non_existent_section(simple_csv):
    """Deleting an unknown section raises SectionNotFound."""
    csv_file = MultiCSVFile(simple_csv)
    with pytest.raises(SectionNotFound):
        del csv_file["section9"]
|
||
|
||
def test_multiple_writes_with_flush(simple_csv):
    """Sections added across several flushes all end up in the file."""
    csv_file = MultiCSVFile(simple_csv)

    section3 = io.StringIO("x,y,z\n10,11,12\n")
    csv_file["section3"] = section3
    csv_file.flush()

    section4 = io.StringIO("p,q,r\n13,14,15\n")
    csv_file["section4"] = section4
    csv_file.flush()

    simple_csv.seek(0)
    expected_content = (
        "[section1]\na,b,c\n1,2,3\n"
        "[section2]\nd,e,f\n4,5,6\n"
        "[section3]\nx,y,z\n10,11,12\n"
        "[section4]\np,q,r\n13,14,15\n"
    )
    assert simple_csv.read() == expected_content
|
||
|
||
def test_reading_after_writing(simple_csv):
    """A section added and flushed can still be read back by name."""
    csv_file = MultiCSVFile(simple_csv)

    section3 = io.StringIO("x,y,z\n10,11,12\n")
    csv_file["section3"] = section3
    csv_file.flush()

    section = csv_file["section3"]
    section.seek(0)
    assert section.read() == "x,y,z\n10,11,12\n"
|
||
|
||
def test_iterates_zero_length_multicsvfile(empty_csv):
    """Iterating a MultiCSVFile built from an empty file yields nothing."""
    csv_file = MultiCSVFile(empty_csv)
    assert list(iter(csv_file)) == []
|
||
|
||
def test_section_not_found_for_deleted_section(simple_csv):
    """A section that was deleted can no longer be looked up."""
    csv_file = MultiCSVFile(simple_csv)

    assert csv_file["section1"]
    del csv_file["section1"]

    with pytest.raises(SectionNotFound):
        csv_file["section1"]
|
||
|
||
@pytest.mark.parametrize("initial_content, expected_sections", [
    ("[first_section]\na,b,c\n1,2,3\n[second_section]\nd,e,f\n4,5,6\n",
     ["first_section", "second_section"]),
    ("", []),
    ("[lonely_section]\ng,h,i\n7,8,9\n", ["lonely_section"]),
])
def test_various_initial_contents(initial_content, expected_sections):
    """Section names are discovered for several initial file contents."""
    file = io.StringIO(initial_content)
    csv_file = MultiCSVFile(file)
    assert list(iter(csv_file)) == expected_sections