Skip to content

Commit

Permalink
Merge pull request #305 from ichiban/fix-eos-detection
Browse files Browse the repository at this point in the history
fix end_of_stream detection
  • Loading branch information
ichiban authored Aug 11, 2023
2 parents 982dbc1 + 6e740e8 commit c682cd3
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 67 deletions.
7 changes: 1 addition & 6 deletions engine/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,7 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro
case err == nil:
if s.mode == ioModeRead {
s.source = f
s.initRead()
} else {
s.sink = f
}
Expand All @@ -1290,12 +1291,6 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro
return Error(err)
}

if s.mode == ioModeRead {
if err := s.initRead(); err == nil {
s.checkEOS()
}
}

return Unify(vm, stream, &s, k, env)
}

Expand Down
9 changes: 5 additions & 4 deletions engine/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3295,6 +3295,7 @@ func TestOpen(t *testing.T) {
assert.True(t, ok)
assert.Equal(t, l, s)

assert.NoError(t, s.initRead())
b, err := io.ReadAll(s.buf)
assert.NoError(t, err)
assert.Equal(t, "test\n", string(b))
Expand Down Expand Up @@ -4768,7 +4769,7 @@ func TestGetByte(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice()
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

s := &Stream{source: &m, mode: ioModeRead, streamType: streamTypeBinary}
Expand Down Expand Up @@ -4926,7 +4927,7 @@ func TestGetChar(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Times(2)
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

v := NewVariable()
Expand Down Expand Up @@ -5096,7 +5097,7 @@ func TestPeekByte(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice()
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

s := &Stream{source: &m, mode: ioModeRead}
Expand Down Expand Up @@ -5256,7 +5257,7 @@ func TestPeekChar(t *testing.T) {

t.Run("error", func(t *testing.T) {
var m mockReader
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice()
m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once()
defer m.AssertExpectations(t)

v := NewVariable()
Expand Down
116 changes: 76 additions & 40 deletions engine/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"os"
"unsafe"
)
Expand All @@ -14,7 +15,6 @@ var (
errWrongStreamType = errors.New("wrong stream type")
errPastEndOfStream = errors.New("past end of stream")
errReposition = errors.New("reposition")
errNotSupported = errors.New("not supported")
)

// Stream is a prolog stream.
Expand All @@ -23,7 +23,7 @@ type Stream struct {

source io.Reader
sink io.Writer
buf *bufio.Reader
buf bufReader
lastRuneSize int

mode ioMode
Expand Down Expand Up @@ -127,21 +127,11 @@ func (s *Stream) ReadByte() (byte, error) {
return 0, errWrongStreamType
}

// After reading a byte, we might be at the end of stream.
bs, err := s.buf.Peek(2)

b, err := s.buf.ReadByte()
if err == nil {
s.position += 1
}
switch len(bs) {
case 2:
s.endOfStream = endOfStreamNot
case 1:
s.endOfStream = endOfStreamAt
case 0:
s.endOfStream = endOfStreamPast
}
s.checkEOS(err)
return b, err
}

Expand Down Expand Up @@ -173,20 +163,10 @@ func (s *Stream) ReadRune() (r rune, size int, err error) {
return 0, 0, errWrongStreamType
}

// After reading a rune, we might be at the end of stream.
b, _ := s.buf.Peek(5) // A rune is 1~4 bytes.

r, n, err := s.buf.ReadRune()
s.position += int64(n)
s.lastRuneSize = n
switch {
case n == 0:
s.endOfStream = endOfStreamPast
case n < len(b):
s.endOfStream = endOfStreamNot
case n == len(b):
s.endOfStream = endOfStreamAt
}
s.checkEOS(err)
return r, n, err
}

Expand Down Expand Up @@ -228,11 +208,7 @@ func (s *Stream) Seek(offset int64, whence int) (int64, error) {
}

s.position = n

if r, ok := sk.(io.Reader); ok && s.buf != nil {
s.buf.Reset(r)
s.checkEOS()
}
s.reset()

return n, nil
}
Expand Down Expand Up @@ -306,39 +282,67 @@ func (s *Stream) Close() error {
}

// initRead prepares the stream for a read operation.
//
// It returns errWrongIOMode if the stream was not opened for reading.
// Otherwise it lazily creates the buffered reader on first use and, if a
// previous read already went past the end of the stream, applies the stream's
// eof_action: eofActionError yields errPastEndOfStream, eofActionReset
// discards the buffer and clears the end-of-stream state via reset. Any other
// eof_action falls through and lets the caller read again.
func (s *Stream) initRead() error {
	if s.mode != ioModeRead {
		return errWrongIOMode
	}

	// buf is a value type, so "not yet initialised" is its zero value.
	if s.buf == (bufReader{}) {
		s.buf = newBufReader(s.source)
	}

	if s.endOfStream == endOfStreamPast {
		switch s.eofAction {
		case eofActionError:
			return errPastEndOfStream
		case eofActionReset:
			s.reset()
		}
	}

	return nil
}

func (s *Stream) checkEOS() {
b, _ := s.buf.Peek(2)
switch len(b) {
case 0:
func (s *Stream) reset() {
if s.mode != ioModeRead {
return
}

s.buf = newBufReader(s.source)
s.endOfStream = endOfStreamNot
}

// checkEOS updates s.endOfStream after a read attempt.
//
// err is the error returned by the read that has just been performed on
// s.buf. The cases are evaluated in order, so an io.EOF from the read itself
// always wins over the "buffer drained" heuristics below it.
func (s *Stream) checkEOS(err error) {
	// After reading, we might be at the end of stream.
	switch b := s.buf.Buffered(); {
	case errors.Is(err, io.EOF):
		// The read itself hit EOF: we tried to read beyond the last byte.
		s.endOfStream = endOfStreamPast
	case b == 0 && errors.Is(s.buf.ReadErr(), io.EOF):
		// An io.Reader may return io.EOF on the very last read together with a non-zero number of bytes.
		// In that case, we are at the end of stream once all the buffered bytes have been consumed.
		s.endOfStream = endOfStreamAt
	case b == 0 && s.position == fileSize(s.source):
		// If the position equals the file size after consuming all the buffered bytes,
		// we are at the end of stream.
		s.endOfStream = endOfStreamAt
	default:
		// At least one byte is buffered or the underlying io.Reader hasn't reported io.EOF yet.
		// The io.Reader may still surprise us with `0, io.EOF` on the next read;
		// in that case we fail to detect the end of stream in advance.
		s.endOfStream = endOfStreamNot
	}
}

// fileSize reports the size in bytes of the regular file behind r.
//
// It returns -1 when no meaningful size is available: r does not implement
// fs.File, Stat fails, or r is not a regular file. The -1 sentinel can never
// equal a valid stream position, so a position comparison against it is
// always false.
func fileSize(r io.Reader) int64 {
	f, ok := r.(fs.File)
	if !ok {
		return -1
	}
	fi, err := f.Stat()
	if err != nil {
		return -1
	}
	if !fi.Mode().IsRegular() {
		// FileInfo.Size is system-dependent for pipes, terminals, directories
		// and the like, so it must not be used to detect the end of stream.
		return -1
	}
	return fi.Size()
}

func (s *Stream) properties() []Term {
ps := make([]Term, 0, 9)

Expand Down Expand Up @@ -535,3 +539,35 @@ func (ss *streams) lookup(a Atom) (*Stream, bool) {
s, ok := ss.aliases[a]
return s, ok
}

// bufReader couples a *bufio.Reader with the errReader it reads through.
// *bufio.Reader doesn't tell us if the underlying io.Reader returned an error.
// We need to know this to determine end_of_stream, so the error of the most
// recent underlying Read is recorded by the errReader and exposed via ReadErr.
type bufReader struct {
// Reader is the buffered reader; newBufReader builds it on top of er.
*bufio.Reader
// er is the errReader that Reader reads through; ReadErr returns its err field.
er *errReader
}

// newBufReader builds a bufReader for r. The source is first wrapped in an
// errReader and the bufio.Reader is layered on top of that wrapper, so every
// Read issued against r passes through the errReader.
func newBufReader(r io.Reader) bufReader {
	tracked := &errReader{r: r}
	return bufReader{Reader: bufio.NewReader(tracked), er: tracked}
}

// ReadErr returns the error recorded by the errReader for the most recent
// Read on the underlying io.Reader, or nil if that Read succeeded.
// It also returns nil for a zero-value bufReader, which has no errReader yet.
func (b bufReader) ReadErr() error {
	if b.er == nil {
		// Zero value: nothing has been read, so there is no error to report.
		return nil
	}
	return b.er.err
}

// errReader is an io.Reader decorator that remembers the error returned by
// the most recent Read on the wrapped reader.
type errReader struct {
	r   io.Reader
	err error
}

// Read forwards to the wrapped reader and stores the error it returned
// (nil included) in e.err before handing the result back unchanged.
func (e *errReader) Read(p []byte) (n int, err error) {
	n, err = e.r.Read(p)
	e.err = err
	return n, err
}
Loading

0 comments on commit c682cd3

Please sign in to comment.