Skip to content

Commit

Permalink
Split up into manageable functions
Browse files Browse the repository at this point in the history
  • Loading branch information
edofic committed Apr 14, 2023
1 parent 16b8ed4 commit c361fcc
Showing 1 changed file with 120 additions and 66 deletions.
186 changes: 120 additions & 66 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,59 @@ import (
openai "github.com/sashabaranov/go-openai"
)

const defaultModel = openai.GPT3Dot5Turbo
const (
defaultModel = openai.GPT3Dot5Turbo
sessionFile = "/tmp/chatgpt-cli-last-session.json"
)

// params holds all command-line configuration for a single CLI invocation.
type params struct {
	maxTokens       int     // -maxTokens: cap on tokens to generate
	systemMsg       string  // -systemMsg: optional system prompt message
	includeFile     string  // -includeFile: path of a file to append as an extra user message
	temperature     float64 // -temperature: sampling temperature
	continueSession bool    // -c: resume the previously saved session (ignores other flags)
	msg             string  // positional args joined with spaces, or stdin when the message is "-"
}

func main() {
maxTokens := flag.Int("maxTokens", 500, "Maximum number of tokens to generate")
systemMsg := flag.String("systemMsg", "", "System message to include with the prompt")
includeFile := flag.String("includeFile", "", "File to include with the prompt")
temperature := flag.Float64("temperature", 0, "ChatGPT temperature")
continueSession := flag.Bool("c", false, "Continue last session (ignores other flags)")
p := parseArgs()

client := getClient()
model := os.Getenv("OPENAI_MODEL")
if model == "" {
model = defaultModel
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()

req := getCompletionRequest(p, model)
req = appendMessages(req, p)

fullResponse, err := streamCompletion(ctx, client, req, func(chunk string) error {
_, err := fmt.Print(chunk)
return err
})
fmt.Println()
if err != nil {
panic(err)
}

req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleAssistant, Content: fullResponse})

err = saveCompletion(req)
if err != nil {
panic(err)
}
}

func parseArgs() params {
// var versions of flags from main, returning a params struct
var p params
flag.IntVar(&p.maxTokens, "maxTokens", 500, "Maximum number of tokens to generate")
flag.StringVar(&p.systemMsg, "systemMsg", "", "System message to include with the prompt")
flag.StringVar(&p.includeFile, "includeFile", "", "File to include with the prompt")
flag.Float64Var(&p.temperature, "temperature", 0, "ChatGPT temperature")
flag.BoolVar(&p.continueSession, "c", false, "Continue last session (ignores other flags)")
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [options] message\n", os.Args[0])
flag.PrintDefaults()
Expand All @@ -31,48 +76,75 @@ func main() {
msg := strings.TrimSpace(strings.Join(flag.Args(), " "))
if msg == "" {
flag.Usage()
return
os.Exit(1)
} else if msg == "-" {
scanner := bufio.NewScanner(os.Stdin)
for scanner.Scan() {
msg += scanner.Text() + "\n"
}
}
p.msg = msg
return p
}

client := getClient()
model := os.Getenv("OPENAI_MODEL")
if model == "" {
model = defaultModel
func getClient() *openai.Client {
apiKey := os.Getenv("OPENAI_API_KEY")
url := os.Getenv("OPENAI_AZURE_ENDPOINT")
if url != "" {
deployment := os.Getenv("OPENAI_AZURE_MODEL")
config := openai.DefaultAzureConfig(apiKey, url, deployment)
return openai.NewClientWithConfig(config)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
return openai.NewClient(apiKey)
}

var req openai.ChatCompletionRequest
if *continueSession {
session, err := os.ReadFile("/tmp/chatgpt-cli-last-session.json")
if err != nil {
panic(err)
}
err = json.Unmarshal(session, &req)
if err != nil {
panic(err)
}
func getCompletionRequest(p params, model string) openai.ChatCompletionRequest {
if p.continueSession {
return loadLastCompletion()
} else {
msgs := []openai.ChatCompletionMessage{}
if *systemMsg != "" {
msgs = append(msgs, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleSystem, Content: *systemMsg})
}
req = openai.ChatCompletionRequest{
Model: model,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
Stream: true,
Messages: msgs,
}
return newCompletionRequest(p, model)
}
}

func loadLastCompletion() openai.ChatCompletionRequest {
var req openai.ChatCompletionRequest
session, err := os.ReadFile(sessionFile)
if err != nil {
panic(err)
}
err = json.Unmarshal(session, &req)
if err != nil {
panic(err)
}
req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: msg})
if *includeFile != "" {
contents, err := os.ReadFile(*includeFile)
return req
}

// saveCompletion persists the request — including the accumulated message
// history — to sessionFile so a later run with -c can resume the chat.
func saveCompletion(req openai.ChatCompletionRequest) error {
	encoded, err := json.Marshal(req)
	if err != nil {
		return err
	}
	return os.WriteFile(sessionFile, encoded, 0644)
}

// newCompletionRequest builds a fresh streaming chat request for the given
// model, seeding the history with the system message when one was supplied.
func newCompletionRequest(p params, model string) openai.ChatCompletionRequest {
	req := openai.ChatCompletionRequest{
		Model:       model,
		MaxTokens:   p.maxTokens,
		Temperature: float32(p.temperature),
		Stream:      true,
		Messages:    []openai.ChatCompletionMessage{},
	}
	if p.systemMsg != "" {
		req.Messages = append(req.Messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: p.systemMsg,
		})
	}
	return req
}

func appendMessages(req openai.ChatCompletionRequest, p params) openai.ChatCompletionRequest {
req.Messages = append(req.Messages, openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: p.msg})
if p.includeFile != "" {
contents, err := os.ReadFile(p.includeFile)
if err != nil {
panic(err)
}
Expand All @@ -81,51 +153,33 @@ func main() {
openai.ChatCompletionMessage{Role: openai.ChatMessageRoleUser, Content: string(contents)},
)
}
return req
}

// streamCompletion sends req as a streaming chat completion and invokes
// callback with each content chunk as it arrives. It returns the
// concatenation of all chunks, or the first error from the stream or the
// callback.
func streamCompletion(ctx context.Context, client *openai.Client, req openai.ChatCompletionRequest, callback func(chunk string) error) (fullResponse string, err error) {
	stream, err := client.CreateChatCompletionStream(ctx, req)
	if err != nil {
		return "", fmt.Errorf("creating chat completion stream: %w", err)
	}
	defer stream.Close()

	responseChunks := []string{}
	for {
		response, err := stream.Recv()
		// io.EOF marks the normal end of the stream.
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return "", fmt.Errorf("stream error: %w", err)
		}

		chunk := response.Choices[0].Delta.Content
		err = callback(chunk)
		if err != nil {
			return "", fmt.Errorf("callback error: %w", err)
		}
		responseChunks = append(responseChunks, chunk)
	}

	return strings.Join(responseChunks, ""), nil
}

0 comments on commit c361fcc

Please sign in to comment.