diff --git a/cmd/eval-dev-quality/cmd/report.go b/cmd/eval-dev-quality/cmd/report.go index 7aba7c2c..b7bcedeb 100644 --- a/cmd/eval-dev-quality/cmd/report.go +++ b/cmd/eval-dev-quality/cmd/report.go @@ -12,7 +12,8 @@ import ( "github.com/symflower/eval-dev-quality/evaluate" "github.com/symflower/eval-dev-quality/evaluate/report" "github.com/symflower/eval-dev-quality/log" - "github.com/symflower/eval-dev-quality/util" + "github.com/symflower/eval-dev-quality/model" + "github.com/symflower/eval-dev-quality/provider/openrouter" ) // Report holds the "report" command. @@ -80,6 +81,29 @@ func (command *Report) Execute(args []string) (err error) { command.logger.Panicf("ERROR: %s", err) } + // Create a CSV file that holds the models meta information. + modelsMetaInformationCSVFile, err := os.OpenFile(filepath.Join(command.ResultPath, "meta.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0755) + if err != nil { + command.logger.Panicf("ERROR: %s", err) + } + defer modelsMetaInformationCSVFile.Close() + + // Fetch all openrouter models since it is the only provider that currently supports querying meta information. + provider := openrouter.NewProvider().(*openrouter.Provider) + models, err := provider.Models() + if err != nil { + command.logger.Panicf("ERROR: %s", err) + } + var modelsMetaInformation []*model.MetaInformation + for _, model := range models { + modelsMetaInformation = append(modelsMetaInformation, model.MetaInformation()) + } + metaInformationRecords := report.MetaInformationRecords(modelsMetaInformation) + // Write models meta information to disk. + if err := report.WriteMetaInformationRecords(modelsMetaInformationCSVFile, metaInformationRecords); err != nil { + command.logger.Panicf("ERROR: %s", err) + } + // Write markdown reports. assessmentsPerModel, err := report.RecordsToAssessmentsPerModel(records) if err != nil { diff --git a/cmd/eval-dev-quality/cmd/report_test.go b/cmd/eval-dev-quality/cmd/report_test.go index 9f09119b..ee90b59e 100644 --- a/cmd/eval-dev-quality/cmd/report_test.go +++ b/cmd/eval-dev-quality/cmd/report_test.go @@ -170,6 +170,17 @@ func TestReportExecute(t *testing.T) { expectedContent := fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent) assert.Equal(t, expectedContent, data) }, + filepath.Join("result-directory", "meta.csv"): func(t *testing.T, filePath, data string) { + records := strings.Split(data, "\n") + // Check if there are at least 3 records, excluding the CSV header. + require.Greater(t, len(records[1:]), 3) + // Check if the records are different. + uniqueRecords := map[string]bool{} + for _, record := range records[1:4] { + uniqueRecords[record] = true + } + assert.Equal(t, len(uniqueRecords), 3) + }, }, }) validate(t, &testCase{ @@ -213,6 +224,7 @@ func TestReportExecute(t *testing.T) { expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent) assert.Equal(t, expectedContent, data) }, + filepath.Join("result-directory", "meta.csv"): nil, }, }) validate(t, &testCase{ @@ -253,6 +265,7 @@ func TestReportExecute(t *testing.T) { expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent) assert.Equal(t, expectedContent, data) }, + filepath.Join("result-directory", "meta.csv"): nil, }, }) } diff --git a/evaluate/report/csv.go b/evaluate/report/csv.go index a0deb79e..90590bdf 100644 --- a/evaluate/report/csv.go +++ b/evaluate/report/csv.go @@ -140,6 +140,45 @@ func assessmentFromRecord(assessmentFields []string) (assessments metrics.Assess return assessments, nil } +// MetaInformationRecords converts the models meta information into sorted CSV records. +func MetaInformationRecords(modelsMetaInformation []*model.MetaInformation) (records [][]string) { + records = [][]string{} + + for _, metaInformation := range modelsMetaInformation { + records = append(records, []string{ + metaInformation.ID, + metaInformation.Name, + strconv.FormatFloat(metaInformation.Pricing.Completion, 'f', -1, 64), + strconv.FormatFloat(metaInformation.Pricing.Image, 'f', -1, 64), + strconv.FormatFloat(metaInformation.Pricing.Prompt, 'f', -1, 64), + strconv.FormatFloat(metaInformation.Pricing.Request, 'f', -1, 64), + }) + } + SortRecords(records) + + return records +} + +// WriteMetaInformationRecords writes the meta information records into a CSV file. +func WriteMetaInformationRecords(writer io.Writer, records [][]string) (err error) { + return WriteCSV(writer, []string{"model-id", "model-name", "completion", "image", "prompt", "request"}, records) +} + +// WriteCSV writes a header and records to a CSV file. +func WriteCSV(writer io.Writer, header []string, records [][]string) (err error) { + csv := csv.NewWriter(writer) + + if err := csv.Write(header); err != nil { + return pkgerrors.WithStack(err) + } + if err := csv.WriteAll(records); err != nil { + return pkgerrors.WithStack(err) + } + csv.Flush() + + return nil +} + // SortRecords sorts CSV records. func SortRecords(records [][]string) { sort.Slice(records, func(i, j int) bool { diff --git a/evaluate/report/csv_test.go b/evaluate/report/csv_test.go index 82beecab..1f143e75 100644 --- a/evaluate/report/csv_test.go +++ b/evaluate/report/csv_test.go @@ -15,6 +15,7 @@ import ( "github.com/symflower/eval-dev-quality/evaluate/metrics" evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task" languagetesting "github.com/symflower/eval-dev-quality/language/testing" + "github.com/symflower/eval-dev-quality/model" modeltesting "github.com/symflower/eval-dev-quality/model/testing" "github.com/symflower/eval-dev-quality/task" ) @@ -480,3 +481,146 @@ func TestRecordsToAssessmentsPerModel(t *testing.T) { }, }) } + +func TestWriteMetaInformationRecords(t *testing.T) { + var file strings.Builder + + err := WriteMetaInformationRecords(&file, [][]string{ + []string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"}, + []string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"}, + []string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"}, + []string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"}, + []string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"}, + }) + require.NoError(t, err) + + assert.Equal(t, bytesutil.StringTrimIndentations(` + model-id,model-name,completion,image,prompt,request + provider/modelA,modelA,0.1,0.2,0.3,0.4 + provider/modelB,modelB,0.01,0.02,0.03,0.04 + provider/modelC,modelC,0.001,0.002,0.003,0.004 + provider/modelD,modelD,0.0001,0.0002,0.0003,0.0004 + provider/modelE,modelE,0.00001,0.00002,0.00003,0.00004 + `), file.String()) +} + +func TestMetaInformationRecords(t *testing.T) { + actualRecords := MetaInformationRecords([]*model.MetaInformation{ + &model.MetaInformation{ + ID: "provider/modelA", + Name: "modelA", + Pricing: model.Pricing{ + Completion: 0.1, + Image: 0.2, + Prompt: 0.3, + Request: 0.4, + }, + }, + &model.MetaInformation{ + ID: "provider/modelB", + Name: "modelB", + Pricing: model.Pricing{ + Completion: 0.01, + Image: 0.02, + Prompt: 0.03, + Request: 0.04, + }, + }, + &model.MetaInformation{ + ID: "provider/modelC", + Name: "modelC", + Pricing: model.Pricing{ + Completion: 0.001, + Image: 0.002, + Prompt: 0.003, + Request: 0.004, + }, + }, + &model.MetaInformation{ + ID: "provider/modelD", + Name: "modelD", + Pricing: model.Pricing{ + Completion: 0.0001, + Image: 0.0002, + Prompt: 0.0003, + Request: 0.0004, + }, + }, + &model.MetaInformation{ + ID: "provider/modelE", + Name: "modelE", + Pricing: model.Pricing{ + Completion: 0.00001, + Image: 0.00002, + Prompt: 0.00003, + Request: 0.00004, + }, + }, + }) + + assert.ElementsMatch(t, [][]string{ + []string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"}, + []string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"}, + []string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"}, + []string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"}, + []string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"}, + }, actualRecords) +} + +func TestWriteCSV(t *testing.T) { + type testCase struct { + Name string + + Header []string + Records [][]string + + ExpectedContent string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + var file strings.Builder + actualErr := WriteCSV(&file, tc.Header, tc.Records) + require.NoError(t, actualErr) + + assert.Equal(t, bytesutil.StringTrimIndentations(tc.ExpectedContent), file.String()) + }) + } + + validate(t, &testCase{ + Name: "Single record", + + Header: []string{ + "model-id", "price", "score", + }, + + Records: [][]string{ + []string{"modelA", "0.01", "1000"}, + }, + + ExpectedContent: ` + model-id,price,score + modelA,0.01,1000 + `, + }) + validate(t, &testCase{ + Name: "Multiple records", + + Header: []string{ + "model-id", "price", "score", + }, + + Records: [][]string{ + []string{"modelA", "0.01", "1000"}, + []string{"modelB", "0.02", "2000"}, + []string{"modelC", "0.03", "3000"}, + }, + + ExpectedContent: ` + model-id,price,score + modelA,0.01,1000 + modelB,0.02,2000 + modelC,0.03,3000 + `, + }) +}