Skip to content

Commit

Permalink
PeekReader, for slightly more tar safety
Browse files Browse the repository at this point in the history
  • Loading branch information
Fallen-Breath committed Dec 8, 2023
1 parent ebfdcba commit 7e05521
Showing 1 changed file with 45 additions and 3 deletions.
48 changes: 45 additions & 3 deletions prime_backup/action/export_backup_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import abstractmethod, ABC
from io import BytesIO
from pathlib import Path
from typing import ContextManager, Optional, List, Tuple
from typing import ContextManager, Optional, List, Tuple, IO, Any

from prime_backup.action import Action
from prime_backup.compressors import Compressor, CompressMethod
Expand Down Expand Up @@ -192,6 +192,41 @@ def _export_backup(self, session, backup: schema.Backup) -> ExportFailures:
return failures


class PeekReader:
    """
    Read-only wrapper around a binary stream that supports a one-shot "peek".

    :meth:`peek` eagerly reads up to ``peek_size`` bytes from the wrapped stream
    and keeps them in memory, so any error the underlying stream would raise on
    its first read surfaces at peek time. Subsequent :meth:`read` calls first
    replay the peeked bytes, then continue reading from the wrapped stream, so
    the consumer sees the exact same byte sequence as reading the stream directly.
    """

    def __init__(self, file_obj: IO[bytes], peek_size: int):
        """
        :param file_obj: The binary stream to wrap. Only its ``read`` method is used
        :param peek_size: Maximum number of bytes :meth:`peek` reads ahead
        """
        self.file_obj = file_obj
        self.peek_size = peek_size
        # Bytes fetched by peek(); None means peek() has not been called yet
        self.peek_buf: Optional[bytes] = None
        # Offset of the next not-yet-returned byte inside peek_buf
        self.peek_buf_idx = 0

    def peek(self):
        """
        Read up to ``peek_size`` bytes from the wrapped stream into the internal buffer.

        Must be called exactly once, before the first :meth:`read`.

        :raises RuntimeError: If called more than once
        """
        if self.peek_buf is not None:
            raise RuntimeError('double peek')
        self.peek_buf = self.file_obj.read(self.peek_size)

    def read(self, n: Optional[int] = -1) -> bytes:
        """
        Read up to ``n`` bytes, serving the peeked bytes first.

        :param n: Maximum number of bytes to return. ``None`` or any negative
            value means "read everything that is left", matching the usual
            ``io`` stream convention
        :return: The bytes read; may be shorter than ``n`` at end of stream
        :raises RuntimeError: If called before :meth:`peek`
        """
        if self.peek_buf is None:
            raise RuntimeError('read before peek')

        # Treat None and every negative size as "read until EOF". Without this,
        # None breaks the integer comparison below and a size such as -2 would
        # slice the buffer incorrectly and push peek_buf_idx below zero.
        if n is None or n < 0:
            n = -1

        buf_len = len(self.peek_buf)

        # Peeked bytes fully consumed (or nothing was peeked): pure pass-through
        if self.peek_buf_idx == buf_len:
            return self.file_obj.read(n)

        if n == -1:
            # Rest of the peek buffer followed by the rest of the stream
            data = self.peek_buf[self.peek_buf_idx:] + self.file_obj.read(n)
            self.peek_buf_idx = buf_len
            return data

        remaining = buf_len - self.peek_buf_idx
        if n <= remaining:
            # Request can be satisfied from the peek buffer alone
            data = self.peek_buf[self.peek_buf_idx:self.peek_buf_idx + n]
            self.peek_buf_idx += n
            return data

        # Request spans the buffer boundary: drain the buffer, then top up from the stream
        data = self.peek_buf[self.peek_buf_idx:] + self.file_obj.read(n - remaining)
        self.peek_buf_idx = buf_len
        return data


class ExportBackupToTarAction(_ExportBackupActionBase):
def __init__(
self, backup_id: int, output_path: Path, tar_format: TarFormat, *,
Expand Down Expand Up @@ -225,12 +260,19 @@ def __export_file(self, tar: tarfile.TarFile, file: schema.File):
blob_path = blob_utils.get_blob_path(file.blob_hash)

with Compressor.create(file.blob_compress).open_decompressed(blob_path) as stream:
# An exception raised inside TarFile.addfile might corrupt the rest of the tar file, which is bad.
# So we read a few bytes from the stream first, to *hopefully* trigger any potential decompression exception in advance,
# making it fail before it can affect the actual tar file
peek_reader = PeekReader(stream, 4096)
peek_reader.peek()

if self.verify_blob:
reader = BypassReader(stream, calc_hash=True)
reader = BypassReader(peek_reader, calc_hash=True)
tar.addfile(tarinfo=info, fileobj=reader)
else:
reader = None
tar.addfile(tarinfo=info, fileobj=stream)
peek_reader: Any
tar.addfile(tarinfo=info, fileobj=peek_reader)
if reader is not None:
# notes: the read len is always <= info.size
self._verify_exported_blob(file, reader.get_read_len(), reader.get_hash())
Expand Down

0 comments on commit 7e05521

Please sign in to comment.