diff --git a/src/db.rs b/src/db.rs index 2ad75184..7c12dc3b 100644 --- a/src/db.rs +++ b/src/db.rs @@ -10,7 +10,6 @@ use crate::{ReadTransaction, Result, WriteTransaction}; use std::fmt::{Debug, Display, Formatter}; use std::fs::{File, OpenOptions}; -use std::io::ErrorKind; use std::marker::PhantomData; use std::ops::RangeFull; use std::path::Path; @@ -692,6 +691,7 @@ impl Database { fn new( file: Box, + allow_initialize: bool, page_size: usize, region_size: Option, read_cache_size_bytes: usize, @@ -704,6 +704,7 @@ impl Database { info!("Opening database {:?}", &file_path); let mem = TransactionalMemory::new( file, + allow_initialize, page_size, region_size, read_cache_size_bytes, @@ -1009,6 +1010,7 @@ impl Builder { Database::new( Box::new(FileBackend::new(file)?), + true, self.page_size, self.region_size, self.read_cache_size_bytes, @@ -1021,12 +1023,9 @@ impl Builder { pub fn open(&self, path: impl AsRef) -> Result { let file = OpenOptions::new().read(true).write(true).open(path)?; - if file.metadata()?.len() == 0 { - return Err(StorageError::Io(ErrorKind::InvalidData.into()).into()); - } - Database::new( Box::new(FileBackend::new(file)?), + false, self.page_size, None, self.read_cache_size_bytes, @@ -1041,6 +1040,7 @@ impl Builder { pub fn create_file(&self, file: File) -> Result { Database::new( Box::new(FileBackend::new(file)?), + true, self.page_size, self.region_size, self.read_cache_size_bytes, @@ -1056,6 +1056,7 @@ impl Builder { ) -> Result { Database::new( Box::new(backend), + true, self.page_size, self.region_size, self.read_cache_size_bytes, diff --git a/src/tree_store/page_store/header.rs b/src/tree_store/page_store/header.rs index 26b96489..891f75e8 100644 --- a/src/tree_store/page_store/header.rs +++ b/src/tree_store/page_store/header.rs @@ -508,6 +508,7 @@ mod test { assert!(TransactionalMemory::new( Box::new(FileBackend::new(file).unwrap()), + false, PAGE_SIZE, None, 0, @@ -597,6 +598,7 @@ mod test { assert!(TransactionalMemory::new( Box::new(FileBackend::new(file).unwrap()), + false, PAGE_SIZE, None, 0, @@ -631,6 +633,7 @@ mod test { assert!(TransactionalMemory::new( Box::new(FileBackend::new(file).unwrap()), + false, PAGE_SIZE, None, 0, @@ -687,6 +690,7 @@ mod test { assert!(TransactionalMemory::new( Box::new(FileBackend::new(file).unwrap()), + false, PAGE_SIZE, None, 0, diff --git a/src/tree_store/page_store/page_manager.rs b/src/tree_store/page_store/page_manager.rs index 4ce0ad94..abb02f70 100644 --- a/src/tree_store/page_store/page_manager.rs +++ b/src/tree_store/page_store/page_manager.rs @@ -18,6 +18,7 @@ use std::cmp::{max, min}; use std::collections::HashMap; use std::collections::HashSet; use std::convert::TryInto; +use std::io::ErrorKind; use std::sync::atomic::{AtomicBool, Ordering}; #[cfg(debug_assertions)] use std::sync::Arc; @@ -107,6 +108,8 @@ impl TransactionalMemory { #[allow(clippy::too_many_arguments)] pub(crate) fn new( file: Box, + // Allow initializing a new database in an empty file + allow_initialize: bool, page_size: usize, requested_region_size: Option, read_cache_size_bytes: usize, @@ -128,8 +131,10 @@ impl TransactionalMemory { write_cache_size_bytes, )?; + let initial_storage_len = storage.raw_file_len()?; + let magic_number: [u8; MAGICNUMBER.len()] = - if storage.raw_file_len()? >= MAGICNUMBER.len() as u64 { + if initial_storage_len >= MAGICNUMBER.len() as u64 { storage .read_direct(0, MAGICNUMBER.len())? .try_into() @@ -138,6 +143,18 @@ impl TransactionalMemory { [0; MAGICNUMBER.len()] }; + if initial_storage_len > 0 { + // File already exists check that the magic number matches + if magic_number != MAGICNUMBER { + return Err(StorageError::Io(ErrorKind::InvalidData.into()).into()); + } + } else { + // File is empty, check that we're allowed to initialize a new database (i.e. the caller is Database::create() and not open()) + if !allow_initialize { + return Err(StorageError::Io(ErrorKind::InvalidData.into()).into()); + } + } + if magic_number != MAGICNUMBER { let region_tracker_required_bytes = RegionTracker::new(INITIAL_REGIONS, MAX_MAX_PAGE_ORDER + 1) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index be811f05..ad24994e 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -9,7 +9,7 @@ use redb::{ use redb::{DatabaseError, ReadableMultimapTable, SavepointError, StorageError, TableError}; use std::borrow::Borrow; use std::fs; -use std::io::ErrorKind; +use std::io::{ErrorKind, Write}; use std::marker::PhantomData; use std::ops::RangeBounds; use std::sync::atomic::{AtomicBool, Ordering}; @@ -1525,6 +1525,25 @@ fn does_not_exist() { } } +#[test] +fn invalid_database_file() { + let mut tmpfile = create_tempfile(); + tmpfile.write_all(b"hi").unwrap(); + let result = Database::open(tmpfile.path()); + if let Err(DatabaseError::Storage(StorageError::Io(e))) = result { + assert!(matches!(e.kind(), ErrorKind::InvalidData)); + } else { + panic!(); + } + + let result = Database::create(tmpfile.path()); + if let Err(DatabaseError::Storage(StorageError::Io(e))) = result { + assert!(matches!(e.kind(), ErrorKind::InvalidData)); + } else { + panic!(); + } +} + #[test] fn wrong_types() { let tmpfile = create_tempfile();