From 667faf5268ecfc4618c1c87ede6d33e6f22d5b49 Mon Sep 17 00:00:00 2001 From: Matthew Jaffee Date: Sat, 18 Nov 2023 22:42:01 -0600 Subject: [PATCH] set up to support multiple different AI at runtime --- README.md | 10 ++++-- cmd/aicli/main.go | 23 ++++++++++--- pkg/aicli/cmd.go | 79 +++++++++++++++++++++++++++++++++++-------- pkg/aicli/cmd_test.go | 7 ++-- pkg/openai/client.go | 8 +++++ 5 files changed, 103 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index e525866..b60e42f 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ Whatever you type into the prompt will be sent to the AI, unless it's a meta com - `\config` Prints out aicli's configuration. - `\file ` Send the path and contents of a file on your local filesystem to the AI. It will be prefixed with a short message explaining that you'll refer to the file later. The AI should just respond with something like "ok". - `\system ` Prepends a system message to the list of messages (or replaces if one is already there). Does not send anything to the AI, but the new system message will be sent with the next message. +- `\set ` Set various config params. See `\config`. ## Configuration @@ -46,10 +47,14 @@ Flags are below. Any flag can also be set as an environment variable, just make ``` $ aicli -help Usage of aicli: + -ai string + Name of service (default "openai") + -context-limit int + Maximum number of bytes of context to keep. Earlier parts of the conversation are discarded. (default 10000) + -model string + Name of model to talk to. Most services have multiple options. (default "gpt-3.5-turbo") -openai-api-key string Your API key for OpenAI. - -openai-model string - Model name for OpenAI. (default "gpt-3.5-turbo") -temperature float Passed to model, higher numbers tend to generate less probable responses. (default 0.7) -verbose @@ -68,6 +73,7 @@ Usage of aicli: ## Future/TODO +- Need to use "model" param after startup - support other services like Anthropic, Cohere - Write conversation, or single response to file - automatically save conversations and allow listing/loading of convos diff --git a/cmd/aicli/main.go b/cmd/aicli/main.go index 6423951..c37e517 100644 --- a/cmd/aicli/main.go +++ b/cmd/aicli/main.go @@ -9,14 +9,29 @@ import ( ) func main() { - cmd := aicli.NewCmd(nil) - err := commandeer.LoadEnv(cmd, "", func(a interface{}) error { return nil }) + flags := NewFlags() + err := commandeer.LoadEnv(flags, "", func(a interface{}) error { return nil }) if err != nil { log.Fatal(err) } - client := openai.NewClient(cmd.OpenAIAPIKey, cmd.OpenAIModel) - cmd.SetAI(client) + + cmd := flags.Cmd + client := openai.NewClient(flags.OpenAI.APIKey, cmd.Model) + cmd.AddAI("openai", client) + cmd.AddAI("echo", &aicli.Echo{}) if err := cmd.Run(); err != nil { log.Fatal(err) } } + +func NewFlags() *Flags { + return &Flags{ + OpenAI: openai.NewConfig(), + Cmd: *aicli.NewCmd(), + } +} + +type Flags struct { + OpenAI openai.Config `flag:"!embed"` + Cmd aicli.Cmd `flag:"!embed"` +} diff --git a/pkg/aicli/cmd.go b/pkg/aicli/cmd.go index 9784222..43b00cc 100644 --- a/pkg/aicli/cmd.go +++ b/pkg/aicli/cmd.go @@ -18,8 +18,8 @@ const ( ) type Cmd struct { - OpenAIAPIKey string `flag:"openai-api-key" help:"Your API key for OpenAI."` - OpenAIModel string `flag:"openai-model" help:"Model name for OpenAI."` + AI string `help:"Name of service"` + Model string `help:"Name of model to talk to. Most services have multiple options."` Temperature float64 `help:"Passed to model, higher numbers tend to generate less probable responses."` Verbose bool `help:"Enables debug output."` ContextLimit int `help:"Maximum number of bytes of context to keep. Earlier parts of the conversation are discarded."` @@ -34,13 +34,13 @@ type Cmd struct { dotAICLIDir string historyPath string - client AI + ais map[string]AI } -func NewCmd(client AI) *Cmd { +func NewCmd() *Cmd { return &Cmd{ - OpenAIAPIKey: "", - OpenAIModel: "gpt-3.5-turbo", + AI: "openai", + Model: "gpt-3.5-turbo", Temperature: 0.7, ContextLimit: 10000, // 10,000 bytes ~2000 tokens @@ -50,12 +50,12 @@ func NewCmd(client AI) *Cmd { stdout: os.Stdout, stderr: os.Stderr, - client: client, + ais: make(map[string]AI), } } -func (cmd *Cmd) SetAI(ai AI) { - cmd.client = ai +func (cmd *Cmd) AddAI(name string, ai AI) { + cmd.ais[name] = ai } func (cmd *Cmd) Run() error { @@ -118,7 +118,7 @@ func (cmd *Cmd) messagesWithinLimit() []Message { } func (cmd *Cmd) sendMessages() error { - msg, err := cmd.client.StreamResp(cmd.messagesWithinLimit(), cmd.stdout) + msg, err := cmd.client().StreamResp(cmd.messagesWithinLimit(), cmd.stdout) if err != nil { return err } @@ -150,6 +150,18 @@ func (cmd *Cmd) handleMeta(line string) { cmd.printMessages() case `\config`: cmd.printConfig() + case `\set`: + if len(parts) != 2 { + err = errors.New("usage: \\set ") + break + } + pv := strings.SplitN(parts[1], " ", 2) + if len(pv) != 2 { + err = errors.New("usage: \\set ") + break + } + param, val := pv[0], pv[1] + err = cmd.Set(param, val) case `\file`: if len(parts) < 2 { err = errors.New("need a file name for \\file command") @@ -198,6 +210,39 @@ func (cmd *Cmd) handleMeta(line string) { } } +func (cmd *Cmd) Set(param, value string) error { + switch param { + case "ai": + cmd.AI = value + case "model": + cmd.Model = value + case "temperature": + temp, err := strconv.ParseFloat(value, 64) + if err != nil { + return errors.Wrap(err, "parsing temp value") + } + cmd.Temperature = temp + case "verbose": + switch strings.ToLower(value) { + case "1", "true", "yes": + cmd.Verbose = true + case "0", "false", "no": + cmd.Verbose = false + default: + return errors.Errorf("could not parse '%s' as bool", value) + } + case "context-limit": + lim, err := strconv.Atoi(value) + if err != nil { + return errors.Wrapf(err, "parsing '%s' to int", value) + } + cmd.ContextLimit = lim + default: + return errors.Errorf("unknown parameter '%s'", param) + } + return nil +} + func (cmd *Cmd) sendFile(file string) error { f, err := os.Open(file) if err != nil { @@ -216,7 +261,7 @@ func (cmd *Cmd) sendFile(file string) error { } b.WriteString("\n```\n") cmd.appendMessage(SimpleMsg{RoleField: "user", ContentField: b.String()}) - msg, err := cmd.client.StreamResp(cmd.messages, cmd.stdout) + msg, err := cmd.client().StreamResp(cmd.messages, cmd.stdout) if err != nil { return errors.Wrap(err, "sending file") } @@ -225,6 +270,10 @@ func (cmd *Cmd) sendFile(file string) error { return nil } +func (cmd *Cmd) client() AI { + return cmd.ais[cmd.AI] +} + func (cmd *Cmd) appendMessage(msg Message) { cmd.messages = append(cmd.messages, msg) cmd.totalLen += len(msg.Content()) @@ -257,15 +306,15 @@ func (cmd *Cmd) errOut(err error, format string, a ...any) { // checkConfig ensures the command configuration is valid before proceeding. func (cmd *Cmd) checkConfig() error { - if cmd.OpenAIAPIKey == "" { - return errors.New("Need an API key") + if cmd.client() == nil { + return errors.Errorf("have no AI named '%s' configured", cmd.AI) } return nil } func (cmd *Cmd) printConfig() { - fmt.Fprintf(cmd.stderr, "OpenAI_API_Key: length=%d\n", len(cmd.OpenAIAPIKey)) - fmt.Fprintf(cmd.stderr, "OpenAIModel: %s\n", cmd.OpenAIModel) + fmt.Fprintf(cmd.stderr, "AI: %s\n", cmd.AI) + fmt.Fprintf(cmd.stderr, "Model: %s\n", cmd.Model) fmt.Fprintf(cmd.stderr, "Temperature: %f\n", cmd.Temperature) fmt.Fprintf(cmd.stderr, "Verbose: %v\n", cmd.Verbose) fmt.Fprintf(cmd.stderr, "ContextLimit: %d\n", cmd.ContextLimit) diff --git a/pkg/aicli/cmd_test.go b/pkg/aicli/cmd_test.go index 042d2c8..977250b 100644 --- a/pkg/aicli/cmd_test.go +++ b/pkg/aicli/cmd_test.go @@ -9,7 +9,9 @@ import ( ) func TestCmd(t *testing.T) { - cmd := NewCmd(&Echo{}) + cmd := NewCmd() + cmd.AI = "echo" + cmd.AddAI("echo", &Echo{}) stdinr, stdinw := io.Pipe() stdout, stdoutw := io.Pipe() stderr, stderrw := io.Pipe() @@ -18,7 +20,6 @@ func TestCmd(t *testing.T) { cmd.stdout = stdoutw cmd.stderr = stderrw cmd.dotAICLIDir = t.TempDir() - cmd.OpenAIAPIKey = "blah" done := make(chan struct{}) var runErr error @@ -41,7 +42,7 @@ func TestCmd(t *testing.T) { _, _ = stdinw.Write([]byte("\\reset\n")) require.NoError(t, runErr) _, _ = stdinw.Write([]byte("\\config\n")) - expect(t, stderr, []byte("OpenAI_API_Key: length=4\nOpenAIModel: gpt-3.5-turbo\nTemperature: 0.700000\nVerbose: false\nContextLimit: 10000\n")) + expect(t, stderr, []byte("AI: echo\nModel: gpt-3.5-turbo\nTemperature: 0.700000\nVerbose: false\nContextLimit: 10000\n")) _, _ = stdinw.Write([]byte("\\reset\n")) _, _ = stdinw.Write([]byte("\\file ./testdata/myfile\n")) expect(t, stdout, []byte("Here is a file named './testdata/myfile' that I'll refer to later, you can just say 'ok': \n```\nhaha\n```\n\n")) diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 509373a..7922708 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -10,6 +10,14 @@ import ( openai "github.com/sashabaranov/go-openai" ) +type Config struct { + APIKey string `flag:"openai-api-key" help:"Your API key for OpenAI."` +} + +func NewConfig() Config { + return Config{} +} + var _ aicli.AI = &Client{} // assert that Client satisfies AI interface type Client struct {