Skip to content

Commit

Permalink
Optimize handling of base TextIO
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Aug 6, 2024
1 parent e212c06 commit 000c559
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 35 deletions.
89 changes: 59 additions & 30 deletions src/multicsv/subtextio.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ class allows for convenient and isolated operations within a given
Caveats:
--------
- Ensure that the range specified by `start` and `end` does not
overlap unintended sections of the text.
- Be cautious of buffer size and memory usage when working with
very large files.
- Writing to and reading from the base TextIO when it is used in SubTextIO
can lead to unexpected results.
- SubTextIO loads the subsection into memory. Thus be cautious
of buffer size when working with very large files.
- Always ensure to call `flush` or use context management to
commit changes back to the base TextIO.
"""
Expand All @@ -110,12 +110,34 @@ def __init__(self, base_io: TextIO, start: int, end: int):
self.start = start
self.end = end
self.position = 0 # Position within the SubTextIO

# Load the relevant part of the file into the buffer
self.base_io.seek(self.start)
self._buffer = self.base_io.read(self.end - self.start)
self.length = len(self._buffer)
self._closed = base_io.closed
self._load()
self.initial_length = self.buffer_length

def _load(self) -> None:
"""
Load the relevant part of the base_io into the buffer.
"""

base_initial_position = self.base_io.tell()
try:
if self.end > self.start:
self.base_io.seek(self.start)
self._buffer = self.base_io.read(self.end - self.start)
base_final_position = self.base_io.tell()
self.base_io.seek(0, os.SEEK_END)
self.is_at_end = base_final_position == self.base_io.tell()
else:
self._buffer = ""
base_final_position = base_initial_position
self.base_io.seek(0, os.SEEK_END)
self.is_at_end = base_final_position == self.base_io.tell()
finally:
self.base_io.seek(base_initial_position)

@property
def buffer_length(self) -> int:
return len(self._buffer)

@property
def mode(self) -> str:
Expand All @@ -133,8 +155,8 @@ def read(self, size: int = -1) -> str:
if self._closed:
raise OpOnClosedError("I/O operation on closed file.")

if size < 0 or size > self.length - self.position:
size = self.length - self.position
if size < 0 or size > self.buffer_length - self.position:
size = self.buffer_length - self.position

result = self._buffer[self.position:self.position + size]
self.position += len(result)
Expand All @@ -144,12 +166,12 @@ def readline(self, limit: int = -1) -> str:
if self._closed:
raise OpOnClosedError("I/O operation on closed file.")

if self.position >= self.length:
if self.position >= self.buffer_length:
return ''

newline_pos = self._buffer.find('\n', self.position)
if newline_pos == -1 or newline_pos >= self.length:
newline_pos = self.length
if newline_pos == -1 or newline_pos >= self.buffer_length:
newline_pos = self.buffer_length

if limit < 0 or limit > newline_pos - self.position:
limit = newline_pos - self.position + 1
Expand Down Expand Up @@ -196,7 +218,6 @@ def write(self, s: str) -> int:
written = len(s)
self.position += written

self.length = max(self.position, self.length)
return written

def writelines(self, lines: Iterable[str]) -> None:
Expand All @@ -213,13 +234,14 @@ def truncate(self, size: Optional[int] = None) -> int:
end = size

self._buffer = self._buffer[:end]
self.length = len(self._buffer)
return self.length
return self.buffer_length

def close(self) -> None:
if not self._closed:
self.flush()
self._closed = True
try:
self.flush()
finally:
self._closed = True

def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
if self._closed:
Expand All @@ -230,12 +252,12 @@ def seek(self, offset: int, whence: int = os.SEEK_SET) -> int:
elif whence == os.SEEK_CUR: # Relative to current position
target = self.position + offset
elif whence == os.SEEK_END: # Relative to the end
target = self.length + offset
target = self.buffer_length + offset
else:
raise InvalidWhenceError(
f"Invalid value for whence: {repr(whence)}")

self.position = max(0, min(target, self.length))
self.position = max(0, min(target, self.buffer_length))
return self.position

def tell(self) -> int:
Expand All @@ -246,14 +268,21 @@ def tell(self) -> int:

def flush(self) -> None:
if not self._closed:
self.base_io.seek(0)
content_before = self.base_io.read(self.start)
self.base_io.seek(self.end)
content_after = self.base_io.read()

self.base_io.seek(0)
self.base_io.write(content_before + self._buffer + content_after)
self.base_io.flush()
base_initial_position = self.base_io.tell()
try:
if self.buffer_length == self.initial_length or self.is_at_end:
self.base_io.seek(self.start)
self.base_io.write(self._buffer)
self.base_io.flush()
else:
self.base_io.seek(self.end)
content_after = self.base_io.read()

self.base_io.seek(self.start)
self.base_io.write(self._buffer + content_after)
self.base_io.flush()
finally:
self.base_io.seek(base_initial_position)

def isatty(self) -> bool:
return False
Expand All @@ -274,7 +303,7 @@ def __iter__(self) -> 'SubTextIO':
return self

def __next__(self) -> str:
if self.position < self.length:
if self.position < self.buffer_length:
return self.readline()
else:
raise StopIteration
Expand Down
10 changes: 5 additions & 5 deletions tests/test_subtextio.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def test_truncate_with_arg(base_textio):
"""

def test_truncate_past_end(base_textio):
initial_content = base_textio.read()
sub_text = SubTextIO(base_textio, start=6, end=21)
sub_text.truncate(999)
assert base_textio.read() == """\
a
test
"""
assert sub_text.length == 15 == 21 - 6
assert sub_text.buffer_length == 15 == 21 - 6

base_textio.seek(0)
assert base_textio.read() == initial_content

def test_seek_tell(base_textio):
sub_text = SubTextIO(base_textio, start=6, end=21)
Expand Down

0 comments on commit 000c559

Please sign in to comment.