Skip to content

Commit

Permalink
fix: misc improvements related to creds and prompting
Browse files Browse the repository at this point in the history
Signed-off-by: Grant Linville <[email protected]>
  • Loading branch information
g-linville committed Jun 24, 2024
1 parent 60da900 commit 47f097e
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
7 changes: 7 additions & 0 deletions pkg/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string
func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
defer context2.GetPauseFuncFromCtx(ctx)()()

if req.Message != "" && len(req.Fields) == 1 && strings.TrimSpace(req.Fields[0]) == "" {
_, _ = fmt.Fprintln(os.Stderr, req.Message)
_, _ = fmt.Fprintln(os.Stderr, "Press enter to continue...")
_, _ = fmt.Fscanln(os.Stdin)
return "", nil
}

if req.Message != "" && len(req.Fields) != 1 {
_, _ = fmt.Fprintln(os.Stderr, req.Message)
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
Time: time.Now(),
CallContext: callCtx.GetCallContext(),
Type: EventTypeCallFinish,
Content: getFinishEventContent(*state, callCtx),
Content: getEventContent(*state.Continuation.Result, callCtx),
})
if callCtx.Tool.Chat {
return &State{
Expand Down Expand Up @@ -681,7 +681,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
CallContext: callCtx.GetCallContext(),
Type: EventTypeCallProgress,
ChatCompletionID: status.CompletionID,
Content: message.String(),
Content: getEventContent(message.String(), *callCtx),
})
} else {
monitor.Event(Event{
Expand Down Expand Up @@ -821,13 +821,13 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
return state, callResults, nil
}

func getFinishEventContent(state State, callCtx engine.Context) string {
// If it is a credential tool, the finish event contains its output, which is sensitive, so we don't return it.
func getEventContent(content string, callCtx engine.Context) string {
// If it is a credential tool, the progress and finish events may contain its output, which is sensitive, so we don't return it.
if callCtx.ToolCategory == engine.CredentialToolCategory {
return ""
}

return *state.Continuation.Result
return content
}

func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) {
Expand Down
8 changes: 4 additions & 4 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,10 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str

inputMap := make(map[string]any)
if input != "" {
err := json.Unmarshal([]byte(input), &inputMap)
if err != nil {
return "", "", nil, fmt.Errorf("failed to unmarshal input: %w", err)
}
// Sometimes this function can be called with input that is not a JSON string.
// This typically happens during chat mode.
// That's why we ignore the error if this fails to unmarshal.
_ = json.Unmarshal([]byte(input), &inputMap)
}

fields, err := shlex.Split(toolName)
Expand Down

0 comments on commit 47f097e

Please sign in to comment.