Skip to content

Commit

Permalink
fix end_of_stream detection: peek() based detection requires extra by…
Browse files Browse the repository at this point in the history
…tes from the io.Reader. Use buffered bytes and file sizes instead.
  • Loading branch information
ichiban committed Aug 11, 2023
1 parent b77736e commit d67ceac
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 63 deletions.
7 changes: 1 addition & 6 deletions engine/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,7 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro
case err == nil:
if s.mode == ioModeRead {
s.source = f
s.initRead()
} else {
s.sink = f
}
Expand All @@ -1290,12 +1291,6 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro
return Error(err)
}

if s.mode == ioModeRead {
if err := s.initRead(); err == nil {
s.checkEOS()
}
}

return Unify(vm, stream, &s, k, env)
}

Expand Down
9 changes: 5 additions & 4 deletions engine/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3295,6 +3295,7 @@ func TestOpen(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, l, s)

assert.NoError(t, s.initRead())
b, err := io.ReadAll(s.buf)
assert.NoError(t, err)
assert.Equal(t, "test\n", string(b))
Expand Down Expand Up @@ -4767,7 +4768,7 @@ func TestGetByte(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice()
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

s := &Stream{source: &m, mode: ioModeRead, streamType: streamTypeBinary}
Expand Down Expand Up @@ -4925,7 +4926,7 @@ func TestGetChar(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Times(2)
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

v := NewVariable()
Expand Down Expand Up @@ -5095,7 +5096,7 @@ func TestPeekByte(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice()
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

s := &Stream{source: &m, mode: ioModeRead}
Expand Down Expand Up @@ -5255,7 +5256,7 @@ func TestPeekChar(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice()
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

v := NewVariable()
Expand Down
116 changes: 76 additions & 40 deletions engine/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"unsafe"
)
Expand All @@ -14,7 +15,6 @@ var (
errWrongStreamType = errors.New("wrong stream type")
errPastEndOfStream = errors.New("past end of stream")
errReposition = errors.New("reposition")
errNotSupported = errors.New("not supported")
)

// Stream is a prolog stream.
Expand All @@ -23,7 +23,7 @@ type Stream struct {

source io.Reader
sink io.Writer
buf *bufio.Reader
buf bufReader
lastRuneSize int

mode ioMode
Expand Down Expand Up @@ -127,21 +127,11 @@ func (s *Stream) ReadByte() (byte, error) {
return 0, errWrongStreamType
}

// After reading a byte, we might be at the end of stream.
bs, err := s.buf.Peek(2)

b, err := s.buf.ReadByte()
if err == nil {
s.position += 1
}
switch len(bs) {
case 2:
s.endOfStream = endOfStreamNot
case 1:
s.endOfStream = endOfStreamAt
case 0:
s.endOfStream = endOfStreamPast
}
s.checkEOS(err)
return b, err
}

Expand Down Expand Up @@ -173,20 +163,10 @@ func (s *Stream) ReadRune() (r rune, size int, err error) {
return 0, 0, errWrongStreamType
}

// After reading a rune, we might be at the end of stream.
b, _ := s.buf.Peek(5) // A rune is 1~4 bytes.

r, n, err := s.buf.ReadRune()
s.position += int64(n)
s.lastRuneSize = n
switch {
case n == 0:
s.endOfStream = endOfStreamPast
case n < len(b):
s.endOfStream = endOfStreamNot
case n == len(b):
s.endOfStream = endOfStreamAt
}
s.checkEOS(err)
return r, n, err
}

Expand Down Expand Up @@ -228,11 +208,7 @@ func (s *Stream) Seek(offset int64, whence int) (int64, error) {
}

s.position = n

if r, ok := sk.(io.Reader); ok && s.buf != nil {
s.buf.Reset(r)
s.checkEOS()
}
s.reset()

return n, nil
}
Expand Down Expand Up @@ -306,39 +282,67 @@ func (s *Stream) Close() error {
}

func (s *Stream) initRead() error {
if s.buf == nil {
s.buf = bufio.NewReader(s.source)
}

if s.mode != ioModeRead {
return errWrongIOMode
}

if s.buf == (bufReader{}) {
s.buf = newBufReader(s.source)
}

if s.endOfStream == endOfStreamPast {
switch s.eofAction {
case eofActionError:
return errPastEndOfStream
case eofActionReset:
_, err := s.Seek(0, io.SeekStart)
return err
s.reset()
}
}

return nil
}

func (s *Stream) checkEOS() {
b, _ := s.buf.Peek(2)
switch len(b) {
case 0:
func (s *Stream) reset() {
if s.mode != ioModeRead {
return
}

Check warning on line 308 in engine/stream.go

View check run for this annotation

Codecov / codecov/patch

engine/stream.go#L307-L308

Added lines #L307 - L308 were not covered by tests

s.buf = newBufReader(s.source)
s.endOfStream = endOfStreamNot
}

func (s *Stream) checkEOS(err error) {
// After reading, we might be at the end of stream.
switch b := s.buf.Buffered(); {
case errors.Is(err, io.EOF):
s.endOfStream = endOfStreamPast
case 1:
case b == 0 && errors.Is(s.buf.ReadErr(), io.EOF):
// io.Reader may return io.EOF at the very last read with a non-zero number of bytes.
// In that case, we can say we're at the end of stream after consuming all the buffered bytes.
s.endOfStream = endOfStreamAt
case b == 0 && s.position == fileSize(s.source):
// If the position equals to the file size after consuming all the buffered bytes,
// we can say we're at the end of stream.
s.endOfStream = endOfStreamAt
default:
// At least one byte is buffered or the underlying io.Reader hasn't reported io.EOF yet.
// io.Reader may surprise us with `0, io.EOF`. In that case, we fail to detect the end of stream.
s.endOfStream = endOfStreamNot
}
}

func fileSize(r io.Reader) int64 {
f, ok := r.(fs.File)
if !ok {
return -1
}
fi, err := f.Stat()
if err != nil {
return -1
}

Check warning on line 342 in engine/stream.go

View check run for this annotation

Codecov / codecov/patch

engine/stream.go#L341-L342

Added lines #L341 - L342 were not covered by tests
return fi.Size()
}

func (s *Stream) properties() []Term {
ps := make([]Term, 0, 9)

Expand Down Expand Up @@ -535,3 +539,35 @@ func (ss *streams) lookup(a Atom) (*Stream, bool) {
s, ok := ss.aliases[a]
return s, ok
}

// bufReader is a wrapper around *bufio.Reader.
// *bufio.Reader doesn't tell us if the underlying io.Reader returned an error.
// We need to know this to determine end_of_stream.
type bufReader struct {
*bufio.Reader
er *errReader
}

func newBufReader(r io.Reader) bufReader {
er := errReader{r: r}
return bufReader{
Reader: bufio.NewReader(&er),
er: &er,
}
}

func (b bufReader) ReadErr() error {
return b.er.err
}

type errReader struct {
r io.Reader
err error
}

func (e *errReader) Read(p []byte) (n int, err error) {
defer func() {
e.err = err
}()
return e.r.Read(p)
}
49 changes: 40 additions & 9 deletions engine/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func TestStream_ReadByte(t *testing.T) {
s: &Stream{source: bytes.NewReader([]byte{3}), streamType: streamTypeBinary, position: 2},
b: 3,
pos: 3,
eos: endOfStreamAt,
eos: endOfStreamNot,
},
{
title: "input binary: empty",
Expand Down Expand Up @@ -309,11 +309,27 @@ func TestStream_ReadRune(t *testing.T) {
eos: endOfStreamNot,
},
{
title: "input text: 1 rune left",
title: "input text: 1 rune left, abrupt EOF",
s: &Stream{source: bytes.NewReader([]byte("c")), streamType: streamTypeText, position: 2},
r: 'c',
size: 1,
pos: 3,
eos: endOfStreamNot,
},
{
title: "input text: 1 rune left, non-abrupt EOF",
s: &Stream{source: newNonAbruptReader([]byte("c")), streamType: streamTypeText, position: 2},
r: 'c',
size: 1,
pos: 3,
eos: endOfStreamAt,
},
{
title: "input text: 1 rune left, file",
s: &Stream{source: mustOpen(testdata, "testdata/a.txt"), streamType: streamTypeText, position: 0},
r: 'a',
size: 1,
pos: 1,
eos: endOfStreamAt,
},
{
Expand Down Expand Up @@ -396,7 +412,6 @@ func TestStream_Seek(t *testing.T) {
whence int
pos int64
err error
eos endOfStream
}{
{
title: "ok",
Expand All @@ -411,10 +426,10 @@ func TestStream_Seek(t *testing.T) {
whence: 0,
err: errors.New("ng"),
},
{title: "reader", s: s, offset: 0, whence: 0, pos: 0, eos: endOfStreamNot},
{title: "reader", s: s, offset: 1, whence: 0, pos: 1, eos: endOfStreamNot},
{title: "reader", s: s, offset: 2, whence: 0, pos: 2, eos: endOfStreamAt},
{title: "reader", s: s, offset: 3, whence: 0, pos: 3, eos: endOfStreamPast},
{title: "reader", s: s, offset: 0, whence: 0, pos: 0},
{title: "reader", s: s, offset: 1, whence: 0, pos: 1},
{title: "reader", s: s, offset: 2, whence: 0, pos: 2},
{title: "reader", s: s, offset: 3, whence: 0, pos: 3},
{
title: "not seeker",
s: &Stream{source: &okSeeker.mockReader, reposition: true, position: 123},
Expand All @@ -429,8 +444,6 @@ func TestStream_Seek(t *testing.T) {
pos, err := tt.s.Seek(tt.offset, tt.whence)
assert.Equal(t, tt.pos, pos)
assert.Equal(t, tt.err, err)

assert.Equal(t, tt.eos, tt.s.endOfStream)
})
}
}
Expand Down Expand Up @@ -577,3 +590,21 @@ func TestStream_Flush(t *testing.T) {
assert.NoError(t, s.Flush())
})
}

type nonAbruptReader struct {
*bytes.Reader
}

func newNonAbruptReader(b []byte) nonAbruptReader {
return nonAbruptReader{
Reader: bytes.NewReader(b),
}
}

func (r nonAbruptReader) Read(b []byte) (int, error) {
n, err := r.Reader.Read(b)
if err == nil && r.Reader.Len() == 0 {
err = io.EOF
}
return n, err
}
9 changes: 9 additions & 0 deletions engine/text_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"embed"
"errors"
"io"
"io/fs"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -13,6 +14,14 @@ import (
//go:embed testdata
var testdata embed.FS

func mustOpen(fs fs.FS, name string) fs.File {
f, err := fs.Open(name)
if err != nil {
panic(err)
}
return f
}

func TestVM_Compile(t *testing.T) {
tests := []struct {
title string
Expand Down
Loading

0 comments on commit d67ceac

Please sign in to comment.