Skip to content

Commit

Permalink
Merge pull request #211 from symflower/docker-runtime
Browse files Browse the repository at this point in the history
Docker runtime
  • Loading branch information
bauersimon authored Jun 27, 2024
2 parents 1d388d1 + c49a891 commit 6fd5180
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 49 deletions.
192 changes: 143 additions & 49 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package cmd

import (
"context"
"os"
"os/exec"
"path/filepath"
"slices"
"sort"
Expand Down Expand Up @@ -68,6 +70,11 @@ type Evaluate struct {
// NoDisqualification indicates that models are not to be disqualified if they fail to solve basic language tasks.
NoDisqualification bool `long:"no-disqualification" description:"By default, models that cannot solve basic language tasks are disqualified for more complex tasks. Overwriting this behavior runs all tasks regardless."`

// Runtime indicates if the evaluation is run locally or inside a container.
Runtime string `long:"runtime" description:"The runtime which will be used for the evaluation." default:"local" choice:"local" choice:"docker"`
// RuntimeImage determines the container image used for any container runtime.
RuntimeImage string `long:"runtime-image" description:"The container image to use for the evaluation." default:""`

// logger holds the logger of the command.
logger *log.Logger
// timestamp holds the timestamp of the command execution.
Expand All @@ -94,29 +101,6 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}()
evaluationContext = &evaluate.Context{}

// Setup evaluation result directory.
{
command.ResultPath = strings.ReplaceAll(command.ResultPath, "%datetime%", command.timestamp.Format("2006-01-02-15:04:05")) // REMARK Use a datetime format with a dash, so directories can be easily marked because they are only one group.
uniqueResultPath, err := util.UniqueDirectory(command.ResultPath)
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
command.ResultPath = uniqueResultPath
evaluationContext.ResultPath = uniqueResultPath
command.logger.Printf("Writing results to %s", command.ResultPath)
}

// Initialize logging within result directory.
{
log, logClose, err := log.WithFile(command.logger, filepath.Join(command.ResultPath, "evaluation.log"))
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
cleanup = logClose
command.logger = log
evaluationContext.Log = log
}

// Check and validate common options.
{
if command.InstallToolsPath == "" {
Expand Down Expand Up @@ -154,6 +138,16 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}
evaluationContext.Runs = command.Runs

if command.Runtime == "docker" {
if _, err := exec.LookPath("docker"); err != nil {
command.logger.Panic("docker runtime could not be found")
}
}

if command.RuntimeImage == "" {
command.RuntimeImage = "ghcr.io/symflower/eval-dev-quality:v" + evaluate.Version
}

evaluationContext.NoDisqualification = command.NoDisqualification
}

Expand All @@ -170,6 +164,29 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
evaluationContext.TestdataPath = testdataPath
}

// Setup evaluation result directory.
{
command.ResultPath = strings.ReplaceAll(command.ResultPath, "%datetime%", command.timestamp.Format("2006-01-02-15:04:05")) // REMARK Use a datetime format with a dash, so directories can be easily marked because they are only one group.
uniqueResultPath, err := util.UniqueDirectory(command.ResultPath)
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
command.ResultPath = uniqueResultPath
evaluationContext.ResultPath = uniqueResultPath
command.logger.Printf("Writing results to %s", command.ResultPath)
}

// Initialize logging within result directory.
{
log, logClose, err := log.WithFile(command.logger, filepath.Join(command.ResultPath, "evaluation.log"))
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
cleanup = logClose
command.logger = log
evaluationContext.Log = log
}

// Register custom OpenAI API providers and models.
{
customProviders := map[string]*openaiapi.Provider{}
Expand Down Expand Up @@ -356,6 +373,56 @@ func (command *Evaluate) Execute(args []string) (err error) {
command.logger.Panic("ERROR: empty evaluation context")
}

switch command.Runtime {
case "local":
return command.evaluateLocal(evaluationContext)
case "docker":
return command.evaluateDocker(evaluationContext)
default:
command.logger.Panicf("ERROR: unknown runtime")
}

return nil
}

