From adfab2478c0454c86f3dbf84af301c2086d4ab40 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Thu, 4 Jul 2024 16:12:07 +0100 Subject: [PATCH] Task for code transpilation, so models can transpile Go code to Java and back Closes #201 --- .mockery.yml | 1 + evaluate/task/task-transpile.go | 145 +++++++++ evaluate/task/task-transpile_test.go | 272 +++++++++++++++++ evaluate/task/task.go | 4 + model/capability.go | 6 + model/llm/llm.go | 100 +++++++ model/llm/llm_test.go | 276 ++++++++++++++++++ model/testing/CapabilityTranspile_mock_gen.go | 59 ++++ model/testing/helper.go | 22 ++ testdata/golang/transpile/repository.json | 5 + testdata/java/transpile/repository.json | 5 + 11 files changed, 895 insertions(+) create mode 100644 evaluate/task/task-transpile.go create mode 100644 evaluate/task/task-transpile_test.go create mode 100644 model/testing/CapabilityTranspile_mock_gen.go create mode 100644 testdata/golang/transpile/repository.json create mode 100644 testdata/java/transpile/repository.json diff --git a/.mockery.yml b/.mockery.yml index 3484c7c9..15c1cfe5 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -15,6 +15,7 @@ packages: Model: CapabilityWriteTests: CapabilityRepairCode: + CapabilityTranspile: github.com/symflower/eval-dev-quality/provider: interfaces: Loader: diff --git a/evaluate/task/task-transpile.go b/evaluate/task/task-transpile.go new file mode 100644 index 00000000..9a0ed4a2 --- /dev/null +++ b/evaluate/task/task-transpile.go @@ -0,0 +1,145 @@ +package task + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + pkgerrors "github.com/pkg/errors" + "github.com/symflower/eval-dev-quality/evaluate/metrics" + "github.com/symflower/eval-dev-quality/language" + "github.com/symflower/eval-dev-quality/language/golang" + "github.com/symflower/eval-dev-quality/language/java" + "github.com/symflower/eval-dev-quality/log" + "github.com/symflower/eval-dev-quality/model" + evaltask "github.com/symflower/eval-dev-quality/task" +) + +// TaskTranspile holds the transpilation task. +type TaskTranspile struct{} + +// TaskArgumentsTranspile holds extra arguments to be used in a query prompt. +type TaskArgumentsTranspile struct { + // SourceLanguage holds the source language for transpilation. + SourceLanguage language.Language + // StubFilePath holds the path for the file containing just the function signature of the language we are transpiling to. + StubFilePath string +} + +var _ evaltask.Task = (*TaskTranspile)(nil) + +// Identifier returns the transpilation task identifier. +func (t *TaskTranspile) Identifier() evaltask.Identifier { + return IdentifierTranspile +} + +// Run transpiles code between Go and Java and runs predefined tests to check if the transpilation was successful. +func (t *TaskTranspile) Run(ctx evaltask.Context) (repositoryAssessment map[evaltask.Identifier]metrics.Assessments, problems []error, err error) { + modelCapability, ok := ctx.Model.(model.CapabilityTranspile) + if !ok { + return nil, nil, pkgerrors.Wrap(evaltask.ErrTaskUnsupportedByModel, fmt.Sprintf("%q does not support %q", ctx.Model.ID(), string(t.Identifier()))) + } + + taskLogger, err := newTaskLogger(ctx, t) + if err != nil { + return nil, nil, err + } + defer func() { + taskLogger.finalize(problems) + }() + + var packagePaths []string + files, err := os.ReadDir(ctx.Repository.DataPath()) + if err != nil { + return nil, nil, pkgerrors.WithStack(err) + } + for _, file := range files { + if file.IsDir() && !strings.HasPrefix(file.Name(), ".") { // Ignore hidden directories. + packagePaths = append(packagePaths, filepath.Join(ctx.Repository.DataPath(), file.Name())) + } + } + + modelAssessments := metrics.NewAssessments() + for _, packagePath := range packagePaths { + if err := ctx.Repository.Reset(ctx.Logger); err != nil { + ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) + } + + var sourceLanguage language.Language + if _, ok := ctx.Language.(*golang.Language); ok { + sourceLanguage = &java.Language{} + } else { + sourceLanguage = &golang.Language{} + } + + sourceFilePath, stubFilePath, err := t.unpackTranspilerPackage(ctx, taskLogger.Logger, sourceLanguage, packagePath) + if err != nil { + return nil, nil, err + } + + modelContext := model.Context{ + Language: ctx.Language, + + RepositoryPath: packagePath, + FilePath: sourceFilePath, + + Arguments: &TaskArgumentsTranspile{ + SourceLanguage: sourceLanguage, + StubFilePath: stubFilePath, + }, + + Logger: taskLogger.Logger, + } + assessments, err := modelCapability.Transpile(modelContext) + if err != nil { + problems = append(problems, pkgerrors.WithMessage(err, sourceFilePath)) + + continue + } + if assessments[metrics.AssessmentKeyProcessingTime] == 0 { + return nil, nil, pkgerrors.Errorf("no model response time measurement present for %q at repository %q", ctx.Model.ID(), ctx.Repository.Name()) + } + modelAssessments.Add(assessments) + modelAssessments.Award(metrics.AssessmentKeyResponseNoError) + + coverage, ps, err := ctx.Language.Execute(taskLogger.Logger, packagePath) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, pkgerrors.WithMessage(err, sourceFilePath)) + + continue + } + taskLogger.Printf("Executes tests with %d coverage objects", coverage) + modelAssessments.Award(metrics.AssessmentKeyFilesExecuted) + modelAssessments.AwardPoints(metrics.AssessmentKeyCoverage, coverage) + } + + repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: modelAssessments, + } + + return repositoryAssessment, problems, nil +} + +// unpackTranspilerPackage checks if the testdata repository for the transpilation task is well-formed and returns the path to the implementation file and also the path to the file that holds the stub. +func (t *TaskTranspile) unpackTranspilerPackage(ctx evaltask.Context, fileLogger *log.Logger, sourceLanguage language.Language, packagePath string) (sourceFilePath string, stubFilePath string, err error) { + // Check if the package path has a directory called "implementation" with a source file in the language to transpile from. + files, err := sourceLanguage.Files(fileLogger, filepath.Join(packagePath, "implementation")) + if err != nil { + return "", "", pkgerrors.WithStack(err) + } else if len(files) != 1 { + return "", "", pkgerrors.Errorf("package %q in repository %q must have an \"implementation\" directory with just one %s source file to transpile", packagePath, ctx.Repository.Name(), sourceLanguage.Name()) + } else if strings.HasSuffix(files[0], sourceLanguage.DefaultTestFileSuffix()) { + return "", "", pkgerrors.Errorf("package %q in repository %q must have an \"implementation\" directory with only a %s source file, but found a test file %q", packagePath, ctx.Repository.Name(), sourceLanguage.Name(), sourceFilePath) + } + sourceFilePath = filepath.Join(packagePath, "implementation", sourceFilePath) + + stubFilePath, err = packageHasSourceAndTestFile(fileLogger, ctx.Repository.Name(), packagePath, ctx.Language) + if err != nil { + return "", "", err + } + stubFilePath = filepath.Join(packagePath, stubFilePath) + + return sourceFilePath, stubFilePath, nil +} diff --git a/evaluate/task/task-transpile_test.go b/evaluate/task/task-transpile_test.go new file mode 100644 index 00000000..6dfb3f5c --- /dev/null +++ b/evaluate/task/task-transpile_test.go @@ -0,0 +1,272 @@ +package task + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/symflower/eval-dev-quality/evaluate/metrics" + metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" + tasktesting "github.com/symflower/eval-dev-quality/evaluate/task/testing" + "github.com/symflower/eval-dev-quality/language/golang" + "github.com/symflower/eval-dev-quality/language/java" + "github.com/symflower/eval-dev-quality/log" + modeltesting "github.com/symflower/eval-dev-quality/model/testing" + evaltask "github.com/symflower/eval-dev-quality/task" + "github.com/zimmski/osutil" + "github.com/zimmski/osutil/bytesutil" +) + +func TestTaskTranspileRun(t *testing.T) { + validate := func(t *testing.T, tc *tasktesting.TestCaseTask) { + t.Run(tc.Name, func(t *testing.T) { + task, err := TaskForIdentifier(IdentifierTranspile) + require.NoError(t, err) + tc.Task = task + + tc.Validate(t, + func(logger *log.Logger, testDataPath string, repositoryPathRelative string) (repository evaltask.Repository, cleanup func(), err error) { + return TemporaryRepository(logger, testDataPath, repositoryPathRelative) + }, + ) + }) + } + + t.Run("Transpile Java into Go", func(t *testing.T) { + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "transpile", "cascadingIfElse") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "cascadingIfElse"), repositoryPath)) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("cascadingIfElse.go") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package cascadingIfElse + + func cascadingIfElse(i int) int { + if i == 1 { + return 2 + } else if i == 3 { + return 4 + } else { + return 5 + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Single test case", + + Model: modelMock, + Language: &golang.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("golang", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyCoverage: 30, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + }, + ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){ + filepath.Join(string(IdentifierTranspile), "mocked-model", "golang", "golang", "transpile.log"): func(t *testing.T, filePath, data string) { + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#00") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#01") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#02") + }, + }, + }) + } + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "golang", "transpile") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "cascadingIfElse"), filepath.Join(repositoryPath, "cascadingIfElse"))) + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "golang", "transpile", "sort"), filepath.Join(repositoryPath, "sort"))) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("cascadingIfElse.go") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package cascadingIfElse + + func cascadingIfElse(i int) int { + if i == 1 { + return 2 + } else if i == 3 { + return 4 + } else { + return 5 + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + transpiledSourceFilePath = filepath.Join("sort.go") + transpiledSourceFileContent = bytesutil.StringTrimIndentations(` + package isSorted + + func isSorted(a []int) bool { + i := 0 + for i < len(a)-1 && a[i] <= a[i+1] { + i++ + } + + return i == len(a)-1 + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Multiple test cases", + + Model: modelMock, + Language: &golang.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("golang", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyCoverage: 50, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + }, + ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){ + filepath.Join(string(IdentifierTranspile), "mocked-model", "golang", "golang", "transpile.log"): func(t *testing.T, filePath, data string) { + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#00") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#01") + assert.Contains(t, data, "PASS: TestSymflowerCascadingIfElse/#02") + + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#00") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#01") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#02") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#03") + assert.Contains(t, data, "PASS: TestSymflowerIsSorted/#04") + }, + }, + }) + } + }) + t.Run("Transpile Go into Java", func(t *testing.T) { + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "transpile", "cascadingIfElse") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "transpile", "cascadingIfElse"), repositoryPath)) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("src", "main", "java", "com", "eval", "CascadingIfElse.java") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + class CascadingIfElse { + static int cascadingIfElse(int i) { + if (i == 1) { + return 2; + } else if (i == 3) { + return 4; + } else { + return 5; + } + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Single test case", + + Model: modelMock, + Language: &java.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("java", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyCoverage: 80, + metrics.AssessmentKeyFilesExecuted: 1, + metrics.AssessmentKeyResponseNoError: 1, + }, + }, + ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){ + filepath.Join(string(IdentifierTranspile), "mocked-model", "java", "java", "transpile.log"): func(t *testing.T, filePath, data string) { + assert.Contains(t, data, "BUILD SUCCESS") + }, + }, + }) + } + { + temporaryDirectoryPath := t.TempDir() + + repositoryPath := filepath.Join(temporaryDirectoryPath, "java", "transpile") + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "transpile", "cascadingIfElse"), filepath.Join(repositoryPath, "cascadingIfElse"))) + require.NoError(t, osutil.CopyTree(filepath.Join("..", "..", "testdata", "java", "transpile", "sort"), filepath.Join(repositoryPath, "sort"))) + + modelMock := modeltesting.NewMockCapabilityTranspileNamed(t, "mocked-model") + + transpiledSourceFilePath := filepath.Join("src", "main", "java", "com", "eval", "CascadingIfElse.java") + transpiledSourceFileContent := bytesutil.StringTrimIndentations(` + package com.eval; + + class CascadingIfElse { + static int cascadingIfElse(int i) { + if (i == 1) { + return 2; + } else if (i == 3) { + return 4; + } else { + return 5; + } + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + transpiledSourceFilePath = filepath.Join("src", "main", "java", "com", "eval", "Sort.java") + transpiledSourceFileContent = bytesutil.StringTrimIndentations(` + package com.eval; + + class Sort { + static boolean isSorted(int[] a) { + int i = 0; + while (i < a.length - 1 && a[i] <= a[i + 1]) { + i++; + } + + return i == a.length - 1; + } + } + `) + modelMock.RegisterGenerateSuccess(t, transpiledSourceFilePath, transpiledSourceFileContent, metricstesting.AssessmentsWithProcessingTime).Once() + + validate(t, &tasktesting.TestCaseTask{ + Name: "Multiple test cases", + + Model: modelMock, + Language: &java.Language{}, + TestDataPath: temporaryDirectoryPath, + RepositoryPath: filepath.Join("java", "transpile"), + + ExpectedRepositoryAssessment: map[evaltask.Identifier]metrics.Assessments{ + IdentifierTranspile: metrics.Assessments{ + metrics.AssessmentKeyCoverage: 140, + metrics.AssessmentKeyFilesExecuted: 2, + metrics.AssessmentKeyResponseNoError: 2, + }, + }, + ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){ + filepath.Join(string(IdentifierTranspile), "mocked-model", "java", "java", "transpile.log"): func(t *testing.T, filePath, data string) { + assert.Contains(t, data, "BUILD SUCCESS") + }, + }, + }) + } + }) +} diff --git a/evaluate/task/task.go b/evaluate/task/task.go index a41d7af6..43a4c517 100644 --- a/evaluate/task/task.go +++ b/evaluate/task/task.go @@ -39,6 +39,8 @@ var ( IdentifierWriteTestsSymflowerFix = registerIdentifier("write-tests-symflower-fix") // IdentifierCodeRepair holds the identifier for the "code repair" task. IdentifierCodeRepair = registerIdentifier("code-repair") + // IdentifierTranspile holds the identifier for the "transpile" task. + IdentifierTranspile = registerIdentifier("transpile") ) // TaskForIdentifier returns a task based on the task identifier. @@ -48,6 +50,8 @@ func TaskForIdentifier(taskIdentifier evaltask.Identifier) (task evaltask.Task, return &TaskWriteTests{}, nil case IdentifierCodeRepair: return &TaskCodeRepair{}, nil + case IdentifierTranspile: + return &TaskTranspile{}, nil default: return nil, pkgerrors.Wrap(evaltask.ErrTaskUnknown, string(taskIdentifier)) } diff --git a/model/capability.go b/model/capability.go index 0a62cb27..c9a688fb 100644 --- a/model/capability.go +++ b/model/capability.go @@ -13,3 +13,9 @@ type CapabilityRepairCode interface { // RepairCode queries the model to repair a source code with compilation error. RepairCode(ctx Context) (assessments metrics.Assessments, err error) } + +// CapabilityTranspile defines the capability of a model to transpile code. +type CapabilityTranspile interface { + // Transpile queries the model to transpile source code to another language. + Transpile(ctx Context) (assessments metrics.Assessments, err error) +} diff --git a/model/llm/llm.go b/model/llm/llm.go index ae3f7961..cce92ab5 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -130,6 +130,46 @@ func llmCodeRepairSourceFilePrompt(data *llmCodeRepairSourceFilePromptContext) ( return b.String(), nil } +// llmTranspileSourceFilePromptContext is the template context for a transpilation LLM prompt. +type llmTranspileSourceFilePromptContext struct { + // llmSourceFilePromptContext holds the context for a source file prompt. + llmSourceFilePromptContext + + // SourceLanguage holds the source language for transpilation. + SourceLanguage language.Language + // StubCode holds the function signature of the language we are transpiling to. + StubCode string +} + +// llmTranspileSourceFilePromptTemplate is the template for generating an LLM transpilation prompt. +var llmTranspileSourceFilePromptTemplate = template.Must(template.New("model-llm-transpile-source-file-prompt").Parse(bytesutil.StringTrimIndentations(` + Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", transpile it into a {{ .TargetLanguage.Name }} source file. + The response must contain only the transpiled {{ .TargetLanguage.Name }} source code and nothing else. + + ` + "```" + `{{ .Language.ID }} + {{ .Code }} + ` + "```" + ` + + The transpiled {{ .TargetLanguage.Name }} code file must have the following signature and package: + + ` + "```" + `{{ .TargetLanguage.ID }} + {{ .StubCode }} + ` + "```" + ` +`))) + +// llmTranspileSourceFilePrompt returns the prompt to transpile a source file. +func llmTranspileSourceFilePrompt(data *llmTranspileSourceFilePromptContext) (message string, err error) { + data.Code = strings.TrimSpace(data.Code) + data.StubCode = strings.TrimSpace(data.StubCode) + + var b strings.Builder + if err := llmTranspileSourceFilePromptTemplate.Execute(&b, data); err != nil { + return "", pkgerrors.WithStack(err) + } + + return b.String(), nil +} + var _ model.Model = (*Model)(nil) // ID returns the unique ID of this model. @@ -272,6 +312,66 @@ func (m *Model) RepairCode(ctx model.Context) (assessment metrics.Assessments, e return assessment, nil } +var _ model.CapabilityTranspile = (*Model)(nil) + +// Transpile queries the model to transpile source code to another language. +func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, err error) { + transpileArguments, ok := ctx.Arguments.(*evaluatetask.TaskArgumentsTranspile) + if !ok { + return nil, pkgerrors.Errorf("unexpected type %#v", ctx.Arguments) + } + + data, err := os.ReadFile(filepath.Join(ctx.RepositoryPath, transpileArguments.StubFilePath)) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + stubFileContent := strings.TrimSpace(string(data)) + + data, err = os.ReadFile(filepath.Join(ctx.RepositoryPath, ctx.FilePath)) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + sourceFileContent := strings.TrimSpace(string(data)) + + importPath := ctx.Language.ImportPath(ctx.RepositoryPath, transpileArguments.StubFilePath) + + request, err := llmTranspileSourceFilePrompt(&llmTranspileSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: ctx.Language, + + Code: sourceFileContent, + FilePath: ctx.FilePath, + ImportPath: importPath, + }, + + SourceLanguage: transpileArguments.SourceLanguage, + StubCode: stubFileContent, + }) + if err != nil { + return nil, err + } + + response, duration, err := m.query(ctx.Logger, request) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + + assessment, sourceFileContent, err = prompt.ParseResponse(response) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + assessment[metrics.AssessmentKeyProcessingTime] = uint64(duration.Milliseconds()) + assessment[metrics.AssessmentKeyResponseCharacterCount] = uint64(len(response)) + assessment[metrics.AssessmentKeyGenerateTestsForFileCharacterCount] = uint64(len(sourceFileContent)) + + err = os.WriteFile(filepath.Join(ctx.RepositoryPath, transpileArguments.StubFilePath), []byte(sourceFileContent), 0644) + if err != nil { + return nil, pkgerrors.WithStack(err) + } + + return assessment, nil +} + // Cost returns the cost of the model. func (m *Model) Cost() (cost float64) { return m.cost diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index 284f1c03..2c0c470e 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -340,3 +340,279 @@ func TestLLMCodeRepairSourceFilePrompt(t *testing.T) { `), }) } + +func TestLLMTranspileSourceFilePrompt(t *testing.T) { + type testCase struct { + Name string + + Data *llmTranspileSourceFilePromptContext + + ExpectedMessage string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + actualMessage, actualErr := llmTranspileSourceFilePrompt(tc.Data) + require.NoError(t, actualErr) + + assert.Equal(t, tc.ExpectedMessage, actualMessage) + }) + } + + validate(t, &testCase{ + Name: "Plain", + + Data: &llmTranspileSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, + + Code: bytesutil.StringTrimIndentations(` + package foobar + + func foobar(i int) int { + return i + 1 + } + `), + FilePath: "/path/to/foobar.go", + ImportPath: "foobar", + }, + SourceLanguage: &java.Language{}, + StubCode: bytesutil.StringTrimIndentations(` + package com.eval; + + class Foobar { + static int foobar(int i) {} + } + `), + }, + + ExpectedMessage: bytesutil.StringTrimIndentations(` + Given the following Go code file "/path/to/foobar.go" with package "foobar", transpile it into a Java source file. + The response must contain only the transpiled Java source code and nothing else. + + ` + "```" + `golang + package foobar + + func foobar(i int) int { + return i + 1 + } + ` + "```" + ` + + The transpiled Java code file must have the following signature and package: + + ` + "```" + `java + package com.eval; + + class Foobar { + static int foobar(int i) {} + } + ` + "```" + ` + `), + }) +} + +func TestModelTranspile(t *testing.T) { + type testCase struct { + Name string + + SetupMock func(t *testing.T, mockedProvider *providertesting.MockQuery) + + Language language.Language + TargetLanguage language.Language + + RepositoryPath string + SourceFilePath string + StubFilePath string + + ExpectedAssessment metrics.Assessments + ExpectedStubFileContent string + } + + validate := func(t *testing.T, tc *testCase) { + logOutput, logger := log.Buffer() + defer func() { + if t.Failed() { + t.Log(logOutput.String()) + } + }() + + temporaryPath := t.TempDir() + repositoryPath := filepath.Join(temporaryPath, filepath.Base(tc.RepositoryPath)) + require.NoError(t, osutil.CopyTree(tc.RepositoryPath, repositoryPath)) + + modelID := "some-model" + mock := providertesting.NewMockQuery(t) + tc.SetupMock(t, mock) + llm := NewModel(mock, modelID) + + ctx := model.Context{ + Language: tc.Language, + + RepositoryPath: repositoryPath, + FilePath: tc.SourceFilePath, + + Arguments: &evaluatetask.TaskArgumentsTranspile{ + SourceLanguage: tc.TargetLanguage, + StubFilePath: tc.StubFilePath, + }, + + Logger: logger, + } + + actualAssessment, actualError := llm.Transpile(ctx) + assert.NoError(t, actualError) + metricstesting.AssertAssessmentsEqual(t, tc.ExpectedAssessment, actualAssessment) + + actualStubFileContent, err := os.ReadFile(filepath.Join(repositoryPath, tc.StubFilePath)) + assert.NoError(t, err) + + assert.Equal(t, strings.TrimSpace(bytesutil.StringTrimIndentations(tc.ExpectedStubFileContent)), string(actualStubFileContent)) + } + + t.Run("Transpile Go into Java", func(t *testing.T) { + stubFileContent := ` + package com.eval; + + class BinarySearch { + static int binarySearch(int[] a, int x) { + int index = -1; + + int min = 0; + int max = a.length - 1; + + while (index == -1 && min <= max) { + int m = (min + max) / 2; + + if (x == a[m]) { + index = m; + } else if (x < a[m]) { + max = m - 1; + } else { + min = m + 1; + } + } + + return index; + } + } + ` + validate(t, &testCase{ + Name: "Binary search", + + SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { + mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return(bytesutil.StringTrimIndentations(` + `+"```"+` + package com.eval; + + class BinarySearch { + static int binarySearch(int[] a, int x) { + int index = -1; + + int min = 0; + int max = a.length - 1; + + while (index == -1 && min <= max) { + int m = (min + max) / 2; + + if (x == a[m]) { + index = m; + } else if (x < a[m]) { + max = m - 1; + } else { + min = m + 1; + } + } + + return index; + } + } + `+"```"+` + `), nil) + }, + + Language: &golang.Language{}, + TargetLanguage: &java.Language{}, + + RepositoryPath: filepath.Join("..", "..", "testdata", "golang", "transpile", "binarySearch"), + SourceFilePath: filepath.Join("implementation", "binarySearch.go"), + StubFilePath: filepath.Join("src", "main", "java", "com", "eval", "BinarySearch.java"), + + ExpectedAssessment: metrics.Assessments{ + metrics.AssessmentKeyResponseNoExcess: 1, + metrics.AssessmentKeyResponseWithCode: 1, + }, + ExpectedStubFileContent: stubFileContent, + }) + }) + t.Run("Transpile Java into Go", func(t *testing.T) { + stubFileContent := ` + package binarySearch + + func binarySearch(a []int, x int) int { + index := -1 + + min := 0 + max := len(a) - 1 + + for index == -1 && min <= max { + m := (min + max) / 2 + + if x == a[m] { + index = m + } else if x < a[m] { + max = m - 1 + } else { + min = m + 1 + } + } + + return index + } + ` + validate(t, &testCase{ + Name: "Binary Search", + + SetupMock: func(t *testing.T, mockedProvider *providertesting.MockQuery) { + mockedProvider.On("Query", mock.Anything, "some-model", mock.Anything).Return(bytesutil.StringTrimIndentations(` + `+"```"+` + package binarySearch + + func binarySearch(a []int, x int) int { + index := -1 + + min := 0 + max := len(a) - 1 + + for index == -1 && min <= max { + m := (min + max) / 2 + + if x == a[m] { + index = m + } else if x < a[m] { + max = m - 1 + } else { + min = m + 1 + } + } + + return index + } + `+"```"+` + `), nil) + }, + + Language: &java.Language{}, + TargetLanguage: &golang.Language{}, + + RepositoryPath: filepath.Join("..", "..", "testdata", "java", "transpile", "binarySearch"), + SourceFilePath: filepath.Join("implementation", "BinarySearch.java"), + StubFilePath: filepath.Join("binarySearch.go"), + + ExpectedAssessment: metrics.Assessments{ + metrics.AssessmentKeyResponseNoExcess: 1, + metrics.AssessmentKeyResponseWithCode: 1, + }, + ExpectedStubFileContent: stubFileContent, + }) + }) +} diff --git a/model/testing/CapabilityTranspile_mock_gen.go b/model/testing/CapabilityTranspile_mock_gen.go new file mode 100644 index 00000000..9fdd36a4 --- /dev/null +++ b/model/testing/CapabilityTranspile_mock_gen.go @@ -0,0 +1,59 @@ +// Code generated by mockery v2.40.3. DO NOT EDIT. + +package modeltesting + +import ( + mock "github.com/stretchr/testify/mock" + metrics "github.com/symflower/eval-dev-quality/evaluate/metrics" + + model "github.com/symflower/eval-dev-quality/model" +) + +// MockCapabilityTranspile is an autogenerated mock type for the CapabilityTranspile type +type MockCapabilityTranspile struct { + mock.Mock +} + +// Transpile provides a mock function with given fields: ctx +func (_m *MockCapabilityTranspile) Transpile(ctx model.Context) (metrics.Assessments, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Transpile") + } + + var r0 metrics.Assessments + var r1 error + if rf, ok := ret.Get(0).(func(model.Context) (metrics.Assessments, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(model.Context) metrics.Assessments); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(metrics.Assessments) + } + } + + if rf, ok := ret.Get(1).(func(model.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewMockCapabilityTranspile creates a new instance of MockCapabilityTranspile. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockCapabilityTranspile(t interface { + mock.TestingT + Cleanup(func()) +}) *MockCapabilityTranspile { + mock := &MockCapabilityTranspile{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/model/testing/helper.go b/model/testing/helper.go index f66a808c..65fea703 100644 --- a/model/testing/helper.go +++ b/model/testing/helper.go @@ -55,6 +55,14 @@ func (m *MockCapabilityRepairCode) RegisterGenerateError(err error) *mock.Call { return m.On("RepairCode", mock.Anything).Return(nil, err) } +// RegisterGenerateSuccess registers a mock call for successful generation. +func (m *MockCapabilityTranspile) RegisterGenerateSuccess(t *testing.T, filePath string, fileContent string, assessment metrics.Assessments) *mock.Call { + return m.On("Transpile", mock.Anything).Return(assessment, nil).Run(func(args mock.Arguments) { + ctx := args.Get(0).(model.Context) + require.NoError(t, os.WriteFile(filepath.Join(ctx.RepositoryPath, filePath), []byte(fileContent), 0600)) + }) +} + // MockModelCapabilityWriteTests holds a mock implementing the "Model" and the "CapabilityWriteTests" interface. type MockModelCapabilityWriteTests struct { *MockModel @@ -82,3 +90,17 @@ func NewMockCapabilityRepairCodeNamed(t *testing.T, id string) *MockModelCapabil MockCapabilityRepairCode: NewMockCapabilityRepairCode(t), } } + +// MockModelCapabilityTranspile holds a mock implementing the "Model" and the "CapabilityTranspile" interface. +type MockModelCapabilityTranspile struct { + *MockModel + *MockCapabilityTranspile +} + +// NewMockCapabilityTranspileNamed returns a new named mocked model. +func NewMockCapabilityTranspileNamed(t *testing.T, id string) *MockModelCapabilityTranspile { + return &MockModelCapabilityTranspile{ + MockModel: NewMockModelNamed(t, id), + MockCapabilityTranspile: NewMockCapabilityTranspile(t), + } +} diff --git a/testdata/golang/transpile/repository.json b/testdata/golang/transpile/repository.json new file mode 100644 index 00000000..ab88c46a --- /dev/null +++ b/testdata/golang/transpile/repository.json @@ -0,0 +1,5 @@ +{ + "tasks": [ + "transpile" + ] +} diff --git a/testdata/java/transpile/repository.json b/testdata/java/transpile/repository.json new file mode 100644 index 00000000..ab88c46a --- /dev/null +++ b/testdata/java/transpile/repository.json @@ -0,0 +1,5 @@ +{ + "tasks": [ + "transpile" + ] +}