Skip to content

Commit

Permalink
chat sast: add support for user-input
Browse files Browse the repository at this point in the history
  • Loading branch information
ittaigilat-cx committed Jan 3, 2024
1 parent 3e74db2 commit bf4b3ed
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 47 deletions.
90 changes: 52 additions & 38 deletions internal/commands/chat-sast.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (

const ScanResultsFileErrorFormat = "Error reading and parsing scan results %s"
const CreatePromptErrorFormat = "Error creating prompt for result ID %s"
const UserInputRequiredErrorFormat = "%s is required when %s is provided"

func ChatSastSubCommand(chatWrapper wrappers.ChatWrapper) *cobra.Command {
chatSastCmd := &cobra.Command{
Expand Down Expand Up @@ -59,12 +60,15 @@ func runChatSast(sastChatWrapper wrappers.ChatSastWrapper) func(cmd *cobra.Comma

statefulWrapper := wrapper.NewStatefulWrapper(connector.NewFileSystemConnector(""), chatAPIKey, chatModel, dropLen, 0)

newConversation := false
var userInput string
if chatConversationID == "" {
newConversation = true
chatConversationID = statefulWrapper.GenerateId().String()
} else {
userInput, _ := cmd.Flags().GetString(params.ChatUserInput)
userInput, _ = cmd.Flags().GetString(params.ChatUserInput)
if userInput == "" {
msg := fmt.Sprintf("%s is required when %s is provided", params.ChatUserInput, params.ChatConversationID)
msg := fmt.Sprintf(UserInputRequiredErrorFormat, params.ChatUserInput, params.ChatConversationID)
logger.PrintIfVerbose(msg)
return outputError(cmd, uuid.Nil, errors.Errorf(msg))
}
Expand All @@ -76,43 +80,23 @@ func runChatSast(sastChatWrapper wrappers.ChatSastWrapper) func(cmd *cobra.Comma
return outputError(cmd, id, errors.Errorf(ConversationIDErrorFormat, chatConversationID))
}

scanResults, err := chatsast.ReadResultsSAST(scanResultsFile)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, errors.Errorf(ScanResultsFileErrorFormat, scanResultsFile))
}

if sastResultId == "" {
msg := fmt.Sprintf("currently only %s is supported", params.ChatSastResultId)
logger.PrintIfVerbose(msg)
return outputError(cmd, uuid.Nil, errors.Errorf(msg))
}

//languages := GetLanguages(scanResults, sastLanguage)
//queriesByLanguage := GetQueries(scanResults, languages, sastQuery)
sastResult, err := chatsast.GetResultById(scanResults, sastResultId)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, err)
}

sources, err := chatsast.GetSourcesForResult(sastResult, sourceDir)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, err)
}

prompt, err := chatsast.CreatePrompt(sastResult, sources)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, errors.Errorf(CreatePromptErrorFormat, sastResultId))
}

var newMessages []message.Message
newMessages = append(newMessages, message.Message{
Role: role.User,
Content: prompt,
})
if newConversation {
prompt, err := buildPrompt(scanResultsFile, sastResultId, sourceDir)
if err != nil {
logger.PrintIfVerbose(err.Error())
return outputError(cmd, id, err)
}
newMessages = append(newMessages, message.Message{
Role: role.User,
Content: prompt,
})
} else {
newMessages = append(newMessages, message.Message{
Role: role.User,
Content: userInput,
})
}

response, err := sastChatWrapper.Call(statefulWrapper, id, newMessages)
if err != nil {
Expand All @@ -128,6 +112,36 @@ func runChatSast(sastChatWrapper wrappers.ChatSastWrapper) func(cmd *cobra.Comma
}
}