// WriteCSVs writes the various CSV reports to disk.
func writeCSVs(resultPath string, assessments *report.AssessmentStore) (err error) {
// Write the "evaluation.csv" containing all data.
csv, err := report.GenerateCSV(assessments)
if err != nil {
return pkgerrors.Wrap(err, "could not create evaluation.csv summary")
}
if err := os.WriteFile(filepath.Join(resultPath, "evaluation.csv"), []byte(csv), 0644); err != nil {
return pkgerrors.Wrap(err, "could not write evaluation.csv summary")
}

// Write the "models-summed.csv" containing the summary per model.
byModel := assessments.CollapseByModel()
csvByModel, err := report.GenerateCSV(byModel)
if err != nil {
return pkgerrors.Wrap(err, "could not create models-summed.csv summary")
}
if err := os.WriteFile(filepath.Join(resultPath, "models-summed.csv"), []byte(csvByModel), 0644); err != nil {
return pkgerrors.Wrap(err, "could not write models-summed.csv summary")
}

// Write the individual "language-summed.csv" containing the summary per model per language.
byLanguage := assessments.CollapseByLanguage()
for language, modelsByLanguage := range byLanguage {
csvByLanguage, err := report.GenerateCSV(modelsByLanguage)
if err != nil {
return pkgerrors.Wrap(err, "could not create "+language.ID()+"-summed.csv summary")
}
if err := os.WriteFile(filepath.Join(resultPath, language.ID()+"-summed.csv"), []byte(csvByLanguage), 0644); err != nil {
return pkgerrors.Wrap(err, "could not write "+language.ID()+"-summed.csv summary")
}
}

return nil
}

