diff --git a/basic_test.go b/basic_test.go
index a2cb88b..505afaf 100644
--- a/basic_test.go
+++ b/basic_test.go
@@ -3,6 +3,7 @@ package diskv
 import (
 	"bytes"
 	"errors"
+	"io"
 	"math/rand"
 	"regexp"
 	"strings"
@@ -428,3 +429,96 @@ func TestHybridStore(t *testing.T) {
 
 	}
 }
+
+// Make sure that temporary files used for atomic writes never
+// show up in the key listing
+func TestIgnoreAtomicTempFiles(t *testing.T) {
+	var (
+		basePath = "test-data"
+	)
+	// Simplest transform function: put all the data files into the base dir.
+	flatTransform := func(s string) []string { return []string{} }
+
+	// Initialize a new diskv store, rooted at basePath,
+	// with no cache.
+	d := New(Options{
+		BasePath:     basePath,
+		Transform:    flatTransform,
+		CacheSizeMax: 0,
+	})
+
+	// Write something in so everything is set up
+	if err := d.Write("foo", []byte("bar")); err != nil {
+		t.Fatal(err)
+	}
+
+	// Prepare an entry that will be written through a stream.
+	key := "key1"
+	data := make([]byte, 1024*1024)
+	rand.Read(data)
+	half := len(data) / 2
+
+	// Get a pipe
+	rdr, wtr := io.Pipe()
+
+	// Start the write. done is closed once WriteStream has returned,
+	// after which writeErr holds its result.
+	var writeErr error
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		writeErr = d.WriteStream(key, rdr, true)
+		// If WriteStream gave up early, fail pending pipe writes
+		// instead of leaving the test blocked on them.
+		rdr.Close()
+	}()
+
+	// Whatever happens below, end the stream and wait for the writer
+	// before erasing the store, so a failed check cannot leave the
+	// goroutine blocked on the pipe.
+	defer func() {
+		wtr.Close()
+		<-done
+		d.EraseAll()
+	}()
+
+	// Feed the first half of the entry. A pipe write returns only after
+	// the reader has consumed it, so the write is now in flight.
+	if _, err := wtr.Write(data[:half]); err != nil {
+		t.Fatalf("pipe write: %s", err)
+	}
+
+	// Now list keys: there should be 1 key.
+	keys := d.Keys(nil)
+	var count int
+	for range keys {
+		count++
+	}
+	if count != 1 {
+		t.Fatalf("Expected 1 key, got %d", count)
+	}
+
+	// Now complete the write
+	if _, err := wtr.Write(data[half:]); err != nil {
+		t.Fatalf("pipe write: %s", err)
+	}
+	wtr.Close()
+
+	// Wait for WriteStream to return, so the entry is in its final
+	// location before we list again.
+	<-done
+	if writeErr != nil {
+		t.Fatal(writeErr)
+	}
+
+	// And make sure we see exactly two keys. The counter still holds
+	// the result of the first listing, so reset it.
+	count = 0
+	keys = d.Keys(nil)
+	for range keys {
+		count++
+	}
+	if count != 2 {
+		t.Fatalf("Expected 2 keys, got %d", count)
+	}
+}
diff --git a/diskv.go b/diskv.go
index 9f07b85..46241a6 100644
--- a/diskv.go
+++ b/diskv.go
@@ -9,17 +9,21 @@ import (
 	"fmt"
 	"io"
 	"io/ioutil"
+	"math/rand"
 	"os"
 	"path/filepath"
 	"strings"
 	"sync"
 	"syscall"
+	"time"
 )
 
 const (
 	defaultBasePath             = "diskv"
 	defaultFilePerm os.FileMode = 0666
 	defaultPathPerm os.FileMode = 0777
+
+	DefaultAtomicPrefix = ".diskv_atomic_temp"
 )
 
 // PathKey represents a string key that has been transformed to
