From 9adadfea4819bf3cb4e66e817883753e8c5cc229 Mon Sep 17 00:00:00 2001 From: Andreas Humenberger Date: Thu, 18 Jul 2024 13:22:27 +0200 Subject: [PATCH] Log model responses as artifact in separate file Part of #204 --- cmd/eval-dev-quality/cmd/evaluate_test.go | 2 + evaluate/evaluate.go | 4 + evaluate/evaluate_test.go | 2 + log/logger.go | 91 ++++++++++++++++++++--- log/logger_test.go | 39 ++++++++++ model/llm/llm.go | 9 ++- 6 files changed, 131 insertions(+), 16 deletions(-) diff --git a/cmd/eval-dev-quality/cmd/evaluate_test.go b/cmd/eval-dev-quality/cmd/evaluate_test.go index 11537344..6856f59b 100644 --- a/cmd/eval-dev-quality/cmd/evaluate_test.go +++ b/cmd/eval-dev-quality/cmd/evaluate_test.go @@ -527,6 +527,7 @@ func TestEvaluateExecute(t *testing.T) { }, filepath.Join("result-directory", "README.md"): nil, filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "ollama_"+log.CleanModelNameForFileSystem(providertesting.OllamaTestModel), "golang", "golang", "plain", "evaluation.log"): nil, + filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "ollama_"+log.CleanModelNameForFileSystem(providertesting.OllamaTestModel), "golang", "golang", "plain", "response-1.log"): nil, }, ExpectedOutputValidate: func(t *testing.T, output, resultPath string) { assert.Contains(t, output, `Starting services for provider "ollama"`) @@ -598,6 +599,7 @@ func TestEvaluateExecute(t *testing.T) { }, filepath.Join("result-directory", "README.md"): nil, filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "custom-ollama_"+log.CleanModelNameForFileSystem(providertesting.OllamaTestModel), "golang", "golang", "plain", "evaluation.log"): nil, + filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "custom-ollama_"+log.CleanModelNameForFileSystem(providertesting.OllamaTestModel), "golang", "golang", "plain", "response-1.log"): nil, }, }) } diff --git a/evaluate/evaluate.go b/evaluate/evaluate.go index 5d63a8c7..81e76eba 100644 --- a/evaluate/evaluate.go +++ b/evaluate/evaluate.go @@ -108,6 +108,8 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore, totalScore uin logger.Printf("Run %d/%d", rl+1, ctx.Runs) } + logger := logger.With(log.AttributeKeyRun, rl+1) + for _, language := range ctx.Languages { logger := logger.With(log.AttributeKeyLanguage, language.ID()) @@ -214,6 +216,8 @@ func Evaluate(ctx *Context) (assessments *report.AssessmentStore, totalScore uin logger.Printf("Run %d/%d", rl+1, ctx.Runs) } + logger := logger.With(log.AttributeKeyRun, rl+1) + for _, language := range ctx.Languages { languageID := language.ID() logger := logger.With(log.AttributeKeyLanguage, languageID) diff --git a/evaluate/evaluate_test.go b/evaluate/evaluate_test.go index c11ddba4..deba8177 100644 --- a/evaluate/evaluate_test.go +++ b/evaluate/evaluate_test.go @@ -352,6 +352,7 @@ func TestEvaluate(t *testing.T) { filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain", "evaluation.log"): func(t *testing.T, filePath, data string) { assert.Contains(t, data, "Attempt 1/3: "+ErrEmptyResponseFromModel.Error()) }, + filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain", "response-1.log"): nil, "evaluation.csv": nil, }, }) @@ -421,6 +422,7 @@ func TestEvaluate(t *testing.T) { filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain", "evaluation.log"): func(t *testing.T, filePath, data string) { assert.Contains(t, data, "DONE 0 tests, 1 error") }, + filepath.Join(string(evaluatetask.IdentifierWriteTests), log.CleanModelNameForFileSystem(mockedModelID), "golang", "golang", "plain", "response-1.log"): nil, "evaluation.csv": nil, }, }) diff --git a/log/logger.go b/log/logger.go index 6a254dbe..cea1a1d4 100644 --- a/log/logger.go +++ b/log/logger.go @@ -21,13 +21,20 @@ import ( type AttributeKey string const ( - AttributeKeyLanguage AttributeKey = "Language" + AttributeKeyArtifact AttributeKey = "Artifact" + AttributeKeyLanguage = "Language" AttributeKeyModel = "Model" AttributeKeyRepository = "Repository" AttributeKeyResultPath = "ResultPath" + AttributeKeyRun = "Run" AttributeKeyTask = "Task" ) +// Attribute returns a logging attribute. +func Attribute(key AttributeKey, value any) (attribute slog.Attr) { + return slog.Any(string(key), value) +} + // Flags defines how log messages should be printed. type Flags int @@ -104,6 +111,11 @@ func (l *Logger) Printf(format string, args ...any) { l.Logger.Info(fmt.Sprintf(format, args...)) } +// PrintWith logs the given message at the "info" level. +func (l *Logger) PrintWith(message string, args ...any) { + l.Logger.Info(message, args...) +} + // Panicf is equivalent to "Printf" followed by a panic. func (l *Logger) Panicf(format string, args ...any) { message := fmt.Sprintf(format, args...) @@ -169,19 +181,29 @@ func STDOUT() (logger *Logger) { // newLogWriter returns a logger that writes to a file and to the parent logger at the same time. func newLogWriter(parent io.Writer, filePath string) (writer io.Writer, err error) { + file, err := openLogFile(filePath) + if err != nil { + return nil, err + } + addOpenLogFile(file) + + writer = io.MultiWriter(parent, file) + + return writer, nil +} + +// openLogFile opens the given file and creates it if necessary. +func openLogFile(filePath string) (file *os.File, err error) { if err := os.MkdirAll(filepath.Dir(filePath), 0755); err != nil { return nil, pkgerrors.WithStack(err) } - file, err := os.OpenFile(filePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) + file, err = os.OpenFile(filePath, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return nil, pkgerrors.WithStack(err) } - addOpenLogFile(file) - writer = io.MultiWriter(parent, file) - - return writer, nil + return file, nil } // spawningHandler is a structural logging handler which spawns a new log file if one of the given log file spawners triggers. @@ -231,17 +253,36 @@ func (h *spawningHandler) Enabled(ctx context.Context, level slog.Level) bool { } // Handle handles the Record. -func (h *spawningHandler) Handle(ctx context.Context, record slog.Record) error { +func (h *spawningHandler) Handle(ctx context.Context, record slog.Record) (err error) { + writer := h.writer + attributes := maps.Clone(h.attributes) + record.Attrs(func(attribute slog.Attr) bool { + attributes[AttributeKey(attribute.Key)] = attribute.Value.String() + + return true + }) + for _, spawner := range artifactLogFileSpawners { + if !spawner.NeedsSpawn(attributes) { + continue + } + + logFilePath := spawner.FilePath(attributes) + writer, err = newLogWriter(writer, logFilePath) + if err != nil { + return err + } + } + if h.flags&FlagDate != 0 { - fmt.Fprint(h.writer, record.Time.Format("2006/01/02")) - fmt.Fprint(h.writer, " ") + fmt.Fprint(writer, record.Time.Format("2006/01/02")) + fmt.Fprint(writer, " ") } if h.flags&FlagTime != 0 { - fmt.Fprint(h.writer, record.Time.Format("15:04:05")) - fmt.Fprint(h.writer, " ") + fmt.Fprint(writer, record.Time.Format("15:04:05")) + fmt.Fprint(writer, " ") } - fmt.Fprintln(h.writer, record.Message) + fmt.Fprintln(writer, record.Message) return nil } @@ -338,6 +379,32 @@ var defaultLogFileSpawners = []logFileSpawner{ }, } +var artifactLogFileSpawners = []logFileSpawner{ + logFileSpawner{ + NeededAttributes: []AttributeKey{ + AttributeKeyResultPath, + + AttributeKeyArtifact, + AttributeKeyLanguage, + AttributeKeyModel, + AttributeKeyRepository, + AttributeKeyRun, + AttributeKeyTask, + }, + FilePath: func(attributes map[AttributeKey]string) string { + resultPath := attributes[AttributeKeyResultPath] + modelID := attributes[AttributeKeyModel] + languageID := attributes[AttributeKeyLanguage] + repositoryName := attributes[AttributeKeyRepository] + taskIdentifier := attributes[AttributeKeyTask] + run := attributes[AttributeKeyRun] + artifact := attributes[AttributeKeyArtifact] + + return filepath.Join(resultPath, taskIdentifier, CleanModelNameForFileSystem(modelID), languageID, repositoryName, fmt.Sprintf("%s-%s.log", artifact, run)) + }, + }, +} + // logFileSpawner defines when a new log file is spawned. type logFileSpawner struct { // NeededAttributes holds the list of attributes that need to be set in order to spawn a new log file. diff --git a/log/logger_test.go b/log/logger_test.go index 6313457c..af69a6f1 100644 --- a/log/logger_test.go +++ b/log/logger_test.go @@ -181,6 +181,45 @@ func TestLoggerWith(t *testing.T) { filepath.Join("taskA", "modelA", "languageA", "repositoryB", "evaluation.log"): "", }, }) + + t.Run("Artifacts", func(t *testing.T) { + validate(t, &testCase{ + Name: "Response", + + Do: func(logger *Logger, temporaryPath string) { + logger = logger.With(AttributeKeyResultPath, temporaryPath) + logger = logger.With(AttributeKeyLanguage, "languageA") + logger = logger.With(AttributeKeyModel, "modelA") + logger = logger.With(AttributeKeyRepository, "repositoryA") + logger = logger.With(AttributeKeyTask, "taskA") + logger = logger.With(AttributeKeyRun, "1") + + logger.PrintWith("Artifact content", Attribute(AttributeKeyArtifact, "response")) + logger.Print("No artifact content") + }, + + ExpectedLogOutput: ` + Spawning new log file at $TEMPORARY_PATH/evaluation.log + Spawning new log file at $TEMPORARY_PATH/taskA/modelA/languageA/repositoryA/evaluation.log + Artifact content + No artifact content + `, + ExpectedFiles: map[string]string{ + "evaluation.log": ` + Spawning new log file at $TEMPORARY_PATH/taskA/modelA/languageA/repositoryA/evaluation.log + Artifact content + No artifact content + `, + filepath.Join("taskA", "modelA", "languageA", "repositoryA", "evaluation.log"): ` + Artifact content + No artifact content + `, + filepath.Join("taskA", "modelA", "languageA", "repositoryA", "response-1.log"): ` + Artifact content + `, + }, + }) + }) } func TestCleanModelNameForFileSystem(t *testing.T) { diff --git a/model/llm/llm.go b/model/llm/llm.go index 4b69766a..056604b3 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -2,6 +2,7 @@ package llm import ( "context" + "fmt" "os" "path/filepath" "strings" @@ -170,17 +171,17 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e return assessment, nil } -func (m *Model) query(log *log.Logger, request string) (response string, duration time.Duration, err error) { +func (m *Model) query(logger *log.Logger, request string) (response string, duration time.Duration, err error) { if err := retry.Do( func() error { - log.Printf("Querying model %q with:\n%s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t")))) + logger.Printf("Querying model %q with:\n%s", m.ID(), string(bytesutil.PrefixLines([]byte(request), []byte("\t")))) start := time.Now() response, err = m.provider.Query(context.Background(), m.model, request) if err != nil { return err } duration = time.Since(start) - log.Printf("Model %q responded (%d ms) with:\n%s", m.ID(), duration.Milliseconds(), string(bytesutil.PrefixLines([]byte(response), []byte("\t")))) + logger.PrintWith(fmt.Sprintf("Model %q responded (%d ms) with:\n%s", m.ID(), duration.Milliseconds(), string(bytesutil.PrefixLines([]byte(response), []byte("\t")))), log.Attribute(log.AttributeKeyArtifact, "response")) return nil }, @@ -189,7 +190,7 @@ func (m *Model) query(log *log.Logger, request string) (response string, duratio retry.DelayType(retry.BackOffDelay), retry.LastErrorOnly(true), retry.OnRetry(func(n uint, err error) { - log.Printf("Attempt %d/%d: %s", n+1, m.queryAttempts, err) + logger.Printf("Attempt %d/%d: %s", n+1, m.queryAttempts, err) }), ); err != nil { return "", 0, err