diff --git a/agents/agents.go b/agents/agents.go index c3eff54a5..f99b41c7d 100644 --- a/agents/agents.go +++ b/agents/agents.go @@ -3,6 +3,7 @@ package agents import ( "context" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/tools" ) @@ -11,7 +12,7 @@ import ( type Agent interface { // Plan Given an input and previous steps decide what to do next. Returns // either actions or a finish. - Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll + Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]any, intermediateMessages []llms.ChatMessage) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) //nolint:lll GetInputKeys() []string GetOutputKeys() []string GetTools() []tools.Tool diff --git a/agents/conversational.go b/agents/conversational.go index 5380d9072..b53199b9a 100644 --- a/agents/conversational.go +++ b/agents/conversational.go @@ -62,8 +62,9 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option) func (a *ConversationalAgent) Plan( ctx context.Context, intermediateSteps []schema.AgentStep, - inputs map[string]string, -) ([]schema.AgentAction, *schema.AgentFinish, error) { + inputs map[string]any, + _ []llms.ChatMessage, +) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) { fullInputs := make(map[string]any, len(inputs)) for key, value := range inputs { fullInputs[key] = value @@ -88,7 +89,7 @@ func (a *ConversationalAgent) Plan( chains.WithStreamingFunc(stream), ) if err != nil { - return nil, nil, err + return nil, nil, nil, err } return a.parseOutput(output) @@ -130,7 +131,7 @@ func constructScratchPad(steps []schema.AgentStep) string { return scratchPad } -func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) { +func (a *ConversationalAgent) 
parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) { if strings.Contains(output, _conversationalFinalAnswerAction) { splits := strings.Split(output, _conversationalFinalAnswerAction) @@ -141,18 +142,18 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, Log: output, } - return nil, finishAction, nil + return nil, finishAction, nil, nil } r := regexp.MustCompile(`Action: (.*?)[\n]*Action Input: (.*)`) matches := r.FindStringSubmatch(output) if len(matches) == 0 { - return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output) + return nil, nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output) } return []schema.AgentAction{ {Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output}, - }, nil, nil + }, nil, nil, nil } //go:embed prompts/conversational_prefix.txt diff --git a/agents/executor.go b/agents/executor.go index 5835ad008..0acedde72 100644 --- a/agents/executor.go +++ b/agents/executor.go @@ -8,6 +8,7 @@ import ( "github.com/tmc/langchaingo/callbacks" "github.com/tmc/langchaingo/chains" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" "github.com/tmc/langchaingo/tools" ) @@ -48,16 +49,18 @@ func NewExecutor(agent Agent, opts ...Option) *Executor { } func (e *Executor) Call(ctx context.Context, inputValues map[string]any, _ ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll - inputs, err := inputsToString(inputValues) - if err != nil { - return nil, err - } + // inputs, err := inputsToString(inputValues) + // if err != nil { + // return nil, err + //} nameToTool := getNameToTool(e.Agent.GetTools()) steps := make([]schema.AgentStep, 0) + var intermediateMessages []llms.ChatMessage + var err error for i := 0; i < e.MaxIterations; i++ { var finish map[string]any - steps, finish, err = e.doIteration(ctx, steps, nameToTool, inputs) + steps, finish, intermediateMessages, err = e.doIteration(ctx, steps, 
nameToTool, inputValues, intermediateMessages) if finish != nil || err != nil { return finish, err } @@ -78,9 +81,13 @@ func (e *Executor) doIteration( // nolint ctx context.Context, steps []schema.AgentStep, nameToTool map[string]tools.Tool, - inputs map[string]string, -) ([]schema.AgentStep, map[string]any, error) { - actions, finish, err := e.Agent.Plan(ctx, steps, inputs) + inputs map[string]any, + intermediateMessages []llms.ChatMessage, +) ([]schema.AgentStep, map[string]any, []llms.ChatMessage, error) { + actions, finish, newIntermediateMessages, err := e.Agent.Plan(ctx, steps, inputs, intermediateMessages) + if len(newIntermediateMessages) > 0 { + intermediateMessages = append(intermediateMessages, newIntermediateMessages...) + } if errors.Is(err, ErrUnableToParseOutput) && e.ErrorHandler != nil { formattedObservation := err.Error() if e.ErrorHandler.Formatter != nil { @@ -89,60 +96,79 @@ func (e *Executor) doIteration( // nolint steps = append(steps, schema.AgentStep{ Observation: formattedObservation, }) - return steps, nil, nil + return steps, nil, intermediateMessages, nil } if err != nil { - return steps, nil, err + return steps, nil, intermediateMessages, err } if len(actions) == 0 && finish == nil { - return steps, nil, ErrAgentNoReturn + return steps, nil, intermediateMessages, ErrAgentNoReturn } if finish != nil { if e.CallbacksHandler != nil { e.CallbacksHandler.HandleAgentFinish(ctx, *finish) } - return steps, e.getReturn(finish, steps), nil + return steps, e.getReturn(finish, steps), intermediateMessages, nil } - for _, action := range actions { - steps, err = e.doAction(ctx, steps, nameToTool, action) - if err != nil { - return steps, nil, err + stepStreams := make([]<-chan schema.AgentStepWithError, len(actions)) + for index, action := range actions { + stepStreams[index] = e.doAction(ctx, nameToTool, action) + } + for _, stepStream := range stepStreams { + agentStepWithError := <-stepStream + if agentStepWithError.Error != nil { + return 
steps, nil, intermediateMessages, agentStepWithError.Error } + steps = append(steps, agentStepWithError.AgentStep) } - return steps, nil, nil + return steps, nil, intermediateMessages, nil } func (e *Executor) doAction( ctx context.Context, - steps []schema.AgentStep, nameToTool map[string]tools.Tool, action schema.AgentAction, -) ([]schema.AgentStep, error) { - if e.CallbacksHandler != nil { - e.CallbacksHandler.HandleAgentAction(ctx, action) - } +) <-chan schema.AgentStepWithError { + agentStepStream := make(chan schema.AgentStepWithError) + go func() { + defer close(agentStepStream) + if e.CallbacksHandler != nil { + e.CallbacksHandler.HandleAgentAction(ctx, action) + } - tool, ok := nameToTool[strings.ToUpper(action.Tool)] - if !ok { - return append(steps, schema.AgentStep{ - Action: action, - Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool), - }), nil - } + tool, ok := nameToTool[strings.ToUpper(action.Tool)] + if !ok { + agentStepStream <- schema.AgentStepWithError{ + AgentStep: schema.AgentStep{ + Action: action, + Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool), + }, + Error: nil, + } + return + } - observation, err := tool.Call(ctx, action.ToolInput) - if err != nil { - return nil, err - } + observation, err := tool.Call(ctx, action.ToolInput) + if err != nil { + agentStepStream <- schema.AgentStepWithError{ + AgentStep: schema.AgentStep{}, Error: err, + } + return + } - return append(steps, schema.AgentStep{ - Action: action, - Observation: observation, - }), nil + agentStepStream <- schema.AgentStepWithError{ + AgentStep: schema.AgentStep{ + Action: action, + Observation: observation, + }, + Error: nil, + } + }() + return agentStepStream } func (e *Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any { @@ -172,19 +198,19 @@ func (e *Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn return e.CallbacksHandler } -func 
inputsToString(inputValues map[string]any) (map[string]string, error) { - inputs := make(map[string]string, len(inputValues)) - for key, value := range inputValues { - valueStr, ok := value.(string) - if !ok { - return nil, fmt.Errorf("%w: %s", ErrExecutorInputNotString, key) - } - - inputs[key] = valueStr - } - - return inputs, nil -} +// func inputsToString(inputValues map[string]any) (map[string]string, error) { +// inputs := make(map[string]string, len(inputValues)) +// for key, value := range inputValues { +// valueStr, ok := value.(string) +// if !ok { +// return nil, fmt.Errorf("%w: %s", ErrExecutorInputNotString, key) +// } +// +// inputs[key] = valueStr +// } +// +// return inputs, nil +//} func getNameToTool(t []tools.Tool) map[string]tools.Tool { if len(t) == 0 { diff --git a/agents/executor_test.go b/agents/executor_test.go index 2d7193fbe..750f1b738 100644 --- a/agents/executor_test.go +++ b/agents/executor_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/agents" "github.com/tmc/langchaingo/chains" + "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/prompts" "github.com/tmc/langchaingo/schema" @@ -24,27 +25,27 @@ type testAgent struct { outputKeys []string recordedIntermediateSteps []schema.AgentStep - recordedInputs map[string]string + recordedInputs map[string]any numPlanCalls int } func (a *testAgent) Plan( _ context.Context, intermediateSteps []schema.AgentStep, - inputs map[string]string, -) ([]schema.AgentAction, *schema.AgentFinish, error) { + inputs map[string]any, _ []llms.ChatMessage, +) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) { a.recordedIntermediateSteps = intermediateSteps a.recordedInputs = inputs a.numPlanCalls++ - return a.actions, a.finish, a.err + return a.actions, a.finish, nil, a.err } -func (a testAgent) GetInputKeys() []string { +func (a *testAgent) GetInputKeys() []string { return a.inputKeys } 
-func (a testAgent) GetOutputKeys() []string { +func (a *testAgent) GetOutputKeys() []string { return a.outputKeys } diff --git a/agents/markl_test.go b/agents/markl_test.go index 3112f77ea..65924c170 100644 --- a/agents/markl_test.go +++ b/agents/markl_test.go @@ -40,7 +40,7 @@ func TestMRKLOutputParser(t *testing.T) { a := OneShotZeroAgent{} for _, tc := range testCases { - actions, finish, err := a.parseOutput(tc.input) + actions, finish, _, err := a.parseOutput(tc.input) require.ErrorIs(t, tc.expectedErr, err) require.Equal(t, tc.expectedActions, actions) require.Equal(t, tc.expectedFinish, finish) diff --git a/agents/mrkl.go b/agents/mrkl.go index 4bc577408..1259e50aa 100644 --- a/agents/mrkl.go +++ b/agents/mrkl.go @@ -63,8 +63,9 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho func (a *OneShotZeroAgent) Plan( ctx context.Context, intermediateSteps []schema.AgentStep, - inputs map[string]string, -) ([]schema.AgentAction, *schema.AgentFinish, error) { + inputs map[string]any, + _ []llms.ChatMessage, +) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) { fullInputs := make(map[string]any, len(inputs)) for key, value := range inputs { fullInputs[key] = value @@ -90,7 +91,7 @@ func (a *OneShotZeroAgent) Plan( chains.WithStreamingFunc(stream), ) if err != nil { - return nil, nil, err + return nil, nil, nil, err } return a.parseOutput(output) @@ -131,7 +132,7 @@ func constructMrklScratchPad(steps []schema.AgentStep) string { return scratchPad } -func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) { +func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) { if strings.Contains(output, _finalAnswerAction) { splits := strings.Split(output, _finalAnswerAction) @@ -140,16 +141,16 @@ func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *sc a.OutputKey: 
splits[len(splits)-1], }, Log: output, - }, nil + }, nil, nil } r := regexp.MustCompile(`Action:\s*(.+)\s*Action Input:\s(?s)*(.+)`) matches := r.FindStringSubmatch(output) if len(matches) == 0 { - return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output) + return nil, nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output) } return []schema.AgentAction{ {Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output}, - }, nil, nil + }, nil, nil, nil } diff --git a/agents/openai_functions_agent.go b/agents/openai_functions_agent.go index 1ff621a29..5882ac2bf 100644 --- a/agents/openai_functions_agent.go +++ b/agents/openai_functions_agent.go @@ -48,18 +48,21 @@ func NewOpenAIFunctionsAgent(llm llms.Model, tools []tools.Tool, opts ...Option) } } -func (o *OpenAIFunctionsAgent) functions() []llms.FunctionDefinition { - res := make([]llms.FunctionDefinition, 0) +func (o *OpenAIFunctionsAgent) tools() []llms.Tool { + res := make([]llms.Tool, 0) for _, tool := range o.Tools { - res = append(res, llms.FunctionDefinition{ - Name: tool.Name(), - Description: tool.Description(), - Parameters: map[string]any{ - "properties": map[string]any{ - "__arg1": map[string]string{"title": "__arg1", "type": "string"}, + res = append(res, llms.Tool{ + Type: "function", + Function: &llms.FunctionDefinition{ + Name: tool.Name(), + Description: tool.Description(), + Parameters: map[string]any{ + "properties": map[string]any{ + "__arg1": map[string]string{"title": "__arg1", "type": "string"}, + }, + "required": []string{"__arg1"}, + "type": "object", }, - "required": []string{"__arg1"}, - "type": "object", }, }) } @@ -70,13 +73,14 @@ func (o *OpenAIFunctionsAgent) functions() []llms.FunctionDefinition { func (o *OpenAIFunctionsAgent) Plan( ctx context.Context, intermediateSteps []schema.AgentStep, - inputs map[string]string, -) ([]schema.AgentAction, *schema.AgentFinish, error) { + inputs map[string]any, + intermediateMessages 
[]llms.ChatMessage, +) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) { fullInputs := make(map[string]any, len(inputs)) for key, value := range inputs { fullInputs[key] = value } - fullInputs[agentScratchpad] = o.constructScratchPad(intermediateSteps) + fullInputs[agentScratchpad] = o.constructScratchPad(intermediateMessages, intermediateSteps) var stream func(ctx context.Context, chunk []byte) error @@ -89,7 +93,7 @@ func (o *OpenAIFunctionsAgent) Plan( prompt, err := o.Prompt.FormatPrompt(fullInputs) if err != nil { - return nil, nil, err + return nil, nil, nil, err } mcList := make([]llms.MessageContent, len(prompt.Messages())) @@ -112,14 +116,20 @@ func (o *OpenAIFunctionsAgent) Plan( case llms.AIChatMessage: mc = llms.MessageContent{ Role: role, - Parts: []llms.ContentPart{ - llms.ToolCall{ - ID: p.ToolCalls[0].ID, - Type: p.ToolCalls[0].Type, - FunctionCall: p.ToolCalls[0].FunctionCall, - }, - }, } + var contentParts []llms.ContentPart + for _, toolCall := range p.ToolCalls { + contentParts = append(contentParts, llms.ToolCall{ + ID: toolCall.ID, + Type: toolCall.Type, + FunctionCall: toolCall.FunctionCall, + }) + } + if len(text) > 0 { + contentParts = append(contentParts, llms.TextContent{Text: text}) + } + mc.Parts = contentParts + default: mc = llms.MessageContent{ Role: role, @@ -130,9 +140,9 @@ func (o *OpenAIFunctionsAgent) Plan( } result, err := o.LLM.GenerateContent(ctx, mcList, - llms.WithFunctions(o.functions()), llms.WithStreamingFunc(stream)) + llms.WithTools(o.tools()), llms.WithStreamingFunc(stream)) if err != nil { - return nil, nil, err + return nil, nil, nil, err } return o.ParseOutput(result) @@ -173,66 +183,86 @@ func createOpenAIFunctionPrompt(opts Options) prompts.ChatPromptTemplate { return tmpl } -func (o *OpenAIFunctionsAgent) constructScratchPad(steps []schema.AgentStep) []llms.ChatMessage { +func (o *OpenAIFunctionsAgent) constructScratchPad(intermediateMessages []llms.ChatMessage, steps []schema.AgentStep) 
[]llms.ChatMessage { if len(steps) == 0 { return nil } messages := make([]llms.ChatMessage, 0) - for _, step := range steps { - messages = append(messages, llms.FunctionChatMessage{ - Name: step.Action.Tool, - Content: step.Observation, - }) + + for _, message := range intermediateMessages { + var messageToolCalls []llms.ToolCall + aiChatMessage, ok := message.(llms.AIChatMessage) + if ok { + messageToolCalls = append(messageToolCalls, aiChatMessage.ToolCalls...) + messages = append(messages, message) + for _, toolCall := range messageToolCalls { + for _, step := range steps { + toolCallID := step.Action.ToolID + functionName := step.Action.Tool + arguments := step.Action.ToolInputOriginalArguments + if toolCallID != toolCall.ID || functionName != toolCall.FunctionCall.Name || + arguments != toolCall.FunctionCall.Arguments { + // add tool call messages only for previous assistant message + continue + } + messages = append(messages, llms.ToolChatMessage{ + ID: toolCallID, + Content: step.Observation, + }) + } + } + } } return messages } func (o *OpenAIFunctionsAgent) ParseOutput(contentResp *llms.ContentResponse) ( - []schema.AgentAction, *schema.AgentFinish, error, + []schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error, ) { - choice := contentResp.Choices[0] - - // finish - if choice.FuncCall == nil { - return nil, &schema.AgentFinish{ - ReturnValues: map[string]any{ - "output": choice.Content, - }, - Log: choice.Content, - }, nil - } - - // action - functionCall := choice.FuncCall - functionName := functionCall.Name - toolInputStr := functionCall.Arguments - toolInputMap := make(map[string]any, 0) - err := json.Unmarshal([]byte(toolInputStr), &toolInputMap) - if err != nil { - return nil, nil, err - } - - toolInput := toolInputStr - if arg1, ok := toolInputMap["__arg1"]; ok { - toolInputCheck, ok := arg1.(string) - if ok { - toolInput = toolInputCheck + var agentActions []schema.AgentAction + intermediateMessages := make([]llms.ChatMessage, 0) + 
for _, choice := range contentResp.Choices { + // finish + if len(choice.ToolCalls) == 0 { + return nil, &schema.AgentFinish{ + ReturnValues: map[string]any{ + "output": choice.Content, + }, + Log: choice.Content, + }, nil, nil } - } + for _, toolCall := range choice.ToolCalls { + functionName := toolCall.FunctionCall.Name + toolInputMap := make(map[string]any) + toolInputStr := toolCall.FunctionCall.Arguments + err := json.Unmarshal([]byte(toolInputStr), &toolInputMap) + if err != nil { + return nil, nil, nil, err + } - contentMsg := "\n" - if choice.Content != "" { - contentMsg = fmt.Sprintf("responded: %s\n", choice.Content) + toolInput := toolInputStr + if arg1, ok := toolInputMap["__arg1"]; ok { + toolInputCheck, ok := arg1.(string) + if ok { + toolInput = toolInputCheck + } + } + contentMsg := "\n" + if choice.Content != "" { + contentMsg = fmt.Sprintf("responded: %s\n", choice.Content) + } + agentActions = append(agentActions, schema.AgentAction{ + Tool: functionName, + ToolInput: toolInput, + Log: fmt.Sprintf("Invoking: %s with %s \n %s \n", functionName, toolInputStr, contentMsg), + ToolID: toolCall.ID, + ToolInputOriginalArguments: toolInputStr, + }) + } + intermediateMessages = append(intermediateMessages, choice.ChatMessage) } - return []schema.AgentAction{ - { - Tool: functionName, - ToolInput: toolInput, - Log: fmt.Sprintf("Invoking: %s with %s \n %s \n", functionName, toolInputStr, contentMsg), - ToolID: choice.ToolCalls[0].ID, - }, - }, nil, nil + return agentActions, nil, intermediateMessages, nil } diff --git a/llms/generatecontent.go b/llms/generatecontent.go index 8702143b0..daafbcf48 100644 --- a/llms/generatecontent.go +++ b/llms/generatecontent.go @@ -143,6 +143,9 @@ type ContentChoice struct { // ToolCalls is a list of tool calls the model asks to invoke. 
ToolCalls []ToolCall + + // ChatMessage is the assistant tool_calls message from the LLM; it must be sent back in the next agent executor iteration to preserve the LLM's context. + ChatMessage ChatMessage } // TextParts is a helper function to create a MessageContent with a role and a diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go index 2bb572a0a..a01fe8731 100644 --- a/llms/openai/internal/openaiclient/chat.go +++ b/llms/openai/internal/openaiclient/chat.go @@ -457,7 +457,7 @@ func parseStreamingChatResponse(ctx context.Context, r *http.Response, payload * func combineStreamingChatResponse( ctx context.Context, payload *ChatRequest, - responseChan chan StreamedChatResponsePayload, + responseChan <-chan StreamedChatResponsePayload, ) (*ChatCompletionResponse, error) { response := ChatCompletionResponse{ Choices: []*ChatCompletionChoice{ diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 78f8334d2..cce7ddfd7 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -43,17 +43,7 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio return llms.GenerateFromSinglePrompt(ctx, o, prompt, options...) } -// GenerateContent implements the Model interface.
-func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, goerr113, funlen - if o.CallbacksHandler != nil { - o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) - } - - opts := llms.CallOptions{} - for _, opt := range options { - opt(&opts) - } - +func buildMessagesForRequestFromContent(messages []llms.MessageContent) ([]*ChatMessage, error) { chatMsgs := make([]*ChatMessage, 0, len(messages)) for _, mc := range messages { msg := &ChatMessage{MultiContent: mc.Parts} @@ -95,6 +85,24 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten chatMsgs = append(chatMsgs, msg) } + return chatMsgs, nil +} + +// GenerateContent implements the Model interface. +func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, goerr113, funlen + if o.CallbacksHandler != nil { + o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages) + } + + opts := llms.CallOptions{} + for _, opt := range options { + opt(&opts) + } + + chatMsgs, err := buildMessagesForRequestFromContent(messages) + if err != nil { + return nil, err + } req := &openaiclient.ChatRequest{ Model: opts.Model, StopWords: opts.StopWords, @@ -152,6 +160,10 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten choices := make([]*llms.ContentChoice, len(result.Choices)) for i, c := range result.Choices { + llmMessage, err := messageFromMessage(c.Message) + if err != nil { + return nil, err + } choices[i] = &llms.ContentChoice{ Content: c.Message.Content, StopReason: fmt.Sprint(c.FinishReason), @@ -161,6 +173,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten "TotalTokens": result.Usage.TotalTokens, "ReasoningTokens": result.Usage.CompletionTokensDetails.ReasoningTokens, }, + ChatMessage: llmMessage, } 
// Legacy function call handling @@ -170,19 +183,21 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten Arguments: c.Message.FunctionCall.Arguments, } } - for _, tool := range c.Message.ToolCalls { - choices[i].ToolCalls = append(choices[i].ToolCalls, llms.ToolCall{ - ID: tool.ID, - Type: string(tool.Type), - FunctionCall: &llms.FunctionCall{ - Name: tool.Function.Name, - Arguments: tool.Function.Arguments, - }, - }) - } - // populate legacy single-function call field for backwards compatibility - if len(choices[i].ToolCalls) > 0 { - choices[i].FuncCall = choices[i].ToolCalls[0].FunctionCall + if c.FinishReason == "tool_calls" { + for _, tool := range c.Message.ToolCalls { + choices[i].ToolCalls = append(choices[i].ToolCalls, llms.ToolCall{ + ID: tool.ID, + Type: string(tool.Type), + FunctionCall: &llms.FunctionCall{ + Name: tool.Function.Name, + Arguments: tool.Function.Arguments, + }, + }) + } + // populate legacy single-function call field for backwards compatibility + if len(choices[i].ToolCalls) > 0 { + choices[i].FuncCall = choices[i].ToolCalls[0].FunctionCall + } } } response := &llms.ContentResponse{Choices: choices} @@ -248,6 +263,31 @@ func toolFromTool(t llms.Tool) (openaiclient.Tool, error) { return tool, nil } +// messageFromMessage converts an OpenAI ChatMessage to an llms.ChatMessage so it can be passed to the next iteration of the agentic flow.
+func messageFromMessage(c ChatMessage) (llms.ChatMessage, error) { + // TODO: for now only the assistant tool_calls response message is supported. + switch c.Role { + case "assistant": + var llmToolCalls []llms.ToolCall + for _, toolCall := range c.ToolCalls { + llmToolCalls = append(llmToolCalls, llms.ToolCall{ + ID: toolCall.ID, + Type: string(toolCall.Type), + FunctionCall: &llms.FunctionCall{ + Name: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + }, + }) + } + return llms.AIChatMessage{ + Content: c.Content, + ToolCalls: llmToolCalls, + }, nil + default: + return nil, fmt.Errorf("message role %v not supported", c.Role) + } +} + // toolCallsFromToolCalls converts a slice of llms.ToolCall to a slice of ToolCall. func toolCallsFromToolCalls(tcs []llms.ToolCall) []openaiclient.ToolCall { toolCalls := make([]openaiclient.ToolCall, len(tcs)) diff --git a/schema/schema.go b/schema/schema.go index 120078f10..c26976978 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -2,10 +2,11 @@ package schema // AgentAction is the agent's action to take. type AgentAction struct { - Tool string - ToolInput string - Log string - ToolID string + Tool string + ToolInput string + Log string + ToolID string + ToolInputOriginalArguments string } // AgentStep is a step of the agent. @@ -19,3 +20,9 @@ type AgentFinish struct { ReturnValues map[string]any Log string } + +// AgentStepWithError combines AgentStep with Error for concurrent execution. +type AgentStepWithError struct { + AgentStep + Error error +}