diff --git a/cmd/eval-dev-quality/cmd/report.go b/cmd/eval-dev-quality/cmd/report.go
index 6e4f3616..4965b762 100644
--- a/cmd/eval-dev-quality/cmd/report.go
+++ b/cmd/eval-dev-quality/cmd/report.go
@@ -12,7 +12,8 @@ import (
 	"github.com/symflower/eval-dev-quality/evaluate"
 	"github.com/symflower/eval-dev-quality/evaluate/report"
 	"github.com/symflower/eval-dev-quality/log"
-	"github.com/symflower/eval-dev-quality/util"
+	"github.com/symflower/eval-dev-quality/model"
+	"github.com/symflower/eval-dev-quality/provider"
 )
 
 // Report holds the "report" command.
@@ -80,6 +81,46 @@ func (command *Report) Execute(args []string) (err error) {
 		command.logger.Panicf("ERROR: %s", err)
 	}
 
+	// Create a CSV file that holds the models meta information.
+	// "O_EXCL" makes the report fail instead of overwriting an existing "meta.csv".
+	var modelsMetaInformationCSVFile *os.File
+	if modelsMetaInformationCSVFile, err = os.OpenFile(filepath.Join(command.ResultPath, "meta.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0644); err != nil {
+		command.logger.Panicf("ERROR: %s", err)
+	}
+	defer modelsMetaInformationCSVFile.Close()
+
+	// Fetch all models meta information.
+	var modelsMetaInformation []*model.MetaInformation
+	for _, p := range provider.Providers {
+		// Start services of providers.
+		if service, ok := p.(provider.Service); ok {
+			command.logger.Printf("Starting services for provider %q", p.ID())
+			shutdown, err := service.Start(command.logger)
+			if err != nil {
+				command.logger.Panicf("ERROR: could not start services for provider %q: %s", p.ID(), err)
+			}
+			// The deferred shutdown runs when "Execute" returns, not at the end of this loop iteration.
+			defer shutdown()
+		}
+
+		models, err := p.Models()
+		if err != nil {
+			command.logger.Panicf("ERROR: %s", err)
+		}
+		for _, m := range models {
+			// Models without meta information are left out of the CSV.
+			if modelMetaInformation := m.MetaInformation(); modelMetaInformation != nil {
+				modelsMetaInformation = append(modelsMetaInformation, modelMetaInformation)
+			}
+		}
+	}
+	metaInformationRecords := report.MetaInformationRecords(modelsMetaInformation)
+	report.SortEvaluationRecords(metaInformationRecords)
+	// Write models meta information to disk.
+	if err := report.WriteMetaInformationRecords(modelsMetaInformationCSVFile, metaInformationRecords); err != nil {
+		command.logger.Panicf("ERROR: %s", err)
+	}
+
 	// Write markdown reports.
 	assessmentsPerModel, err := report.RecordsToAssessmentsPerModel(records)
 	if err != nil {
diff --git a/cmd/eval-dev-quality/cmd/report_test.go b/cmd/eval-dev-quality/cmd/report_test.go
index 9f09119b..a16d8b77 100644
--- a/cmd/eval-dev-quality/cmd/report_test.go
+++ b/cmd/eval-dev-quality/cmd/report_test.go
@@ -170,6 +170,7 @@ func TestReportExecute(t *testing.T) {
 				expectedContent := fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent)
 				assert.Equal(t, expectedContent, data)
 			},
+			filepath.Join("result-directory", "meta.csv"): nil,
 		},
 	})
 	validate(t, &testCase{
@@ -213,6 +214,7 @@ func TestReportExecute(t *testing.T) {
 				expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
 				assert.Equal(t, expectedContent, data)
 			},
+			filepath.Join("result-directory", "meta.csv"): nil,
 		},
 	})
 	validate(t, &testCase{
@@ -253,6 +255,7 @@ func TestReportExecute(t *testing.T) {
 				expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
 				assert.Equal(t, expectedContent, data)
 			},
+			filepath.Join("result-directory", "meta.csv"): nil,
 		},
 	})
 }
diff --git a/evaluate/report/csv.go b/evaluate/report/csv.go
index 7b1e9594..df76bf46 100644
--- a/evaluate/report/csv.go
+++ b/evaluate/report/csv.go
@@ -140,6 +140,40 @@ func assessmentFromRecord(assessmentFields []string) (assessments metrics.Assess
 	return assessments, nil
 }
 
