From bf4b3edfc7098017bc2963cfbda9251c0c028483 Mon Sep 17 00:00:00 2001 From: Ittai Gilat Date: Wed, 3 Jan 2024 14:07:55 +0200 Subject: [PATCH] chat sast: add support for user-input --- internal/commands/chat-sast.go | 90 +++++++++++++++++------------ internal/commands/chat-sast_test.go | 40 ++++++++++--- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/internal/commands/chat-sast.go b/internal/commands/chat-sast.go index 112648eb3..3d059fb18 100644 --- a/internal/commands/chat-sast.go +++ b/internal/commands/chat-sast.go @@ -19,6 +19,7 @@ import ( const ScanResultsFileErrorFormat = "Error reading and parsing scan results %s" const CreatePromptErrorFormat = "Error creating prompt for result ID %s" +const UserInputRequiredErrorFormat = "%s is required when %s is provided" func ChatSastSubCommand(chatWrapper wrappers.ChatWrapper) *cobra.Command { chatSastCmd := &cobra.Command{ @@ -59,12 +60,15 @@ func runChatSast(sastChatWrapper wrappers.ChatSastWrapper) func(cmd *cobra.Comma statefulWrapper := wrapper.NewStatefulWrapper(connector.NewFileSystemConnector(""), chatAPIKey, chatModel, dropLen, 0) + newConversation := false + var userInput string if chatConversationID == "" { + newConversation = true chatConversationID = statefulWrapper.GenerateId().String() } else { - userInput, _ := cmd.Flags().GetString(params.ChatUserInput) + userInput, _ = cmd.Flags().GetString(params.ChatUserInput) if userInput == "" { - msg := fmt.Sprintf("%s is required when %s is provided", params.ChatUserInput, params.ChatConversationID) + msg := fmt.Sprintf(UserInputRequiredErrorFormat, params.ChatUserInput, params.ChatConversationID) logger.PrintIfVerbose(msg) return outputError(cmd, uuid.Nil, errors.Errorf(msg)) } @@ -76,43 +80,23 @@ func runChatSast(sastChatWrapper wrappers.ChatSastWrapper) func(cmd *cobra.Comma return outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID)) } - scanResults, err := chatsast.ReadResultsSAST(scanResultsFile) - if err != nil { - logger.PrintIfVerbose(err.Error()) - return outputError(cmd, id, errors.Errorf(ScanResultsFileErrorFormat, scanResultsFile)) - } - - if sastResultId == "" { - msg := fmt.Sprintf("currently only %s is supported", params.ChatSastResultId) - logger.PrintIfVerbose(msg) - return outputError(cmd, uuid.Nil, errors.Errorf(msg)) - } - - //languages := GetLanguages(scanResults, sastLanguage) - //queriesByLanguage := GetQueries(scanResults, languages, sastQuery) - sastResult, err := chatsast.GetResultById(scanResults, sastResultId) - if err != nil { - logger.PrintIfVerbose(err.Error()) - return outputError(cmd, id, err) - } - - sources, err := chatsast.GetSourcesForResult(sastResult, sourceDir) - if err != nil { - logger.PrintIfVerbose(err.Error()) - return outputError(cmd, id, err) - } - - prompt, err := chatsast.CreatePrompt(sastResult, sources) - if err != nil { - logger.PrintIfVerbose(err.Error()) - return outputError(cmd, id, errors.Errorf(CreatePromptErrorFormat, sastResultId)) - } - var newMessages []message.Message - newMessages = append(newMessages, message.Message{ - Role: role.User, - Content: prompt, - }) + if newConversation { + prompt, err := buildPrompt(scanResultsFile, sastResultId, sourceDir) + if err != nil { + logger.PrintIfVerbose(err.Error()) + return outputError(cmd, id, err) + } + newMessages = append(newMessages, message.Message{ + Role: role.User, + Content: prompt, + }) + } else { + newMessages = append(newMessages, message.Message{ + Role: role.User, + Content: userInput, + }) + } response, err := sastChatWrapper.Call(statefulWrapper, id, newMessages) if err != nil { @@ -128,6 +112,36 @@ func runChatSast(sastChatWrapper wrappers.ChatSastWrapper) func(cmd *cobra.Comma } } +func buildPrompt(scanResultsFile, sastResultId, sourceDir string) (string, error) { + scanResults, err := chatsast.ReadResultsSAST(scanResultsFile) + if err != nil { + return "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(ScanResultsFileErrorFormat, scanResultsFile), err) + } + + if sastResultId == "" { + return "", errors.Errorf(fmt.Sprintf("error in build-prompt: currently only --%s is supported", params.ChatSastResultId)) + } + + //languages := GetLanguages(scanResults, sastLanguage) + //queriesByLanguage := GetQueries(scanResults, languages, sastQuery) + sastResult, err := chatsast.GetResultById(scanResults, sastResultId) + if err != nil { + return "", fmt.Errorf("error in build-prompt: %w", err) + } + + sources, err := chatsast.GetSourcesForResult(sastResult, sourceDir) + if err != nil { + return "", fmt.Errorf("error in build-prompt: %w", err) + } + + prompt, err := chatsast.CreatePrompt(sastResult, sources) + if err != nil { + return "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(CreatePromptErrorFormat, sastResultId), err) + } + + return prompt, nil +} + func getMessageContents(response []message.Message) []string { var responseContent []string for _, r := range response { diff --git a/internal/commands/chat-sast_test.go b/internal/commands/chat-sast_test.go index a97f50fbf..fff40ff1e 100644 --- a/internal/commands/chat-sast_test.go +++ b/internal/commands/chat-sast_test.go @@ -14,7 +14,7 @@ func TestChatSastHelp(t *testing.T) { execCmdNilAssertion(t, "help", "chat", "sast") } -func TestChatSastInvalidId(t *testing.T) { +func TestChatSastInvalidConversationId(t *testing.T) { buffer, err := executeRedirectedTestCommand("chat", "sast", "--conversation-id", "invalidId", "--chat-apikey", "apiKey", @@ -29,11 +29,23 @@ func TestChatSastInvalidId(t *testing.T) { assert.Assert(t, strings.Contains(s, fmt.Sprintf(ConversationIDErrorFormat, "invalidId")), s) } -func TestChatSastInvalidScanResultsFile(t *testing.T) { +func TestChatSastNoUserInput(t *testing.T) { buffer, err := executeRedirectedTestCommand("chat", "sast", "--conversation-id", uuid.New().String(), "--chat-apikey", "apiKey", - "--user-input", "userInput", + "--scan-results-file", "file", + "--source-dir", "dir", + "--sast-result-id", "resultId") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := string(output) + assert.Assert(t, strings.Contains(s, fmt.Sprintf(UserInputRequiredErrorFormat, "user-input", "conversation-id")), s) +} + +func TestChatSastInvalidScanResultsFile(t *testing.T) { + buffer, err := executeRedirectedTestCommand("chat", "sast", + "--chat-apikey", "apiKey", "--scan-results-file", "invalidFile", "--source-dir", "dir", "--sast-result-id", "resultId") @@ -46,9 +58,7 @@ func TestChatSastInvalidScanResultsFile(t *testing.T) { func TestChatSastInvalideResultId(t *testing.T) { buffer, err := executeRedirectedTestCommand("chat", "sast", - "--conversation-id", uuid.New().String(), "--chat-apikey", "apiKey", - "--user-input", "userInput", "--scan-results-file", "./data/cx_result.json", "--source-dir", "dir", "--sast-result-id", "invalidResultId") @@ -61,7 +71,6 @@ func TestChatSastInvalideResultId(t *testing.T) { func TestChatSastInvalidSourceDir(t *testing.T) { buffer, err := executeRedirectedTestCommand("chat", "sast", - "--conversation-id", uuid.New().String(), "--chat-apikey", "apiKey", "--user-input", "userInput", "--scan-results-file", "./data/cx_result.json", @@ -74,11 +83,9 @@ func TestChatSastInvalidSourceDir(t *testing.T) { assert.Assert(t, strings.Contains(s, "open invalidDir"), s) } -func TestChatSastCorrectResponse(t *testing.T) { +func TestChatSastFirstMessageCorrectResponse(t *testing.T) { buffer, err := executeRedirectedTestCommand("chat", "sast", - "--conversation-id", uuid.New().String(), "--chat-apikey", "apiKey", - "--user-input", "userInput", "--scan-results-file", "./data/cx_result.json", "--source-dir", "./data", "--sast-result-id", "13588362") @@ -88,3 +95,18 @@ func TestChatSastCorrectResponse(t *testing.T) { s := strings.ToLower(string(output)) assert.Assert(t, strings.Contains(s, "mock"), s) } + +func TestChatSastSecondMessageCorrectResponse(t *testing.T) { + buffer, err := executeRedirectedTestCommand("chat", "sast", + "--chat-apikey", "apiKey", + "--scan-results-file", "./data/cx_result.json", + "--source-dir", "./data", + "--sast-result-id", "13588362", + "--conversation-id", uuid.New().String(), + "--user-input", "userInput") + assert.NilError(t, err) + output, err := io.ReadAll(buffer) + assert.NilError(t, err) + s := strings.ToLower(string(output)) + assert.Assert(t, strings.Contains(s, "mock"), s) +}