From 2be1f8759cd2d8fda6571c6f27a1b8d697fd991c Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Thu, 4 Jul 2024 16:12:07 +0100 Subject: [PATCH] Task for code transpilation, so models can transpile Go code to Java and back Closes #201 --- .mockery.yml | 1 + evaluate/metrics/assessment.go | 4 + evaluate/metrics/assessment_test.go | 7 +- evaluate/report/csv_test.go | 38 +- evaluate/task/task-transpile.go | 193 ++++++++++ evaluate/task/task-transpile_test.go | 351 ++++++++++++++++++ evaluate/task/task.go | 6 + model/capability.go | 6 + model/llm/llm.go | 100 +++++ model/llm/llm_test.go | 224 +++++++++++ model/testing/CapabilityTranspile_mock_gen.go | 59 +++ model/testing/helper.go | 22 ++ testdata/golang/transpile/repository.json | 5 + testdata/java/transpile/repository.json | 5 + 14 files changed, 1000 insertions(+), 21 deletions(-) create mode 100644 evaluate/task/task-transpile.go create mode 100644 evaluate/task/task-transpile_test.go create mode 100644 model/testing/CapabilityTranspile_mock_gen.go create mode 100644 testdata/golang/transpile/repository.json create mode 100644 testdata/java/transpile/repository.json diff --git a/.mockery.yml b/.mockery.yml index 3484c7c9..15c1cfe5 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -15,6 +15,7 @@ packages: Model: CapabilityWriteTests: CapabilityRepairCode: + CapabilityTranspile: github.com/symflower/eval-dev-quality/provider: interfaces: Loader: diff --git a/evaluate/metrics/assessment.go b/evaluate/metrics/assessment.go index 5113c115..0e911e0a 100644 --- a/evaluate/metrics/assessment.go +++ b/evaluate/metrics/assessment.go @@ -44,6 +44,9 @@ var ( // AssessmentKeyCoverage counts execution coverage objects. AssessmentKeyCoverage = RegisterAssessmentKey("coverage", 10) + // AssessmentKeyTestsPassing holds the percentage of passing tests. + AssessmentKeyTestsPassing = RegisterAssessmentKey("tests-passing", 10) + // AssessmentKeyResponseCharacterCount counts the number of characters of a response. AssessmentKeyResponseCharacterCount = RegisterAssessmentKey("response-character-count", 0) // AssessmentKeyGenerateTestsForFileCharacterCount counts the number of characters of a generated test file. @@ -167,6 +170,7 @@ func CombineWithSymflowerFixAssessments(model Assessments, fixed Assessments) (c combined[AssessmentKeyResponseNoError] = model[AssessmentKeyResponseNoError] combined[AssessmentKeyResponseNoExcess] = model[AssessmentKeyResponseNoExcess] combined[AssessmentKeyResponseWithCode] = model[AssessmentKeyResponseWithCode] + combined[AssessmentKeyTestsPassing] = fixed[AssessmentKeyTestsPassing] return combined } diff --git a/evaluate/metrics/assessment_test.go b/evaluate/metrics/assessment_test.go index bd88ab64..4b37135b 100644 --- a/evaluate/metrics/assessment_test.go +++ b/evaluate/metrics/assessment_test.go @@ -137,7 +137,7 @@ func TestAssessmentString(t *testing.T) { Assessment: NewAssessments(), - ExpectedString: "score=0, coverage=0, files-executed=0, files-executed-maximum-reachable=0, generate-tests-for-file-character-count=0, processing-time=0, response-character-count=0, response-no-error=0, response-no-excess=0, response-with-code=0", + ExpectedString: "score=0, coverage=0, files-executed=0, files-executed-maximum-reachable=0, generate-tests-for-file-character-count=0, processing-time=0, response-character-count=0, response-no-error=0, response-no-excess=0, response-with-code=0, tests-passing=0", }) validate(t, &testCase{ @@ -153,9 +153,10 @@ func TestAssessmentString(t *testing.T) { AssessmentKeyResponseNoExcess: 4, AssessmentKeyResponseWithCode: 5, AssessmentKeyProcessingTime: 200, + AssessmentKeyTestsPassing: 70, }, - ExpectedString: "score=15, coverage=1, files-executed=2, files-executed-maximum-reachable=2, generate-tests-for-file-character-count=50, processing-time=200, response-character-count=100, response-no-error=3, response-no-excess=4, response-with-code=5", + ExpectedString: "score=85, coverage=1, files-executed=2, files-executed-maximum-reachable=2, generate-tests-for-file-character-count=50, processing-time=200, response-character-count=100, response-no-error=3, response-no-excess=4, response-with-code=5, tests-passing=70", }) } @@ -310,6 +311,7 @@ func TestCombineModelAndSymflowerFixAssessments(t *testing.T) { AssessmentKeyProcessingTime: uint64(100), AssessmentKeyCoverage: 10, AssessmentKeyResponseNoError: 1, + AssessmentKeyTestsPassing: 100, }, ExpectedAssessments: Assessments{ @@ -321,6 +323,7 @@ func TestCombineModelAndSymflowerFixAssessments(t *testing.T) { AssessmentKeyResponseNoError: 0, AssessmentKeyResponseWithCode: 1, AssessmentKeyResponseNoExcess: 1, + AssessmentKeyTestsPassing: 100, }, }) } diff --git a/evaluate/report/csv_test.go b/evaluate/report/csv_test.go index 89fcdfd7..2d2c245c 100644 --- a/evaluate/report/csv_test.go +++ b/evaluate/report/csv_test.go @@ -27,7 +27,7 @@ func TestNewEvaluationFile(t *testing.T) { require.NoError(t, err) expectedEvaluationFileContent := bytesutil.StringTrimIndentations(` - model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code + model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing `) assert.Equal(t, expectedEvaluationFileContent, string(actualEvaluationFileContent)) @@ -66,8 +66,8 @@ func TestWriteEvaluationRecord(t *testing.T) { }, ExpectedCSV: ` - model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code - mocked-model,golang,golang/plain,write-tests,0,0,0,0,0,0,0,0,0,0 + model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing + mocked-model,golang,golang/plain,write-tests,0,0,0,0,0,0,0,0,0,0,0 `, }) validate(t, &testCase{ @@ -89,9 +89,9 @@ func TestWriteEvaluationRecord(t *testing.T) { }, ExpectedCSV: ` - model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code - mocked-model,golang,golang/plain,write-tests,2,0,1,1,0,0,0,1,0,0 - mocked-model,golang,golang/plain,write-tests-symflower-fix,12,10,1,1,0,0,0,1,0,0 + model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing + mocked-model,golang,golang/plain,write-tests,2,0,1,1,0,0,0,1,0,0,0 + mocked-model,golang,golang/plain,write-tests-symflower-fix,12,10,1,1,0,0,0,1,0,0,0 `, }) } @@ -224,37 +224,37 @@ func TestEvaluationFileWriteLines(t *testing.T) { Name: "No records", ExpectedEvaluationFile: ` - model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code + model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing `, }) validate(t, &testCase{ Name: "Single record", RawRecords: [][]string{ - []string{"modelA", "golang", "golang/light", "write-tests", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"}, + []string{"modelA", "golang", "golang/light", "write-tests", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"}, }, ExpectedEvaluationFile: ` - model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code - modelA,golang,golang/light,write-tests,1,1,1,1,1,1,1,1,1,1 + model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing + modelA,golang,golang/light,write-tests,1,1,1,1,1,1,1,1,1,1,1 `, }) validate(t, &testCase{ Name: "Multiple records", RawRecords: [][]string{ - []string{"modelA", "golang", "golang/light", "write-tests", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"}, - []string{"modelA", "golang", "golang/plain", "write-tests", "2", "2", "2", "2", "2", "2", "2", "2", "2", "2"}, - []string{"modelA", "java", "java/light", "write-tests", "3", "3", "3", "3", "3", "3", "3", "3", "3", "3"}, - []string{"modelA", "java", "java/plain", "write-tests", "4", "4", "4", "4", "4", "4", "4", "4", "4", "4"}, + []string{"modelA", "golang", "golang/light", "write-tests", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1", "1"}, + []string{"modelA", "golang", "golang/plain", "write-tests", "2", "2", "2", "2", "2", "2", "2", "2", "2", "2", "2"}, + []string{"modelA", "java", "java/light", "write-tests", "3", "3", "3", "3", "3", "3", "3", "3", "3", "3", "3"}, + []string{"modelA", "java", "java/plain", "write-tests", "4", "4", "4", "4", "4", "4", "4", "4", "4", "4", "4"}, }, ExpectedEvaluationFile: ` - model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code - modelA,golang,golang/light,write-tests,1,1,1,1,1,1,1,1,1,1 - modelA,golang,golang/plain,write-tests,2,2,2,2,2,2,2,2,2,2 - modelA,java,java/light,write-tests,3,3,3,3,3,3,3,3,3,3 - modelA,java,java/plain,write-tests,4,4,4,4,4,4,4,4,4,4 + model-id,language,repository,task,score,coverage,files-executed,files-executed-maximum-reachable,generate-tests-for-file-character-count,processing-time,response-character-count,response-no-error,response-no-excess,response-with-code,tests-passing + modelA,golang,golang/light,write-tests,1,1,1,1,1,1,1,1,1,1,1 + modelA,golang,golang/plain,write-tests,2,2,2,2,2,2,2,2,2,2,2 + modelA,java,java/light,write-tests,3,3,3,3,3,3,3,3,3,3,3 + modelA,java,java/plain,write-tests,4,4,4,4,4,4,4,4,4,4,4 `, }) } diff --git a/evaluate/task/task-transpile.go b/evaluate/task/task-transpile.go new file mode 100644 index 00000000..926ffcfd --- /dev/null +++ b/evaluate/task/task-transpile.go @@ -0,0 +1,193 @@ +package task + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + pkgerrors "github.com/pkg/errors" + "github.com/symflower/eval-dev-quality/evaluate/metrics" + "github.com/symflower/eval-dev-quality/language" + "github.com/symflower/eval-dev-quality/language/golang" + "github.com/symflower/eval-dev-quality/language/java" + "github.com/symflower/eval-dev-quality/log" + "github.com/symflower/eval-dev-quality/model" + evaltask "github.com/symflower/eval-dev-quality/task" +) + +// TaskTranspile holds the transpilation task. +type TaskTranspile struct{} + +// TaskArgumentsTranspile holds extra arguments to be used in a query prompt. +type TaskArgumentsTranspile struct { + // OriginLanguage holds the language we are transpiling from. + OriginLanguage language.Language + // OriginFilePath holds the path for the file containing the source code we want to transpile. + OriginFilePath string +} + +var _ evaltask.Task = (*TaskTranspile)(nil) + +// Identifier returns the transpilation task identifier. +func (t *TaskTranspile) Identifier() evaltask.Identifier { + return IdentifierTranspile +} + +// Run transpiles code between languages and runs predefined tests to check if the transpilation was successful. +func (t *TaskTranspile) Run(ctx evaltask.Context) (repositoryAssessment map[evaltask.Identifier]metrics.Assessments, problems []error, err error) { + modelCapability, ok := ctx.Model.(model.CapabilityTranspile) + if !ok { + return nil, nil, pkgerrors.Wrap(evaltask.ErrTaskUnsupportedByModel, fmt.Sprintf("%q does not support %q", ctx.Model.ID(), string(t.Identifier()))) + } + + taskLogger, err := newTaskLogger(ctx, t) + if err != nil { + return nil, nil, err + } + defer func() { + taskLogger.finalize(problems) + }() + + var packagePaths []string + files, err := os.ReadDir(ctx.Repository.DataPath()) + if err != nil { + return nil, nil, pkgerrors.WithStack(err) + } + for _, file := range files { + if file.IsDir() && !strings.HasPrefix(file.Name(), ".") { // Ignore hidden directories. + packagePaths = append(packagePaths, file.Name()) + } + } + + modelAssessments := metrics.NewAssessments() + withSymflowerAssessments := metrics.NewAssessments() + + maximumReachableFiles := uint64(len(packagePaths)) + modelAssessments[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles + withSymflowerAssessments[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles + + for _, packagePath := range packagePaths { + modelAssessmentsForFile := metrics.NewAssessments() + withSymflowerAssessmentsForFile := modelAssessmentsForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until a failure actually happens. + + if err := ctx.Repository.Reset(ctx.Logger); err != nil { + ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) + } + + var originLanguage language.Language + if _, ok := ctx.Language.(*golang.Language); ok { + originLanguage = &java.Language{} + } else { + originLanguage = &golang.Language{} + } + + originFilePath, stubFilePath, err := t.unpackTranspilerPackage(ctx, taskLogger.Logger, originLanguage, packagePath) + if err != nil { + return nil, nil, err + } + + modelContext := model.Context{ + Language: ctx.Language, + + RepositoryPath: filepath.Join(ctx.Repository.DataPath(), packagePath), + FilePath: stubFilePath, + + Arguments: &TaskArgumentsTranspile{ + OriginLanguage: originLanguage, + OriginFilePath: originFilePath, + }, + + Logger: taskLogger.Logger, + } + assessments, err := modelCapability.Transpile(modelContext) + if err != nil { + problems = append(problems, pkgerrors.WithMessage(err, originFilePath)) + + continue + } + if assessments[metrics.AssessmentKeyProcessingTime] == 0 { + return nil, nil, pkgerrors.Errorf("no model response time measurement present for %q at repository %q", ctx.Model.ID(), ctx.Repository.Name()) + } + modelAssessmentsForFile.Add(assessments) + modelAssessmentsForFile.Award(metrics.AssessmentKeyResponseNoError) + + testResult, ps, err := ctx.Language.ExecuteTests(taskLogger.Logger, filepath.Join(ctx.Repository.DataPath(), packagePath)) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, pkgerrors.WithMessage(err, originFilePath)) + + // If there is an execution timeout do not run "symflower fix" because the code itself is correct. + if errors.Is(err, context.DeadlineExceeded) { + modelAssessments.Add(modelAssessmentsForFile) + withSymflowerAssessments.Add(withSymflowerAssessmentsForFile) + + continue + } + + // Run "symflower fix" if the model response fails to execute. + if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". + withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, filepath.Join(ctx.Repository.DataPath(), packagePath)) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, err) + + modelAssessments.Add(modelAssessmentsForFile) + withSymflowerAssessments.Add(withSymflowerAssessmentsForFile) + + continue + } else { + testsPassing := withSymflowerFixTestResult.TestsPass / withSymflowerFixTestResult.TestsTotal * 100 + taskLogger.Printf("Executes tests with %d percent tests passing after \"symflower fix\"", testsPassing) + + // Symflower was able to fix a failure so now update the assessment with the improved results. + withSymflowerFixAssessments := metrics.NewAssessments() + withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime + withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) + withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyTestsPassing, uint64(testsPassing)) + + withSymflowerAssessmentsForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentsForFile, withSymflowerFixAssessments) + } + } + } else { + testsPassing := testResult.TestsPass / testResult.TestsTotal * 100 + taskLogger.Printf("Executes tests with %d percent tests passing", testsPassing) + modelAssessmentsForFile.Award(metrics.AssessmentKeyFilesExecuted) + modelAssessmentsForFile.AwardPoints(metrics.AssessmentKeyTestsPassing, uint64(testsPassing)) + } + + modelAssessments.Add(modelAssessmentsForFile) + withSymflowerAssessments.Add(withSymflowerAssessmentsForFile) + } + + repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: modelAssessments, + IdentifierTranspileSymflowerFix: withSymflowerAssessments, + } + + return repositoryAssessment, problems, nil +} + +// unpackTranspilerPackage checks if the testdata repository for the transpilation task is well-formed and returns the path to the implementation file and also the path to the file that holds the stub. +func (t *TaskTranspile) unpackTranspilerPackage(ctx evaltask.Context, fileLogger *log.Logger, originLanguage language.Language, packagePath string) (originFilePath string, stubFilePath string, err error) { + packagePathAbsolute := filepath.Join(ctx.Repository.DataPath(), packagePath) + // Check if the package path has a directory called "implementation" with a source file in the language to transpile from. + files, err := originLanguage.Files(fileLogger, filepath.Join(packagePathAbsolute, "implementation")) + if err != nil { + return "", "", pkgerrors.WithStack(err) + } else if len(files) != 1 { + return "", "", pkgerrors.Errorf("package %q in repository %q must have an \"implementation\" directory with just one %s source file to transpile", packagePath, ctx.Repository.Name(), originLanguage.Name()) + } else if strings.HasSuffix(files[0], originLanguage.DefaultTestFileSuffix()) { + return "", "", pkgerrors.Errorf("package %q in repository %q must have an \"implementation\" directory with only a %s source file, but found a test file %q", packagePath, ctx.Repository.Name(), originLanguage.Name(), originFilePath) + } + originFilePath = filepath.Join("implementation", files[0]) + + stubFilePath, err = packageSourceFile(fileLogger, packagePathAbsolute, ctx.Language) + if err != nil { + return "", "", err + } + + return originFilePath, stubFilePath, nil +} diff --git a/evaluate/task/task-transpile_test.go b/evaluate/task/task-transpile_test.go new file mode 100644 index 00000000..64c0031b --- /dev/null +++ b/evaluate/task/task-transpile_test.go @@ -0,0 +1,351 @@ +package task + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/symflower/eval-dev-quality/evaluate/metrics" + metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" + tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing" + "github.com/symflower/eval-dev-quality/language/golang" + "github.com/symflower/eval-dev-quality/language/java" + "github.com/symflower/eval-dev-quality/log" + modeltesting "github.com/symflower/eval-dev-quality/model/testing" + evaltask "github.com/symflower/eval-dev-quality/task" + "github.com/zimmski/osutil" + "github.com/zimmski/osutil/bytesutil" +) + +func TestTaskTranspileRun(t *testing.T) { + validate := func(t *testing.T, tc *tasktesting.TestCaseTask) { + t.Run(tc.Name, func(t *testing.T) { + task, err := TaskForIdentifier(IdentifierTranspile) + require.NoError(t, err) + tc.Task = task + + tc.Validate(t, + func(logger *log.Logger, testDataPath string, repositoryPathRelative string) (repository evaltask.Repository, cleanup func(), err error) { + return TemporaryRepository(logger, testDataPath, repositoryPathRelative) + }, + ) + }) + } + + t.Run("Transpile Java into Go", func(t *testing.T) { + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "transpile", "cascadingIfElse") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "cascadingIfElse"), repositoryPath)) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("cascadingIfElse.go") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package cascadingIfElse + + func cascadingIfElse(i int) int { + if i == 1 { + return 2 + } else if i == 3 { + return 4 + } else { + return 5 + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Single test case", + + Model: modelMock, + Language: &golang.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("golang", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 1000, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + IdentifierTranspileSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 1000, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + }, + ValidateLog: func(t *testing.T, data string) { + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#00") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#01") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#02") + }, + }) + } + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "transpile") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "cascadingIfElse"), filepath.Join(repositoryPath, "cascadingIfElse"))) + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "isSorted"), filepath.Join(repositoryPath, "isSorted"))) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("cascadingIfElse.go") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package cascadingIfElse + + func cascadingIfElse(i int) int { + if i == 1 { + return 2 + } else if i == 3 { + return 4 + } else { + return 5 + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + transpiledSourceFilePath = filepath.Join("isSorted.go") + transpiledSourceFileContent = bytesutil.StringTrimIndentations(` + package isSorted + + func isSorted(a []int) bool { + i := 0 + for i < len(a)-1 && a[i] <= a[i+1] { + i++ + } + + return i == len(a)-1 + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Multiple test cases", + + Model: modelMock, + Language: &golang.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("golang", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 2000, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + IdentifierTranspileSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 2000, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + }, + ValidateLog: func(t *testing.T, data string) { + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#00") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#01") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#02") + + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#00") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#01") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#02") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#03") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#04") + }, + }) + } + }) + t.Run("Transpile Go into Java", func(t *testing.T) { + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "transpile", "cascadingIfElse") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "transpile", "cascadingIfElse"), repositoryPath)) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("src", "main", "java", "com", "eval", "CascadingIfElse.java") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + class CascadingIfElse { + static int cascadingIfElse(int i) { + if (i == 1) { + return 2; + } else if (i == 3) { + return 4; + } else { + return 5; + } + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Single test case", + + Model: modelMock, + Language: &java.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("java", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 1000, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + IdentifierTranspileSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 1000, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + }, + ValidateLog: func(t *testing.T, data string) { + assert.Contains(t, data, "BUILD SUCCESS") + }, + }) + } + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "transpile") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "transpile", "cascadingIfElse"), filepath.Join(repositoryPath, "cascadingIfElse"))) + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "transpile", "isSorted"), filepath.Join(repositoryPath, "isSorted"))) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("src", "main", "java", "com", "eval", "CascadingIfElse.java") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + class CascadingIfElse { + static int cascadingIfElse(int i) { + if (i == 1) { + return 2; + } else if (i == 3) { + return 4; + } else { + return 5; + } + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + transpiledSourceFilePath = filepath.Join("src", "main", "java", "com", "eval", "IsSorted.java") + transpiledSourceFileContent = bytesutil.StringTrimIndentations(` + package com.eval; + + class IsSorted { + static boolean isSorted(int[] a) { + int i = 0; + while (i < a.length - 1 && a[i] <= a[i + 1]) { + i++; + } + + return i == a.length - 1; + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Multiple test cases", + + Model: modelMock, + Language: &java.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("java", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 2000, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + IdentifierTranspileSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 2000, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + }, + ValidateLog: func(t *testing.T, data string) { + assert.Contains(t, data, "BUILD SUCCESS") + }, + }) + } + }) + t.Run("Symflower fix", func(t *testing.T) { + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "transpile", "cascadingIfElse") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "cascadingIfElse"), repositoryPath)) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("cascadingIfElse.go") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package cascadingIfElse + + import "strings" + + func cascadingIfElse(i int) int { + if i == 1 { + return 2 + } else if i == 3 { + return 4 + } else { + return 5 + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Model generated test with unused import", + + Model: modelMock, + Language: &golang.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("golang", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 0, + metrics.AssessmentKeyResponseNoError: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + }, + IdentifierTranspileSymflowerFix: metrics.Assessments{ + metrics.AssessmentKeyTestsPassing: 1000, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + }, + ExpectedProblemContains: []string{ + "imported and not used", + }, + ValidateLog: func(t *testing.T, data string) { + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#00") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#01") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#02") + }, + }) + } + }) +} diff --git a/evaluate/task/task.go b/evaluate/task/task.go index d411461d..eedb025a 100644 --- a/evaluate/task/task.go +++ b/evaluate/task/task.go @@ -38,6 +38,10 @@ var ( IdentifierWriteTestsSymflowerFix = registerIdentifier("write-tests-symflower-fix") // IdentifierCodeRepair holds the identifier for the "code repair" task. IdentifierCodeRepair = registerIdentifier("code-repair") + // IdentifierTranspile holds the identifier for the "transpile" task. + IdentifierTranspile = registerIdentifier("transpile") + // IdentifierTranspileSymflowerFix holds the identifier for the "transpile" task with the "symflower fix" applied. + IdentifierTranspileSymflowerFix = registerIdentifier("transpile-symflower-fix") ) // TaskForIdentifier returns a task based on the task identifier. @@ -47,6 +51,8 @@ func TaskForIdentifier(taskIdentifier evaltask.Identifier) (task evaltask.Task, return &TaskWriteTests{}, nil case IdentifierCodeRepair: return &TaskCodeRepair{}, nil + case IdentifierTranspile: + return &TaskTranspile{}, nil default: return nil, pkgerrors.Wrap(evaltask.ErrTaskUnknown, string(taskIdentifier)) } diff --git a/model/capability.go b/model/capability.go index 0a62cb27..c9a688fb 100644 --- a/model/capability.go +++ b/model/capability.go @@ -13,3 +13,9 @@ type CapabilityRepairCode interface { // RepairCode queries the model to repair a source code with compilation error. RepairCode(ctx Context) (assessments metrics.Assessments, err error) } + +// CapabilityTranspile defines the capability of a model to transpile code. +type CapabilityTranspile interface { + // Transpile queries the model to transpile source code to another language. + Transpile(ctx Context) (assessments metrics.Assessments, err error) +} diff --git a/model/llm/llm.go b/model/llm/llm.go index 4b69766a..62e36687 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -116,6 +116,46 @@ func llmCodeRepairSourceFilePrompt(data *llmCodeRepairSourceFilePromptContext) ( return b.String(), nil } +// llmTranspileSourceFilePromptContext is the template context for a transpilation LLM prompt. +type llmTranspileSourceFilePromptContext struct { + // llmSourceFilePromptContext holds the context for a source file prompt. + llmSourceFilePromptContext + + // OriginLanguage holds the language we are transpiling from. + OriginLanguage language.Language + // OriginFileContent holds the code we want to transpile. + OriginFileContent string +} + +// llmTranspileSourceFilePromptTemplate is the template for generating an LLM transpilation prompt. +var llmTranspileSourceFilePromptTemplate = template.Must(template.New("model-llm-transpile-source-file-prompt").Parse(bytesutil.StringTrimIndentations(` + Given the following {{ .OriginLanguage.Name }} code file, transpile it into a {{ .Language.Name }} code file. + The response must contain only the transpiled {{ .Language.Name }} source code in a fenced code block and nothing else. + + ` + "```" + `{{ .OriginLanguage.ID }} + {{ .OriginFileContent }} + ` + "```" + ` + + The transpiled {{ .Language.Name }} code file must have the package "{{ .ImportPath }}" and the following signature: + + ` + "```" + `{{ .Language.ID }} + {{ .Code }} + ` + "```" + ` +`))) + +// llmTranspileSourceFilePrompt returns the prompt to transpile a source file. +func llmTranspileSourceFilePrompt(data *llmTranspileSourceFilePromptContext) (message string, err error) { + data.Code = strings.TrimSpace(data.Code) + data.OriginFileContent = strings.TrimSpace(data.OriginFileContent) + + var b strings.Builder + if err := llmTranspileSourceFilePromptTemplate.Execute(&b, data); err != nil { + return "", pkgerrors.WithStack(err) + } + + return b.String(), nil +} + var _ model.Model = (*Model)(nil) // ID returns the unique ID of this model. @@ -253,6 +293,66 @@ func (m *Model) RepairCode(ctx model.Context) (assessment metrics.Assessments, e return assessment, nil } +var _ model.CapabilityTranspile = (*Model)(nil) + +// Transpile queries the model to transpile source code to another language. +func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, err error) { + transpileArguments, ok := ctx.Arguments.(*evaluatetask.TaskArgumentsTranspile) + if !ok { + return nil, pkgerrors.Errorf("unexpected type %#v", ctx.Arguments) + } + + data, err := os.ReadFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath)) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + stubFileContent := strings.TrimSpace(string(data)) + + data, err = os.ReadFile(filepath.Join(ctx.RepositoryPath, transpileArguments.OriginFilePath)) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + originFileContent := strings.TrimSpace(string(data)) + + importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath) + + request, err := llmTranspileSourceFilePrompt(&llmTranspileSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: ctx.Language, + + Code: stubFileContent, + FilePath: ctx.FilePath, + ImportPath: importPath, + }, + + OriginLanguage: transpileArguments.OriginLanguage, + OriginFileContent: originFileContent, + }) + if err != nil { + return nil, err + } + + response, duration, err := m.query(ctx.Logger, request) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + + assessment, originFileContent, err = prompt.ParseResponse(response) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + assessment[metrics.AssessmentKeyProcessingTime] = uint64(duration.Milliseconds()) + assessment[metrics.AssessmentKeyResponseCharacterCount] = uint64(len(response)) + assessment[metrics.AssessmentKeyGenerateTestsForFileCharacterCount] = uint64(len(originFileContent)) + + err = os.WriteFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath), []byte(originFileContent), 0644) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + + return assessment, nil +} + var _ model.SetQueryAttempts = (*Model)(nil) // SetQueryAttempts sets the number of query attempts to perform when a model request errors in the process of solving a task. diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index f7583d99..0a3de5c4 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -393,3 +393,227 @@ func TestLLMCodeRepairSourceFilePrompt(t *testing.T) { `), }) } + +func TestLLMTranspileSourceFilePrompt(t *testing.T) { + type testCase struct { + Name string + + Data *llmTranspileSourceFilePromptContext + + ExpectedMessage string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + actualMessage, actualErr := llmTranspileSourceFilePrompt(tc.Data) + require.NoError(t, actualErr) + + assert.Equal(t, tc.ExpectedMessage, actualMessage) + }) + } + + validate(t, &testCase{ + Name: "Transpile Go into Java", + + Data: &llmTranspileSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &java.Language{}, + + Code: bytesutil.StringTrimIndentations(` + package com.eval; + + class Foobar { + static int foobar(int i) {} + } + `), + FilePath: "Foobar.java", + ImportPath: "com.eval", + }, + OriginLanguage: &golang.Language{}, + OriginFileContent: bytesutil.StringTrimIndentations(` + package foobar + + func foobar(i int) int { + return i + 1 + } + `), + }, + + ExpectedMessage: bytesutil.StringTrimIndentations(` + Given the following Go code file, transpile it into a Java code file. + The response must contain only the transpiled Java source code in a fenced code block and nothing else. + + ` + "```" + `golang + package foobar + + func foobar(i int) int { + return i + 1 + } + ` + "```" + ` + + The transpiled Java code file must have the package "com.eval" and the following signature: + + ` + "```" + `java + package com.eval; + + class Foobar { + static int foobar(int i) {} + } + ` + "```" + ` + `), + }) +} + +func TestModelTranspile(t *testing.T) { + type testCase struct { + Name string + + SetupMock func(t *testing.T, mockedProvider *providertesting.MockQuery) + + Language language.Language + OriginLanguage language.Language + + RepositoryPath string + OriginFilePath string + StubFilePath string + + ExpectedAssessment metrics.Assessments + ExpectedStubFileContent string + } + + validate := func(t *testing.T, tc *testCase) { + logOutput, logger := log.Buffer() + defer func() { + if t.Failed() { + t.Log(logOutput.String()) + } + }() + + temporaryPath := t.TempDir() + repositoryPath := filepath.Join(temporaryPath, filepath.Base(tc.RepositoryPath)) + require.NoError(t, osutil.CopyTree(tc.RepositoryPath, repositoryPath)) + + modelID := "some-model" + mock := providertesting.NewMockQuery(t) + tc.SetupMock(t, mock) + llm := NewModel(mock, modelID) + + ctx := model.Context{ + Language: tc.Language, + + RepositoryPath: repositoryPath, + FilePath: tc.StubFilePath, + + Arguments: &evaluatetask.TaskArgumentsTranspile{ + OriginLanguage: tc.OriginLanguage, + OriginFilePath: tc.OriginFilePath, + }, + + Logger: logger, + } + + actualAssessment, actualError := llm.Transpile(ctx) + assert.NoError(t, actualError) + metricstesting.AssertAssessmentsEqual(t, tc.ExpectedAssessment, actualAssessment) + + actualStubFileContent, err := os.ReadFile(filepath.Join(repositoryPath, tc.StubFilePath)) + assert.NoError(t, err) + + assert.Equal(t, strings.TrimSpace(tc.ExpectedStubFileContent), string(actualStubFileContent)) + } + + t.Run("Transpile Java into Go", func(t *testing.T) { + transpiledFileContent := bytesutil.StringTrimIndentations(` + package binarySearch + + func binarySearch(a []int, x int) int { + index := -1 + + min := 0 + max := len(a) - 1 + + for index == -1 && min <= max { + m := (min + max) / 2 + + if x == a[m] { + index = m + } else if x < a[m] { + max = m - 1 + } else { + min = m + 1 + } + } + + return index + } + `) + validate(t, &testCase{ + Name: "Binary search", + + SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { + mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return("```\n"+transpiledFileContent+"```\n", nil) + }, + + Language: &golang.Language{}, + OriginLanguage: &java.Language{}, + + RepositoryPath: filepath.Join("..", "..", "testdata", "golang", "transpile", "binarySearch"), + OriginFilePath: filepath.Join("implementation", "BinarySearch.java"), + StubFilePath: filepath.Join("binarySearch.go"), + + ExpectedAssessment: metrics.Assessments{ + metrics.AssessmentKeyResponseNoExcess: 1, + metrics.AssessmentKeyResponseWithCode: 1, + }, + ExpectedStubFileContent: transpiledFileContent, + }) + }) + t.Run("Transpile Go into Java", func(t *testing.T) { + transpiledFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + class BinarySearch { + static int binarySearch(int[] a, int x) { + int index = -1; + + int min = 0; + int max = a.length - 1; + + while (index == -1 && min <= max) { + int m = (min + max) / 2; + + if (x == a[m]) { + index = m; + } else if (x < a[m]) { + max = m - 1; + } else { + min = m + 1; + } + } + + return index; + } + } + `) + validate(t, &testCase{ + Name: "Binary Search", + + SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { + mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return("```\n"+transpiledFileContent+"```\n", nil) + }, + + Language: &java.Language{}, + OriginLanguage: &golang.Language{}, + + RepositoryPath: filepath.Join("..", "..", "testdata", "java", "transpile", "binarySearch"), + OriginFilePath: filepath.Join("implementation", "binarySearch.go"), + StubFilePath: filepath.Join("src", "main", "java", "com", "eval", "BinarySearch.java"), + + ExpectedAssessment: metrics.Assessments{ + metrics.AssessmentKeyResponseNoExcess: 1, + metrics.AssessmentKeyResponseWithCode: 1, + }, + ExpectedStubFileContent: transpiledFileContent, + }) + }) +} diff --git a/model/testing/CapabilityTranspile_mock_gen.go b/model/testing/CapabilityTranspile_mock_gen.go new file mode 100644 index 00000000..9fdd36a4 --- /dev/null +++ b/model/testing/CapabilityTranspile_mock_gen.go @@ -0,0 +1,59 @@ +// Code generated by mockery v2.40.3. DO NOT EDIT. + +package modeltesting + +import ( + mock "github.com/stretchr/testify/mock" + metrics "github.com/symflower/eval-dev-quality/evaluate/metrics" + + model "github.com/symflower/eval-dev-quality/model" +) + +// MockCapabilityTranspile is an autogenerated mock type for the CapabilityTranspile type +type MockCapabilityTranspile struct { + mock.Mock +} + +// Transpile provides a mock function with given fields: ctx +func (_m *MockCapabilityTranspile) Transpile(ctx model.Context) (metrics.Assessments, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Transpile") + } + + var r0 metrics.Assessments + var r1 error + if rf, ok := ret.Get(0).(func(model.Context) (metrics.Assessments, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(model.Context) metrics.Assessments); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(metrics.Assessments) + } + } + + if rf, ok := ret.Get(1).(func(model.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewMockCapabilityTranspile creates a new instance of MockCapabilityTranspile. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCapabilityTranspile(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCapabilityTranspile { + mock := &MockCapabilityTranspile{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/model/testing/helper.go b/model/testing/helper.go index e984c45d..93f7fade 100644 --- a/model/testing/helper.go +++ b/model/testing/helper.go @@ -45,6 +45,14 @@ func (m *MockCapabilityRepairCode) RegisterGenerateError(err error) *mock.Call { return m.On("RepairCode", mock.Anything).Return(nil, err) } +// RegisterGenerateSuccess registers a mock call for successful generation. +func (m *MockCapabilityTranspile) RegisterGenerateSuccess(t *testing.T, filePath string, fileContent string, assessment metrics.Assessments) *mock.Call { + return m.On("Transpile", mock.Anything).Return(assessment, nil).Run(func(args mock.Arguments) { + ctx := args.Get(0).(model.Context) + require.NoError(t, os.WriteFile(filepath.Join(ctx.RepositoryPath, filePath), []byte(fileContent), 0600)) + }) +} + // MockModelCapabilityWriteTests holds a mock implementing the "Model" and the "CapabilityWriteTests" interface. type MockModelCapabilityWriteTests struct { *MockModel @@ -72,3 +80,17 @@ func NewMockCapabilityRepairCodeNamed(t *testing.T, id string) *MockModelCapabil MockCapabilityRepairCode: NewMockCapabilityRepairCode(t), } } + +// MockModelCapabilityTranspile holds a mock implementing the "Model" and the "CapabilityTranspile" interface. +type MockModelCapabilityTranspile struct { + *MockModel + *MockCapabilityTranspile +} + +// NewMockCapabilityTranspileNamed returns a new named mocked model. +func NewMockCapabilityTranspileNamed(t *testing.T, id string) *MockModelCapabilityTranspile { + return &MockModelCapabilityTranspile{ + MockModel: NewMockModelNamed(t, id), + MockCapabilityTranspile: NewMockCapabilityTranspile(t), + } +} diff --git a/testdata/golang/transpile/repository.json b/testdata/golang/transpile/repository.json new file mode 100644 index 00000000..ab88c46a --- /dev/null +++ b/testdata/golang/transpile/repository.json @@ -0,0 +1,5 @@ +{ + "tasks": [ + "transpile" + ] +} diff --git a/testdata/java/transpile/repository.json b/testdata/java/transpile/repository.json new file mode 100644 index 00000000..ab88c46a --- /dev/null +++ b/testdata/java/transpile/repository.json @@ -0,0 +1,5 @@ +{ + "tasks": [ + "transpile" + ] +}