+// MetaInformationRecords converts the models meta information into CSV records.
+func MetaInformationRecords(modelsMetaInformation []*model.MetaInformation) (records [][]string) {
+	// Allocate once for all models; the result is non-nil even without any input.
+	records = make([][]string, 0, len(modelsMetaInformation))
+	for _, metaInformation := range modelsMetaInformation {
+		// Prices use the shortest decimal form without an exponent, e.g. "0.00001".
+		records = append(records, []string{
+			metaInformation.ID,
+			metaInformation.Name,
+			strconv.FormatFloat(metaInformation.Pricing.Completion, 'f', -1, 64),
+			strconv.FormatFloat(metaInformation.Pricing.Image, 'f', -1, 64),
+			strconv.FormatFloat(metaInformation.Pricing.Prompt, 'f', -1, 64),
+			strconv.FormatFloat(metaInformation.Pricing.Request, 'f', -1, 64),
+		})
+	}
+
+	return records
+}
+
+// WriteMetaInformationRecords writes a header row followed by the given meta information records as CSV to the writer.
+func WriteMetaInformationRecords(writer io.Writer, records [][]string) (err error) {
+	csvWriter := csv.NewWriter(writer)
+	header := []string{"model-id", "model-name", "completion", "image", "prompt", "request"}
+	if err := csvWriter.Write(header); err != nil {
+		return pkgerrors.WithStack(err)
+	}
+	// "WriteAll" flushes the writer itself and returns any flush error, so no extra "Flush" is needed.
+	if err := csvWriter.WriteAll(records); err != nil {
+		return pkgerrors.WithStack(err)
+	}
+
+	return nil
+}
+
 // SortEvaluationRecords sorts the evaluation records.
 func SortEvaluationRecords(records [][]string) {
 	sort.Slice(records, func(i, j int) bool {
diff --git a/evaluate/report/csv_test.go b/evaluate/report/csv_test.go
index 01b737e0..c7de0628 100644
--- a/evaluate/report/csv_test.go
+++ b/evaluate/report/csv_test.go
@@ -15,6 +15,7 @@ import (
 	"github.com/symflower/eval-dev-quality/evaluate/metrics"
 	evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task"
 	languagetesting "github.com/symflower/eval-dev-quality/language/testing"
+	"github.com/symflower/eval-dev-quality/model"
 	modeltesting "github.com/symflower/eval-dev-quality/model/testing"
 	"github.com/symflower/eval-dev-quality/task"
 )
@@ -480,3 +481,90 @@ func TestRecordsToAssessmentsPerModel(t *testing.T) {
 		},
 	})
 }
+
+// TestWriteMetaInformationRecords checks that the header row and all records are written as CSV.
+func TestWriteMetaInformationRecords(t *testing.T) {
+	var file strings.Builder
+
+	err := WriteMetaInformationRecords(&file, [][]string{
+		{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"},
+		{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"},
+		{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"},
+		{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"},
+		{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"},
+	})
+	require.NoError(t, err)
+
+	assert.Equal(t, bytesutil.StringTrimIndentations(`
+		model-id,model-name,completion,image,prompt,request
+		provider/modelA,modelA,0.1,0.2,0.3,0.4
+		provider/modelB,modelB,0.01,0.02,0.03,0.04
+		provider/modelC,modelC,0.001,0.002,0.003,0.004
+		provider/modelD,modelD,0.0001,0.0002,0.0003,0.0004
+		provider/modelE,modelE,0.00001,0.00002,0.00003,0.00004
+	`), file.String())
+}
+
+// TestMetaInformationRecords checks the conversion of meta information into CSV records, including small prices.
+func TestMetaInformationRecords(t *testing.T) {
+	actualRecords := MetaInformationRecords([]*model.MetaInformation{
+		{
+			ID:   "provider/modelA",
+			Name: "modelA",
+			Pricing: model.Pricing{
+				Completion: 0.1,
+				Image:      0.2,
+				Prompt:     0.3,
+				Request:    0.4,
+			},
+		},
+		{
+			ID:   "provider/modelB",
+			Name: "modelB",
+			Pricing: model.Pricing{
+				Completion: 0.01,
+				Image:      0.02,
+				Prompt:     0.03,
+				Request:    0.04,
+			},
+		},
+		{
+			ID:   "provider/modelC",
+			Name: "modelC",
+			Pricing: model.Pricing{
+				Completion: 0.001,
+				Image:      0.002,
+				Prompt:     0.003,
+				Request:    0.004,
+			},
+		},
+		{
+			ID:   "provider/modelD",
+			Name: "modelD",
+			Pricing: model.Pricing{
+				Completion: 0.0001,
+				Image:      0.0002,
+				Prompt:     0.0003,
+				Request:    0.0004,
+			},
+		},
+		{
+			ID:   "provider/modelE",
+			Name: "modelE",
+			Pricing: model.Pricing{
+				Completion: 0.00001,
+				Image:      0.00002,
+				Prompt:     0.00003,
+				Request:    0.00004,
+			},
+		},
+	})
+
+	assert.ElementsMatch(t, [][]string{
+		{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"},
+		{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"},
+		{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"},
+		{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"},
+		{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"},
+	}, actualRecords)
+}