From d67ceac38408e0c08971447e690b6f865be4f8e9 Mon Sep 17 00:00:00 2001 From: Yutaka Ichibangase Date: Fri, 11 Aug 2023 11:09:29 +0900 Subject: [PATCH] fix end_of_stream detection: peek() based detection requires extra bytes from the io.Reader. Use buffered bytes and file sizes instead. --- engine/builtin.go | 7 +-- engine/builtin_test.go | 9 ++-- engine/stream.go | 116 +++++++++++++++++++++++++++-------------- engine/stream_test.go | 49 +++++++++++++---- engine/text_test.go | 9 ++++ interpreter_test.go | 10 ++-- 6 files changed, 137 insertions(+), 63 deletions(-) diff --git a/engine/builtin.go b/engine/builtin.go index a21a91c..e707963 100644 --- a/engine/builtin.go +++ b/engine/builtin.go @@ -1266,6 +1266,7 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro case err == nil: if s.mode == ioModeRead { s.source = f + s.initRead() } else { s.sink = f } @@ -1290,12 +1291,6 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro return Error(err) } - if s.mode == ioModeRead { - if err := s.initRead(); err == nil { - s.checkEOS() - } - } - return Unify(vm, stream, &s, k, env) } diff --git a/engine/builtin_test.go b/engine/builtin_test.go index f9475ab..6195c35 100644 --- a/engine/builtin_test.go +++ b/engine/builtin_test.go @@ -3295,6 +3295,7 @@ func TestOpen(t *testing.T) { assert.True(t, ok) assert.Equal(t, l, s) + assert.NoError(t, s.initRead()) b, err := io.ReadAll(s.buf) assert.NoError(t, err) assert.Equal(t, "test\n", string(b)) @@ -4767,7 +4768,7 @@ func TestGetByte(t *testing.T) { t.Run("error", func(t *testing.T) { var m mockReader - m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice() + m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once() defer m.AssertExpectations(t) s := &Stream{source: &m, mode: ioModeRead, streamType: streamTypeBinary} @@ -4925,7 +4926,7 @@ func TestGetChar(t *testing.T) { t.Run("error", func(t *testing.T) { var m mockReader - m.On("Read", mock.Anything).Return(0, errors.New("failed")).Times(2) + m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once() defer m.AssertExpectations(t) v := NewVariable() @@ -5095,7 +5096,7 @@ func TestPeekByte(t *testing.T) { t.Run("error", func(t *testing.T) { var m mockReader - m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice() + m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once() defer m.AssertExpectations(t) s := &Stream{source: &m, mode: ioModeRead} @@ -5255,7 +5256,7 @@ func TestPeekChar(t *testing.T) { t.Run("error", func(t *testing.T) { var m mockReader - m.On("Read", mock.Anything).Return(0, errors.New("failed")).Twice() + m.On("Read", mock.Anything).Return(0, errors.New("failed")).Once() defer m.AssertExpectations(t) v := NewVariable() diff --git a/engine/stream.go b/engine/stream.go index 353b015..fe7f7da 100644 --- a/engine/stream.go +++ b/engine/stream.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "unsafe" ) @@ -14,7 +15,6 @@ var ( errWrongStreamType = errors.New("wrong stream type") errPastEndOfStream = errors.New("past end of stream") errReposition = errors.New("reposition") - errNotSupported = errors.New("not supported") ) // Stream is a prolog stream. @@ -23,7 +23,7 @@ type Stream struct { source io.Reader sink io.Writer - buf *bufio.Reader + buf bufReader lastRuneSize int mode ioMode @@ -127,21 +127,11 @@ func (s *Stream) ReadByte() (byte, error) { return 0, errWrongStreamType } - // After reading a byte, we might be at the end of stream. - bs, err := s.buf.Peek(2) - b, err := s.buf.ReadByte() if err == nil { s.position += 1 } - switch len(bs) { - case 2: - s.endOfStream = endOfStreamNot - case 1: - s.endOfStream = endOfStreamAt - case 0: - s.endOfStream = endOfStreamPast - } + s.checkEOS(err) return b, err } @@ -173,20 +163,10 @@ func (s *Stream) ReadRune() (r rune, size int, err error) { return 0, 0, errWrongStreamType } - // After reading a rune, we might be at the end of stream. - b, _ := s.buf.Peek(5) // A rune is 1~4 bytes. - r, n, err := s.buf.ReadRune() s.position += int64(n) s.lastRuneSize = n - switch { - case n == 0: - s.endOfStream = endOfStreamPast - case n < len(b): - s.endOfStream = endOfStreamNot - case n == len(b): - s.endOfStream = endOfStreamAt - } + s.checkEOS(err) return r, n, err } @@ -228,11 +208,7 @@ func (s *Stream) Seek(offset int64, whence int) (int64, error) { } s.position = n - - if r, ok := sk.(io.Reader); ok && s.buf != nil { - s.buf.Reset(r) - s.checkEOS() - } + s.reset() return n, nil } @@ -306,39 +282,67 @@ func (s *Stream) Close() error { } func (s *Stream) initRead() error { - if s.buf == nil { - s.buf = bufio.NewReader(s.source) - } - if s.mode != ioModeRead { return errWrongIOMode } + if s.buf == (bufReader{}) { + s.buf = newBufReader(s.source) + } + if s.endOfStream == endOfStreamPast { switch s.eofAction { case eofActionError: return errPastEndOfStream case eofActionReset: - _, err := s.Seek(0, io.SeekStart) - return err + s.reset() } } return nil } -func (s *Stream) checkEOS() { - b, _ := s.buf.Peek(2) - switch len(b) { - case 0: +func (s *Stream) reset() { + if s.mode != ioModeRead { + return + } + + s.buf = newBufReader(s.source) + s.endOfStream = endOfStreamNot +} + +func (s *Stream) checkEOS(err error) { + // After reading, we might be at the end of stream. + switch b := s.buf.Buffered(); { + case errors.Is(err, io.EOF): s.endOfStream = endOfStreamPast - case 1: + case b == 0 && errors.Is(s.buf.ReadErr(), io.EOF): + // io.Reader may return io.EOF at the very last read with a non-zero number of bytes. + // In that case, we can say we're at the end of stream after consuming all the buffered bytes. + s.endOfStream = endOfStreamAt + case b == 0 && s.position == fileSize(s.source): + // If the position equals to the file size after consuming all the buffered bytes, + // we can say we're at the end of stream. s.endOfStream = endOfStreamAt default: + // At least one byte is buffered or the underlying io.Reader hasn't reported io.EOF yet. + // io.Reader may surprise us with `0, io.EOF`. In that case, we fail to detect the end of stream. s.endOfStream = endOfStreamNot } } +func fileSize(r io.Reader) int64 { + f, ok := r.(fs.File) + if !ok { + return -1 + } + fi, err := f.Stat() + if err != nil { + return -1 + } + return fi.Size() +} + func (s *Stream) properties() []Term { ps := make([]Term, 0, 9) @@ -535,3 +539,35 @@ func (ss *streams) lookup(a Atom) (*Stream, bool) { s, ok := ss.aliases[a] return s, ok } + +// bufReader is a wrapper around *bufio.Reader. +// *bufio.Reader doesn't tell us if the underlying io.Reader returned an error. +// We need to know this to determine end_of_stream. +type bufReader struct { + *bufio.Reader + er *errReader +} + +func newBufReader(r io.Reader) bufReader { + er := errReader{r: r} + return bufReader{ + Reader: bufio.NewReader(&er), + er: &er, + } +} + +func (b bufReader) ReadErr() error { + return b.er.err +} + +type errReader struct { + r io.Reader + err error +} + +func (e *errReader) Read(p []byte) (n int, err error) { + defer func() { + e.err = err + }() + return e.r.Read(p) +} diff --git a/engine/stream_test.go b/engine/stream_test.go index 0038e7a..711f9fc 100644 --- a/engine/stream_test.go +++ b/engine/stream_test.go @@ -235,7 +235,7 @@ func TestStream_ReadByte(t *testing.T) { s: &Stream{source: bytes.NewReader([]byte{3}), streamType: streamTypeBinary, position: 2}, b: 3, pos: 3, - eos: endOfStreamAt, + eos: endOfStreamNot, }, { title: "input binary: empty", @@ -309,11 +309,27 @@ func TestStream_ReadRune(t *testing.T) { eos: endOfStreamNot, }, { - title: "input text: 1 rune left", + title: "input text: 1 rune left, abrupt EOF", s: &Stream{source: bytes.NewReader([]byte("c")), streamType: streamTypeText, position: 2}, r: 'c', size: 1, pos: 3, + eos: endOfStreamNot, + }, + { + title: "input text: 1 rune left, non-abrupt EOF", + s: &Stream{source: newNonAbruptReader([]byte("c")), streamType: streamTypeText, position: 2}, + r: 'c', + size: 1, + pos: 3, + eos: endOfStreamAt, + }, + { + title: "input text: 1 rune left, file", + s: &Stream{source: mustOpen(testdata, "testdata/a.txt"), streamType: streamTypeText, position: 0}, + r: 'a', + size: 1, + pos: 1, eos: endOfStreamAt, }, { @@ -396,7 +412,6 @@ func TestStream_Seek(t *testing.T) { whence int pos int64 err error - eos endOfStream }{ { title: "ok", @@ -411,10 +426,10 @@ func TestStream_Seek(t *testing.T) { whence: 0, err: errors.New("ng"), }, - {title: "reader", s: s, offset: 0, whence: 0, pos: 0, eos: endOfStreamNot}, - {title: "reader", s: s, offset: 1, whence: 0, pos: 1, eos: endOfStreamNot}, - {title: "reader", s: s, offset: 2, whence: 0, pos: 2, eos: endOfStreamAt}, - {title: "reader", s: s, offset: 3, whence: 0, pos: 3, eos: endOfStreamPast}, + {title: "reader", s: s, offset: 0, whence: 0, pos: 0}, + {title: "reader", s: s, offset: 1, whence: 0, pos: 1}, + {title: "reader", s: s, offset: 2, whence: 0, pos: 2}, + {title: "reader", s: s, offset: 3, whence: 0, pos: 3}, { title: "not seeker", s: &Stream{source: &okSeeker.mockReader, reposition: true, position: 123}, @@ -429,8 +444,6 @@ func TestStream_Seek(t *testing.T) { pos, err := tt.s.Seek(tt.offset, tt.whence) assert.Equal(t, tt.pos, pos) assert.Equal(t, tt.err, err) - - assert.Equal(t, tt.eos, tt.s.endOfStream) }) } } @@ -577,3 +590,21 @@ func TestStream_Flush(t *testing.T) { assert.NoError(t, s.Flush()) }) } + +type nonAbruptReader struct { + *bytes.Reader +} + +func newNonAbruptReader(b []byte) nonAbruptReader { + return nonAbruptReader{ + Reader: bytes.NewReader(b), + } +} + +func (r nonAbruptReader) Read(b []byte) (int, error) { + n, err := r.Reader.Read(b) + if err == nil && r.Reader.Len() == 0 { + err = io.EOF + } + return n, err +} diff --git a/engine/text_test.go b/engine/text_test.go index c72dc56..e18fc2e 100644 --- a/engine/text_test.go +++ b/engine/text_test.go @@ -5,6 +5,7 @@ import ( "embed" "errors" "io" + "io/fs" "testing" "github.com/stretchr/testify/assert" @@ -13,6 +14,14 @@ import ( //go:embed testdata var testdata embed.FS +func mustOpen(fs fs.FS, name string) fs.File { + f, err := fs.Open(name) + if err != nil { + panic(err) + } + return f +} + func TestVM_Compile(t *testing.T) { tests := []struct { title string diff --git a/interpreter_test.go b/interpreter_test.go index 3ad87ce..0784ba7 100644 --- a/interpreter_test.go +++ b/interpreter_test.go @@ -229,6 +229,7 @@ func TestNew_variableNames(t *testing.T) { output string outputFn func(t *testing.T, output string) err error + waits bool }{ {name: "1", query: `catch(write_term(T,[quoted(true), variable_names([N=T])]), error(instantiation_error, _), true).`}, {name: "2", query: `N = 'X', write_term(T,[quoted(true), variable_names([N=T])]).`, output: `X`}, @@ -269,7 +270,7 @@ func TestNew_variableNames(t *testing.T) { }}, {name: "26", query: `T=(X,Y,Z), write_term(T,[quoted(true), variable_names(['X'=Z,'X'=Y,'X'=X])]).`, output: `X,X,X`}, {name: "27", query: `T=(1,2,3), T=(X,Y,Z), write_term(T,[quoted(true), variable_names(['X'=Z,'X'=Y,'X'=X])]).`, output: `1,2,3`}, - {name: "32", query: `read_term(T,[variable_names(VN_list)]), VN_list=[_=1,_=2,_=3], writeq(VN_list).`, err: context.DeadlineExceeded}, + {name: "32", query: `read_term(T,[variable_names(VN_list)]), VN_list=[_=1,_=2,_=3], writeq(VN_list).`, err: context.DeadlineExceeded, waits: true}, {name: "29", input: `B+C+A+B+C+A.`, query: `read_term(T,[variable_names(VN_list)]), VN_list=[_=1,_=2,_=3], writeq(VN_list).`, output: `['B'=1,'C'=2,'A'=3]`}, {name: "30", query: `catch(write_term(T, [variable_names(VN_list)]), error(instantiation_error, _), true).`}, {name: "31", query: `catch((VN_list = 1, write_term(T, [variable_names(VN_list)])), error(domain_error(write_option, variable_names(_)), _), true).`}, @@ -297,9 +298,9 @@ func TestNew_variableNames(t *testing.T) { {name: "60", query: `catch((O = alias(_), open(f,write,_,[O])), error(instantiation_error, _), true).`}, {name: "42", query: `catch((O = type(nontype), open(f,write,_,[O])), error(domain_error(stream_option, type(nontype)), _), true).`}, {name: "61", query: `catch((O = alias(1), open(f,write,_,[O])), error(domain_error(stream_option, alias(1)), _), true).`}, - {name: "45", query: `read_term(T,[variable_names(VN_list)]).`, err: context.DeadlineExceeded}, + {name: "45", query: `read_term(T,[variable_names(VN_list)]).`, waits: true}, {name: "46", input: `a.`, query: `read_term(T,[variable_names(VN_list)]), T = a, VN_list = [].`}, - {name: "47", query: `VN_list = 42, read_term(T,[variable_names(VN_list)]).`, err: context.DeadlineExceeded}, + {name: "47", query: `VN_list = 42, read_term(T,[variable_names(VN_list)]).`, waits: true}, {name: "48", input: `a.`, query: `VN_list = 42, \+read_term(T,[variable_names(VN_list)]).`}, {name: "49", input: `a b.`, query: `VN_list = 42, catch(read_term(T,[variable_names(VN_list)]), error(syntax_error(_), _), true).`}, {name: "53", query: `catch(write_term(S,[quoted(true), variable_names([N=T])]), error(instantiation_error, _), true).`}, @@ -307,7 +308,7 @@ func TestNew_variableNames(t *testing.T) { {name: "55", query: `S=1+T,N=' /*r*/V',write_term(S,[quoted(true), variable_names([N=T])]).`, output: `1+ /*r*/V`}, {name: "58", query: `S=1+T,N=(+),write_term(S,[quoted(true), variable_names([N=T])]).`, output: `1++`}, {name: "59", query: `S=T+1,N=(+),write_term(S,[quoted(true), variable_names([N=T])]).`, output: `++1`}, - {name: "69", query: `read_term(T, [singletons(1)]).`, err: context.DeadlineExceeded}, + {name: "69", query: `read_term(T, [singletons(1)]).`, waits: true}, {name: "70", input: `a.`, query: `\+read_term(T, [singletons(1)]).`}, } @@ -331,6 +332,7 @@ func TestNew_variableNames(t *testing.T) { } else { tt.outputFn(t, out.String()) } + assert.Equal(t, tt.waits, errors.Is(ctx.Err(), context.DeadlineExceeded)) }) } }