From e6767e350425ca83dab11014bd57a7fdc69288fd Mon Sep 17 00:00:00 2001 From: Rui Azevedo Date: Tue, 18 Jun 2024 15:41:06 +0100 Subject: [PATCH] refactor, Group all LLM task logic Part of #201 --- model/llm/llm.go | 56 ++++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/model/llm/llm.go b/model/llm/llm.go index 92e4e3e0..c1339c91 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -217,34 +217,6 @@ func (m *Model) generateTestsForFile(ctx task.Context) (assessment metrics.Asses return assessment, nil } -func (m *Model) query(log *log.Logger, request string) (response string, duration time.Duration, err error) { - if err := retry.Do( - func() error { - log.Printf("Querying model %q with:\n%s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t")))) - start := time.Now() - response, err = m.provider.Query(context.Background(), m.model, request) - if err != nil { - return err - } - duration = time.Since(start) - log.Printf("Model %q responded (%d ms) with:\n%s", m.ID(), duration.Milliseconds(), string(bytesutil.PrefixLines([]byte(response), []byte("\t")))) - - return nil - }, - retry.Attempts(m.queryAttempts), - retry.Delay(5*time.Second), - retry.DelayType(retry.BackOffDelay), - retry.LastErrorOnly(true), - retry.OnRetry(func(n uint, err error) { - log.Printf("Attempt %d/%d: %s", n+1, m.queryAttempts, err) - }), - ); err != nil { - return "", 0, err - } - - return response, duration, nil -} - // repairSourceCodeFile queries the model to repair a source code with compilation error. func (m *Model) repairSourceCodeFile(ctx task.Context, codeRepairArguments *evaluatetask.TaskArgumentsCodeRepair) (assessment metrics.Assessments, err error) { assessment = map[metrics.AssessmentKey]uint64{} @@ -303,6 +275,34 @@ func (m *Model) SetCost(cost float64) { m.cost = cost } +func (m *Model) query(log *log.Logger, request string) (response string, duration time.Duration, err error) { + if err := retry.Do( + func() error { + log.Printf("Querying model %q with:\n%s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t")))) + start := time.Now() + response, err = m.provider.Query(context.Background(), m.model, request) + if err != nil { + return err + } + duration = time.Since(start) + log.Printf("Model %q responded (%d ms) with:\n%s", m.ID(), duration.Milliseconds(), string(bytesutil.PrefixLines([]byte(response), []byte("\t")))) + + return nil + }, + retry.Attempts(m.queryAttempts), + retry.Delay(5*time.Second), + retry.DelayType(retry.BackOffDelay), + retry.LastErrorOnly(true), + retry.OnRetry(func(n uint, err error) { + log.Printf("Attempt %d/%d: %s", n+1, m.queryAttempts, err) + }), + ); err != nil { + return "", 0, err + } + + return response, duration, nil +} + var _ model.SetQueryAttempts = (*Model)(nil) // SetQueryAttempts sets the number of query attempts to perform when a model request errors in the process of solving a task.