From 000c559b68d44c173f8b04a81f2a19072394362c Mon Sep 17 00:00:00 2001 From: Vitaliy Mysak Date: Tue, 6 Aug 2024 15:31:21 -0700 Subject: [PATCH] Optimize handling of base TextIO --- src/multicsv/subtextio.py | 89 ++++++++++++++++++++++++++------------- tests/test_subtextio.py | 10 ++--- 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/src/multicsv/subtextio.py b/src/multicsv/subtextio.py index f9e0851..2eecd53 100644 --- a/src/multicsv/subtextio.py +++ b/src/multicsv/subtextio.py @@ -93,10 +93,10 @@ class allows for convenient and isolated operations within a given Caveats: -------- - - Ensure that the range specified by `start` and `end` does not - overlap unintended sections of the text. - - Be cautious of buffer size and memory usage when working with - very large files. + - Writing to and reading from the base TextIO when it is used in SubTextIO + can lead to unexpected results. + - SubTextIO loads the subsection into memory. Thus be cautious + of buffer size when working with very large files. - Always ensure to call `flush` or use context management to commit changes back to the base TextIO. """ @@ -110,12 +110,34 @@ def __init__(self, base_io: TextIO, start: int, end: int): self.start = start self.end = end self.position = 0 # Position within the SubTextIO - - # Load the relevant part of the file into the buffer - self.base_io.seek(self.start) - self._buffer = self.base_io.read(self.end - self.start) - self.length = len(self._buffer) self._closed = base_io.closed + self._load() + self.initial_length = self.buffer_length + + def _load(self) -> None: + """ + Load the relevant part of the base_io into the buffer. + """ + + base_initial_position = self.base_io.tell() + try: + if self.end > self.start: + self.base_io.seek(self.start) + self._buffer = self.base_io.read(self.end - self.start) + base_final_position = self.base_io.tell() + self.base_io.seek(0, os.SEEK_END) + self.is_at_end = base_final_position == self.base_io.tell() + else: + self._buffer = "" + base_final_position = base_initial_position + self.base_io.seek(0, os.SEEK_END) + self.is_at_end = base_final_position == self.base_io.tell() + finally: + self.base_io.seek(base_initial_position) + + @property + def buffer_length(self) -> int: + return len(self._buffer) @property def mode(self) -> str: @@ -133,8 +155,8 @@ def read(self, size: int = -1) -> str: if self._closed: raise OpOnClosedError("I/O operation on closed file.") - if size < 0 or size > self.length - self.position: - size = self.length - self.position + if size < 0 or size > self.buffer_length - self.position: + size = self.buffer_length - self.position result = self._buffer[self.position:self.position + size] self.position += len(result) @@ -144,12 +166,12 @@ def readline(self, limit: int = -1) -> str: if self._closed: raise OpOnClosedError("I/O operation on closed file.") - if self.position >= self.length: + if self.position >= self.buffer_length: return '' newline_pos = self._buffer.find('\n', self.position) - if newline_pos == -1 or newline_pos >= self.length: - newline_pos = self.length + if newline_pos == -1 or newline_pos >= self.buffer_length: + newline_pos = self.buffer_length if limit < 0 or limit > newline_pos - self.position: limit = newline_pos - self.position + 1 @@ -196,7 +218,6 @@ def write(self, s: str) -> int: written = len(s) self.position += written - self.length = max(self.position, self.length) return written def writelines(self, lines: Iterable[str]) -> None: @@ -213,13 +234,14 @@ def truncate(self, size: Optional[int] = None) -> int: end = size self._buffer = self._buffer[:end] - self.length = len(self._buffer) - return self.length + return self.buffer_length def close(self) -> None: if not self._closed: - self.flush() - self._closed = True + try: + self.flush() + finally: + self._closed = True def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: if self._closed: @@ -230,12 +252,12 @@ def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: elif whence == os.SEEK_CUR: # Relative to current position target = self.position + offset elif whence == os.SEEK_END: # Relative to the end - target = self.length + offset + target = self.buffer_length + offset else: raise InvalidWhenceError( f"Invalid value for whence: {repr(whence)}") - self.position = max(0, min(target, self.length)) + self.position = max(0, min(target, self.buffer_length)) return self.position def tell(self) -> int: @@ -246,14 +268,21 @@ def tell(self) -> int: def flush(self) -> None: if not self._closed: - self.base_io.seek(0) - content_before = self.base_io.read(self.start) - self.base_io.seek(self.end) - content_after = self.base_io.read() - - self.base_io.seek(0) - self.base_io.write(content_before + self._buffer + content_after) - self.base_io.flush() + base_initial_position = self.base_io.tell() + try: + if self.buffer_length == self.initial_length or self.is_at_end: + self.base_io.seek(self.start) + self.base_io.write(self._buffer) + self.base_io.flush() + else: + self.base_io.seek(self.end) + content_after = self.base_io.read() + + self.base_io.seek(self.start) + self.base_io.write(self._buffer + content_after) + self.base_io.flush() + finally: + self.base_io.seek(base_initial_position) def isatty(self) -> bool: return False @@ -274,7 +303,7 @@ def __iter__(self) -> 'SubTextIO': return self def __next__(self) -> str: - if self.position < self.length: + if self.position < self.buffer_length: return self.readline() else: raise StopIteration diff --git a/tests/test_subtextio.py b/tests/test_subtextio.py index 1aebe5e..70eb95d 100644 --- a/tests/test_subtextio.py +++ b/tests/test_subtextio.py @@ -115,13 +115,13 @@ def test_truncate_with_arg(base_textio): """ def test_truncate_past_end(base_textio): + initial_content = base_textio.read() sub_text = SubTextIO(base_textio, start=6, end=21) sub_text.truncate(999) - assert base_textio.read() == """\ -a -test -""" - assert sub_text.length == 15 == 21 - 6 + assert sub_text.buffer_length == 15 == 21 - 6 + + base_textio.seek(0) + assert base_textio.read() == initial_content def test_seek_tell(base_textio): sub_text = SubTextIO(base_textio, start=6, end=21)