diff --git a/src/multicsv/exceptions.py b/src/multicsv/exceptions.py index 9e10148..2aada06 100644 --- a/src/multicsv/exceptions.py +++ b/src/multicsv/exceptions.py @@ -25,7 +25,7 @@ class BaseMustBeReadable(SubTextIOErrror, ValueError): pass -class StartsBeyondBaseContent(SubTextIOErrror, ValueError): +class EndsBeyondBaseContent(SubTextIOErrror, ValueError): pass diff --git a/src/multicsv/subtextio.py b/src/multicsv/subtextio.py index 4a25968..8846340 100644 --- a/src/multicsv/subtextio.py +++ b/src/multicsv/subtextio.py @@ -4,7 +4,7 @@ from .exceptions import OpOnClosedError, \ InvalidWhenceError, InvalidSubtextCoordinates, \ BaseMustBeReadable, BaseMustBeSeakable, \ - StartsBeyondBaseContent, BaseIOClosed + EndsBeyondBaseContent, BaseIOClosed class SubTextIO(TextIO): @@ -142,18 +142,16 @@ def _load(self) -> None: self._base_io.seek(0, os.SEEK_END) base_last_position = self._base_io.tell() - if self.start > base_last_position: - raise StartsBeyondBaseContent( - "Start position is greater than base TextIO length.") + if self.end > base_last_position: + raise EndsBeyondBaseContent( + "End position is greater than base TextIO length.") if self.end > self.start: self._base_io.seek(self.end) base_final_position = self._base_io.tell() self.is_at_end = base_final_position == base_last_position - - if self.end <= base_last_position: - self._base_io.seek(self.start) - self._buffer = self._base_io.read(self.end - self.start) + self._base_io.seek(self.start) + self._buffer = self._base_io.read(self.end - self.start) else: base_final_position = self.start self._base_io.seek(0, os.SEEK_END) diff --git a/tests/test_subtextio.py b/tests/test_subtextio.py index bd93da2..3ba1347 100644 --- a/tests/test_subtextio.py +++ b/tests/test_subtextio.py @@ -4,7 +4,7 @@ import pytest import os from multicsv.subtextio import SubTextIO -from multicsv.exceptions import OpOnClosedError, InvalidWhenceError, InvalidSubtextCoordinates, StartsBeyondBaseContent +from multicsv.exceptions import OpOnClosedError, InvalidWhenceError, InvalidSubtextCoordinates, EndsBeyondBaseContent @pytest.fixture @@ -394,7 +394,10 @@ def test_mode(mode): import tempfile with tempfile.NamedTemporaryFile(mode=mode) as tmp: - if tmp.readable(): + if tmp.writable(): + tmp.truncate(30) + + if tmp.readable() and tmp.writable(): sub_text = SubTextIO(tmp, start=0, end=21) assert sub_text.mode == mode @@ -407,8 +410,8 @@ def test_invalid_range(base_textio): SubTextIO(base_textio, start=15, end=10) def test_invalid_range_past_initial(base_textio): - with pytest.raises(StartsBeyondBaseContent): - SubTextIO(base_textio, start=30, end=40) + with pytest.raises(EndsBeyondBaseContent): + SubTextIO(base_textio, start=5, end=40) def test_no_readable_requirement(): import tempfile