@@ -76,12 +80,12 @@ type Options struct {
 	CacheSizeMax uint64 // bytes
 	PathPerm     os.FileMode
 	FilePerm     os.FileMode
-	// If TempDir is set, it will enable filesystem atomic writes by
-	// writing temporary files to that location before being moved
-	// to BasePath.
-	// Note that TempDir MUST be on the same device/partition as
-	// BasePath.
+	// Deprecated: TempDir is no longer needed; all writes are now atomic.
 	TempDir string
+	// AtomicPrefix sets the name of a directory which will be created
+	// within BasePath to store temporary files for atomic writes.
+	// It defaults to DefaultAtomicPrefix; you probably don't need to change it.
+	AtomicPrefix string
 
 	Index     Index
 	IndexLess LessFunction
@@ -96,6 +100,7 @@ type Diskv struct {
 	mu        sync.RWMutex
 	cache     map[string][]byte
 	cacheSize uint64
+	rnd       *rand.Rand
 }
 
 // New returns an initialized Diskv structure, ready to use.
@@ -105,6 +110,9 @@ func New(o Options) *Diskv {
 	if o.BasePath == "" {
 		o.BasePath = defaultBasePath
 	}
+	if o.AtomicPrefix == "" {
+		o.AtomicPrefix = DefaultAtomicPrefix
+	}
 
 	if o.AdvancedTransform == nil {
 		if o.Transform == nil {
@@ -132,12 +140,18 @@ func New(o Options) *Diskv {
 		Options:   o,
 		cache:     map[string][]byte{},
 		cacheSize: 0,
+		rnd:       rand.New(rand.NewSource(time.Now().UnixNano())),
 	}
 
 	if d.Index != nil && d.IndexLess != nil {
 		d.Index.Initialize(d.IndexLess, d.Keys(nil))
 	}
 
+	// Just in case there were any failures during writes previously, we
+	// remove the atomic write directory (and any temp files within it).
+	// The directory will be created the first time we do a Write.
+	os.RemoveAll(d.atomicTempPath()) // error deliberately ignored
+
 	return d
 }
 
@@ -196,41 +210,18 @@ func (d *Diskv) WriteStream(key string, r io.Reader, sync bool) error {
 	return d.writeStreamWithLock(pathKey, r, sync)
 }
 
-// createKeyFileWithLock either creates the key file directly, or
-// creates a temporary file in TempDir if it is set.
-func (d *Diskv) createKeyFileWithLock(pathKey *PathKey) (*os.File, error) {
-	if d.TempDir != "" {
-		if err := os.MkdirAll(d.TempDir, d.PathPerm); err != nil {
-			return nil, fmt.Errorf("temp mkdir: %s", err)
-		}
-		f, err := ioutil.TempFile(d.TempDir, "")
-		if err != nil {
-			return nil, fmt.Errorf("temp file: %s", err)
-		}
-
-		if err := os.Chmod(f.Name(), d.FilePerm); err != nil {
-			f.Close()           // error deliberately ignored
-			os.Remove(f.Name()) // error deliberately ignored
-			return nil, fmt.Errorf("chmod: %s", err)
-		}
-		return f, nil
-	}
-
-	mode := os.O_WRONLY | os.O_CREATE | os.O_TRUNC // overwrite if exists
-	f, err := os.OpenFile(d.completeFilename(pathKey), mode, d.FilePerm)
-	if err != nil {
-		return nil, fmt.Errorf("open file: %s", err)
-	}
-	return f, nil
-}
-
 // writeStream does no input validation checking.
 func (d *Diskv) writeStreamWithLock(pathKey *PathKey, r io.Reader, sync bool) error {
+	// fullPath is the on-disk location of the key
+	fullPath := d.completeFilename(pathKey)
+
 	if err := d.ensurePathWithLock(pathKey); err != nil {
 		return fmt.Errorf("ensure path: %s", err)
 	}
 
-	f, err := d.createKeyFileWithLock(pathKey)
+	// Get a temporary file we can write to.
+	// We'll move it when we're all done.
+	f, err := d.createAtomicTempFile(pathKey.FileName)
 	if err != nil {
 		return fmt.Errorf("create key file: %s", err)
 	}
@@ -269,12 +260,10 @@ func (d *Diskv) writeStreamWithLock(pathKey *PathKey, r io.Reader, sync bool) er
 		return fmt.Errorf("file close: %s", err)
 	}
 
-	fullPath := d.completeFilename(pathKey)
-	if f.Name() != fullPath {
-		if err := os.Rename(f.Name(), fullPath); err != nil {
-			os.Remove(f.Name()) // error deliberately ignored
-			return fmt.Errorf("rename: %s", err)
-		}
+	// Move the temporary file to the final location.
+	if err := os.Rename(f.Name(), fullPath); err != nil {
+		os.Remove(f.Name()) // error deliberately ignored
+		return fmt.Errorf("rename: %s", err)
 	}
 
 	if d.Index != nil {
@@ -596,7 +585,7 @@ func (d *Diskv) walker(c chan<- string, prefix string, cancel <-chan struct{}) f
 
 		key := d.InverseTransform(pathKey)
 
-		if info.IsDir() || !strings.HasPrefix(key, prefix) {
+		if info.IsDir() || !strings.HasPrefix(key, prefix) || strings.HasPrefix(dir, d.AtomicPrefix) {
 			return nil // "pass"
 		}
 
@@ -627,6 +616,39 @@ func (d *Diskv) completeFilename(pathKey *PathKey) string {
 	return filepath.Join(d.pathFor(pathKey), pathKey.FileName)
 }
 
+// createAtomicTempFile creates a uniquely named temporary file, based on
+// name, inside the atomic temp directory (creating the directory first if
+// needed) and gives it the store's FilePerm.
+func (d *Diskv) createAtomicTempFile(name string) (*os.File, error) {
+	if err := d.ensureAtomicTempDir(); err != nil {
+		return nil, fmt.Errorf("temp mkdir: %s", err)
+	}
+	f, err := ioutil.TempFile(d.atomicTempPath(), name)
+	if err != nil {
+		return nil, fmt.Errorf("temp file: %s", err)
+	}
+	// ioutil.TempFile creates the file with mode 0600, so apply the
+	// configured permissions explicitly.
+	if err := os.Chmod(f.Name(), d.FilePerm); err != nil {
+		f.Close()           // error deliberately ignored
+		os.Remove(f.Name()) // error deliberately ignored
+		return nil, fmt.Errorf("chmod: %s", err)
+	}
+	return f, nil
+}
+
+// ensureAtomicTempDir creates the directory for atomic-write temp files
+// (and any missing parents) with PathPerm, if it does not already exist.
+func (d *Diskv) ensureAtomicTempDir() error {
+	return os.MkdirAll(d.atomicTempPath(), d.PathPerm)
+}
+
+// atomicTempPath returns the directory, inside BasePath, that holds the
+// temporary files used for atomic writes.
+func (d *Diskv) atomicTempPath() string {
+	return filepath.Join(d.BasePath, d.AtomicPrefix)
+}
+
 // cacheWithLock attempts to cache the given key-value pair in the store's
 // cache. It can fail if the value is larger than the cache's maximum size.
 func (d *Diskv) cacheWithLock(key string, val []byte) error {
diff --git a/issues_test.go b/issues_test.go
index 5bf9e35..db5fe7a 100644
--- a/issues_test.go
+++ b/issues_test.go
@@ -189,3 +189,61 @@ func TestIssue40(t *testing.T) {
 	// is no room in the cache for this entry and it panics.
 	d.Read(k2)
 }
+
+// Test issue #63, where a reader obtained from ReadStream will start
+// to return invalid data if WriteStream is called before you finish
+// reading.
+func TestIssue63(t *testing.T) {
+	var (
+		basePath = "test-data"
+	)
+	// Simplest transform function: put all the data files into the base dir.
+	flatTransform := func(s string) []string { return []string{} }
+
+	// Initialize a new diskv store, rooted at basePath,
+	// with no cache.
+	d := New(Options{
+		BasePath:     basePath,
+		Transform:    flatTransform,
+		CacheSizeMax: 0,
+	})
+
+	defer d.EraseAll()
+
+	// Write a big entry
+	k1 := "key1"
+	d1 := make([]byte, 1024*1024)
+	rand.Read(d1)
+	if err := d.Write(k1, d1); err != nil {
+		t.Fatal(err)
+	}
+
+	// Open a reader. We set the direct flag to be sure we're going straight to disk.
+	s1, err := d.ReadStream(k1, true)
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Registered after EraseAll, so the stream is closed before the store is erased.
+	defer s1.Close()
+
+	// Now generate a second big entry and put it in the *same* key
+	d2 := make([]byte, 1024*1024)
+	rand.Read(d2)
+	if err := d.Write(k1, d2); err != nil {
+		t.Fatal(err)
+	}
+
+	// Now read from that stream we opened
+	out, err := ioutil.ReadAll(s1)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(out) != len(d1) {
+		t.Fatalf("Invalid read: got %v bytes expected %v\n", len(out), len(d1))
+	}
+	for i := range out {
+		if out[i] != d1[i] {
+			t.Fatalf("Output differs from expected at byte %v", i)
+		}
+	}
+}