Skip to content

Commit

Permalink
Merge pull request #298 from symflower/296-model-costs
Browse files Browse the repository at this point in the history
Store models meta information in a CSV file, so it can be further used in data visualization
  • Loading branch information
ruiAzevedo19 authored Aug 6, 2024
2 parents 2987562 + 38c7c6b commit 49cc4e0
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 48 deletions.
43 changes: 30 additions & 13 deletions cmd/eval-dev-quality/cmd/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/symflower/eval-dev-quality/evaluate"
"github.com/symflower/eval-dev-quality/evaluate/report"
"github.com/symflower/eval-dev-quality/log"
"github.com/symflower/eval-dev-quality/model"
"github.com/symflower/eval-dev-quality/provider/openrouter"
)

// Report holds the "report" command.
Expand Down Expand Up @@ -41,19 +43,11 @@ func (command *Report) Execute(args []string) (err error) {
if err = osutil.MkdirAll(filepath.Dir(command.ResultPath)); err != nil {
command.logger.Panicf("ERROR: %s", err)
}
if _, err := os.Stat(filepath.Join(command.ResultPath, "evaluation.csv")); err != nil {
if os.IsNotExist(err) {
evaluationCSVFile, err = os.Create(filepath.Join(command.ResultPath, "evaluation.csv"))
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
defer evaluationCSVFile.Close()
} else {
command.logger.Panicf("ERROR: %s", err)
}
} else {
command.logger.Panicf("ERROR: an evaluation CSV file already exists in %s", command.ResultPath)

if evaluationCSVFile, err = os.OpenFile(filepath.Join(command.ResultPath, "evaluation.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0755); err != nil {
command.logger.Panicf("ERROR: %s", err)
}
defer evaluationCSVFile.Close()

// Collect all evaluation CSV file paths.
allEvaluationPaths := map[string]bool{}
Expand All @@ -76,7 +70,7 @@ func (command *Report) Execute(args []string) (err error) {

return nil
}
report.SortEvaluationRecords(records)
report.SortRecords(records)

// Write all records into a single evaluation CSV file.
evaluationFile, err := report.NewEvaluationFile(evaluationCSVFile)
Expand All @@ -87,6 +81,29 @@ func (command *Report) Execute(args []string) (err error) {
command.logger.Panicf("ERROR: %s", err)
}

// Create a CSV file that holds the models meta information.
modelsMetaInformationCSVFile, err := os.OpenFile(filepath.Join(command.ResultPath, "meta.csv"), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0755)
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
defer modelsMetaInformationCSVFile.Close()

// Fetch all openrouter models since it is the only provider that currently supports querying meta information.
provider := openrouter.NewProvider().(*openrouter.Provider)
models, err := provider.Models()
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
var modelsMetaInformation []*model.MetaInformation
for _, model := range models {
modelsMetaInformation = append(modelsMetaInformation, model.MetaInformation())
}
metaInformationRecords := report.MetaInformationRecords(modelsMetaInformation)
// Write models meta information to disk.
if err := report.WriteMetaInformationRecords(modelsMetaInformationCSVFile, metaInformationRecords); err != nil {
command.logger.Panicf("ERROR: %s", err)
}

// Write markdown reports.
assessmentsPerModel, err := report.RecordsToAssessmentsPerModel(records)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions cmd/eval-dev-quality/cmd/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,17 @@ func TestReportExecute(t *testing.T) {
expectedContent := fmt.Sprintf("%s\n%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
filepath.Join("result-directory", "meta.csv"): func(t *testing.T, filePath, data string) {
records := strings.Split(data, "\n")
// Check if there are at least 3 records, excluding the CSV header.
require.Greater(t, len(records[1:]), 3)
// Check if the records are different.
uniqueRecords := map[string]bool{}
for _, record := range records[1:4] {
uniqueRecords[record] = true
}
assert.Equal(t, len(uniqueRecords), 3)
},
},
})
validate(t, &testCase{
Expand Down Expand Up @@ -213,6 +224,7 @@ func TestReportExecute(t *testing.T) {
expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
filepath.Join("result-directory", "meta.csv"): nil,
},
})
validate(t, &testCase{
Expand Down Expand Up @@ -253,6 +265,7 @@ func TestReportExecute(t *testing.T) {
expectedContent := fmt.Sprintf("%s\n%s%s%s", strings.Join(report.EvaluationHeader(), ","), claudeEvaluationCSVFileContent, gemmaEvaluationCSVFileContent, gpt4EvaluationCSVFileContent)
assert.Equal(t, expectedContent, data)
},
filepath.Join("result-directory", "meta.csv"): nil,
},
})
}
Expand Down
43 changes: 41 additions & 2 deletions evaluate/report/csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,47 @@ func assessmentFromRecord(assessmentFields []string) (assessments metrics.Assess
return assessments, nil
}

