Skip to content

Commit

Permalink
Perform "write test" with Symflower template
Browse files Browse the repository at this point in the history
Part of #350
  • Loading branch information
bauersimon committed Oct 4, 2024
1 parent 9d8f006 commit 7b6849e
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 40 deletions.
24 changes: 24 additions & 0 deletions evaluate/task/symflower.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package task

import (
"context"
"fmt"
"time"

pkgerrors "github.com/pkg/errors"
Expand Down Expand Up @@ -31,6 +32,29 @@ func symflowerFix(logger *log.Logger, repositoryPath string, language language.L
return uint64(time.Since(start).Milliseconds()), nil
}

// symflowerTemplate runs the "symflower uts" command for the given file and returns the command's execution time in milliseconds.
// The command is executed inside the repository path which is also handed over as workspace.
// On failure, a zero duration and the command's error (with stack) are returned.
func symflowerTemplate(logger *log.Logger, repositoryPath string, language language.Language, filePath string) (duration uint64, err error) {
	start := time.Now()

	output, err := util.CommandWithResult(context.Background(), logger, &util.Command{
		Command: []string{
			tools.SymflowerPath, "uts",
			"--language", language.ID(),
			"--workspace", repositoryPath,
			filePath,
		},

		Directory: repositoryPath,
	})
	// Take the measurement right away so it only covers the command itself and not the printing of its output below.
	duration = uint64(time.Since(start).Milliseconds())
	if err != nil {
		return 0, pkgerrors.WithStack(err)
	}

	fmt.Println(output)

	return duration, nil
}

// ExecuteWithSymflowerFix runs the "symflower fix" command and calculates the new assessments.
func ExecuteWithSymflowerFix(ctx evaltask.Context, logger *log.Logger, packagePath string) (testResult *language.TestResult, processingTime uint64, problems []error, err error) {
// Run "symflower fix" if the model response fails to execute.
Expand Down
59 changes: 55 additions & 4 deletions evaluate/task/task-write-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package task

import (
"fmt"
"os"
"strings"

pkgerrors "github.com/pkg/errors"
Expand All @@ -18,6 +19,12 @@ type TaskWriteTests struct {

var _ evaltask.Task = (*TaskWriteTests)(nil)

// TaskArgumentsWriteTest holds extra arguments to be used in a query prompt of the "write test" task.
// It is handed to the model through the "Arguments" field of the model context.
type TaskArgumentsWriteTest struct {
// Template holds the content of a test template that the generated tests should be based on.
Template string
}

// Identifier returns the write test task identifier.
func (t *TaskWriteTests) Identifier() evaltask.Identifier {
return IdentifierWriteTests
Expand Down Expand Up @@ -46,34 +53,76 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva

modelAssessment := metrics.NewAssessments()
withSymflowerFixAssessment := metrics.NewAssessments()
withSymflowerTemplateAssessment := metrics.NewAssessments()
withSymflowerTemplateAndFixAssessment := metrics.NewAssessments()

maximumReachableFiles := uint64(len(filePaths))
modelAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles
withSymflowerFixAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles
withSymflowerTemplateAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles
withSymflowerTemplateAndFixAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles

for _, filePath := range filePaths {
// Handle this task case without a template.
if err := ctx.Repository.Reset(ctx.Logger); err != nil {
ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath)
modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, nil)
problems = append(problems, ps...)
if err != nil {
return nil, problems, err
}
modelAssessment.Add(modelAssessmentFile)
withSymflowerFixAssessment.Add(withSymflowerFixAssessmentFile)

if !(ctx.Language.ID() == "golang" || ctx.Language.ID() == "java") { // Symflower templates currently only exist for Go and Java.
withSymflowerTemplateAssessment.Add(modelAssessmentFile)
withSymflowerTemplateAndFixAssessment.Add(withSymflowerFixAssessmentFile)

continue
}

// Handle this task case with a template.
if err := ctx.Repository.Reset(ctx.Logger); err != nil {
ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

_, err = symflowerTemplate(taskLogger.Logger, dataPath, ctx.Language, filePath) // TODO incorporate duration
if err != nil {
problems = append(problems, pkgerrors.WithMessage(err, "generating Symflower template"))

continue
}

testTemplateFilePath := ctx.Language.TestFilePath(dataPath, filePath)
testTemplate, err := os.ReadFile(testTemplateFilePath)
if err != nil {
return nil, nil, pkgerrors.WithMessagef(err, "reading Symflower template from %q", testTemplateFilePath)
}

modelTemplateAssessmentFile, templateWithSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, &TaskArgumentsWriteTest{
Template: string(testTemplate),
})
problems = append(problems, ps...)
if err != nil {
return nil, problems, err
}
withSymflowerTemplateAssessment.Add(modelTemplateAssessmentFile)
withSymflowerTemplateAndFixAssessment.Add(templateWithSymflowerFixAssessmentFile)
}

repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{
IdentifierWriteTests: modelAssessment,
IdentifierWriteTestsSymflowerFix: withSymflowerFixAssessment,
IdentifierWriteTests: modelAssessment,
IdentifierWriteTestsSymflowerFix: withSymflowerFixAssessment,
IdentifierWriteTestsSymflowerTemplate: withSymflowerTemplateAssessment,
IdentifierWriteTestsSymflowerTemplateSymflowerFix: withSymflowerTemplateAndFixAssessment,
}

return repositoryAssessment, problems, nil
}

func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, modelCapability model.CapabilityWriteTests, dataPath string, filePath string) (modelAssessment metrics.Assessments, withSymflowerFixAssessment metrics.Assessments, problems []error, err error) {
func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, modelCapability model.CapabilityWriteTests, dataPath string, filePath string, arguments *TaskArgumentsWriteTest) (modelAssessment metrics.Assessments, withSymflowerFixAssessment metrics.Assessments, problems []error, err error) {
modelAssessment = metrics.NewAssessments()
withSymflowerFixAssessment = modelAssessment // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until we successfully applied "symflower fix".

Expand All @@ -84,6 +133,8 @@ func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, model
FilePath: filePath,

Logger: taskLogger.Logger,

Arguments: arguments,
}
assessments, err := modelCapability.WriteTests(modelContext)
if err != nil {
Expand Down
83 changes: 71 additions & 12 deletions evaluate/task/task-write-test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -38,21 +39,21 @@ func TestTaskWriteTestsRun(t *testing.T) {
})
}

t.Run("Clear repository on each task file", func(t *testing.T) {
t.Run("Clear repository on each task file", func(t *testing.T) { // This test simulates failing tests for the first of two task cases and ensures that they don't influence test execution for the second one.
temporaryDirectoryPath := t.TempDir()

repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain")
require.NoError(t, os.MkdirAll(repositoryPath, 0700))
require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "go.mod"), []byte("module plain\n\ngo 1.21.5"), 0600))
require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskA.go"), []byte("package plain\n\nfunc TaskA(){}"), 0600))
require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskB.go"), []byte("package plain\n\nfunc TaskB(){}"), 0600))
require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "caseA.go"), []byte("package plain\n\nfunc caseA(){}"), 0600))
require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "caseB.go"), []byte("package plain\n\nfunc caseB(){}"), 0600))

modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")

// Generate invalid code for the first taskcontext.
modelMock.RegisterGenerateSuccess(t, "taskA_test.go", "does not compile", metricstesting.AssessmentsWithProcessingTime).Once()
// Generate valid code for the second taskcontext.
modelMock.RegisterGenerateSuccess(t, "taskB_test.go", "package plain\n\nimport \"testing\"\n\nfunc TestTaskB(t *testing.T){}", metricstesting.AssessmentsWithProcessingTime).Once()
// Generate invalid code for caseA (with and without template).
modelMock.RegisterGenerateSuccess(t, "caseA_test.go", "does not compile", metricstesting.AssessmentsWithProcessingTime).Twice()
// Generate valid code for caseB (with and without template).
modelMock.RegisterGenerateSuccess(t, "caseB_test.go", "package plain\n\nimport \"testing\"\n\nfunc TestCaseB(t *testing.T){}", metricstesting.AssessmentsWithProcessingTime).Twice()

validate(t, &tasktesting.TestCaseTask{
Name: "Plain",
Expand All @@ -73,14 +74,27 @@ func TestTaskWriteTestsRun(t *testing.T) {
metrics.AssessmentKeyFilesExecutedMaximumReachable: 2,
metrics.AssessmentKeyResponseNoError: 2,
},
// With the template there would be coverage but it is overwritten.
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 2,
metrics.AssessmentKeyResponseNoError: 2,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 2,
metrics.AssessmentKeyResponseNoError: 2,
},
},
ExpectedProblemContains: []string{
"expected 'package', found does",
"exit status 1",
"expected 'package', found does", // Model error.
"exit status 1", // Symflower fix not applicable.
"expected 'package', found does", // Model error (overwrote template).
"exit status 1", // Symflower fix not applicable (overwrote template).
},
ValidateLog: func(t *testing.T, data string) {
assert.Contains(t, data, "Evaluating model \"mocked-model\"")
assert.Contains(t, data, "PASS: TestTaskB")
assert.Equal(t, 1, strings.Count(data, "Evaluating model \"mocked-model\""))
assert.Equal(t, 4, strings.Count(data, "PASS: TestCaseB")) // Bare model result, with fix, with template, with template and fix.
},
})
})
Expand All @@ -93,7 +107,7 @@ func TestTaskWriteTestsRun(t *testing.T) {
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath))

modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
modelMock.RegisterGenerateSuccess(t, "plain_test.go", testFileContent, metricstesting.AssessmentsWithProcessingTime).Once()
modelMock.RegisterGenerateSuccess(t, "plain_test.go", testFileContent, metricstesting.AssessmentsWithProcessingTime).Twice()

validate(t, &tasktesting.TestCaseTask{
Name: testName,
Expand Down Expand Up @@ -127,6 +141,18 @@ func TestTaskWriteTestsRun(t *testing.T) {
metrics.AssessmentKeyResponseNoError: 1,
metrics.AssessmentKeyCoverage: 10,
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
metrics.AssessmentKeyCoverage: 10,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
metrics.AssessmentKeyCoverage: 10,
},
}
validateGo(t, "Model generated correct test", &golang.Language{}, bytesutil.StringTrimIndentations(`
package plain
Expand All @@ -150,9 +176,20 @@ func TestTaskWriteTestsRun(t *testing.T) {
metrics.AssessmentKeyResponseNoError: 1,
metrics.AssessmentKeyCoverage: 10,
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
metrics.AssessmentKeyCoverage: 10,
},
}
expectedProblems := []string{
"imported and not used",
"imported and not used",
}
validateGo(t, "Model generated test with unused import", &golang.Language{}, bytesutil.StringTrimIndentations(`
package plain
Expand All @@ -177,10 +214,20 @@ func TestTaskWriteTestsRun(t *testing.T) {
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
}
expectedProblems := []string{
"expected declaration, found this",
"unable to format source code",
"expected declaration, found this",
"unable to format source code",
}
validateGo(t, "Model generated test that is unfixable", &golang.Language{}, bytesutil.StringTrimIndentations(`
package plain
Expand Down Expand Up @@ -233,6 +280,18 @@ func TestTaskWriteTestsRun(t *testing.T) {
metrics.AssessmentKeyCoverage: 10,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 10,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 10,
metrics.AssessmentKeyResponseNoError: 1,
},
},
ExpectedProblemContains: nil,
ValidateLog: func(t *testing.T, data string) {
Expand Down
4 changes: 4 additions & 0 deletions evaluate/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ var (
IdentifierWriteTests = registerIdentifier("write-tests")
// IdentifierWriteTestsSymflowerFix holds the identifier for the "write test" task with the "symflower fix" applied.
IdentifierWriteTestsSymflowerFix = registerIdentifier("write-tests-symflower-fix")
// IdentifierWriteTestsSymflowerTemplate holds the identifier for the "write test" task based on a Symflower template.
IdentifierWriteTestsSymflowerTemplate = registerIdentifier("write-tests-symflower-template")
// IdentifierWriteTestsSymflowerTemplateSymflowerFix holds the identifier for the "write test" task based on a Symflower template with the "symflower fix" applied.
IdentifierWriteTestsSymflowerTemplateSymflowerFix = registerIdentifier("write-tests-symflower-template-symflower-fix")
// IdentifierCodeRepair holds the identifier for the "code repair" task.
IdentifierCodeRepair = registerIdentifier("code-repair")
// IdentifierTranspile holds the identifier for the "transpile" task.
Expand Down
22 changes: 22 additions & 0 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ type llmSourceFilePromptContext struct {
type llmWriteTestSourceFilePromptContext struct {
// llmSourceFilePromptContext holds the context for a source file prompt.
llmSourceFilePromptContext

// Template holds the template data to base the tests on.
// The prompt only includes the template section if this is not empty.
Template string
}

// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt.
Expand All @@ -91,6 +94,14 @@ var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-wr
` + "```" + `{{ .Language.ID }}
{{ .Code }}
` + "```" + `
{{- if .Template}}
The tests should be based on this template:
` + "```" + `{{ .Language.ID }}
{{ .Template -}}
` + "```" + `
{{- end}}
`)))

// Format returns the prompt for generating an LLM test generation.
Expand Down Expand Up @@ -196,6 +207,15 @@ var _ model.CapabilityWriteTests = (*Model)(nil)

// WriteTests generates test files for the given implementation file in a repository.
func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, err error) {
var templateContent string
if ctx.Arguments != nil {
if arguments, ok := ctx.Arguments.(*evaluatetask.TaskArgumentsWriteTest); !ok {
return nil, pkgerrors.Errorf("unexpected type %#v", ctx.Arguments)
} else {
templateContent = arguments.Template
}
}

data, err := os.ReadFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath))
if err != nil {
return nil, pkgerrors.WithStack(err)
Expand All @@ -212,6 +232,8 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e
FilePath: ctx.FilePath,
ImportPath: importPath,
},

Template: templateContent,
}).Format()
if err != nil {
return nil, err
Expand Down
Loading

0 comments on commit 7b6849e

Please sign in to comment.