Skip to content

Commit

Permalink
add some tests
Browse files Browse the repository at this point in the history
tests of the CLI itself have uncovered some interesting oddities with testing
readline applications which I'll have to think about how to work with
  • Loading branch information
jaffee committed Nov 12, 2023
1 parent e2f3487 commit 3f19073
Show file tree
Hide file tree
Showing 9 changed files with 162 additions and 15 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ Usage of aicli:

## Future/TODO

- set/view system message
- abstract openai-specific stuff for testing and supporting other models
- support other services like Anthropic, Cohere
- Write conversation, or single response to file
- automatically save conversations and allow listing/loading of convos
- Load old conversation from file
Expand Down
15 changes: 10 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ module github.com/jaffee/aicli

go 1.21.1

require github.com/wader/readline v0.0.0-20230307172220-bcb7158e7448
require (
github.com/jaffee/commandeer v0.6.0
github.com/pkg/errors v0.9.1
github.com/sashabaranov/go-openai v1.17.0
github.com/stretchr/testify v1.2.2
github.com/wader/readline v0.0.0-20230307172220-bcb7158e7448
)

require (
github.com/jaffee/commandeer v0.6.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/sashabaranov/go-openai v1.17.0 // indirect
golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/text v0.3.7 // indirect
)
7 changes: 6 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no=
Expand Down Expand Up @@ -57,6 +58,7 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
Expand All @@ -77,9 +79,11 @@ github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B
github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE=
github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo=
github.com/spf13/pflag v1.0.3 h1:zPAT6CGy6wXeQ7NtTnaTerfKOsV6V6F8agHXFiazDkg=
github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/yZzE=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
Expand Down Expand Up @@ -109,8 +113,9 @@ golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b h1:2n253B2r0pYSmEV+UNCQoPfU/FiaizQEK5Gu4Bq4JE8=
golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
Expand Down
14 changes: 10 additions & 4 deletions pkg/aicli/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type Cmd struct {
stdout io.Writer
stderr io.Writer

historyPath string

client AI
}

Expand Down Expand Up @@ -54,9 +56,10 @@ func (cmd *Cmd) Run() error {
}

rl, err := readline.NewEx(&readline.Config{
Prompt: "> ",
HistoryFile: getHistoryFilePath(),
HistoryLimit: 1000000,
Prompt: "> ",
HistoryFile: cmd.getHistoryFilePath(),
HistoryLimit: 1000000,
ForceUseInteractive: true, // seems to be needed for testing

Stdin: cmd.stdin,
Stdout: cmd.stdout,
Expand Down Expand Up @@ -216,7 +219,10 @@ func (cmd *Cmd) printConfig() {
fmt.Fprintf(cmd.stderr, "Verbose: %v\n", cmd.Verbose)
}

func getHistoryFilePath() string {
func (cmd *Cmd) getHistoryFilePath() string {
if cmd.historyPath != "" {
return cmd.historyPath
}
home, err := os.UserHomeDir()
if err != nil {
// if we can't get a home dir, we'll use the local directory
Expand Down
54 changes: 54 additions & 0 deletions pkg/aicli/cmd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package aicli

import (
"bytes"
"io"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestCmd(t *testing.T) {
cmd := NewCmd(&Echo{})
stdinr, stdinw := io.Pipe()
stdout := &bytes.Buffer{}
stderr := &bytes.Buffer{}

cmd.stdin = stdinr
cmd.stdout = stdout
cmd.stderr = stderr
cmd.historyPath = t.TempDir() + "/.aicli_history"
cmd.OpenAI_API_Key = "blah"

done := make(chan struct{})
var runErr error
go func() {
runErr = cmd.Run()
close(done)
}()
time.Sleep(time.Millisecond * 10)
require.NoError(t, runErr)
expect(t, stdout, []byte{0x20, 0x08, 0x1b, 0x5b, 0x36, 0x6e, 0x3e, 0x20})
stdinw.Close()

select {
case <-done:
case <-time.After(time.Second * 2):
t.Fatal("command should have ended after stdin was closed")
}

require.NoError(t, runErr)
}

func expect(t *testing.T, r io.Reader, exp []byte) {
t.Helper()
buffer := make([]byte, len(exp))

n, err := r.Read(buffer)
if err != nil && err.Error() != "EOF" {
require.NoError(t, err)
}

require.Equal(t, exp, buffer[:n])
}
29 changes: 29 additions & 0 deletions pkg/aicli/echoai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package aicli

import (
"fmt"
"io"

"github.com/pkg/errors"
)

var _ AI = &Echo{} // assert that Client satisfies AI interface

// Echo is an AI implementation for testing which repeats the user's last
// message back with some extra information.
type Echo struct{}

func (c *Echo) StreamResp(msgs []Message, output io.Writer) (Message, error) {
var resp string
if len(msgs) == 0 {
resp = "0 msgs"
} else {
resp = fmt.Sprintf("msgs: %d, role: %s, content: %s", len(msgs), RoleAssistant, msgs[len(msgs)-1].Content())
}

_, err := output.Write([]byte(resp))
if err != nil {
return nil, errors.Wrap(err, "writing msg")
}
return SimpleMsg{RoleField: RoleAssistant, ContentField: resp}, nil
}
44 changes: 44 additions & 0 deletions pkg/aicli/echoai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package aicli_test

import (
"bytes"
"fmt"
"testing"

"github.com/jaffee/aicli/pkg/aicli"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestEcho(t *testing.T) {
c := &aicli.Echo{}

cases := []struct {
msgs []aicli.Message
exp string
}{
{msgs: nil, exp: "0 msgs"},
{
msgs: []aicli.Message{
aicli.SimpleMsg{
ContentField: "hello",
RoleField: aicli.RoleUser,
},
},
exp: "msgs: 1, role: assistant, content: hello",
},
}

for i, tst := range cases {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
buf := &bytes.Buffer{}
msg, err := c.StreamResp(tst.msgs, buf)
require.NoError(t, err)

assert.Equal(t, tst.exp, string(buf.Bytes()))
assert.Equal(t, tst.exp, msg.Content())
assert.Equal(t, aicli.RoleAssistant, msg.Role())
})
}

}
6 changes: 6 additions & 0 deletions pkg/aicli/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ type AI interface {
StreamResp(msgs []Message, output io.Writer) (Message, error)
}

const (
RoleAssistant = "assistant"
RoleUser = "user"
RoleSystem = "system"
)

type Message interface {
Role() string
Content() string
Expand Down
5 changes: 2 additions & 3 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
openai "github.com/sashabaranov/go-openai"
)

var _ aicli.AI = &Client{} // assert that Client satisfies AI interface

type Client struct {
model string

Expand All @@ -24,9 +26,6 @@ func NewClient(apiKey, model string) *Client {
}
}

func (c *Client) SetModel(model string) {
}

func toOpenAIMessages(msgs []aicli.Message) []openai.ChatCompletionMessage {
ret := make([]openai.ChatCompletionMessage, len(msgs))
for i, msg := range msgs {
Expand Down

0 comments on commit 3f19073

Please sign in to comment.