diff --git a/src/multicsv/file.py b/src/multicsv/file.py index 26ffac3..fb1ef0d 100644 --- a/src/multicsv/file.py +++ b/src/multicsv/file.py @@ -115,6 +115,8 @@ def __init__(self, file: TextIO): self._initialized = True def __getitem__(self, key: str) -> TextIO: + self._check_closed() + for item in self._sections: if item.name == key: return item.descriptor @@ -123,6 +125,8 @@ def __getitem__(self, key: str) -> TextIO: f"have section named {key!r}.") def __setitem__(self, key: str, value: TextIO) -> None: + self._check_closed() + def make_section() -> MultiCSVSection: return MultiCSVSection(name=key, descriptor=value) @@ -134,6 +138,8 @@ def make_section() -> MultiCSVSection: self._sections.append(make_section()) def __delitem__(self, key: str) -> None: + self._check_closed() + found = None for i, item in enumerate(self._sections): if item.name == key: @@ -147,12 +153,16 @@ def __delitem__(self, key: str) -> None: del self._sections[i] def __iter__(self) -> Iterator[str]: + self._check_closed() + return iter(map(lambda x: x.name, self._sections)) def __len__(self) -> int: return len(self._sections) def __contains__(self, key: object) -> bool: + self._check_closed() + for item in self._sections: if item.name == key: return True diff --git a/tests/test_file.py b/tests/test_file.py index a6f15e4..330ec13 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -2,7 +2,7 @@ import pytest from typing import TextIO from multicsv.file import MultiCSVFile -from multicsv.exceptions import SectionNotFound, CSVFileBaseIOClosed +from multicsv.exceptions import SectionNotFound, CSVFileBaseIOClosed, OpOnClosedCSVFileError @pytest.fixture @@ -189,3 +189,26 @@ def test_various_initial_contents(initial_content, expected_sections): file = io.StringIO(initial_content) csv_file = MultiCSVFile(file) assert list(iter(csv_file)) == expected_sections + + +def test_op_on_closed(simple_csv): + csv_file = MultiCSVFile(simple_csv) + + assert csv_file["section1"] + csv_file.close() + + with pytest.raises(OpOnClosedCSVFileError): + csv_file["section1"] + + csv_file.close() + + with pytest.raises(OpOnClosedCSVFileError): + csv_file["section1"] + + +def test_op_on_closed_via_context(simple_csv): + with MultiCSVFile(simple_csv) as csv_file: + assert csv_file["section1"] + + with pytest.raises(OpOnClosedCSVFileError): + csv_file["section1"]