func buildPrompt(scanResultsFile, sastResultId, sourceDir string) (string, error) {
scanResults, err := chatsast.ReadResultsSAST(scanResultsFile)
if err != nil {
return "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(ScanResultsFileErrorFormat, scanResultsFile), err)
}

if sastResultId == "" {
return "", errors.Errorf(fmt.Sprintf("error in build-prompt: currently only --%s is supported", params.ChatSastResultId))
}

//languages := GetLanguages(scanResults, sastLanguage)
//queriesByLanguage := GetQueries(scanResults, languages, sastQuery)
sastResult, err := chatsast.GetResultById(scanResults, sastResultId)
if err != nil {
return "", fmt.Errorf("error in build-prompt: %w", err)
}

sources, err := chatsast.GetSourcesForResult(sastResult, sourceDir)
if err != nil {
return "", fmt.Errorf("error in build-prompt: %w", err)
}

prompt, err := chatsast.CreatePrompt(sastResult, sources)
if err != nil {
return "", fmt.Errorf("error in build-prompt: %s: %w", fmt.Sprintf(CreatePromptErrorFormat, sastResultId), err)
}

return prompt, nil
}

func getMessageContents(response []message.Message) []string {
var responseContent []string
for _, r := range response {
Expand Down
40 changes: 31 additions & 9 deletions internal/commands/chat-sast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func TestChatSastHelp(t *testing.T) {
execCmdNilAssertion(t, "help", "chat", "sast")
}

func TestChatSastInvalidId(t *testing.T) {
func TestChatSastInvalidConversationId(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--conversation-id", "invalidId",
"--chat-apikey", "apiKey",
Expand All @@ -29,11 +29,23 @@ func TestChatSastInvalidId(t *testing.T) {
assert.Assert(t, strings.Contains(s, fmt.Sprintf(ConversationIDErrorFormat, "invalidId")), s)
}

func TestChatSastInvalidScanResultsFile(t *testing.T) {
func TestChatSastNoUserInput(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--conversation-id", uuid.New().String(),
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--scan-results-file", "file",
"--source-dir", "dir",
"--sast-result-id", "resultId")
assert.NilError(t, err)
output, err := io.ReadAll(buffer)
assert.NilError(t, err)
s := string(output)
assert.Assert(t, strings.Contains(s, fmt.Sprintf(UserInputRequiredErrorFormat, "user-input", "conversation-id")), s)
}

func TestChatSastInvalidScanResultsFile(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--chat-apikey", "apiKey",
"--scan-results-file", "invalidFile",
"--source-dir", "dir",
"--sast-result-id", "resultId")
Expand All @@ -46,9 +58,7 @@ func TestChatSastInvalidScanResultsFile(t *testing.T) {

func TestChatSastInvalideResultId(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--conversation-id", uuid.New().String(),
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--scan-results-file", "./data/cx_result.json",
"--source-dir", "dir",
"--sast-result-id", "invalidResultId")
Expand All @@ -61,7 +71,6 @@ func TestChatSastInvalideResultId(t *testing.T) {

func TestChatSastInvalidSourceDir(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--conversation-id", uuid.New().String(),
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--scan-results-file", "./data/cx_result.json",
Expand All @@ -74,11 +83,9 @@ func TestChatSastInvalidSourceDir(t *testing.T) {
assert.Assert(t, strings.Contains(s, "open invalidDir"), s)
}

func TestChatSastCorrectResponse(t *testing.T) {
func TestChatSastFirstMessageCorrectResponse(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--conversation-id", uuid.New().String(),
"--chat-apikey", "apiKey",
"--user-input", "userInput",
"--scan-results-file", "./data/cx_result.json",
"--source-dir", "./data",
"--sast-result-id", "13588362")
Expand All @@ -88,3 +95,18 @@ func TestChatSastCorrectResponse(t *testing.T) {
s := strings.ToLower(string(output))
assert.Assert(t, strings.Contains(s, "mock"), s)
}

func TestChatSastSecondMessageCorrectResponse(t *testing.T) {
buffer, err := executeRedirectedTestCommand("chat", "sast",
"--chat-apikey", "apiKey",
"--scan-results-file", "./data/cx_result.json",
"--source-dir", "./data",
"--sast-result-id", "13588362",
"--conversation-id", uuid.New().String(),
"--user-input", "userInput")
assert.NilError(t, err)
output, err := io.ReadAll(buffer)
assert.NilError(t, err)
s := strings.ToLower(string(output))
assert.Assert(t, strings.Contains(s, "mock"), s)
}

0 comments on commit bf4b3ed

Please sign in to comment.