Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Agents): Handle multiple tool calls for OpenAI function agent (rebase #858) #1054

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
14 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion agents/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agents
import (
"context"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
Expand All @@ -11,7 +12,7 @@ import (
type Agent interface {
// Plan Given an input and previous steps decide what to do next. Returns
// either actions or a finish.
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]any, intermediateMessages []llms.ChatMessage) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) //nolint:lll
GetInputKeys() []string
GetOutputKeys() []string
GetTools() []tools.Tool
Expand Down
15 changes: 8 additions & 7 deletions agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
func (a *ConversationalAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
inputs map[string]any,
_ []llms.ChatMessage,
) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
fullInputs[key] = value
Expand All @@ -88,7 +89,7 @@ func (a *ConversationalAgent) Plan(
chains.WithStreamingFunc(stream),
)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

return a.parseOutput(output)
Expand Down Expand Up @@ -130,7 +131,7 @@ func constructScratchPad(steps []schema.AgentStep) string {
return scratchPad
}

func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
if strings.Contains(output, _conversationalFinalAnswerAction) {
splits := strings.Split(output, _conversationalFinalAnswerAction)

Expand All @@ -141,18 +142,18 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction,
Log: output,
}

return nil, finishAction, nil
return nil, finishAction, nil, nil
}

r := regexp.MustCompile(`Action: (.*?)[\n]*Action Input: (.*)`)
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
return nil, nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
}

return []schema.AgentAction{
{Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output},
}, nil, nil
}, nil, nil, nil
}

//go:embed prompts/conversational_prefix.txt
Expand Down
126 changes: 76 additions & 50 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
Expand Down Expand Up @@ -48,16 +49,18 @@ func NewExecutor(agent Agent, opts ...Option) *Executor {
}

func (e *Executor) Call(ctx context.Context, inputValues map[string]any, _ ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll
inputs, err := inputsToString(inputValues)
if err != nil {
return nil, err
}
// inputs, err := inputsToString(inputValues)
// if err != nil {
// return nil, err
//}
nameToTool := getNameToTool(e.Agent.GetTools())

steps := make([]schema.AgentStep, 0)
var intermediateMessages []llms.ChatMessage
var err error
for i := 0; i < e.MaxIterations; i++ {
var finish map[string]any
steps, finish, err = e.doIteration(ctx, steps, nameToTool, inputs)
steps, finish, intermediateMessages, err = e.doIteration(ctx, steps, nameToTool, inputValues, intermediateMessages)
if finish != nil || err != nil {
return finish, err
}
Expand All @@ -78,9 +81,13 @@ func (e *Executor) doIteration( // nolint
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
inputs map[string]string,
) ([]schema.AgentStep, map[string]any, error) {
actions, finish, err := e.Agent.Plan(ctx, steps, inputs)
inputs map[string]any,
intermediateMessages []llms.ChatMessage,
) ([]schema.AgentStep, map[string]any, []llms.ChatMessage, error) {
actions, finish, newIntermediateMessages, err := e.Agent.Plan(ctx, steps, inputs, intermediateMessages)
if len(newIntermediateMessages) > 0 {
intermediateMessages = append(intermediateMessages, newIntermediateMessages...)
}
if errors.Is(err, ErrUnableToParseOutput) && e.ErrorHandler != nil {
formattedObservation := err.Error()
if e.ErrorHandler.Formatter != nil {
Expand All @@ -89,60 +96,79 @@ func (e *Executor) doIteration( // nolint
steps = append(steps, schema.AgentStep{
Observation: formattedObservation,
})
return steps, nil, nil
return steps, nil, intermediateMessages, nil
}
if err != nil {
return steps, nil, err
return steps, nil, intermediateMessages, err
}

if len(actions) == 0 && finish == nil {
return steps, nil, ErrAgentNoReturn
return steps, nil, intermediateMessages, ErrAgentNoReturn
}

if finish != nil {
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentFinish(ctx, *finish)
}
return steps, e.getReturn(finish, steps), nil
return steps, e.getReturn(finish, steps), intermediateMessages, nil
}

for _, action := range actions {
steps, err = e.doAction(ctx, steps, nameToTool, action)
if err != nil {
return steps, nil, err
stepStreams := make([]<-chan schema.AgentStepWithError, len(actions))
for index, action := range actions {
stepStreams[index] = e.doAction(ctx, nameToTool, action)
}
for _, stepStream := range stepStreams {
agentStepWithError := <-stepStream
if agentStepWithError.Error != nil {
return steps, nil, intermediateMessages, agentStepWithError.Error
}
steps = append(steps, agentStepWithError.AgentStep)
}

return steps, nil, nil
return steps, nil, intermediateMessages, nil
}

func (e *Executor) doAction(
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
action schema.AgentAction,
) ([]schema.AgentStep, error) {
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentAction(ctx, action)
}
) <-chan schema.AgentStepWithError {
agentStepStream := make(chan schema.AgentStepWithError)
go func() {
defer close(agentStepStream)
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentAction(ctx, action)
}

tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
return append(steps, schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
}), nil
}
tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
agentStepStream <- schema.AgentStepWithError{
AgentStep: schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
},
Error: nil,
}
return
}

observation, err := tool.Call(ctx, action.ToolInput)
if err != nil {
return nil, err
}
observation, err := tool.Call(ctx, action.ToolInput)
if err != nil {
agentStepStream <- schema.AgentStepWithError{
AgentStep: schema.AgentStep{}, Error: err,
}
return
}

return append(steps, schema.AgentStep{
Action: action,
Observation: observation,
}), nil
agentStepStream <- schema.AgentStepWithError{
AgentStep: schema.AgentStep{
Action: action,
Observation: observation,
},
Error: nil,
}
}()
return agentStepStream
}

