Skip to content

Commit

Permalink
Merge pull request #284 from ichiban/term-interface
Browse files Browse the repository at this point in the history
Term must implement WriteTerm() and Compare()
  • Loading branch information
ichiban authored Feb 8, 2023
2 parents bdec4d0 + cef77c4 commit 146378f
Show file tree
Hide file tree
Showing 22 changed files with 1,448 additions and 1,468 deletions.
62 changes: 58 additions & 4 deletions engine/atom.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package engine

import (
"fmt"
"io"
"regexp"
"strings"
"sync"
Expand Down Expand Up @@ -235,6 +236,63 @@ func NewAtom(name string) Atom {
return a
}

// WriteTerm outputs the Atom to an io.Writer.
func (a Atom) WriteTerm(w io.Writer, opts *WriteOptions, _ *Env) error {
ew := errWriter{w: w}
openClose := (opts.left != (operator{}) || opts.right != (operator{})) && opts.ops.defined(a)

if openClose {
if opts.left.name != 0 && opts.left.specifier.class() == operatorClassPrefix {
_, _ = ew.Write([]byte(" "))
}
_, _ = ew.Write([]byte("("))
opts = opts.withLeft(operator{}).withRight(operator{})
}

if opts.quoted && needQuoted(a) {
if opts.left != (operator{}) && needQuoted(opts.left.name) { // Avoid 'FOO''BAR'.
_, _ = ew.Write([]byte(" "))
}
_, _ = ew.Write([]byte(quote(a.String())))
if opts.right != (operator{}) && needQuoted(opts.right.name) { // Avoid 'FOO''BAR'.
_, _ = ew.Write([]byte(" "))
}
} else {
if (letterDigit(opts.left.name) && letterDigit(a)) || (graphic(opts.left.name) && graphic(a)) {
_, _ = ew.Write([]byte(" "))
}
_, _ = ew.Write([]byte(a.String()))
if (letterDigit(opts.right.name) && letterDigit(a)) || (graphic(opts.right.name) && graphic(a)) {
_, _ = ew.Write([]byte(" "))
}
}

if openClose {
_, _ = ew.Write([]byte(")"))
}

return ew.err
}

// Compare compares the Atom with a Term.
func (a Atom) Compare(t Term, env *Env) int {
switch t := env.Resolve(t).(type) {
case Variable, Float, Integer:
return 1
case Atom:
switch d := strings.Compare(a.String(), t.String()); {
case d > 0:
return 1
case d < 0:
return -1
default:
return 0
}
default: // Custom atomic terms, Compound.
return -1
}
}

func (a Atom) String() string {
if a <= utf8.MaxRune {
return string(rune(a))
Expand All @@ -244,10 +302,6 @@ func (a Atom) String() string {
return atomTable.names[a-(utf8.MaxRune+1)]
}

func (a Atom) GoString() string {
return fmt.Sprintf("%#v", a.String())
}

// Apply returns a Compound which Functor is the Atom and args are the arguments. If the arguments are empty,
// then returns itself.
func (a Atom) Apply(args ...Term) Term {
Expand Down
65 changes: 65 additions & 0 deletions engine/atom_test.go
Original file line number Diff line number Diff line change
@@ -1,2 +1,67 @@
package engine

import (
"bytes"
"github.com/stretchr/testify/assert"
"testing"
)

func TestAtom_WriteTerm(t *testing.T) {
tests := []struct {
name string
opts WriteOptions
output string
}{
{name: `a`, opts: WriteOptions{quoted: false}, output: `a`},
{name: `a`, opts: WriteOptions{quoted: true}, output: `a`},
{name: "\a\b\f\n\r\t\v\x00\\'\"`", opts: WriteOptions{quoted: false}, output: "\a\b\f\n\r\t\v\x00\\'\"`"},
{name: "\a\b\f\n\r\t\v\x00\\'\"`", opts: WriteOptions{quoted: true}, output: "'\\a\\b\\f\\n\\r\\t\\v\\x0\\\\\\\\'\"`'"},
{name: `,`, opts: WriteOptions{quoted: false}, output: `,`},
{name: `,`, opts: WriteOptions{quoted: true}, output: `','`},
{name: `[]`, opts: WriteOptions{quoted: false}, output: `[]`},
{name: `[]`, opts: WriteOptions{quoted: true}, output: `[]`},
{name: `{}`, opts: WriteOptions{quoted: false}, output: `{}`},
{name: `{}`, opts: WriteOptions{quoted: true}, output: `{}`},
{name: `-`, output: `-`},
{name: `-`, opts: WriteOptions{ops: operators{atomPlus: {}, atomMinus: {}}, left: operator{specifier: operatorSpecifierFY, name: atomPlus}}, output: ` (-)`},
{name: `-`, opts: WriteOptions{ops: operators{atomPlus: {}, atomMinus: {}}, right: operator{name: atomPlus}}, output: `(-)`},
{name: `X`, opts: WriteOptions{quoted: true, left: operator{name: NewAtom(`F`)}}, output: ` 'X'`}, // So that it won't be 'F''X'.
{name: `X`, opts: WriteOptions{quoted: true, right: operator{name: NewAtom(`F`)}}, output: `'X' `}, // So that it won't be 'X''F'.
{name: `foo`, opts: WriteOptions{left: operator{name: NewAtom(`bar`)}}, output: ` foo`}, // So that it won't be barfoo.
{name: `foo`, opts: WriteOptions{right: operator{name: NewAtom(`bar`)}}, output: `foo `}, // So that it won't be foobar.},
}

var buf bytes.Buffer
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf.Reset()
assert.NoError(t, NewAtom(tt.name).WriteTerm(&buf, &tt.opts, nil))
assert.Equal(t, tt.output, buf.String())
})
}
}

func TestAtom_Compare(t *testing.T) {
x := NewVariable()

tests := []struct {
title string
a Atom
t Term
o int
}{
{title: `a > X`, a: NewAtom("a"), t: x, o: 1},
{title: `a > 1.0`, a: NewAtom("a"), t: Float(1), o: 1},
{title: `a > 1`, a: NewAtom("a"), t: Integer(1), o: 1},
{title: `a > 'Z'`, a: NewAtom("a"), t: NewAtom("Z"), o: 1},
{title: `a = a`, a: NewAtom("a"), t: NewAtom("a"), o: 0},
{title: `a < b`, a: NewAtom("a"), t: NewAtom("b"), o: -1},
{title: `a < f(a)`, a: NewAtom("a"), t: NewAtom("f").Apply(NewAtom("a")), o: -1},
}

for _, tt := range tests {
t.Run(tt.title, func(t *testing.T) {
assert.Equal(t, tt.o, tt.a.Compare(tt.t, nil))
})
}
}
127 changes: 96 additions & 31 deletions engine/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func SubsumesTerm(_ *VM, general, specific Term, k Cont, env *Env) *Promise {
return Bool(false)
}

if d := env.compare(theta.simplify(general), specific); d != 0 {
if d := theta.simplify(general).Compare(specific, env); d != 0 {
return Bool(false)
}

Expand Down Expand Up @@ -793,6 +793,60 @@ func collectionOf(vm *VM, agg func([]Term, *Env) Term, template, goal, instances
}, env)
}

