Skip to content

Commit

Permalink
Add initial implementation of MultiCSVFile
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Aug 7, 2024
1 parent b141847 commit 9709660
Show file tree
Hide file tree
Showing 4 changed files with 384 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/multicsv/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,20 @@ class StartsBeyondBaseContent(SubTextIOErrror, ValueError):

class BaseIOClosed(SubTextIOErrror, ValueError):
pass


class MultiCSVFileError(Exception):
"""Base class for all MultiCSVFile custom exceptions."""
pass


class OpOnClosedCSVFileError(MultiCSVFileError, ValueError):
pass


class CSVFileBaseIOClosed(MultiCSVFileError, ValueError):
pass


class SectionNotFound(MultiCSVFileError, KeyError):
pass
167 changes: 167 additions & 0 deletions src/multicsv/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@

import csv
from typing import TextIO, Optional, Type, List, MutableMapping, Iterator
import shutil
from .subtextio import SubTextIO
from .exceptions import OpOnClosedCSVFileError, CSVFileBaseIOClosed, \
SectionNotFound
from .section import MultiCSVSection


class MultiCSVFile(MutableMapping[str, TextIO]):
def __init__(self, file: TextIO):
self._initialized = False
self._file = file
self._closed = self._file.closed
self._sections: List[MultiCSVSection] = []
self._initialize_sections()
self._initialized = True

def __getitem__(self, key: str) -> TextIO:
for item in self._sections:
if item.name == key:
return item.descriptor

raise SectionNotFound("MultiCSVFile does not "
f"have section named {key!r}.")

def __setitem__(self, key: str, value: TextIO) -> None:
def make_section() -> MultiCSVSection:
return MultiCSVSection(name=key, descriptor=value)

for i, item in enumerate(self._sections):
if item.name == key:
self._sections[i] = make_section()
return

self._sections.append(make_section())

def __delitem__(self, key: str) -> None:
found = None
for i, item in enumerate(self._sections):
if item.name == key:
found = i
break

if found is None:
raise SectionNotFound("MultiCSVFile does not "
f"have section named {key!r}.")
else:
del self._sections[i]

def __iter__(self) -> Iterator[str]:
return iter(map(lambda x: x.name, self._sections))

def __len__(self) -> int:
return len(self._sections)

def __contains__(self, key: object) -> bool:
for item in self._sections:
if item.name == key:
return True
return False

def close(self) -> None:
if not self._closed:
try:
self.flush()
finally:
self._closed = True

def _write_section(self, section: MultiCSVSection) -> None:
self._file.write(f"[{section.name}]\n")

initial_section_pos = section.descriptor.tell()
try:
section.descriptor.seek(0)
shutil.copyfileobj(section.descriptor, self._file)
finally:
section.descriptor.seek(initial_section_pos)

def _write_file(self) -> None:
self._file.seek(0)
self._file.truncate()

for section in self._sections:
self._write_section(section)

def flush(self) -> None:
if self._file.closed:
raise CSVFileBaseIOClosed("Base file is closed in flush.")

if self._closed:
return

initial_file_pos = self._file.tell()
try:
self._write_file()
finally:
self._file.seek(initial_file_pos)

def __enter__(self) -> 'MultiCSVFile':
return self

def __exit__(self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[object]) -> None:
self.close()

def _initialize_sections_wrapped(self) -> None:
reader = csv.reader(self._file)

current_section: Optional[str] = None
section_start = 0
previous_position = 0

self._file.seek(0)

def end_section() -> None:
if current_section is not None:
descriptor = SubTextIO(self._file,
start=section_start,
end=previous_position)
section = MultiCSVSection(name=current_section,
descriptor=descriptor)
self._sections.append(section)

for row in reader:
current_position = self._file.tell()
if row:

first = row[0]
rest = row[1:]

if first.startswith("[") and \
first.endswith("]") and \
all(not x for x in rest):

end_section()
current_section = first[1:-1]
section_start = current_position

previous_position = current_position

end_section()

def _initialize_sections(self) -> None:
initial_file_pos = self._file.tell()
try:
self._initialize_sections_wrapped()
finally:
self._file.seek(initial_file_pos)

def _check_closed(self) -> None:
"""
Helper method to verify if the IO object is closed.
"""

if self._closed:
raise OpOnClosedCSVFileError("I/O operation on closed file.")

def __del__(self) -> None:
if self._initialized:
try:
self.close()
except CSVFileBaseIOClosed:
pass
9 changes: 9 additions & 0 deletions src/multicsv/section.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

from dataclasses import dataclass
from typing import TextIO


@dataclass(frozen=True)
class MultiCSVSection:
name: str
descriptor: TextIO
191 changes: 191 additions & 0 deletions tests/test_file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import io
import pytest
from typing import TextIO
from multicsv.file import MultiCSVFile
from multicsv.exceptions import SectionNotFound, CSVFileBaseIOClosed


@pytest.fixture
def simple_csv() -> TextIO:
content = """\
[section1]
a,b,c
1,2,3
[section2]
d,e,f
4,5,6
"""
return io.StringIO(content)


