From 7b6849e378eb67d802acd1245397953e142726cb Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Fri, 4 Oct 2024 13:30:30 +0200 Subject: [PATCH] Perform "write test" with Symflower template Part of #350 --- evaluate/task/symflower.go | 24 +++++ evaluate/task/task-write-test.go | 59 ++++++++++- evaluate/task/task-write-test_test.go | 83 +++++++++++++--- evaluate/task/task.go | 4 + model/llm/llm.go | 22 ++++ model/llm/llm_test.go | 138 +++++++++++++++++++++----- 6 files changed, 290 insertions(+), 40 deletions(-) diff --git a/evaluate/task/symflower.go b/evaluate/task/symflower.go index 1c8adfa4..2fe02482 100644 --- a/evaluate/task/symflower.go +++ b/evaluate/task/symflower.go @@ -2,6 +2,7 @@ package task import ( "context" + "fmt" "time" pkgerrors "github.com/pkg/errors" @@ -31,6 +32,29 @@ func symflowerFix(logger *log.Logger, repositoryPath string, language language.L return uint64(time.Since(start).Milliseconds()), nil } +// symflowerTemplate runs the "symflower uts" command and returns its execution time in milliseconds. +func symflowerTemplate(logger *log.Logger, repositoryPath string, language language.Language, filePath string) (duration uint64, err error) { + start := time.Now() + + o, err := util.CommandWithResult(context.Background(), logger, &util.Command{ + Command: []string{ + tools.SymflowerPath, "uts", + "--language", language.ID(), + "--workspace", repositoryPath, + filePath, + }, + + Directory: repositoryPath, + }) + if err != nil { + return 0, pkgerrors.WithStack(err) + } + + fmt.Println(o) + + return uint64(time.Since(start).Milliseconds()), nil +} + // ExecuteWithSymflowerFix runs the "symflower fix" command and calculates the new assessments. func ExecuteWithSymflowerFix(ctx evaltask.Context, logger *log.Logger, packagePath string) (testResult *language.TestResult, processingTime uint64, problems []error, err error) { // Run "symflower fix" if the model response fails to execute. 
diff --git a/evaluate/task/task-write-test.go b/evaluate/task/task-write-test.go index 7d6c9526..5dfdbd42 100644 --- a/evaluate/task/task-write-test.go +++ b/evaluate/task/task-write-test.go @@ -2,6 +2,7 @@ package task import ( "fmt" + "os" "strings" pkgerrors "github.com/pkg/errors" @@ -18,6 +19,12 @@ type TaskWriteTests struct { var _ evaltask.Task = (*TaskWriteTests)(nil) +// TaskArgumentsWriteTest holds extra arguments to be used in a query prompt. +type TaskArgumentsWriteTest struct { + // Template holds the template data to base the tests onto. + Template string +} + // Identifier returns the write test task identifier. func (t *TaskWriteTests) Identifier() evaltask.Identifier { return IdentifierWriteTests @@ -46,34 +53,76 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva modelAssessment := metrics.NewAssessments() withSymflowerFixAssessment := metrics.NewAssessments() + withSymflowerTemplateAssessment := metrics.NewAssessments() + withSymflowerTemplateAndFixAssessment := metrics.NewAssessments() maximumReachableFiles := uint64(len(filePaths)) modelAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles withSymflowerFixAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles + withSymflowerTemplateAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles + withSymflowerTemplateAndFixAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles for _, filePath := range filePaths { + // Handle this task case without a template. 
if err := ctx.Repository.Reset(ctx.Logger); err != nil { ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) } - modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath) + modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, nil) problems = append(problems, ps...) if err != nil { return nil, problems, err } modelAssessment.Add(modelAssessmentFile) withSymflowerFixAssessment.Add(withSymflowerFixAssessmentFile) + + if !(ctx.Language.ID() == "golang" || ctx.Language.ID() == "java") { // Symflower templates currently only exist for Go and Java. + withSymflowerTemplateAssessment.Add(modelAssessmentFile) + withSymflowerTemplateAndFixAssessment.Add(withSymflowerFixAssessmentFile) + + continue + } + + // Handle this task case with a template. + if err := ctx.Repository.Reset(ctx.Logger); err != nil { + ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) + } + + _, err = symflowerTemplate(taskLogger.Logger, dataPath, ctx.Language, filePath) // TODO incorporate duration + if err != nil { + problems = append(problems, pkgerrors.WithMessage(err, "generating Symflower template")) + + continue + } + + testTemplateFilePath := ctx.Language.TestFilePath(dataPath, filePath) + testTemplate, err := os.ReadFile(testTemplateFilePath) + if err != nil { + return nil, nil, pkgerrors.WithMessagef(err, "reading Symflower template from %q", testTemplateFilePath) + } + + modelTemplateAssessmentFile, templateWithSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath, &TaskArgumentsWriteTest{ + Template: string(testTemplate), + }) + problems = append(problems, ps...) 
+ if err != nil { + return nil, problems, err + } + withSymflowerTemplateAssessment.Add(modelTemplateAssessmentFile) + withSymflowerTemplateAndFixAssessment.Add(templateWithSymflowerFixAssessmentFile) } repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{ - IdentifierWriteTests: modelAssessment, - IdentifierWriteTestsSymflowerFix: withSymflowerFixAssessment, + IdentifierWriteTests: modelAssessment, + IdentifierWriteTestsSymflowerFix: withSymflowerFixAssessment, + IdentifierWriteTestsSymflowerTemplate: withSymflowerTemplateAssessment, + IdentifierWriteTestsSymflowerTemplateSymflowerFix: withSymflowerTemplateAndFixAssessment, } return repositoryAssessment, problems, nil } -func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, modelCapability model.CapabilityWriteTests, dataPath string, filePath string) (modelAssessment metrics.Assessments, withSymflowerFixAssessment metrics.Assessments, problems []error, err error) { +func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, modelCapability model.CapabilityWriteTests, dataPath string, filePath string, arguments *TaskArgumentsWriteTest) (modelAssessment metrics.Assessments, withSymflowerFixAssessment metrics.Assessments, problems []error, err error) { modelAssessment = metrics.NewAssessments() withSymflowerFixAssessment = modelAssessment // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until we successfully applied "symflower fix". 
@@ -84,6 +133,8 @@ func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, model FilePath: filePath, Logger: taskLogger.Logger, + + Arguments: arguments, } assessments, err := modelCapability.WriteTests(modelContext) if err != nil { diff --git a/evaluate/task/task-write-test_test.go b/evaluate/task/task-write-test_test.go index ebd7182d..7dbc3487 100644 --- a/evaluate/task/task-write-test_test.go +++ b/evaluate/task/task-write-test_test.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -38,21 +39,21 @@ func TestTaskWriteTestsRun(t *testing.T) { }) } - t.Run("Clear repository on each task file", func(t *testing.T) { + t.Run("Clear repository on each task file", func(t *testing.T) { // This test simulates failing tests for the first of two task cases and ensures that they don't influence test execution for the second one. temporaryDirectoryPath := t.TempDir() repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "plain") require.NoError(t, os.MkdirAll(repositoryPath, 0700)) require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "go.mod"), []byte("module plain\n\ngo 1.21.5"), 0600)) - require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskA.go"), []byte("package plain\n\nfunc TaskA(){}"), 0600)) - require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "taskB.go"), []byte("package plain\n\nfunc TaskB(){}"), 0600)) + require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "caseA.go"), []byte("package plain\n\nfunc caseA(){}"), 0600)) + require.NoError(t, os.WriteFile(filepath.Join(repositoryPath, "caseB.go"), []byte("package plain\n\nfunc caseB(){}"), 0600)) modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model") - // Generate invalid code for the first taskcontext. 
- modelMock.RegisterGenerateSuccess(t, "taskA_test.go", "does not compile", metricstesting.AssessmentsWithProcessingTime).Once() - // Generate valid code for the second taskcontext. - modelMock.RegisterGenerateSuccess(t, "taskB_test.go", "package plain\n\nimport \"testing\"\n\nfunc TestTaskB(t *testing.T){}", metricstesting.AssessmentsWithProcessingTime).Once() + // Generate invalid code for caseA (with and without template). + modelMock.RegisterGenerateSuccess(t, "caseA_test.go", "does not compile", metricstesting.AssessmentsWithProcessingTime).Twice() + // Generate valid code for caseB (with and without template). + modelMock.RegisterGenerateSuccess(t, "caseB_test.go", "package plain\n\nimport \"testing\"\n\nfunc TestCaseB(t *testing.T){}", metricstesting.AssessmentsWithProcessingTime).Twice() validate(t, &tasktesting.TestCaseTask{ Name: "Plain", @@ -73,14 +74,27 @@ func TestTaskWriteTestsRun(t *testing.T) { metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, metrics.AssessmentKeyResponseNoError: 2, }, + // With the template there would be coverage but it is overwritten. + IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{ + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, }, ExpectedProblemContains: []string{ - "expected 'package', found does", - "exit status 1", + "expected 'package', found does", // Model error. + "exit status 1", // Symflower fix not applicable. + "expected 'package', found does", // Model error (overwrote template). + "exit status 1", // Symflower fix not applicable (overwrote template). 
}, ValidateLog: func(t *testing.T, data string) { - assert.Contains(t, data, "Evaluating model \"mocked-model\"") - assert.Contains(t, data, "PASS: TestTaskB") + assert.Equal(t, 1, strings.Count(data, "Evaluating model \"mocked-model\"")) + assert.Equal(t, 4, strings.Count(data, "PASS: TestCaseB")) // Bare model result, with fix, with template, with template and fix. }, }) }) @@ -93,7 +107,7 @@ func TestTaskWriteTestsRun(t *testing.T) { require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "plain"), repositoryPath)) modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model") - modelMock.RegisterGenerateSuccess(t, "plain_test.go", testFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + modelMock.RegisterGenerateSuccess(t, "plain_test.go", testFileContent, metricstesting.AssessmentsWithProcessingTime).Twice() validate(t, &tasktesting.TestCaseTask{ Name: testName, @@ -127,6 +141,18 @@ func TestTaskWriteTestsRun(t *testing.T) { metrics.AssessmentKeyResponseNoError: 1, metrics.AssessmentKeyCoverage: 10, }, + IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{ + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + metrics.AssessmentKeyCoverage: 10, + }, + IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + metrics.AssessmentKeyCoverage: 10, + }, } validateGo(t, "Model generated correct test", &golang.Language{}, bytesutil.StringTrimIndentations(` package plain @@ -150,9 +176,20 @@ func TestTaskWriteTestsRun(t *testing.T) { metrics.AssessmentKeyResponseNoError: 1, metrics.AssessmentKeyCoverage: 10, }, + IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + 
metrics.AssessmentKeyResponseNoError: 1, + }, + IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + metrics.AssessmentKeyCoverage: 10, + }, } expectedProblems := []string{ "imported and not used", + "imported and not used", } validateGo(t, "Model generated test with unused import", &golang.Language{}, bytesutil.StringTrimIndentations(` package plain @@ -177,10 +214,20 @@ func TestTaskWriteTestsRun(t *testing.T) { metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, metrics.AssessmentKeyResponseNoError: 1, }, + IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, } expectedProblems := []string{ "expected declaration, found this", "unable to format source code", + "expected declaration, found this", + "unable to format source code", } validateGo(t, "Model generated test that is unfixable", &golang.Language{}, bytesutil.StringTrimIndentations(` package plain @@ -233,6 +280,18 @@ func TestTaskWriteTestsRun(t *testing.T) { metrics.AssessmentKeyCoverage: 10, metrics.AssessmentKeyResponseNoError: 1, }, + IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyCoverage: 10, + metrics.AssessmentKeyResponseNoError: 1, + }, + IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyCoverage: 10, + metrics.AssessmentKeyResponseNoError: 1, + }, }, ExpectedProblemContains: 
nil, ValidateLog: func(t *testing.T, data string) { diff --git a/evaluate/task/task.go b/evaluate/task/task.go index 56a5997f..ef149feb 100644 --- a/evaluate/task/task.go +++ b/evaluate/task/task.go @@ -37,6 +37,10 @@ var ( IdentifierWriteTests = registerIdentifier("write-tests") // IdentifierWriteTestsSymflowerFix holds the identifier for the "write test" task with the "symflower fix" applied. IdentifierWriteTestsSymflowerFix = registerIdentifier("write-tests-symflower-fix") + // IdentifierWriteTestsSymflowerTemplate holds the identifier for the "write test" task based on a Symflower template. + IdentifierWriteTestsSymflowerTemplate = registerIdentifier("write-tests-symflower-template") + // IdentifierWriteTestsSymflowerTemplateSymflowerFix holds the identifier for the "write test" task based on a Symflower template with the "symflower fix" applied. + IdentifierWriteTestsSymflowerTemplateSymflowerFix = registerIdentifier("write-tests-symflower-template-symflower-fix") // IdentifierCodeRepair holds the identifier for the "code repair" task. IdentifierCodeRepair = registerIdentifier("code-repair") // IdentifierTranspile holds the identifier for the "transpile" task. diff --git a/model/llm/llm.go b/model/llm/llm.go index ac8a2afa..9d384807 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -80,6 +80,9 @@ type llmSourceFilePromptContext struct { type llmWriteTestSourceFilePromptContext struct { // llmSourceFilePromptContext holds the context for a source file prompt. llmSourceFilePromptContext + + // Template holds the template data to base the tests onto. + Template string } // llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt. 
@@ -91,6 +94,14 @@ var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-wr ` + "```" + `{{ .Language.ID }} {{ .Code }} ` + "```" + ` + {{- if .Template}} + + The tests should be based on this template: + + ` + "```" + `{{ .Language.ID }} + {{ .Template -}} + ` + "```" + ` + {{- end}} `))) // Format returns the prompt for generating an LLM test generation. @@ -196,6 +207,15 @@ var _ model.CapabilityWriteTests = (*Model)(nil) // WriteTests generates test files for the given implementation file in a repository. func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, err error) { + var templateContent string + if ctx.Arguments != nil { + if arguments, ok := ctx.Arguments.(*evaluatetask.TaskArgumentsWriteTest); !ok { + return nil, pkgerrors.Errorf("unexpected type %#v", ctx.Arguments) + } else { + templateContent = arguments.Template + } + } + data, err := os.ReadFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath)) if err != nil { return nil, pkgerrors.WithStack(err) @@ -212,6 +232,8 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e FilePath: ctx.FilePath, ImportPath: importPath, }, + + Template: templateContent, }).Format() if err != nil { return nil, err diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index 2f519e5d..70ea1f65 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -311,38 +311,128 @@ func TestFormatPromptContext(t *testing.T) { }) } - validate(t, &testCase{ - Name: "Write Test", + t.Run("Write Test", func(t *testing.T) { + validate(t, &testCase{ + Name: "No Template", - Context: &llmWriteTestSourceFilePromptContext{ - llmSourceFilePromptContext: llmSourceFilePromptContext{ - Language: &golang.Language{}, + Context: &llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, - Code: bytesutil.StringTrimIndentations(` - package increment + Code: bytesutil.StringTrimIndentations(` + 
package increment + + func increment(i int) int + return i + 1 + } + `), + FilePath: filepath.Join("path", "to", "increment.go"), + ImportPath: "increment", + }, + }, + + ExpectedMessage: bytesutil.StringTrimIndentations(` + Given the following Go code file "path/to/increment.go" with package "increment", provide a test file for this code. + The tests should produce 100 percent code coverage and must compile. + The response must contain only the test code in a fenced code block and nothing else. + + ` + "```" + `golang + package increment + + func increment(i int) int + return i + 1 + } + ` + "```" + ` + `), + }) + + validate(t, &testCase{ + Name: "With Template", + + Context: &llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, + + Code: bytesutil.StringTrimIndentations(` + package increment + + func increment(i int) int + return i + 1 + } + `), + FilePath: filepath.Join("path", "to", "increment.go"), + ImportPath: "increment", + }, - func increment(i int) int - return i + 1 + Template: bytesutil.StringTrimIndentations(` + package increment + + import ( + "testing" + + "github.com/stretchr/testify/assert" + ) + + func TestIncrement(t *testing.T) { + type testCase struct { + Name string + + I int + + Expected int } - `), - FilePath: filepath.Join("path", "to", "increment.go"), - ImportPath: "increment", + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T){ + assert.Equal(t, tc.Expected, increment(tc.I)) + }) + } + } + `), }, - }, - ExpectedMessage: bytesutil.StringTrimIndentations(` - Given the following Go code file "path/to/increment.go" with package "increment", provide a test file for this code. - The tests should produce 100 percent code coverage and must compile. - The response must contain only the test code in a fenced code block and nothing else. 
+ ExpectedMessage: bytesutil.StringTrimIndentations(` + Given the following Go code file "path/to/increment.go" with package "increment", provide a test file for this code. + The tests should produce 100 percent code coverage and must compile. + The response must contain only the test code in a fenced code block and nothing else. - ` + "```" + `golang - package increment + ` + "```" + `golang + package increment - func increment(i int) int - return i + 1 - } - ` + "```" + ` - `), + func increment(i int) int + return i + 1 + } + ` + "```" + ` + + The tests should be based on this template: + + ` + "```" + `golang + package increment + + import ( + "testing" + + "github.com/stretchr/testify/assert" + ) + + func TestIncrement(t *testing.T) { + type testCase struct { + Name string + + I int + + Expected int + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T){ + assert.Equal(t, tc.Expected, increment(tc.I)) + }) + } + } + ` + "```" + ` + `), + }) }) validate(t, &testCase{