Skip to content

Commit

Permalink
Allow custom validation for "write-test" task so one can ensure that …
Browse files Browse the repository at this point in the history
…tests for Spring Boot actually spin up Spring Boot

Part of #365
  • Loading branch information
bauersimon committed Oct 25, 2024
1 parent 72395e7 commit 4326f25
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 45 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,21 @@ It is possible to configure some model prompt parameters through `repository.jso

This `prompt.test-framework` setting is currently only respected for the test generation task `write-tests`.

When task results are validated, some repositories might require custom logic. For example: generating tests for a Spring Boot project requires ensuring that the tests used an actual Spring context (i.e. Spring Boot was initialized when the tests were executed). Therefore, the `repository.json` supports adding rudimentary custom validation:

```json
{
"tasks": ["write-tests"],
"validation": {
"execution": {
"stdout": "Initializing Spring" // Ensure the string "Initializing Spring" is contained in the execution output.
}
}
}
```

This `validation.execution.stdout` setting is currently only respected for the test generation task `write-tests`.

## Tasks

### Task: Test Generation
Expand Down
4 changes: 2 additions & 2 deletions evaluate/task/write-test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, model
problems = append(problems, ps...)
if err != nil {
problems = append(problems, pkgerrors.WithMessage(err, filePath))
} else {
} else if ctx.Repository.Configuration().Validation.Execution.Validate(testResult.StdOut) {
taskLogger.Printf("Executes tests with %d coverage objects", testResult.Coverage)
modelAssessment.Award(metrics.AssessmentKeyFilesExecuted)
modelAssessment.AwardPoints(metrics.AssessmentKeyCoverage, testResult.Coverage)
Expand All @@ -187,7 +187,7 @@ func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, model
problems = append(problems, ps...)
if err != nil {
problems = append(problems, err)
} else {
} else if ctx.Repository.Configuration().Validation.Execution.Validate(withSymflowerFixTestResult.StdOut) {
ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage)

// Symflower was able to fix a failure so now update the assessment with the improved results.
Expand Down
147 changes: 104 additions & 43 deletions evaluate/task/write-test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,16 +395,17 @@ func TestWriteTestsRun(t *testing.T) {
})
}

{
temporaryDirectoryPath := t.TempDir()
repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "spring-plain")
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "spring-plain"), repositoryPath))
modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
modelMock.RegisterGenerateSuccessValidateContext(t, func(t *testing.T, c model.Context) {
args, ok := c.Arguments.(*ArgumentsWriteTest)
require.Truef(t, ok, "unexpected type %T", c.Arguments)
assert.Equal(t, "JUnit 5 for Spring", args.TestFramework)
}, filepath.Join("src", "test", "java", "com", "example", "controller", "SomeControllerTest.java"), bytesutil.StringTrimIndentations(`
t.Run("Spring Boot", func(t *testing.T) {
{
temporaryDirectoryPath := t.TempDir()
repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "spring-plain")
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "spring-plain"), repositoryPath))
modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
modelMock.RegisterGenerateSuccessValidateContext(t, func(t *testing.T, c model.Context) {
args, ok := c.Arguments.(*ArgumentsWriteTest)
require.Truef(t, ok, "unexpected type %T", c.Arguments)
assert.Equal(t, "JUnit 5 for Spring", args.TestFramework)
}, filepath.Join("src", "test", "java", "com", "example", "controller", "SomeControllerTest.java"), bytesutil.StringTrimIndentations(`
package com.example.controller;
import org.junit.jupiter.api.*;
Expand All @@ -431,45 +432,105 @@ func TestWriteTestsRun(t *testing.T) {
}
`), metricstesting.AssessmentsWithProcessingTime)

validate(t, &tasktesting.TestCaseTask{
Name: "Spring Boot",
validate(t, &tasktesting.TestCaseTask{
Name: "Spring Boot Test",

Model: modelMock,
Language: &java.Language{},
TestDataPath: temporaryDirectoryPath,
RepositoryPath: filepath.Join("java", "spring-plain"),
Model: modelMock,
Language: &java.Language{},
TestDataPath: temporaryDirectoryPath,
RepositoryPath: filepath.Join("java", "spring-plain"),

ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{
IdentifierWriteTests: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{
IdentifierWriteTests: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
},
},
IdentifierWriteTestsSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
ValidateLog: func(t *testing.T, data string) {
assert.Equal(t, 2, strings.Count(data, "Starting SomeControllerTest using Java"), "Expected two successful Spring startup announcements (one bare and one for template)")
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
})
}
{
temporaryDirectoryPath := t.TempDir()
repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "spring-plain")
require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "spring-plain"), repositoryPath))
modelMock := modeltesting.NewMockCapabilityWriteTestsNamed(t, "mocked-model")
modelMock.RegisterGenerateSuccessValidateContext(t, func(t *testing.T, c model.Context) {
args, ok := c.Arguments.(*ArgumentsWriteTest)
require.Truef(t, ok, "unexpected type %T", c.Arguments)
assert.Equal(t, "JUnit 5 for Spring", args.TestFramework)
}, filepath.Join("src", "test", "java", "com", "example", "controller", "SomeControllerTest.java"), bytesutil.StringTrimIndentations(`
package com.example.controller;
import com.example.controller.SomeController;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
class SomeControllerTests {
@Test // Normal JUnit tests instead of Spring Boot.
void helloGet() {
SomeController controller = new SomeController();
String result = controller.helloGet();
assertEquals("get!", result);
}
}
`), metricstesting.AssessmentsWithProcessingTime)