func variant(t1, t2 Term, env *Env) bool {
s := map[Variable]Variable{}
rest := [][2]Term{
{t1, t2},
}
var xy [2]Term
for len(rest) > 0 {
rest, xy = rest[:len(rest)-1], rest[len(rest)-1]
x, y := env.Resolve(xy[0]), env.Resolve(xy[1])
switch x := x.(type) {
case Variable:
switch y := y.(type) {
case Variable:
if z, ok := s[x]; ok {
if z != y {
return false
}
} else {
s[x] = y
}
default:
return false
}
case Compound:
switch y := y.(type) {
case Compound:
if x.Functor() != y.Functor() || x.Arity() != y.Arity() {
return false
}
for i := 0; i < x.Arity(); i++ {
rest = append(rest, [2]Term{x.Arg(i), y.Arg(i)})
}
default:
return false
}
default:
if x != y {
return false
}
}
}
return true
}

func iteratedGoalTerm(t Term, env *Env) Term {
for {
c, ok := env.Resolve(t).(Compound)
if !ok || c.Functor() != atomCaret || c.Arity() != 2 {
return t
}
t = c.Arg(1)
}
}

// FindAll collects all the solutions of goal as instances, which unify with template. instances may contain duplications.
func FindAll(vm *VM, template, goal, instances Term, k Cont, env *Env) *Promise {
iter := ListIterator{List: instances, Env: env, AllowPartial: true}
Expand Down Expand Up @@ -833,7 +887,7 @@ func Compare(vm *VM, order, term1, term2 Term, k Cont, env *Env) *Promise {
return Error(typeError(validTypeAtom, order, env))
}

switch o := env.compare(term1, term2); o {
switch o := term1.Compare(term2, env); o {
case 1:
return Unify(vm, atomGreaterThan, order, k, env)
case -1:
Expand Down Expand Up @@ -957,7 +1011,7 @@ func KeySort(vm *VM, pairs, sorted Term, k Cont, env *Env) *Promise {
}

sort.SliceStable(elems, func(i, j int) bool {
return env.compare(elems[i].(Compound).Arg(0), elems[j].(Compound).Arg(0)) == -1
return elems[i].(Compound).Arg(0).Compare(elems[j].(Compound).Arg(0), env) == -1
})

return Unify(vm, sorted, List(elems...), k, env)
Expand All @@ -976,7 +1030,7 @@ func Throw(_ *VM, ball Term, _ Cont, env *Env) *Promise {
// Catch calls goal. If an exception is thrown and unifies with catcher, it calls recover.
func Catch(vm *VM, goal, catcher, recover Term, k Cont, env *Env) *Promise {
return catch(func(err error) *Promise {
e, ok := env.Resolve(err).(Exception)
e, ok := err.(Exception)
if !ok {
e = Exception{term: atomError.Apply(NewAtom("system_error"), NewAtom(err.Error()))}
}
Expand Down Expand Up @@ -1209,7 +1263,11 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro
s := Stream{vm: vm, mode: streamMode}
switch f, err := openFile(name, int(s.mode), 0644); {
case err == nil:
s.sourceSink = f
if s.mode == ioModeRead {
s.source = f
} else {
s.sink = f
}
if fi, err := f.Stat(); err == nil {
s.reposition = fi.Mode()&fs.ModeType == 0
}
Expand All @@ -1231,8 +1289,10 @@ func Open(vm *VM, sourceSink, mode, stream, options Term, k Cont, env *Env) *Pro
return Error(err)
}

if err := s.initRead(); err == nil {
s.checkEOS()
if s.mode == ioModeRead {
if err := s.initRead(); err == nil {
s.checkEOS()
}
}

return Unify(vm, stream, &s, k, env)
Expand Down Expand Up @@ -1404,7 +1464,7 @@ func WriteTerm(vm *VM, streamOrAlias, t, options Term, k Cont, env *Env) *Promis
return Error(err)
}

opts := writeOptions{
opts := WriteOptions{
ops: vm.operators,
priority: 1200,
}
Expand All @@ -1418,19 +1478,24 @@ func WriteTerm(vm *VM, streamOrAlias, t, options Term, k Cont, env *Env) *Promis
return Error(err)
}

switch err := writeTerm(s, env.Resolve(t), &opts, env); err {
case nil:
return k(env)
case errWrongIOMode:
w, err := s.textWriter()
switch {
case errors.Is(err, errWrongIOMode):
return Error(permissionError(operationOutput, permissionTypeStream, streamOrAlias, env))
case errWrongStreamType:
case errors.Is(err, errWrongStreamType):
return Error(permissionError(operationOutput, permissionTypeBinaryStream, streamOrAlias, env))
default:
case err != nil:
return Error(err)
}

if err := env.Resolve(t).WriteTerm(w, &opts, env); err != nil {
return Error(err)
}

return k(env)
}

func writeTermOption(opts *writeOptions, option Term, env *Env) error {
func writeTermOption(opts *WriteOptions, option Term, env *Env) error {
switch o := env.Resolve(option).(type) {
case Variable:
return InstantiationError(env)
Expand Down Expand Up @@ -1589,16 +1654,16 @@ func PutByte(vm *VM, streamOrAlias, byt Term, k Cont, env *Env) *Promise {
return Error(typeError(validTypeByte, byt, env))
}

switch err := s.WriteByte(byte(b)); err {
case nil:
return k(env)
case errWrongIOMode:
switch err := s.WriteByte(byte(b)); {
case errors.Is(err, errWrongIOMode):
return Error(permissionError(operationOutput, permissionTypeStream, streamOrAlias, env))
case errWrongStreamType:
case errors.Is(err, errWrongStreamType):
return Error(permissionError(operationOutput, permissionTypeTextStream, streamOrAlias, env))
default:
case err != nil:
return Error(err)
}

return k(env)
default:
return Error(typeError(validTypeByte, byt, env))
}
Expand All @@ -1621,16 +1686,16 @@ func PutChar(vm *VM, streamOrAlias, char Term, k Cont, env *Env) *Promise {

r := rune(c)

switch _, err := s.WriteRune(r); err {
case nil:
return k(env)
case errWrongIOMode:
switch _, err := s.WriteRune(r); {
case errors.Is(err, errWrongIOMode):
return Error(permissionError(operationOutput, permissionTypeStream, streamOrAlias, env))
case errWrongStreamType:
case errors.Is(err, errWrongStreamType):
return Error(permissionError(operationOutput, permissionTypeBinaryStream, streamOrAlias, env))
default:
case err != nil:
return Error(err)
}

return k(env)
default:
return Error(typeError(validTypeCharacter, char, env))
}
Expand Down Expand Up @@ -2014,11 +2079,11 @@ func AtomConcat(vm *VM, atom1, atom2, atom3 Term, k Cont, env *Env) *Promise {
for i := range s {
a1, a2 := s[:i], s[i:]
ks = append(ks, func(context.Context) *Promise {
return Unify(vm, &pattern, tuple(NewAtom(a1), NewAtom(a2)), k, env)
return Unify(vm, pattern, tuple(NewAtom(a1), NewAtom(a2)), k, env)
})
}
ks = append(ks, func(context.Context) *Promise {
return Unify(vm, &pattern, tuple(a3, atomEmpty), k, env)
return Unify(vm, pattern, tuple(a3, atomEmpty), k, env)
})
return Delay(ks...)
default:
Expand Down Expand Up @@ -2260,7 +2325,7 @@ func numberCharsWrite(vm *VM, num, chars Term, k Cont, env *Env) *Promise {
}

var buf bytes.Buffer
_ = writeTerm(&buf, n, &defaultWriteOptions, nil)
_ = n.WriteTerm(&buf, &defaultWriteOptions, nil)
rs := []rune(buf.String())

cs := make([]Term, len(rs))
Expand Down Expand Up @@ -2342,7 +2407,7 @@ func numberCodesWrite(vm *VM, num, codes Term, k Cont, env *Env) *Promise {
}

var buf bytes.Buffer
_ = writeTerm(&buf, n, &defaultWriteOptions, nil)
_ = n.WriteTerm(&buf, &defaultWriteOptions, nil)
rs := []rune(buf.String())

cs := make([]Term, len(rs))
Expand Down
Loading

0 comments on commit 146378f

Please sign in to comment.