@pytest.fixture
def empty_csv() -> TextIO:
return io.StringIO("")


@pytest.fixture
def no_section_csv() -> TextIO:
content = """\
a,b,c
1,2,3
d,e,f
4,5,6
"""
return io.StringIO(content)


def test_read_section(simple_csv):
csv_file = MultiCSVFile(simple_csv)
section1 = csv_file["section1"]
assert section1.read() == "a,b,c\n1,2,3\n"

section2 = csv_file["section2"]
assert section2.read() == "d,e,f\n4,5,6\n"


def test_write_section(simple_csv):
csv_file = MultiCSVFile(simple_csv)
section3 = io.StringIO("g,h,i\n7,8,9\n")
csv_file["section3"] = section3

csv_file.flush()
simple_csv.seek(0)
assert simple_csv.read() == "[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n[section3]\ng,h,i\n7,8,9\n"


def test_delete_section(simple_csv):
csv_file = MultiCSVFile(simple_csv)
del csv_file["section1"]

csv_file.flush()
simple_csv.seek(0)
assert simple_csv.read() == "[section2]\nd,e,f\n4,5,6\n"


def test_iterate_sections(simple_csv):
csv_file = MultiCSVFile(simple_csv)
sections = list(iter(csv_file))
assert sections == ["section1", "section2"]


def test_getitem_not_found(simple_csv):
csv_file = MultiCSVFile(simple_csv)
with pytest.raises(SectionNotFound):
_ = csv_file["section3"]


def test_contains(simple_csv):
csv_file = MultiCSVFile(simple_csv)
assert "section1" in csv_file
assert "section2" in csv_file
assert "section3" not in csv_file


def test_len(simple_csv):
csv_file = MultiCSVFile(simple_csv)
assert len(csv_file) == 2


def test_close(simple_csv):
csv_file = MultiCSVFile(simple_csv)

simple_csv.close()
assert simple_csv.closed is True

with pytest.raises(CSVFileBaseIOClosed):
csv_file.flush()


def test_context_manager(simple_csv):
with MultiCSVFile(simple_csv) as csv_file:
assert len(csv_file) == 2

assert csv_file._closed is True


def test_flush_on_closed_file(simple_csv):
csv_file = MultiCSVFile(simple_csv)
csv_file.close()
csv_file.flush()


def test_initialize_sections_on_empty_file(empty_csv):
csv_file = MultiCSVFile(empty_csv)
assert len(csv_file) == 0


def test_initialize_sections_on_no_section_file(no_section_csv):
csv_file = MultiCSVFile(no_section_csv)
assert len(csv_file) == 0 # No sections should be found


def test_update_existing_section(simple_csv):
csv_file = MultiCSVFile(simple_csv)
new_section = io.StringIO("new,data\n")
csv_file["section1"] = new_section

csv_file.flush()
simple_csv.seek(0)
expected_content = "[section1]\nnew,data\n[section2]\nd,e,f\n4,5,6\n"
assert simple_csv.read() == expected_content


def test_del_non_existent_section(simple_csv):
csv_file = MultiCSVFile(simple_csv)
with pytest.raises(SectionNotFound):
del csv_file["section9"]


def test_multiple_writes_with_flush(simple_csv):
csv_file = MultiCSVFile(simple_csv)

section3 = io.StringIO("x,y,z\n10,11,12\n")
csv_file["section3"] = section3
csv_file.flush()

section4 = io.StringIO("p,q,r\n13,14,15\n")
csv_file["section4"] = section4
csv_file.flush()

simple_csv.seek(0)
expected_content = "[section1]\na,b,c\n1,2,3\n[section2]\nd,e,f\n4,5,6\n[section3]\nx,y,z\n10,11,12\n[section4]\np,q,r\n13,14,15\n"
assert simple_csv.read() == expected_content


def test_reading_after_writing(simple_csv):
csv_file = MultiCSVFile(simple_csv)

section3 = io.StringIO("x,y,z\n10,11,12\n")
csv_file["section3"] = section3
csv_file.flush()

section = csv_file["section3"]
section.seek(0)
assert section.read() == "x,y,z\n10,11,12\n"


def test_iterates_zero_length_multicsvfile(empty_csv):
csv_file = MultiCSVFile(empty_csv)
assert list(iter(csv_file)) == []


def test_section_not_found_for_deleted_section(simple_csv):
csv_file = MultiCSVFile(simple_csv)

assert csv_file["section1"]
del csv_file["section1"]

with pytest.raises(SectionNotFound):
csv_file["section1"]


@pytest.mark.parametrize("initial_content, expected_sections", [
("[first_section]\na,b,c\n1,2,3\n[second_section]\nd,e,f\n4,5,6\n",
["first_section", "second_section"]),
("", []),
("[lonely_section]\ng,h,i\n7,8,9\n", ["lonely_section"]),
])
def test_various_initial_contents(initial_content, expected_sections):
file = io.StringIO(initial_content)
csv_file = MultiCSVFile(file)
assert list(iter(csv_file)) == expected_sections

0 comments on commit 9709660

Please sign in to comment.