From 88d69f0289053d0499cadab8b6ae74e7598e24de Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Wed, 31 Jul 2024 11:52:47 +0100 Subject: [PATCH] Let the LLM models have meta information such as pricing and human-readable names Part of #296 --- model/llm/llm.go | 20 ++++++++++++++++++++ model/model.go | 3 +++ model/symflower/symflower.go | 5 +++++ model/testing/Model_mock_gen.go | 25 ++++++++++++++++++++++++- provider/openrouter/openrouter.go | 5 +++-- 5 files changed, 55 insertions(+), 3 deletions(-) diff --git a/model/llm/llm.go b/model/llm/llm.go index 820b3010..4bd52fee 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -31,6 +31,9 @@ type Model struct { // queryAttempts holds the number of query attempts to perform when a model request errors in the process of solving a task. queryAttempts uint + + // metaInformation holds a model meta information. + metaInformation *model.MetaInformation } // NewModel returns an LLM model corresponding to the given identifier which is queried via the given provider. @@ -43,6 +46,23 @@ func NewModel(provider provider.Query, modelIdentifier string) *Model { } } +// NewModelWithMetaInformation returns a LLM model with meta information corresponding to the given identifier which is queried via the given provider. +func NewModelWithMetaInformation(provider provider.Query, modelIdentifier string, metaInformation *model.MetaInformation) *Model { + return &Model{ + provider: provider, + model: modelIdentifier, + + queryAttempts: 1, + + metaInformation: metaInformation, + } +} + +// MetaInformation returns the meta information of a model. +func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { + return m.metaInformation +} + // llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt. type llmSourceFilePromptContext struct { // Language holds the programming language name. diff --git a/model/model.go b/model/model.go index 8d43f3b4..fac8870d 100644 --- a/model/model.go +++ b/model/model.go @@ -9,6 +9,9 @@ import ( type Model interface { // ID returns the unique ID of this model. ID() (id string) + + // MetaInformation returns the meta information of a model. + MetaInformation() *MetaInformation } // MetaInformation holds a model. diff --git a/model/symflower/symflower.go b/model/symflower/symflower.go index 9c715ea2..ca1f2a74 100644 --- a/model/symflower/symflower.go +++ b/model/symflower/symflower.go @@ -46,6 +46,11 @@ func (m *Model) ID() (id string) { return "symflower" + provider.ProviderModelSeparator + "symbolic-execution" } +// MetaInformation returns the meta information of a model. +func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { + return nil +} + var _ model.CapabilityWriteTests = (*Model)(nil) // generateTestsForFile generates test files for the given implementation file in a repository. diff --git a/model/testing/Model_mock_gen.go b/model/testing/Model_mock_gen.go index 04057324..35f6f95c 100644 --- a/model/testing/Model_mock_gen.go +++ b/model/testing/Model_mock_gen.go @@ -2,7 +2,10 @@ package modeltesting -import mock "github.com/stretchr/testify/mock" +import ( + mock "github.com/stretchr/testify/mock" + model "github.com/symflower/eval-dev-quality/model" +) // MockModel is an autogenerated mock type for the Model type type MockModel struct { @@ -27,6 +30,26 @@ func (_m *MockModel) ID() string { return r0 } +// MetaInformation provides a mock function with given fields: +func (_m *MockModel) MetaInformation() *model.MetaInformation { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MetaInformation") + } + + var r0 *model.MetaInformation + if rf, ok := ret.Get(0).(func() *model.MetaInformation); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.MetaInformation) + } + } + + return r0 +} + // NewMockModel creates a new instance of MockModel. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockModel(t interface { diff --git a/provider/openrouter/openrouter.go b/provider/openrouter/openrouter.go index a8488222..a420b537 100644 --- a/provider/openrouter/openrouter.go +++ b/provider/openrouter/openrouter.go @@ -57,7 +57,7 @@ func (p *Provider) ID() (id string) { // ModelsList holds a list of models. type ModelsList struct { - Models []model.MetaInformation `json:"data"` + Models []*model.MetaInformation `json:"data"` } // Models returns which models are available to be queried via this provider. @@ -69,7 +69,8 @@ func (p *Provider) Models() (models []model.Model, err error) { models = make([]model.Model, len(responseModels.Models)) for i, model := range responseModels.Models { - models[i] = llm.NewModel(p, p.ID()+provider.ProviderModelSeparator+model.ID) + model.ID = p.ID() + provider.ProviderModelSeparator + model.ID + models[i] = llm.NewModelWithMetaInformation(p, p.ID()+provider.ProviderModelSeparator+model.ID, model) } return models, nil