Skip to content

Commit

Permalink
Prohibit ending SubTextIO past length of base io
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Aug 7, 2024
1 parent 1043e84 commit b23cb89
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/multicsv/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BaseMustBeReadable(SubTextIOErrror, ValueError):
pass


class StartsBeyondBaseContent(SubTextIOErrror, ValueError):
class EndsBeyondBaseContent(SubTextIOErrror, ValueError):
pass


Expand Down
14 changes: 6 additions & 8 deletions src/multicsv/subtextio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .exceptions import OpOnClosedError, \
InvalidWhenceError, InvalidSubtextCoordinates, \
BaseMustBeReadable, BaseMustBeSeakable, \
StartsBeyondBaseContent, BaseIOClosed
EndsBeyondBaseContent, BaseIOClosed


class SubTextIO(TextIO):
Expand Down Expand Up @@ -142,18 +142,16 @@ def _load(self) -> None:
self._base_io.seek(0, os.SEEK_END)
base_last_position = self._base_io.tell()

if self.start > base_last_position:
raise StartsBeyondBaseContent(
"Start position is greater than base TextIO length.")
if self.end > base_last_position:
raise EndsBeyondBaseContent(
"End position is greater than base TextIO length.")

if self.end > self.start:
self._base_io.seek(self.end)
base_final_position = self._base_io.tell()
self.is_at_end = base_final_position == base_last_position

if self.end <= base_last_position:
self._base_io.seek(self.start)
self._buffer = self._base_io.read(self.end - self.start)
self._base_io.seek(self.start)
self._buffer = self._base_io.read(self.end - self.start)
else:
base_final_position = self.start
self._base_io.seek(0, os.SEEK_END)
Expand Down
11 changes: 7 additions & 4 deletions tests/test_subtextio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import os
from multicsv.subtextio import SubTextIO
from multicsv.exceptions import OpOnClosedError, InvalidWhenceError, InvalidSubtextCoordinates, StartsBeyondBaseContent
from multicsv.exceptions import OpOnClosedError, InvalidWhenceError, InvalidSubtextCoordinates, EndsBeyondBaseContent


@pytest.fixture
Expand Down Expand Up @@ -394,7 +394,10 @@ def test_mode(mode):
import tempfile

with tempfile.NamedTemporaryFile(mode=mode) as tmp:
if tmp.readable():
if tmp.writable():
tmp.truncate(30)

if tmp.readable() and tmp.writable():
sub_text = SubTextIO(tmp, start=0, end=21)
assert sub_text.mode == mode

Expand All @@ -407,8 +410,8 @@ def test_invalid_range(base_textio):
SubTextIO(base_textio, start=15, end=10)

def test_invalid_range_past_initial(base_textio):
with pytest.raises(StartsBeyondBaseContent):
SubTextIO(base_textio, start=30, end=40)
with pytest.raises(EndsBeyondBaseContent):
SubTextIO(base_textio, start=5, end=40)

def test_no_readable_requirement():
import tempfile
Expand Down

0 comments on commit b23cb89

Please sign in to comment.