bedrock llama2 support cleanup
jaffee committed Dec 18, 2023
1 parent f370277 commit 382df4c
Showing 6 changed files with 110 additions and 72 deletions.
1 change: 0 additions & 1 deletion .github/workflows/.golangci-lint.yml
@@ -2,7 +2,6 @@ name: golangci-lint
 on:
   push:
     branches:
-      - master
       - main
   pull_request:
 
5 changes: 3 additions & 2 deletions go.mod
@@ -5,6 +5,7 @@ go 1.21.1
 replace github.com/sashabaranov/go-openai => github.com/jaffee/go-openai v0.0.0-20231121153610-1c05908c31a0
 
 require (
+	github.com/aws/aws-sdk-go v1.48.7
 	github.com/aws/aws-sdk-go-v2/config v1.25.5
 	github.com/aws/aws-sdk-go-v2/service/bedrock v1.3.3
 	github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.3.3
@@ -31,6 +32,6 @@ require (
 	github.com/aws/smithy-go v1.17.0 // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
-	golang.org/x/sys v0.1.0 // indirect
-	golang.org/x/text v0.3.7 // indirect
+	golang.org/x/sys v0.5.0 // indirect
+	golang.org/x/text v0.13.0 // indirect
 )
11 changes: 8 additions & 3 deletions go.sum
@@ -4,6 +4,8 @@ github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAE
 github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
 github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
 github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8=
+github.com/aws/aws-sdk-go v1.48.7 h1:gDcOhmkohlNk20j0uWpko5cLBbwSkB+xpkshQO45F7Y=
+github.com/aws/aws-sdk-go v1.48.7/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
 github.com/aws/aws-sdk-go-v2 v1.23.1 h1:qXaFsOOMA+HsZtX8WoCa+gJnbyW7qyFFBlPqvTSzbaI=
 github.com/aws/aws-sdk-go-v2 v1.23.1/go.mod h1:i1XDttT4rnf6vxc9AuskLc6s7XBee8rlLilKlc03uAA=
 github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.5.1 h1:ZY3108YtBNq96jNZTICHxN1gSBSbnvIdYwwqnvCV4Mc=
@@ -76,6 +78,8 @@ github.com/jaffee/commandeer v0.6.0 h1:YI44XLWcJN21euhh32sZW8vM/tljPYxhsXIfEPkQK
 github.com/jaffee/commandeer v0.6.0/go.mod h1:kCwfuSvZ2T0NVEr3LDSo6fDUgi0xSBnAVDdkOKTtpLQ=
 github.com/jaffee/go-openai v0.0.0-20231121153610-1c05908c31a0 h1:8ZVcm7+jIxlhnqC5WAOlLfRIr81eMsSU0tpjP2GP4ag=
 github.com/jaffee/go-openai v0.0.0-20231121153610-1c05908c31a0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
+github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
+github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
 github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo=
 github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
 github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvWXihfKN4Q=
@@ -148,11 +152,12 @@ golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
-golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
+golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
-golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
+golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
92 changes: 27 additions & 65 deletions pkg/aws/ai.go
@@ -2,11 +2,10 @@ package aws
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"io"
-	"strings"
 
+	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/config"
 	"github.com/aws/aws-sdk-go-v2/service/bedrock"
 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
@@ -16,7 +15,10 @@ import (
 )
 
 const (
-	ModelLlama213BChatV1 = "meta.llama2-13b-chat-v1"
+	ModelLlama213BChatV1  = "meta.llama2-13b-chat-v1"
+	ModelLlama270BChatV1  = "meta.llama2-70b-chat-v1"
+	ModelTitanTextExpress = "amazon.titan-text-express-v1"
+	ModelTitanEmbedText   = "amazon.titan-embed-text-v1"
 )
 
 // NewAI gets a new AI which uses the default AWS configuration (i.e. ~/.aws/config and standard AWS env vars).
@@ -27,22 +29,16 @@ func NewAI() (*AI, error) {
 		return nil, errors.Wrap(err, "loading default aws config")
 	}
 
-	// brc := bedrock.NewFromConfig(cfg)
-	// lfmOutput, err := brc.ListFoundationModels(context.Background(), &bedrock.ListFoundationModelsInput{})
-	// if err != nil {
-	// 	return nil, errors.Wrap(err, "listing models")
-	// }
-
-	// lcmOutput, err := brc.ListCustomModels(context.Background(), &bedrock.ListCustomModelsInput{})
-	// if err != nil {
-	// 	return nil, errors.Wrap(err, "listing custom models")
-	// }
+	brrc := bedrockruntime.NewFromConfig(cfg)
+	return &AI{
+		client: brrc,
+	}, nil
+}
+
+func NewAIFromConfig(cfg aws.Config) (*AI, error) {
 	brrc := bedrockruntime.NewFromConfig(cfg)
 	return &AI{
 		client: brrc,
-		//Output: lfmOutput,
-		//CustomOutput: lcmOutput,
 	}, nil
 }
 
@@ -54,32 +50,25 @@ type AI struct {
 }
 
 func (ai *AI) GenerateStream(req *aicli.GenerateRequest, output io.Writer) (aicli.Message, error) {
-	fmt.Printf("req: %+v\n", req)
-	accept := "application/json"
-	model := ModelLlama213BChatV1
+	var body []byte
+	var sub AWSSubModel
+	switch req.Model {
+	case ModelLlama213BChatV1, "":
+		// TODO we'll eventually need different implementations for different
+		// models, but I only care about llama2 at the moment
+		sub = LlamaSubModel{}
+	default:
+		return nil, errors.Errorf("%s is not currently a supported model (try 'meta.llama2-13b-chat-v1')", req.Model)
+	}
-	bod := LlamaBody{
-		Prompt:      promptifyMessages(req.Messages),
-		Temperature: req.Temperature,
-		TopP:        0.9,
-		MaxGenLen:   100,
-	}
 
-	bs, err := json.Marshal(bod)
+	body, err := sub.MakeBody(req)
 	if err != nil {
-		return nil, errors.Wrap(err, "marshalling")
+		return nil, errors.Wrap(err, "making body")
 	}
-	fmt.Printf("bod: %s\n", bs)
 
+	accept := "application/json"
 	streamOutput, err := ai.client.InvokeModelWithResponseStream(context.Background(), &bedrockruntime.InvokeModelWithResponseStreamInput{
-		Body:        bs,
-		ModelId:     &model,
+		Body:        body,
+		ModelId:     &req.Model,
 		Accept:      &accept,
 		ContentType: &accept,
 	})
@@ -93,13 +82,13 @@ func (ai *AI) GenerateStream(req *aicli.GenerateRequest, output io.Writer) (aicli.Message, error) {
 	for event := range echan {
 		switch eventT := event.(type) {
 		case *types.ResponseStreamMemberChunk:
-			chunk := &Event{}
-			if err := json.Unmarshal(eventT.Value.Bytes, chunk); err != nil {
-				return nil, errors.Wrap(err, "unmarshaling response")
+			chunk, err := sub.HandleResponseChunk(eventT.Value.Bytes)
+			if err != nil {
+				return nil, errors.Wrap(err, "handling chunk")
 			}
 
-			bldr.WriteString(chunk.Generation)
-			_, err := output.Write([]byte(chunk.Generation))
+			_, _ = bldr.Write(chunk)
+			_, err = output.Write(chunk)
 			if err != nil {
 				return nil, errors.Wrap(err, "writing output")
 			}
@@ -115,34 +104,7 @@ func (ai *AI) GetEmbedding(req *aicli.EmbeddingRequest) ([]aicli.Embedding, error) {
 	return nil, errors.New("unimplemented")
 }
 
-type LlamaBody struct {
-	Prompt      string  `json:"prompt"`
-	Temperature float64 `json:"temperature"`
-	TopP        float64 `json:"top_p"`
-	MaxGenLen   int     `json:"max_gen_len"`
-}
-
-func promptifyMessages(msgs []aicli.Message) string {
-	bldr := &strings.Builder{}
-	bldr.WriteString("[INST] ")
-	msgsStart := 0
-	if msgs[0].Role() == aicli.RoleSystem {
-		fmt.Fprintf(bldr, "<<SYS>>\n%s\n<</SYS>>\n", msgs[0].Content())
-		msgsStart = 1
-	}
-	if len(msgs) == msgsStart {
-		return bldr.String()
-	}
-	for _, msg := range msgs[msgsStart:] {
-		fmt.Fprintf(bldr, "%s: %s\n", msg.Role(), msg.Content())
-	}
-	bldr.WriteString(" [/INST]\n")
-	return bldr.String()
-}
-
-type Event struct {
-	Generation           string  `json:"generation"`
-	PromptTokenCount     int     `json:"prompt_token_count"`
-	GenerationTokenCount int     `json:"generation_token_count"`
-	StopReason           *string `json:"stop_reason"`
+type AWSSubModel interface {
+	MakeBody(req *aicli.GenerateRequest) ([]byte, error)
+	HandleResponseChunk(chunkBytes []byte) ([]byte, error)
 }
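
The new AWSSubModel interface is the seam for supporting further Bedrock models: each model supplies its own request-body encoding and stream-chunk decoding. As a purely hypothetical sketch (not part of this commit), here is what a submodel for the ModelTitanTextExpress constant declared above might look like. The TitanSubModel, titanBody, titanGenConfig, and titanEvent names and the flat prompt concatenation are assumptions of this sketch; the JSON field names (inputText, textGenerationConfig, outputText, completionReason) follow Amazon's published Titan text schema but are unverified here.

package aws

import (
	"encoding/json"
	"fmt"
	"strings"

	"github.com/jaffee/aicli/pkg/aicli"
	"github.com/pkg/errors"
)

// TitanSubModel is a hypothetical AWSSubModel for amazon.titan-text-express-v1.
// It is a sketch only and is not wired into the switch in GenerateStream.
type TitanSubModel struct{}

func (m TitanSubModel) MakeBody(req *aicli.GenerateRequest) ([]byte, error) {
	// Titan takes a single flat prompt rather than llama2's [INST] format,
	// so naively concatenate the messages here.
	prompt := &strings.Builder{}
	for _, msg := range req.Messages {
		fmt.Fprintf(prompt, "%s: %s\n", msg.Role(), msg.Content())
	}
	bod := titanBody{
		InputText: prompt.String(),
		Config: titanGenConfig{
			Temperature:   req.Temperature,
			TopP:          0.9, // mirrors the hardcoded llama2 value above
			MaxTokenCount: 100,
		},
	}
	bs, err := json.Marshal(bod)
	if err != nil {
		return nil, errors.Wrap(err, "marshalling titan body")
	}
	return bs, nil
}

func (m TitanSubModel) HandleResponseChunk(chunkBytes []byte) ([]byte, error) {
	chunk := titanEvent{}
	if err := json.Unmarshal(chunkBytes, &chunk); err != nil {
		return nil, err
	}
	return []byte(chunk.OutputText), nil
}

type titanBody struct {
	InputText string         `json:"inputText"`
	Config    titanGenConfig `json:"textGenerationConfig"`
}

type titanGenConfig struct {
	Temperature   float64 `json:"temperature"`
	TopP          float64 `json:"topP"`
	MaxTokenCount int     `json:"maxTokenCount"`
}

type titanEvent struct {
	OutputText       string  `json:"outputText"`
	CompletionReason *string `json:"completionReason"`
}

Wiring it up would then just be another arm in the GenerateStream switch, e.g. case ModelTitanTextExpress: sub = TitanSubModel{}.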
5 changes: 4 additions & 1 deletion pkg/aws/ai_test.go
@@ -2,6 +2,7 @@ package aws_test
 
 import (
 	"bytes"
+	"os"
 	"testing"
 
 	"github.com/jaffee/aicli/pkg/aicli"
@@ -13,6 +14,8 @@ func TestNewAI(t *testing.T) {
 	if testing.Short() {
 		t.Skip("skipping test in short mode.")
 	}
+	err := os.Setenv("AWS_REGION", "us-east-1")
+	require.NoError(t, err)
 
 	ai, err := aws.NewAI()
 	require.NoError(t, err)
@@ -27,8 +30,8 @@ func TestNewAI(t *testing.T) {
 			ContentField: "hello, please respond with 'hello'",
 		},
 	}}, buf)
-
 	require.NoError(t, err)
+
 	require.Equal(t, "assistant", resp.Role())
 	require.True(t, 4 < len(resp.Content()))
 	require.True(t, 4 < len(buf.Bytes()))
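
Rather than mutating the process environment the way the test does with os.Setenv, callers can inject a fully specified config through the new NewAIFromConfig constructor. A minimal sketch, assuming the same us-east-1 region as the test (the helper name is hypothetical):

import (
	"context"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/jaffee/aicli/pkg/aws"
)

// newRegionPinnedAI builds an AI against an explicit region instead of
// relying on the AWS_REGION environment variable being set.
func newRegionPinnedAI() (*aws.AI, error) {
	cfg, err := config.LoadDefaultConfig(context.Background(),
		config.WithRegion("us-east-1"))
	if err != nil {
		return nil, err
	}
	return aws.NewAIFromConfig(cfg)
}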
68 changes: 68 additions & 0 deletions pkg/aws/llama.go
@@ -0,0 +1,68 @@
+package aws
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	"github.com/jaffee/aicli/pkg/aicli"
+	"github.com/pkg/errors"
+)
+
+type LlamaSubModel struct{}
+
+func (m LlamaSubModel) MakeBody(req *aicli.GenerateRequest) ([]byte, error) {
+	bod := llamaBody{
+		Prompt:      llamaPromptifyMessages(req.Messages),
+		Temperature: req.Temperature,
+		TopP:        0.9,
+		MaxGenLen:   100,
+	}
+
+	bs, err := json.Marshal(bod)
+	if err != nil {
+		return nil, errors.Wrap(err, "marshalling")
+	}
+
+	return bs, nil
+}
+
+func (m LlamaSubModel) HandleResponseChunk(chunkBytes []byte) ([]byte, error) {
+	chunk := llamaEvent{}
+	if err := json.Unmarshal(chunkBytes, &chunk); err != nil {
+		return nil, err
+	}
+	return []byte(chunk.Generation), nil
+}
+
+type llamaBody struct {
+	Prompt      string  `json:"prompt"`
+	Temperature float64 `json:"temperature"`
+	TopP        float64 `json:"top_p"`
+	MaxGenLen   int     `json:"max_gen_len"`
+}
+
+func llamaPromptifyMessages(msgs []aicli.Message) string {
+	bldr := &strings.Builder{}
+	bldr.WriteString("[INST] ")
+	msgsStart := 0
+	if msgs[0].Role() == aicli.RoleSystem {
+		fmt.Fprintf(bldr, "<<SYS>>\n%s\n<</SYS>>\n", msgs[0].Content())
+		msgsStart = 1
+	}
+	if len(msgs) == msgsStart {
+		return bldr.String()
+	}
+	for _, msg := range msgs[msgsStart:] {
+		fmt.Fprintf(bldr, "%s: %s\n", msg.Role(), msg.Content())
+	}
+	bldr.WriteString(" [/INST] ")
+	return bldr.String()
+}
+
+type llamaEvent struct {
+	Generation           string  `json:"generation"`
+	PromptTokenCount     int     `json:"prompt_token_count"`
+	GenerationTokenCount int     `json:"generation_token_count"`
+	StopReason           *string `json:"stop_reason"`
+}
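
For reference, the prompt that llamaPromptifyMessages builds for a system message followed by a single user turn renders as below; {system content} and {user content} stand in for the actual message bodies, and each additional turn adds another "role: content" line before the closing tag:

[INST] <<SYS>>
{system content}
<</SYS>>
user: {user content}
 [/INST]

Note that the whole conversation lands inside one [INST] block with literal role prefixes, which is a simplification of Meta's documented llama2 chat template (one [INST] ... [/INST] pair per user turn); fine for a first pass, but worth keeping in mind if generations come back oddly formatted.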