func (e *Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any {
Expand Down Expand Up @@ -172,19 +198,19 @@ func (e *Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return e.CallbacksHandler
}

func inputsToString(inputValues map[string]any) (map[string]string, error) {
inputs := make(map[string]string, len(inputValues))
for key, value := range inputValues {
valueStr, ok := value.(string)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrExecutorInputNotString, key)
}

inputs[key] = valueStr
}

return inputs, nil
}
// func inputsToString(inputValues map[string]any) (map[string]string, error) {
// inputs := make(map[string]string, len(inputValues))
// for key, value := range inputValues {
// valueStr, ok := value.(string)
// if !ok {
// return nil, fmt.Errorf("%w: %s", ErrExecutorInputNotString, key)
// }
//
// inputs[key] = valueStr
// }
//
// return inputs, nil
//}

func getNameToTool(t []tools.Tool) map[string]tools.Tool {
if len(t) == 0 {
Expand Down
13 changes: 7 additions & 6 deletions agents/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/agents"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
Expand All @@ -24,27 +25,27 @@ type testAgent struct {
outputKeys []string

recordedIntermediateSteps []schema.AgentStep
recordedInputs map[string]string
recordedInputs map[string]any
numPlanCalls int
}

func (a *testAgent) Plan(
_ context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
inputs map[string]any, _ []llms.ChatMessage,
) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
a.recordedIntermediateSteps = intermediateSteps
a.recordedInputs = inputs
a.numPlanCalls++

return a.actions, a.finish, a.err
return a.actions, a.finish, nil, a.err
}

func (a testAgent) GetInputKeys() []string {
func (a *testAgent) GetInputKeys() []string {
return a.inputKeys
}

func (a testAgent) GetOutputKeys() []string {
func (a *testAgent) GetOutputKeys() []string {
return a.outputKeys
}

Expand Down
2 changes: 1 addition & 1 deletion agents/markl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestMRKLOutputParser(t *testing.T) {

a := OneShotZeroAgent{}
for _, tc := range testCases {
actions, finish, err := a.parseOutput(tc.input)
actions, finish, _, err := a.parseOutput(tc.input)
require.ErrorIs(t, tc.expectedErr, err)
require.Equal(t, tc.expectedActions, actions)
require.Equal(t, tc.expectedFinish, finish)
Expand Down
15 changes: 8 additions & 7 deletions agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho
func (a *OneShotZeroAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
inputs map[string]any,
_ []llms.ChatMessage,
) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
fullInputs[key] = value
Expand All @@ -90,7 +91,7 @@ func (a *OneShotZeroAgent) Plan(
chains.WithStreamingFunc(stream),
)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

return a.parseOutput(output)
Expand Down Expand Up @@ -131,7 +132,7 @@ func constructMrklScratchPad(steps []schema.AgentStep) string {
return scratchPad
}

func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
if strings.Contains(output, _finalAnswerAction) {
splits := strings.Split(output, _finalAnswerAction)

Expand All @@ -140,16 +141,16 @@ func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *sc
a.OutputKey: splits[len(splits)-1],
},
Log: output,
}, nil
}, nil, nil
}

r := regexp.MustCompile(`Action:\s*(.+)\s*Action Input:\s(?s)*(.+)`)
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
return nil, nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
}

return []schema.AgentAction{
{Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output},
}, nil, nil
}, nil, nil, nil
}
Loading