From cf3c0f15b73180c847c4da096794d9e42174c11e Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 31 Jul 2024 11:52:47 +0100 Subject: [PATCH] Let the LLM models have meta information such as pricing and human-readable names Part of #296 --- cmd/eval-dev-quality/cmd/evaluate.go | 4 ++-- evaluate/evaluate_test.go | 6 +++--- model/llm/llm.go | 12 +++++++++++- model/llm/llm_test.go | 6 +++--- model/model.go | 3 +++ model/symflower/symflower.go | 5 +++++ model/testing/Model_mock_gen.go | 25 ++++++++++++++++++++++++- provider/ollama/ollama.go | 2 +- provider/openrouter/openrouter.go | 2 +- 9 files changed, 53 insertions(+), 12 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/evaluate.go b/cmd/eval-dev-quality/cmd/evaluate.go index ce783f79..0322b1bc 100644 --- a/cmd/eval-dev-quality/cmd/evaluate.go +++ b/cmd/eval-dev-quality/cmd/evaluate.go @@ -230,7 +230,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. if command.Runtime != "local" { // Copy the models over. for _, modelID := range command.Models { - evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID)) + evaluationContext.Models = append(evaluationContext.Models, llm.NewModel(nil, modelID, nil)) } return evaluationContext, evaluationConfiguration, func() {} @@ -262,7 +262,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate. command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model) } - modelProvider.AddModel(llm.NewModel(modelProvider, model)) + modelProvider.AddModel(llm.NewModel(modelProvider, model, nil)) } } diff --git a/evaluate/evaluate_test.go b/evaluate/evaluate_test.go index deba8177..bff0926e 100644 --- a/evaluate/evaluate_test.go +++ b/evaluate/evaluate_test.go @@ -231,7 +231,7 @@ func TestEvaluate(t *testing.T) { languageGolang := &golang.Language{} mockedModelID := "testing-provider/empty-response-model" mockedQuery := providertesting.NewMockQuery(t) - mockedModel := llm.NewModel(mockedQuery, mockedModelID) + mockedModel := llm.NewModel(mockedQuery, mockedModelID, nil) repositoryPath := filepath.Join("golang", "plain") validate(t, &testCase{ @@ -290,7 +290,7 @@ func TestEvaluate(t *testing.T) { languageGolang := &golang.Language{} mockedModelID := "testing-provider/empty-response-model" mockedQuery := providertesting.NewMockQuery(t) - mockedModel := llm.NewModel(mockedQuery, mockedModelID) + mockedModel := llm.NewModel(mockedQuery, mockedModelID, nil) repositoryPath := filepath.Join("golang", "plain") validate(t, &testCase{ @@ -361,7 +361,7 @@ func TestEvaluate(t *testing.T) { languageGolang := &golang.Language{} mockedModelID := "testing-provider/empty-response-model" mockedQuery := providertesting.NewMockQuery(t) - mockedModel := llm.NewModel(mockedQuery, mockedModelID) + mockedModel := llm.NewModel(mockedQuery, mockedModelID, nil) repositoryPath := filepath.Join("golang", "plain") validate(t, &testCase{ diff --git a/model/llm/llm.go b/model/llm/llm.go index 820b3010..a7d76a79 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -31,18 +31,28 @@ type Model struct { // queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. queryAttempts uint + + // metaInformation holds a model meta information. + metaInformation *model.MetaInformation } // NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider. -func NewModel(provider provider.Query, modelIdentifier string) *Model { +func NewModel(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model { return &Model{ provider: provider, model: modelIdentifier, queryAttempts: 1, + + metaInformation: metaInformation, } } +// MetaInformation returns the meta information of a model. +func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { + return m.metaInformation +} + // llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt. type llmSourceFilePromptContext struct { // Language holds the programming language name. diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index 0a3de5c4..70b00d7a 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -56,7 +56,7 @@ func TestModelGenerateTestsForFile(t *testing.T) { mock := providertesting.NewMockQuery(t) tc.SetupMock(mock) - llm := NewModel(mock, tc.ModelID) + llm := NewModel(mock, tc.ModelID, nil) ctx := model.Context{ Language: tc.Language, @@ -156,7 +156,7 @@ func TestModelRepairSourceCodeFile(t *testing.T) { modelID := "some-model" mock := providertesting.NewMockQuery(t) tc.SetupMock(t, mock) - llm := NewModel(mock, modelID) + llm := NewModel(mock, modelID, nil) ctx := model.Context{ Language: tc.Language, @@ -496,7 +496,7 @@ func TestModelTranspile(t *testing.T) { modelID := "some-model" mock := providertesting.NewMockQuery(t) tc.SetupMock(t, mock) - llm := NewModel(mock, modelID) + llm := NewModel(mock, modelID, nil) ctx := model.Context{ Language: tc.Language, diff --git a/model/model.go b/model/model.go index 8d43f3b4..fac8870d 100644 --- a/model/model.go +++ b/model/model.go @@ -9,6 +9,9 @@ import ( type Model interface { // ID returns the unique ID of this model. ID() (id string) + + // MetaInformation returns the meta information of a model. + MetaInformation() *MetaInformation } // MetaInformation holds a model. diff --git a/model/symflower/symflower.go b/model/symflower/symflower.go index 9c715ea2..ca1f2a74 100644 --- a/model/symflower/symflower.go +++ b/model/symflower/symflower.go @@ -46,6 +46,11 @@ func (m *Model) ID() (id string) { return "symflower" + provider.ProviderModelSeparator + "symbolic-execution" } +// MetaInformation returns the meta information of a model. +func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { + return nil +} + var _ model.CapabilityWriteTests = (*Model)(nil) // generateTestsForFile generates test files for the given implementation file in a repository. diff --git a/model/testing/Model_mock_gen.go b/model/testing/Model_mock_gen.go index 04057324..35f6f95c 100644 --- a/model/testing/Model_mock_gen.go +++ b/model/testing/Model_mock_gen.go @@ -2,7 +2,10 @@ package modeltesting -import mock "github.com/stretchr/testify/mock" +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/symflower/eval-dev-quality/model" +) // MockModel is an autogenerated mock type for the Model type type MockModel struct { @@ -27,6 +30,26 @@ func (_m *MockModel) ID() string { return r0 } +// MetaInformation provides a mock function with given fields: +func (_m *MockModel) MetaInformation() *model.MetaInformation { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MetaInformation") + } + + var r0 *model.MetaInformation + if rf, ok := ret.Get(0).(func() *model.MetaInformation); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.MetaInformation) + } + } + + return r0 +} + // NewMockModel creates a new instance of MockModel. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockModel(t interface { diff --git a/provider/ollama/ollama.go b/provider/ollama/ollama.go index 84f0aff7..2ebee361 100644 --- a/provider/ollama/ollama.go +++ b/provider/ollama/ollama.go @@ -72,7 +72,7 @@ func (p *Provider) Models() (models []model.Model, err error) { models = make([]model.Model, len(ms)) for i, modelName := range ms { - models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+modelName) + models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+modelName, nil) } return models, nil diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index a8488222..fca5deed 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -69,7 +69,7 @@ func (p *Provider) Models() (models []model.Model, err error) { models = make([]model.Model, len(responseModels.Models)) for i, model := range responseModels.Models { - models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID) + models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID, &model) } return models, nil