Skip to content

Commit

Permalink
refactor, Group all LLM task logic
Browse files Browse the repository at this point in the history
Part of #201
  • Loading branch information
ruiAzevedo19 committed Jun 27, 2024
1 parent 08cc00f commit e6767e3
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions model/llm/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,34 +217,6 @@ func (m *Model) generateTestsForFile(ctx task.Context) (assessment metrics.Asses
return assessment, nil
}

func (m *Model) query(log *log.Logger, request string) (response string, duration time.Duration, err error) {
if err := retry.Do(
func() error {
log.Printf("Querying model %q with:\n%s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t"))))
start := time.Now()
response, err = m.provider.Query(context.Background(), m.model, request)
if err != nil {
return err
}
duration = time.Since(start)
log.Printf("Model %q responded (%d ms) with:\n%s", m.ID(), duration.Milliseconds(), string(bytesutil.PrefixLines([]byte(response), []byte("\t"))))

return nil
},
retry.Attempts(m.queryAttempts),
retry.Delay(5*time.Second),
retry.DelayType(retry.BackOffDelay),
retry.LastErrorOnly(true),
retry.OnRetry(func(n uint, err error) {
log.Printf("Attempt %d/%d: %s", n+1, m.queryAttempts, err)
}),
); err != nil {
return "", 0, err
}

return response, duration, nil
}

// repairSourceCodeFile queries the model to repair a source code with compilation error.
func (m *Model) repairSourceCodeFile(ctx task.Context, codeRepairArguments *evaluatetask.TaskArgumentsCodeRepair) (assessment metrics.Assessments, err error) {
assessment = map[metrics.AssessmentKey]uint64{}
Expand Down Expand Up @@ -303,6 +275,34 @@ func (m *Model) SetCost(cost float64) {
m.cost = cost
}

func (m *Model) query(log *log.Logger, request string) (response string, duration time.Duration, err error) {
if err := retry.Do(
func() error {
log.Printf("Querying model %q with:\n%s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t"))))
start := time.Now()
response, err = m.provider.Query(context.Background(), m.model, request)
if err != nil {
return err
}
duration = time.Since(start)
log.Printf("Model %q responded (%d ms) with:\n%s", m.ID(), duration.Milliseconds(), string(bytesutil.PrefixLines([]byte(response), []byte("\t"))))

return nil
},
retry.Attempts(m.queryAttempts),
retry.Delay(5*time.Second),
retry.DelayType(retry.BackOffDelay),
retry.LastErrorOnly(true),
retry.OnRetry(func(n uint, err error) {
log.Printf("Attempt %d/%d: %s", n+1, m.queryAttempts, err)
}),
); err != nil {
return "", 0, err
}

return response, duration, nil
}

var _ model.SetQueryAttempts = (*Model)(nil)

// SetQueryAttempts sets the number of query attempts to perform when a model request errors in the process of solving a task.
Expand Down

0 comments on commit e6767e3

Please sign in to comment.