// evaluateLocal executes the evaluation on the current system.
func (command *Evaluate) evaluateLocal(evaluationContext *evaluate.Context) (err error) {
// Install required tools for the basic evaluation.
if err := tools.InstallEvaluation(command.logger, command.InstallToolsPath); err != nil {
command.logger.Panicf("ERROR: %s", err)
Expand Down Expand Up @@ -392,37 +459,64 @@ func (command *Evaluate) Execute(args []string) (err error) {
return nil
}

// WriteCSVs writes the various CSV reports to disk.
func writeCSVs(resultPath string, assessments *report.AssessmentStore) (err error) {
// Write the "evaluation.csv" containing all data.
csv, err := report.GenerateCSV(assessments)
if err != nil {
return pkgerrors.Wrap(err, "could not create evaluation.csv summary")
}
if err := os.WriteFile(filepath.Join(resultPath, "evaluation.csv"), []byte(csv), 0644); err != nil {
return pkgerrors.Wrap(err, "could not write evaluation.csv summary")
}
// evaluateDocker executes the evaluation for each model inside a docker container.
func (command *Evaluate) evaluateDocker(ctx *evaluate.Context) (err error) {
// Filter all the args to pass them onto the container.
args := util.FilterArgs(os.Args[2:], []string{
"--runtime",
"--model",
"--result-path",
})

// Write the "models-summed.csv" containing the summary per model.
byModel := assessments.CollapseByModel()
csvByModel, err := report.GenerateCSV(byModel)
if err != nil {
return pkgerrors.Wrap(err, "could not create models-summed.csv summary")
}
if err := os.WriteFile(filepath.Join(resultPath, "models-summed.csv"), []byte(csvByModel), 0644); err != nil {
return pkgerrors.Wrap(err, "could not write models-summed.csv summary")
}
// Iterate over each model and start the container.
for _, model := range ctx.Models {
// We are skipping ollama models until we fully support pulling. https://github.com/symflower/eval-dev-quality/issues/100.
if ctx.ProviderForModel[model].ID() == "ollama" {
command.logger.Print("Skipping unsupported ollama model with docker runtime")

// Write the individual "language-summed.csv" containing the summary per model per language.
byLanguage := assessments.CollapseByLanguage()
for language, modelsByLanguage := range byLanguage {
csvByLanguage, err := report.GenerateCSV(modelsByLanguage)
continue
}

// Create for each model a dedicated subfolder inside the results path.
resultPath, err := filepath.Abs(command.ResultPath)
if err != nil {
return pkgerrors.Wrap(err, "could not create "+language.ID()+"-summed.csv summary")
return err
}
if err := os.WriteFile(filepath.Join(resultPath, language.ID()+"-summed.csv"), []byte(csvByLanguage), 0644); err != nil {
return pkgerrors.Wrap(err, "could not write "+language.ID()+"-summed.csv summary")
// Set permission 777 so the non-root docker image is able to store its results inside the result path.
if err := os.Chmod(resultPath, 0777); err != nil {
return err
}

// Commands regarding the docker runtime.
dockerCommand := []string{
"docker",
"run",
"-v", // bind volume
resultPath + ":/home/ubuntu/evaluation",
"--rm", // automatically remove container after it finished
command.RuntimeImage,
}

// Commands for the evaluation to run inside the container.
evaluationCommand := []string{
"eval-dev-quality",
"evaluate",
"--model",
model.ID(),
"--result-path",
"/home/ubuntu/evaluation/" + model.ID(),
}

cmd := append(dockerCommand, evaluationCommand...)
cmd = append(cmd, args...)

commandOutput, err := util.CommandWithResult(context.Background(), command.logger, &util.Command{
Command: cmd,
})
if err != nil {
return pkgerrors.WithMessage(pkgerrors.WithStack(err), commandOutput)
}

}

return nil
Expand Down
36 changes: 36 additions & 0 deletions util/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,39 @@ func CommandWithResult(ctx context.Context, logger *log.Logger, command *Command

return writer.String(), nil
}

// FilterArgs parses args and removes the ignored ones.
func FilterArgs(args []string, ignored []string) (filtered []string) {
filterMap := map[string]bool{}
for _, v := range ignored {
filterMap[v] = true
}

// Resolve args with equals sign.
var resolvedArgs []string
for _, v := range args {
if strings.HasPrefix(v, "--") && strings.Contains(v, "=") {
resolvedArgs = append(resolvedArgs, strings.SplitN(v, "=", 2)...)
} else {
resolvedArgs = append(resolvedArgs, v)
}
}

skip := false
for _, v := range resolvedArgs {
if skip && strings.HasPrefix(v, "--") {
skip = false
}
if filterMap[v] {
skip = true
}

if skip {
continue
}

filtered = append(filtered, v)
}

return filtered
}
59 changes: 59 additions & 0 deletions util/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,62 @@ func TestCommandWithResultTimeout(t *testing.T) {
assert.Error(t, err)
assert.Less(t, duration.Seconds(), 5.0)
}

func TestFilterArgs(t *testing.T) {
type testCase struct {
Name string

Args []string
Ignored []string

ExpectedFiltered []string
}

validate := func(t *testing.T, tc *testCase) {
t.Run(tc.Name, func(t *testing.T) {
actualFiltered := FilterArgs(tc.Args, tc.Ignored)

assert.Equal(t, tc.ExpectedFiltered, actualFiltered)
})
}

validate(t, &testCase{
Name: "Filter arguments",

Args: []string{
"--runtime",
"abc",
"--runs",
"5",
},
Ignored: []string{
"--runtime",
},

ExpectedFiltered: []string{
"--runs",
"5",
},
})

validate(t, &testCase{
Name: "Filter arguments with equals sign",

Args: []string{
"--runtime=abc",
"--runs=5",
"--foo",
"bar",
},
Ignored: []string{
"--runtime",
},

ExpectedFiltered: []string{
"--runs",
"5",
"--foo",
"bar",
},
})
}

0 comments on commit 6fd5180

Please sign in to comment.