// SortEvaluationRecords sorts the evaluation records.
func SortEvaluationRecords(records [][]string) {
// MetaInformationRecords converts the models meta information into sorted CSV records.
// Each record holds the model ID, the model name and the completion, image, prompt and request prices, in that order.
// Entries that are "nil" are skipped instead of causing a "nil" pointer dereference, since the meta information of a model is a pointer that may be unset.
// The returned slice is never "nil", even if no record is produced.
func MetaInformationRecords(modelsMetaInformation []*model.MetaInformation) (records [][]string) {
	// The number of records is bounded by the number of inputs, so allocate the backing array only once.
	records = make([][]string, 0, len(modelsMetaInformation))

	for _, metaInformation := range modelsMetaInformation {
		if metaInformation == nil {
			continue
		}

		// Format prices with the minimal number of digits that still represents the value exactly.
		records = append(records, []string{
			metaInformation.ID,
			metaInformation.Name,
			strconv.FormatFloat(metaInformation.Pricing.Completion, 'f', -1, 64),
			strconv.FormatFloat(metaInformation.Pricing.Image, 'f', -1, 64),
			strconv.FormatFloat(metaInformation.Pricing.Prompt, 'f', -1, 64),
			strconv.FormatFloat(metaInformation.Pricing.Request, 'f', -1, 64),
		})
	}
	SortRecords(records)

	return records
}

// WriteMetaInformationRecords writes the given meta information records as CSV to the writer, preceded by the meta information header row.
func WriteMetaInformationRecords(writer io.Writer, records [][]string) (err error) {
	// The column order must match the field order of the records.
	header := []string{"model-id", "model-name", "completion", "image", "prompt", "request"}

	return WriteCSV(writer, header, records)
}

// WriteCSV writes a header and records to a CSV file.
// The header is written as the first row, followed by the records in the given order.
// All buffered data is flushed to the writer before returning, and the first write or flush error is returned wrapped by "pkgerrors.WithStack".
func WriteCSV(writer io.Writer, header []string, records [][]string) (err error) {
	// Do not call the local variable "csv" since that would shadow the "encoding/csv" package for the rest of the function.
	csvWriter := csv.NewWriter(writer)

	if err := csvWriter.Write(header); err != nil {
		return pkgerrors.WithStack(err)
	}
	// "WriteAll" flushes the writer and returns the error of that flush, so a separate "Flush" call is not needed.
	if err := csvWriter.WriteAll(records); err != nil {
		return pkgerrors.WithStack(err)
	}

	return nil
}

// SortRecords sorts CSV records.
func SortRecords(records [][]string) {
sort.Slice(records, func(i, j int) bool {
for x := range records[i] {
if records[i][x] == records[j][x] {
Expand Down
148 changes: 146 additions & 2 deletions evaluate/report/csv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/symflower/eval-dev-quality/evaluate/metrics"
evaluatetask "github.com/symflower/eval-dev-quality/evaluate/task"
languagetesting "github.com/symflower/eval-dev-quality/language/testing"
"github.com/symflower/eval-dev-quality/model"
modeltesting "github.com/symflower/eval-dev-quality/model/testing"
"github.com/symflower/eval-dev-quality/task"
)
Expand Down Expand Up @@ -260,7 +261,7 @@ func TestEvaluationFileWriteLines(t *testing.T) {
})
}

func TestSortEvaluationRecords(t *testing.T) {
func TestSortRecords(t *testing.T) {
type testCase struct {
Name string

Expand All @@ -271,7 +272,7 @@ func TestSortEvaluationRecords(t *testing.T) {

validate := func(t *testing.T, tc *testCase) {
t.Run(tc.Name, func(t *testing.T) {
SortEvaluationRecords(tc.Records)
SortRecords(tc.Records)

assert.Equal(t, tc.ExpectedRecords, tc.Records)
})
Expand Down Expand Up @@ -480,3 +481,146 @@ func TestRecordsToAssessmentsPerModel(t *testing.T) {
},
})
}

// TestWriteMetaInformationRecords checks that meta information records are written as CSV rows below the meta information header.
func TestWriteMetaInformationRecords(t *testing.T) {
// Write into an in-memory buffer instead of a file on disk.
var file strings.Builder

// Each record holds the model ID, the model name and four prices, matching the six header columns.
err := WriteMetaInformationRecords(&file, [][]string{
[]string{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"},
[]string{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"},
[]string{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"},
[]string{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"},
[]string{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"},
})
require.NoError(t, err)

// The header row comes first, followed by the records in the order they were handed over.
// NOTE(review): "StringTrimIndentations" presumably strips the common indentation and the leading line break of the literal - confirm against "bytesutil".
assert.Equal(t, bytesutil.StringTrimIndentations(`
model-id,model-name,completion,image,prompt,request
provider/modelA,modelA,0.1,0.2,0.3,0.4
provider/modelB,modelB,0.01,0.02,0.03,0.04
provider/modelC,modelC,0.001,0.002,0.003,0.004
provider/modelD,modelD,0.0001,0.0002,0.0003,0.0004
provider/modelE,modelE,0.00001,0.00002,0.00003,0.00004
`), file.String())
}

// TestMetaInformationRecords checks that models meta information is converted into CSV records and that these records are returned sorted.
func TestMetaInformationRecords(t *testing.T) {
	// The meta information is deliberately handed over unsorted, so the sorting of the records is actually exercised.
	actualRecords := MetaInformationRecords([]*model.MetaInformation{
		{
			ID:   "provider/modelE",
			Name: "modelE",
			Pricing: model.Pricing{
				Completion: 0.00001,
				Image:      0.00002,
				Prompt:     0.00003,
				Request:    0.00004,
			},
		},
		{
			ID:   "provider/modelB",
			Name: "modelB",
			Pricing: model.Pricing{
				Completion: 0.01,
				Image:      0.02,
				Prompt:     0.03,
				Request:    0.04,
			},
		},
		{
			ID:   "provider/modelD",
			Name: "modelD",
			Pricing: model.Pricing{
				Completion: 0.0001,
				Image:      0.0002,
				Prompt:     0.0003,
				Request:    0.0004,
			},
		},
		{
			ID:   "provider/modelA",
			Name: "modelA",
			Pricing: model.Pricing{
				Completion: 0.1,
				Image:      0.2,
				Prompt:     0.3,
				Request:    0.4,
			},
		},
		{
			ID:   "provider/modelC",
			Name: "modelC",
			Pricing: model.Pricing{
				Completion: 0.001,
				Image:      0.002,
				Prompt:     0.003,
				Request:    0.004,
			},
		},
	})

	// Compare with "assert.Equal" instead of "assert.ElementsMatch" since the order of the records is part of the contract.
	assert.Equal(t, [][]string{
		{"provider/modelA", "modelA", "0.1", "0.2", "0.3", "0.4"},
		{"provider/modelB", "modelB", "0.01", "0.02", "0.03", "0.04"},
		{"provider/modelC", "modelC", "0.001", "0.002", "0.003", "0.004"},
		{"provider/modelD", "modelD", "0.0001", "0.0002", "0.0003", "0.0004"},
		{"provider/modelE", "modelE", "0.00001", "0.00002", "0.00003", "0.00004"},
	}, actualRecords)
}

// TestWriteCSV checks that a header and records are written as CSV rows, with the header as the first row.
func TestWriteCSV(t *testing.T) {
type testCase struct {
// Name holds the name of the sub-test.
Name string

// Header holds the CSV header row handed to "WriteCSV".
Header []string
// Records holds the CSV records handed to "WriteCSV".
Records [][]string

// ExpectedContent holds the CSV content that is expected to be written.
ExpectedContent string
}

// validate runs a single test case as a sub-test, writing into an in-memory buffer instead of a file on disk.
validate := func(t *testing.T, tc *testCase) {
t.Run(tc.Name, func(t *testing.T) {
var file strings.Builder
actualErr := WriteCSV(&file, tc.Header, tc.Records)
require.NoError(t, actualErr)

// NOTE(review): "StringTrimIndentations" presumably strips the common indentation and the leading line break of the literal - confirm against "bytesutil".
assert.Equal(t, bytesutil.StringTrimIndentations(tc.ExpectedContent), file.String())
})
}

validate(t, &testCase{
Name: "Single record",

Header: []string{
"model-id", "price", "score",
},

Records: [][]string{
[]string{"modelA", "0.01", "1000"},
},

ExpectedContent: `
model-id,price,score
modelA,0.01,1000
`,
})
// Multiple records are expected in the written content in the order they were handed over.
validate(t, &testCase{
Name: "Multiple records",

Header: []string{
"model-id", "price", "score",
},

Records: [][]string{
[]string{"modelA", "0.01", "1000"},
[]string{"modelB", "0.02", "2000"},
[]string{"modelC", "0.03", "3000"},
},

ExpectedContent: `
model-id,price,score
modelA,0.01,1000
modelB,0.02,2000
modelC,0.03,3000
`,
})
}
20 changes: 20 additions & 0 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type Model struct {

// queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task.
queryAttempts uint

// metaInformation holds a model meta information.
metaInformation *model.MetaInformation
}

// NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider.
Expand All @@ -43,6 +46,23 @@ func NewModel(provider provider.Query, modelIdentifier string) *Model {
}
}

// NewModelWithMetaInformation returns an LLM model for the given identifier that is queried via the given provider and carries the given meta information.
func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model {
	llm := new(Model)

	llm.provider = provider
	llm.model = modelIdentifier
	// Start out with a single query attempt.
	llm.queryAttempts = 1
	llm.metaInformation = metaInformation

	return llm
}

// MetaInformation returns the meta information stored for this model, which is "nil" if none has been set.
func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) {
	metaInformation = m.metaInformation

	return metaInformation
}

// llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt.
type llmSourceFilePromptContext struct {
// Language holds the programming language name.
Expand Down
26 changes: 26 additions & 0 deletions model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,32 @@ import (
type Model interface {
	// ID returns the unique ID of this model.
	ID() (id string)

	// MetaInformation returns the meta information of a model.
	// The result is named to stay consistent with "ID" and with the implementing methods, which does not affect which types satisfy the interface.
	MetaInformation() (metaInformation *MetaInformation)
}

// MetaInformation holds the meta information of a model, i.e. its identification and its pricing.
type MetaInformation struct {
// ID holds the model ID.
ID string `json:"id"`
// Name holds the model name.
Name string `json:"name"`

// Pricing holds the pricing information of a model.
Pricing Pricing `json:"pricing"`
}

// Pricing holds the pricing information of a model.
// All prices use the "string" option in their JSON tags, i.e. they are encoded as and decoded from JSON strings instead of JSON numbers.
type Pricing struct {
// Prompt holds the price for a prompt in dollars per token.
Prompt float64 `json:"prompt,string"`
// Completion holds the price for a completion in dollars per token.
Completion float64 `json:"completion,string"`
// Request holds the price for a request in dollars per request.
Request float64 `json:"request,string"`
// Image holds the price for an image in dollars per token.
// NOTE(review): the unit "per token" looks copied from the fields above and might need to be "per image" - confirm against the provider documentation.
Image float64 `json:"image,string"`
}

// Context holds the data needed by a model for running a task.
Expand Down
Loading

0 comments on commit 49cc4e0

Please sign in to comment.