diff --git a/progress/readcloser.go b/progress/readcloser.go new file mode 100644 index 0000000..20fc37e --- /dev/null +++ b/progress/readcloser.go @@ -0,0 +1,52 @@ +package progress + +import "io" + +// ReadCloser is an [io.ReadCloser] that reports the number of bytes read from +// it via a callback. The callback is called at most every "updateInterval" +// bytes. The updateInterval can be set using the +// [ReadCloser.WithUpdateInterval] method. +// +// The following is an example of how to use [ReadCloser] to report the progress of +// reading from a file: +// +// file, _ := os.Open("file.txt") +// progressReader := NewReadCloser(f, func(readBytes int) { +// fmt.Printf("Read %d bytes\n", readBytes) +// }) +// io.ReadAll(progressReader) +type ReadCloser struct { + inner io.ReadCloser + readBytes int + progressFn func(readBytes int) + lastUpdate int + updateInterval int +} + +func NewReadCloser(r io.ReadCloser, progressFn func(readBytes int)) *ReadCloser { + return &ReadCloser{inner: r, progressFn: progressFn, updateInterval: defaultUpdateInterval} +} + +func (r *ReadCloser) WithUpdateInterval(bytes int) *ReadCloser { + r.updateInterval = bytes + return r +} + +func (r *ReadCloser) Read(p []byte) (n int, err error) { + n, err = r.inner.Read(p) + if err != nil { + return n, err + } + r.readBytes += n + if r.lastUpdate == 0 || r.readBytes-r.lastUpdate > r.updateInterval { + r.progressFn(r.readBytes) + r.lastUpdate = r.readBytes + } + return n, nil +} + +func (r *ReadCloser) Close() error { + return r.inner.Close() +} + +var _ io.ReadCloser = (*ReadCloser)(nil) diff --git a/progress/readcloser_test.go b/progress/readcloser_test.go new file mode 100644 index 0000000..38ac3fb --- /dev/null +++ b/progress/readcloser_test.go @@ -0,0 +1,26 @@ +package progress_test + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "go.mau.fi/util/progress" +) + +func TestReadCloser(t *testing.T) { + readCloser := io.NopCloser(bytes.NewReader(bytes.Repeat([]byte{42}, 1024*1024))) + + var progressUpdates []int + progressReader := progress.NewReadCloser(readCloser, func(readBytes int) { + progressUpdates = append(progressUpdates, readBytes) + }) + + data, err := io.ReadAll(progressReader) + assert.NoError(t, err) + assert.Equal(t, data, bytes.Repeat([]byte{42}, 1024*1024)) + + assert.Greater(t, len(progressUpdates), 1) + assert.IsIncreasing(t, progressUpdates) +}