Skip to content

Commit

Permalink
AST-35664 | Enhance AI Guided Remediation answers (#662)
Browse files Browse the repository at this point in the history
* AST-35664 | Enhance AI Guided Remediation answers
* AST-35664 | improved test readability
  • Loading branch information
AlvoBen authored Feb 20, 2024
1 parent 75b2b12 commit 2e486dc
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 43 deletions.
2 changes: 1 addition & 1 deletion internal/commands/chat-sast.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func runChatSast(chatWrapper wrappers.ChatWrapper) func(cmd *cobra.Command, args

responseContent := getMessageContents(response)

responseContent = AddNewlinesIfNecessary(responseContent)
responseContent = addDescriptionForIdentifier(responseContent)

return printer.Print(cmd.OutOrStdout(), &OutputModel{
ConversationID: id.String(),
Expand Down
68 changes: 34 additions & 34 deletions internal/commands/sast-prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package commands

import (
"fmt"
"regexp"
"strings"
)

Expand All @@ -13,12 +12,21 @@ If a question irrelevant to the mentioned source code or SAST result is asked, a
related to source code or SAST results or SAST Queries'.`

const (
confidence = "CONFIDENCE:"
explanation = "EXPLANATION:"
fix = "FIX"
confidence = "**CONFIDENCE:**"
explanation = "**EXPLANATION:**"
fix = "**PROPOSED REMEDIATION:**"
code = "```"
)

const (
confidenceDescription = " A score between 0 (low) and 100 (high) indicating OpenAI's confidence level for the effectiveness of the suggested remediation. <br>"
explanationDescription = " An OpenAI generated description of the vulnerability. <br>"
fixDescription = " A customized snippet, generated by OpenAI, that can be used to remediate the vulnerability in your code. <br>"
)

// This constant is used to format the identifiers (confidence, explanation, fix) and their descriptions with HTML tags
const identifierTitleForamt = "<span style=\"color: regular;\">%s</span><span style=\"color: grey; font-style: italic;\">%s</span>"

const userPromptTemplate = `Checkmarx Static Application Security Testing (SAST) detected the %s vulnerability within the provided %s code snippet.
The attack vector is presented by code snippets annotated by comments in the form ` + "`//SAST Node #X: element (element-type)`" + ` where X is
the node index in the result, ` + "`element`" + ` is the name of the element through which the data flows, and the ` + "`element-type`" + ` is it's type.
Expand All @@ -45,10 +53,13 @@ Please provide a brief explanation for your confidence score, don't mention all
Next, please provide code that remediates the vulnerability so that a developer can copy paste instead of the snippet above.
Your analysis should be presented in the following format:
` + confidence + `number
` + explanation + `short_text
` + fix + `: fixed_snippet`
Your analysis MUST be presented in the following format:
` + confidence +
`number
` + "\n" + explanation +
`short_text
` + "\n" + fix + ":" +
`fixed_snippet`

func GetSystemPrompt() string {
return systemPrompt
Expand Down Expand Up @@ -104,6 +115,7 @@ func createSourceForPrompt(result *Result, sources map[string][]string) (string,
methodLines[lineInMethod] += fmt.Sprintf("//SAST Node #%d%s: %s (%s)", i, edge, node.Name, nodeType)
methodsInPrompt[sourceFilename+":"+node.Method] = methodLines
}

for _, methodLines := range methodsInPrompt {
methodLines = append(methodLines, "// method continues ...")
sourcePrompt = append(sourcePrompt, methodLines...)
Expand Down Expand Up @@ -135,34 +147,22 @@ func GetMethodByMethodLine(filename string, lines []string, methodLineNumber, no
return methodLines, nil
}

func AddNewlinesIfNecessary(responseContent []string) []string {
if len(responseContent) == 0 {
return responseContent
func addDescriptionForIdentifier(responseContent []string) []string {
identifiersDescription := map[string]string{
confidence: confidenceDescription,
explanation: explanationDescription,
fix: fixDescription,
}
stringToFix := responseContent[len(responseContent)-1]

stringToFix = addNewlineIfNecessary(stringToFix, confidence, explanation)
stringToFix = addNewlineIfNecessary(stringToFix, explanation, fix)
return append(responseContent[:len(responseContent)-1], stringToFix)
}

func addNewlineIfNecessary(s, from, to string) string {
startsAt := strings.Index(s, from) + len(from)
upTo := strings.Index(s, to)
if startsAt == -1 || upTo == -1 {
return s
}
if !endsWithNewlineAndWhitespace(s[startsAt:upTo]) {
return s[:upTo] + "\n" + s[upTo:]
if len(responseContent) > 0 {
for i := 0; i < len(responseContent); i++ {
for identifier, description := range identifiersDescription {
responseContent[i] = replaceIdentifierTitleIfNeeded(responseContent[i], identifier, description)
}
}
}
return s
return responseContent
}

func endsWithNewlineAndWhitespace(s string) bool {
// Compile the regular expression that matches a newline followed by
// zero or more whitespace characters at the end of the string.
re := regexp.MustCompile(`\n\s*$`)
// Use the FindString method to find a match. If a match is found,
// it means the string ends with a newline and possibly other whitespace characters.
return re.FindString(s) != ""
func replaceIdentifierTitleIfNeeded(input, identifier, identifierDescription string) string {
return strings.Replace(input, identifier, fmt.Sprintf(identifierTitleForamt, identifier, identifierDescription), 1)
}
27 changes: 19 additions & 8 deletions internal/commands/sast-prompt_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
package commands

import (
"fmt"
"testing"
)

func TestAddNewlinesIfNecessaryNoNewlines(t *testing.T) {
input := confidence + " 35 " + explanation + " this is a short explanation." + fix + " a fixed snippet"
expected := confidence + " 35 \n" + explanation + " this is a short explanation.\n" + fix + " a fixed snippet"
const expectedOutputFormat = "<span style=\"color: regular;\">**CONFIDENCE:**</span><span style=\"color: grey; font-style: italic;\"> " +
"A score between 0 (low) and 100 (high) indicating OpenAI's confidence level for the effectiveness of the suggested remediation. " +
"<br></span>%s<span style=\"color: regular;\">**EXPLANATION:**</span><span style=\"color: grey; font-style: italic;\"> " +
"An OpenAI generated description of the vulnerability. <br></span>%s<span style=\"color: " +
"regular;\">**PROPOSED REMEDIATION:**</span><span style=\"color: grey; font-style: italic;\"> " +
"A customized snippet, generated by OpenAI, that can be used to remediate the vulnerability in your code. <br></span>%s"

func getExpectedOutput(confidenceNumber, explanationText, fixText string) string {
return fmt.Sprintf(expectedOutputFormat, confidenceNumber, explanationText, fixText)
}

func TestAddDescriptionForIdentifiers(t *testing.T) {
input := confidence + " 35 " + explanation + " this is a short explanation." + fix + " a fixed snippet"
expected := getExpectedOutput(" 35 ", " this is a short explanation.", " a fixed snippet")
output := getActual(input, t)

if output[len(output)-1] != expected {
Expand All @@ -16,8 +27,8 @@ func TestAddNewlinesIfNecessaryNoNewlines(t *testing.T) {
}

func TestAddNewlinesIfNecessarySomeNewlines(t *testing.T) {
input := confidence + " 35 " + explanation + " this is a short explanation.\n " + fix + " a fixed snippet"
expected := confidence + " 35 \n" + explanation + " this is a short explanation.\n " + fix + " a fixed snippet"
input := confidence + " 35 " + explanation + " this is a short explanation.\n" + fix + " a fixed snippet"
expected := getExpectedOutput(" 35 ", " this is a short explanation.\n", " a fixed snippet")

output := getActual(input, t)

Expand All @@ -27,8 +38,8 @@ func TestAddNewlinesIfNecessarySomeNewlines(t *testing.T) {
}

func TestAddNewlinesIfNecessaryAllNewlines(t *testing.T) {
input := confidence + " 35\n " + explanation + " this is a short explanation.\n " + fix + " a fixed snippet"
expected := input
input := confidence + " 35\n " + explanation + " this is a short explanation.\n" + fix + " a fixed snippet"
expected := getExpectedOutput(" 35\n ", " this is a short explanation.\n", " a fixed snippet")

output := getActual(input, t)

Expand All @@ -40,7 +51,7 @@ func TestAddNewlinesIfNecessaryAllNewlines(t *testing.T) {
func getActual(input string, t *testing.T) []string {
someText := "some text"
response := []string{someText, someText, input}
output := AddNewlinesIfNecessary(response)
output := addDescriptionForIdentifier(response)
for i := 0; i < len(output)-1; i++ {
if output[i] != response[i] {
t.Errorf("All strings except last expected to stay the same")
Expand Down

0 comments on commit 2e486dc

Please sign in to comment.