validate(t, &tasktesting.TestCaseTask{
Name: "Plain JUnit Test",

Model: modelMock,
Language: &java.Language{},
TestDataPath: temporaryDirectoryPath,
RepositoryPath: filepath.Join("java", "spring-plain"),

ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{
IdentifierWriteTests: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplate: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyResponseNoError: 1,
},
},
IdentifierWriteTestsSymflowerTemplateSymflowerFix: metrics.Assessments{
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1,
metrics.AssessmentKeyFilesExecuted: 1,
metrics.AssessmentKeyCoverage: 20,
metrics.AssessmentKeyResponseNoError: 1,
ValidateLog: func(t *testing.T, data string) {
assert.Contains(t, data, "Tests run: 1") // Tests are running but they are not Spring Boot.
},
},
ValidateLog: func(t *testing.T, data string) {
assert.Equal(t, 2, strings.Count(data, "Starting SomeControllerTest using Java"), "Expected two successful Spring startup announcements (one bare and one for template)")
},
})
}
})
}
})
}

func TestValidateWriteTestsRepository(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions language/golang/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ func (l *Language) ExecuteTests(logger *log.Logger, repositoryPath string) (test
testResult = &language.TestResult{
TestsTotal: uint(testsTotal),
TestsPass: uint(testsPass),

StdOut: commandOutput,
}
testResult.Coverage, err = language.CoverageObjectCountOfFile(logger, coverageFilePath)
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions language/java/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ func (l *Language) ExecuteTests(logger *log.Logger, repositoryPath string) (test
testResult = &language.TestResult{
TestsTotal: uint(testsTotal),
TestsPass: uint(testsPass),

StdOut: commandOutput,
}

testResult.Coverage, err = language.CoverageObjectCountOfFile(logger, coverageFilePath)
Expand Down
2 changes: 2 additions & 0 deletions language/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ type TestResult struct {
TestsPass uint

Coverage uint64

StdOut string
}

// PassingTestsPercentage returns the percentage of passing tests.
Expand Down
2 changes: 2 additions & 0 deletions language/ruby/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ func (l *Language) ExecuteTests(logger *log.Logger, repositoryPath string) (test
testResult = &language.TestResult{
TestsTotal: uint(testsTotal),
TestsPass: uint(testsPass),

StdOut: commandOutput,
}

testResult.Coverage, err = language.CoverageObjectCountOfFile(logger, coverageFilePath)
Expand Down
1 change: 1 addition & 0 deletions language/testing/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func (tc *TestCaseExecuteTests) Validate(t *testing.T) {
assert.ErrorContains(t, actualError, tc.ExpectedErrorText)
} else {
assert.NoError(t, actualError)
actualTestResult.StdOut = ""
assert.Equal(t, tc.ExpectedTestResult, actualTestResult)
}
})
Expand Down
25 changes: 25 additions & 0 deletions task/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"os"
"path/filepath"
"regexp"
"strings"

pkgerrors "github.com/pkg/errors"
Expand All @@ -24,6 +25,17 @@ type RepositoryConfiguration struct {
// TestFramework overwrites the language-specific test framework to use.
TestFramework string `json:"test-framework,omitempty"`
} `json:",omitempty"`

// Validation holds quality gates for evaluation.
Validation struct {
Execution RepositoryConfigurationExecution `json:",omitempty"`
}
}

// RepositoryConfigurationExecution execution-related quality gates for evaluation.
type RepositoryConfigurationExecution struct {
// StdOutRE holds a regular expression that must be part of execution standard output.
StdOutRE string `json:"stdout,omitempty"`
}

// RepositoryConfigurationFileName holds the file name for a repository configuration.
Expand Down Expand Up @@ -70,6 +82,10 @@ func (rc *RepositoryConfiguration) validate(validTasks []Identifier) (err error)
}
}

if _, err := regexp.Compile(rc.Validation.Execution.StdOutRE); err != nil {
return pkgerrors.WithMessagef(err, "invalid regular expression %q", rc.Validation.Execution.StdOutRE)
}

return nil
}

Expand All @@ -85,3 +101,12 @@ func (rc *RepositoryConfiguration) IsFilePathIgnored(filePath string) bool {

return false
}

// Validate validates execution outcomes against the configured quality gates.
func (e *RepositoryConfigurationExecution) Validate(stdout string) bool {
if e.StdOutRE != "" {
return regexp.MustCompile(e.StdOutRE).MatchString(stdout)
}

return true
}
5 changes: 5 additions & 0 deletions testdata/java/spring-plain/repository.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,10 @@
"ignore": ["src/main/java/com/example/Application.java"],
"prompt": {
"test-framework": "JUnit 5 for Spring"
},
"validation": {
"execution": {
"stdout": "Initializing Spring"
}
}
}

0 comments on commit 4326f25

Please sign in to comment.