From 7bdbb8f8e3174f76846a9f5100e7ea68bbcde19c Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 31 Jul 2024 10:26:46 +0100 Subject: [PATCH 1/7] refactor, Convert models costs to numeric values when unmarshaling the JSON response, to avoid these values to be converted latter on Part of #296 --- provider/openrouter/openrouter.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index 7fcef65a..c9c99a0b 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -70,13 +70,13 @@ type Model struct { // Pricing holds the pricing information of a model. type Pricing struct { // Prompt holds the price for a prompt in dollars per token. - Prompt string `json:"prompt"` + Prompt float64 `json:"prompt,string"` // Completion holds the price for a completion in dollars per token. - Completion string `json:"completion"` + Completion float64 `json:"completion,string"` // Request holds the price for a request in dollars per request. - Request string `json:"request"` + Request float64 `json:"request,string"` // Image holds the price for an image in dollars per token. - Image string `json:"image"` + Image float64 `json:"image,string"` } // Models returns which models are available to be queried via this provider. From 74d4990e551d643e0afda0a6b39f3ee14ddf49f7 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 31 Jul 2024 10:46:49 +0100 Subject: [PATCH 2/7] refactor, Move the model's meta information structures from the provider to the model package, since it is model related Part of #296 --- model/model.go | 23 +++++++++++++++++++++++ provider/openrouter/openrouter.go | 25 +------------------------ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/model/model.go b/model/model.go index 5d002ffa..8d43f3b4 100644 --- a/model/model.go +++ b/model/model.go @@ -11,6 +11,29 @@ type Model interface { ID() (id string) } +// MetaInformation holds a model. +type MetaInformation struct { + // ID holds the model id. + ID string `json:"id"` + // Name holds the model name. + Name string `json:"name"` + + // Pricing holds the pricing information of a model. + Pricing Pricing `json:"pricing"` +} + +// Pricing holds the pricing information of a model. +type Pricing struct { + // Prompt holds the price for a prompt in dollars per token. + Prompt float64 `json:"prompt,string"` + // Completion holds the price for a completion in dollars per token. + Completion float64 `json:"completion,string"` + // Request holds the price for a request in dollars per request. + Request float64 `json:"request,string"` + // Image holds the price for an image in dollars per token. + Image float64 `json:"image,string"` +} + // Context holds the data needed by a model for running a task. type Context struct { // Language holds the language for which the task should be evaluated. diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index c9c99a0b..a6ed0e9e 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -53,30 +53,7 @@ func (p *Provider) ID() (id string) { // ModelsList holds a list of models. type ModelsList struct { - Models []Model `json:"data"` -} - -// Model holds a model. -type Model struct { - // ID holds the model id. - ID string `json:"id"` - // Name holds the model name. - Name string `json:"name"` - - // Pricing holds the pricing information of a model. - Pricing Pricing `json:"pricing"` -} - -// Pricing holds the pricing information of a model. -type Pricing struct { - // Prompt holds the price for a prompt in dollars per token. - Prompt float64 `json:"prompt,string"` - // Completion holds the price for a completion in dollars per token. - Completion float64 `json:"completion,string"` - // Request holds the price for a request in dollars per request. - Request float64 `json:"request,string"` - // Image holds the price for an image in dollars per token. - Image float64 `json:"image,string"` + Models []model.MetaInformation `json:"data"` } // Models returns which models are available to be queried via this provider. From 77493124178238be71cc2dcc066c404ee3cb4038 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Tue, 30 Jul 2024 13:23:16 +0100 Subject: [PATCH 3/7] refactor, Extract the logic to fetch models from the openrouter API into a helper, so to it can be reused Part of #296 --- provider/openrouter/openrouter.go | 52 ++++++++++++++++++++++++++----- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index a6ed0e9e..a8488222 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -2,7 +2,11 @@ package openrouter import ( "context" + "encoding/json" "errors" + "io" + "net/http" + "net/url" "strings" "time" @@ -58,16 +62,49 @@ type ModelsList struct { // Models returns which models are available to be queried via this provider. func (p *Provider) Models() (models []model.Model, err error) { - client := p.client() + responseModels, err := p.fetchModels() + if err != nil { + return nil, err + } + + models = make([]model.Model, len(responseModels.Models)) + for i, model := range responseModels.Models { + models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID) + } + + return models, nil +} + +// fetchModels returns the list of models of the provider. +func (p *Provider) fetchModels() (models ModelsList, err error) { + modelsURLPath, err := url.JoinPath(p.baseURL, "models") + if err != nil { + return ModelsList{}, pkgerrors.WithStack(err) + } + request, err := http.NewRequest("GET", modelsURLPath, nil) + if err != nil { + return ModelsList{}, pkgerrors.WithStack(err) + } + request.Header.Set("Accept", "application/json") - var responseModels openai.ModelsList + client := &http.Client{} + var responseBody []byte if err := retry.Do( // Query available models with a retry logic cause "openrouter.ai" has failed us in the past. func() error { - ms, err := client.ListModels(context.Background()) + response, err := client.Do(request) + if err != nil { + return pkgerrors.WithStack(err) + } + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return pkgerrors.Errorf("received status code %d when querying provider models", response.StatusCode) + } + + responseBody, err = io.ReadAll(response.Body) if err != nil { return pkgerrors.WithStack(err) } - responseModels = ms return nil }, @@ -76,12 +113,11 @@ func (p *Provider) Models() (models []model.Model, err error) { retry.DelayType(retry.BackOffDelay), retry.LastErrorOnly(true), ); err != nil { - return nil, err + return ModelsList{}, err } - models = make([]model.Model, len(responseModels.Models)) - for i, model := range responseModels.Models { - models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID) + if err = json.Unmarshal(responseBody, &models); err != nil { + return ModelsList{}, pkgerrors.WithStack(err) } return models, nil From 3f46219ac6faf034f8f4be6a4e915ae6594183e8 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Tue, 30 Jul 2024 16:06:48 +0100 Subject: [PATCH 4/7] refactor, Use the built-in Golang function to open report files, since it can error if the file already exists Part of #296 --- cmd/eval-dev-quality/cmd/report.go | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/report.go b/cmd/eval-dev-quality/cmd/report.go index f71fd01c..6e4f3616 100644 --- a/cmd/eval-dev-quality/cmd/report.go +++ b/cmd/eval-dev-quality/cmd/report.go @@ -12,6 +12,7 @@ import ( "github.com/symflower/eval-dev-quality/evaluate" "github.com/symflower/eval-dev-quality/evaluate/report" "github.com/symflower/eval-dev-quality/log" + "github.com/symflower/eval-dev-quality/util" ) // Report holds the "report" command. @@ -41,19 +42,11 @@ func (command *Report) Execute(args []string) (err error) { if err = osutil.MkdirAll(filepath.Dir(command.ResultPath)); err != nil { command.logger.Panicf("ERROR: %s", err) } - if _, err := os.Stat(filepath.Join(command.ResultPath, "evaluation.csv")); err != nil { - if os.IsNotExist(err) { - evaluationCSVFile, err = os.Create(filepath.Join(command.ResultPath, "evaluation.csv")) - if err != nil { - command.logger.Panicf("ERROR: %s", err) - } - defer evaluationCSVFile.Close() - } else { - command.logger.Panicf("ERROR: %s", err) - } - } else { - command.logger.Panicf("ERROR: an evaluation CSV file already exists in %s", command.ResultPath) + + if evaluationCSVFile, err = os.OpenFile(filepath.Join(command.ResultPath, "evaluation.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0755); err != nil { + command.logger.Panicf("ERROR: %s", err) } + defer evaluationCSVFile.Close() // Collect all evaluation CSV file paths. allEvaluationPaths := map[string]bool{} From a4dfe3ebaad11ffc5f856dafbafd9befadbbca4e Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Fri, 2 Aug 2024 09:42:27 +0100 Subject: [PATCH 5/7] refactor, Rename the function that sorts evaluation records to a more generic name, since it can sort all kind of CSV records Part of #296 --- cmd/eval-dev-quality/cmd/report.go | 2 +- evaluate/report/csv.go | 4 ++-- evaluate/report/csv_test.go | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/report.go b/cmd/eval-dev-quality/cmd/report.go index 6e4f3616..7aba7c2c 100644 --- a/cmd/eval-dev-quality/cmd/report.go +++ b/cmd/eval-dev-quality/cmd/report.go @@ -69,7 +69,7 @@ func (command *Report) Execute(args []string) (err error) { return nil } - report.SortEvaluationRecords(records) + report.SortRecords(records) // Write all records into a single evaluation CSV file. evaluationFile, err := report.NewEvaluationFile(evaluationCSVFile) diff --git a/evaluate/report/csv.go b/evaluate/report/csv.go index 7b1e9594..a0deb79e 100644 --- a/evaluate/report/csv.go +++ b/evaluate/report/csv.go @@ -140,8 +140,8 @@ func assessmentFromRecord(assessmentFields []string) (assessments metrics.Assess return assessments, nil } -// SortEvaluationRecords sorts the evaluation records. -func SortEvaluationRecords(records [][]string) { +// SortRecords sorts CSV records. +func SortRecords(records [][]string) { sort.Slice(records, func(i, j int) bool { for x := range records[i] { if records[i][x] == records[j][x] { diff --git a/evaluate/report/csv_test.go b/evaluate/report/csv_test.go index 01b737e0..82beecab 100644 --- a/evaluate/report/csv_test.go +++ b/evaluate/report/csv_test.go @@ -260,7 +260,7 @@ func TestEvaluationFileWriteLines(t *testing.T) { }) } -func TestSortEvaluationRecords(t *testing.T) { +func TestSortRecords(t *testing.T) { type testCase struct { Name string @@ -271,7 +271,7 @@ func TestSortEvaluationRecords(t *testing.T) { validate := func(t *testing.T, tc *testCase) { t.Run(tc.Name, func(t *testing.T) { - SortEvaluationRecords(tc.Records) + SortRecords(tc.Records) assert.Equal(t, tc.ExpectedRecords, tc.Records) }) From 88d69f0289053d0499cadab8b6ae74e7598e24de Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 31 Jul 2024 11:52:47 +0100 Subject: [PATCH 6/7] Let the LLM models have meta information such as pricing and human-readable names Part of #296 --- model/llm/llm.go | 20 ++++++++++++++++++++ model/model.go | 3 +++ model/symflower/symflower.go | 5 +++++ model/testing/Model_mock_gen.go | 25 ++++++++++++++++++++++++- provider/openrouter/openrouter.go | 5 +++-- 5 files changed, 55 insertions(+), 3 deletions(-) diff --git a/model/llm/llm.go b/model/llm/llm.go index 820b3010..4bd52fee 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -31,6 +31,9 @@ type Model struct { // queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. queryAttempts uint + + // metaInformation holds a model meta information. + metaInformation *model.MetaInformation } // NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider. @@ -43,6 +46,23 @@ func NewModel(provider provider.Query, modelIdentifier string) *Model { } } +// NewModelWithMetaInformation returns a LLM model with meta information corresponding to the given identifier which is queried via the given provider. +func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model { + return &Model{ + provider: provider, + model: modelIdentifier, + + queryAttempts: 1, + + metaInformation: metaInformation, + } +} + +// MetaInformation returns the meta information of a model. +func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { + return m.metaInformation +} + // llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt. type llmSourceFilePromptContext struct { // Language holds the programming language name. diff --git a/model/model.go b/model/model.go index 8d43f3b4..fac8870d 100644 --- a/model/model.go +++ b/model/model.go @@ -9,6 +9,9 @@ import ( type Model interface { // ID returns the unique ID of this model. ID() (id string) + + // MetaInformation returns the meta information of a model. + MetaInformation() *MetaInformation } // MetaInformation holds a model. diff --git a/model/symflower/symflower.go b/model/symflower/symflower.go index 9c715ea2..ca1f2a74 100644 --- a/model/symflower/symflower.go +++ b/model/symflower/symflower.go @@ -46,6 +46,11 @@ func (m *Model) ID() (id string) { return "symflower" + provider.ProviderModelSeparator + "symbolic-execution" } +// MetaInformation returns the meta information of a model. +func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { + return nil +} + var _ model.CapabilityWriteTests = (*Model)(nil) // generateTestsForFile generates test files for the given implementation file in a repository. diff --git a/model/testing/Model_mock_gen.go b/model/testing/Model_mock_gen.go index 04057324..35f6f95c 100644 --- a/model/testing/Model_mock_gen.go +++ b/model/testing/Model_mock_gen.go @@ -2,7 +2,10 @@ package modeltesting -import mock "github.com/stretchr/testify/mock" +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/symflower/eval-dev-quality/model" +) // MockModel is an autogenerated mock type for the Model type type MockModel struct { @@ -27,6 +30,26 @@ func (_m *MockModel) ID() string { return r0 } +// MetaInformation provides a mock function with given fields: +func (_m *MockModel) MetaInformation() *model.MetaInformation { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MetaInformation") + } + + var r0 *model.MetaInformation + if rf, ok := ret.Get(0).(func() *model.MetaInformation); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.MetaInformation) + } + } + + return r0 +} + // NewMockModel creates a new instance of MockModel. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockModel(t interface { diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index a8488222..a420b537 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -57,7 +57,7 @@ func (p *Provider) ID() (id string) { // ModelsList holds a list of models. type ModelsList struct { - Models []model.MetaInformation `json:"data"` + Models []*model.MetaInformation `json:"data"` } // Models returns which models are available to be queried via this provider. @@ -69,7 +69,8 @@ func (p *Provider) Models() (models []model.Model, err error) { models = make([]model.Model, len(responseModels.Models)) for i, model := range responseModels.Models { - models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID) + model.ID = p.ID() + provider.ProviderModelSeparator + model.ID + models[i] = llm.NewModelWithMetaInformation(p, p.ID()+provider.ProviderModelSeparator+model.ID, model) } return models, nil From 38c7c6bbc6a5db539796e175af66e9a156a98b70 Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 31 Jul 2024 12:28:53 +0100 Subject: [PATCH 7/7] Store models meta information in a CSV file, so it can be further used in data visualization Part of #296 --- cmd/eval-dev-quality/cmd/report.go | 26 ++++- cmd/eval-dev-quality/cmd/report_test.go | 13 +++ evaluate/report/csv.go | 39 +++++++ evaluate/report/csv_test.go | 144 ++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 1 deletion(-) diff --git a/cmd/eval-dev-quality/cmd/report.go b/cmd/eval-dev-quality/cmd/report.go index 7aba7c2c..b7bcedeb 100644 --- a/cmd/eval-dev-quality/cmd/report.go +++ b/cmd/eval-dev-quality/cmd/report.go @@ -12,7 +12,8 @@ import ( "github.com/symflower/eval-dev-quality/evaluate" "github.com/symflower/eval-dev-quality/evaluate/report" "github.com/symflower/eval-dev-quality/log" - "github.com/symflower/eval-dev-quality/util" + "github.com/symflower/eval-dev-quality/model" + "github.com/symflower/eval-dev-quality/provider/openrouter" ) // Report holds the "report" command. @@ -80,6 +81,29 @@ func (command *Report) Execute(args []string) (err error) { command.logger.Panicf("ERROR: %s", err) } + // Create a CSV file that holds the models meta information. + modelsMetaInformationCSVFile, err := os.OpenFile(filepath.Join(command.ResultPath, "meta.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0755) + if err != nil { + command.logger.Panicf("ERROR: %s", err) + } + defer modelsMetaInformationCSVFile.Close() + + // Fetch all openrouter models since it is the only provider that currently supports querying meta information. + provider := openrouter.NewProvider().(*openrouter.Provider) + models, err := provider.Models() + if err != nil { + command.logger.Panicf("ERROR: %s", err) + } + var modelsMetaInformation []*model.MetaInformation + for _, model := range models { + modelsMetaInformation = append(modelsMetaInformation, model.MetaInformation()) + } + metaInformationRecords := report.MetaInformationRecords(modelsMetaInformation) + // Write models meta information to disk. + if err := report.WriteMetaInformationRecords(modelsMetaInformationCSVFile, metaInformationRecords); err != nil { + command.logger.Panicf("ERROR: %s", err) + } + // Write markdown reports. assessmentsPerModel, err := report.RecordsToAssessmentsPerModel(records) if err != nil { diff --git a/cmd/eval-dev-quality/cmd/report_test.go b/cmd/eval-dev-quality/cmd/report_test.go index 9f09119b..ee90b59e 100644 --- a/cmd/eval-dev-quality/cmd/report_test.go +++ b/cmd/eval-dev-quality/cmd/report_test.go @@ -170,6 +170,17 @@ func TestReportExecute(t *testing.T) { expectedContent := fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent) assert.Equal(t, expectedContent, data) }, + filepath.Join("result-directory", "meta.csv"): func(t *testing.T, filePath, data string) { + records := strings.Split(data, "\n") + // Check if there are at least 3 records, excluding the CSV header. + require.Greater(t, len(records[1:]), 3) + // Check if the records are different. + uniqueRecords := map[string]bool{} + for _, record := range records[1:4] { + uniqueRecords[record] = true + } + assert.Equal(t, len(uniqueRecords), 3) + }, }, }) validate(t, &testCase{ @@ -213,6 +224,7 @@ func TestReportExecute(t *testing.T) { expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent) assert.Equal(t, expectedContent, data) }, + filepath.Join("result-directory", "meta.csv"): nil, }, }) validate(t, &testCase{ @@ -253,6 +265,7 @@ func TestReportExecute(t *testing.T) { expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent) assert.Equal(t, expectedContent, data) }, + filepath.Join("result-directory", "meta.csv"): nil, }, }) } diff --git a/evaluate/report/csv.go b/evaluate/report/csv.go index a0deb79e..90590bdf 100644 --- a/evaluate/report/csv.go +++ b/evaluate/report/csv.go @@ -140,6 +140,45 @@ func assessmentFromRecord(assessmentFields []string) (assessments metrics.Assess return assessments, nil } +// MetaInformationRecords converts the models meta information into sorted CSV records. +func MetaInformationRecords(modelsMetaInformation []*model.MetaInformation) (records [][]string) { + records = [][]string{} + + for _, metaInformation := range modelsMetaInformation { + records = append(records, []string{ + metaInformation.ID, + metaInformation.Name, + strconv.FormatFloat(metaInformation.Pricing.Completion, 'f', -1, 64), + strconv.FormatFloat(metaInformation.Pricing.Image, 'f', -1, 64), + strconv.FormatFloat(metaInformation.Pricing.Prompt, 'f', -1, 64), + strconv.FormatFloat(metaInformation.Pricing.Request, 'f', -1, 64), + }) + } + SortRecords(records) + + return records +} + +// WriteMetaInformationRecords writes the meta information records into a CSV file. +func WriteMetaInformationRecords(writer io.Writer, records [][]string) (err error) { + return WriteCSV(writer, []string{"model-id", "model-name", "completion", "image", "prompt", "request"}, records) +} + +// WriteCSV writes a header and records to a CSV file. +func WriteCSV(writer io.Writer, header []string, records [][]string) (err error) { + csv := csv.NewWriter(writer) + + if err := csv.Write(header); err != nil { + return pkgerrors.WithStack(err) + } + if err := csv.WriteAll(records); err != nil { + return pkgerrors.WithStack(err) + } + csv.Flush() + + return nil +} + // SortRecords sorts CSV records. func SortRecords(records [][]string) { sort.Slice(records, func(i, j int) bool { diff --git a/evaluate/report/csv_test.go b/evaluate/report/csv_test.go index 82beecab..1f143e75 100644 --- a/evaluate/report/csv_test.go +++ b/evaluate/report/csv_test.go @@ -15,6 +15,7 @@ import ( "github.com/symflower/eval-dev-quality/evaluate/metrics" evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task" languagetesting "github.com/symflower/eval-dev-quality/language/testing" + "github.com/symflower/eval-dev-quality/model" modeltesting "github.com/symflower/eval-dev-quality/model/testing" "github.com/symflower/eval-dev-quality/task" ) @@ -480,3 +481,146 @@ func TestRecordsToAssessmentsPerModel(t *testing.T) { }, }) } + +func TestWriteMetaInformationRecords(t *testing.T) { + var file strings.Builder + + err := WriteMetaInformationRecords(&file, [][]string{ + []string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"}, + []string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"}, + []string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"}, + []string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"}, + []string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"}, + }) + require.NoError(t, err) + + assert.Equal(t, bytesutil.StringTrimIndentations(` + model-id,model-name,completion,image,prompt,request + provider/modelA,modelA,0.1,0.2,0.3,0.4 + provider/modelB,modelB,0.01,0.02,0.03,0.04 + provider/modelC,modelC,0.001,0.002,0.003,0.004 + provider/modelD,modelD,0.0001,0.0002,0.0003,0.0004 + provider/modelE,modelE,0.00001,0.00002,0.00003,0.00004 + `), file.String()) +} + +func TestMetaInformationRecords(t *testing.T) { + actualRecords := MetaInformationRecords([]*model.MetaInformation{ + &model.MetaInformation{ + ID: "provider/modelA", + Name: "modelA", + Pricing: model.Pricing{ + Completion: 0.1, + Image: 0.2, + Prompt: 0.3, + Request: 0.4, + }, + }, + &model.MetaInformation{ + ID: "provider/modelB", + Name: "modelB", + Pricing: model.Pricing{ + Completion: 0.01, + Image: 0.02, + Prompt: 0.03, + Request: 0.04, + }, + }, + &model.MetaInformation{ + ID: "provider/modelC", + Name: "modelC", + Pricing: model.Pricing{ + Completion: 0.001, + Image: 0.002, + Prompt: 0.003, + Request: 0.004, + }, + }, + &model.MetaInformation{ + ID: "provider/modelD", + Name: "modelD", + Pricing: model.Pricing{ + Completion: 0.0001, + Image: 0.0002, + Prompt: 0.0003, + Request: 0.0004, + }, + }, + &model.MetaInformation{ + ID: "provider/modelE", + Name: "modelE", + Pricing: model.Pricing{ + Completion: 0.00001, + Image: 0.00002, + Prompt: 0.00003, + Request: 0.00004, + }, + }, + }) + + assert.ElementsMatch(t, [][]string{ + []string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"}, + []string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"}, + []string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"}, + []string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"}, + []string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"}, + }, actualRecords) +} + +func TestWriteCSV(t *testing.T) { + type testCase struct { + Name string + + Header []string + Records [][]string + + ExpectedContent string + } + + validate := func(t *testing.T, tc *testCase) { + t.Run(tc.Name, func(t *testing.T) { + var file strings.Builder + actualErr := WriteCSV(&file, tc.Header, tc.Records) + require.NoError(t, actualErr) + + assert.Equal(t, bytesutil.StringTrimIndentations(tc.ExpectedContent), file.String()) + }) + } + + validate(t, &testCase{ + Name: "Single record", + + Header: []string{ + "model-id", "price", "score", + }, + + Records: [][]string{ + []string{"modelA", "0.01", "1000"}, + }, + + ExpectedContent: ` + model-id,price,score + modelA,0.01,1000 + `, + }) + validate(t, &testCase{ + Name: "Multiple records", + + Header: []string{ + "model-id", "price", "score", + }, + + Records: [][]string{ + []string{"modelA", "0.01", "1000"}, + []string{"modelB", "0.02", "2000"}, + []string{"modelC", "0.03", "3000"}, + }, + + ExpectedContent: ` + model-id,price,score + modelA,0.01,1000 + modelB,0.02,2000 + modelC,0.03,3000 + `, + }) +}