Skip to content

Commit

Permalink
Fix bug in config set when flag --config is provided (#274)
Browse files Browse the repository at this point in the history
* config set agent.model wasn't working when you use "--config" to
specify a non default config file
* The break occured because of
https://github.com/jlewi/foyle/blob/e03b8fe40b65f777de3b7c9ed02b612403c9c71f/app/pkg/config/config.go#L366
* That change meant InitViper was no longer operating on the global
instance.
* However this line
https://github.com/jlewi/foyle/blob/e03b8fe40b65f777de3b7c9ed02b612403c9c71f/app/cmd/config.go#L60
which was modifying the configuration was using the global instance
* This PR fixes this by being explicit about the instance of viper used.
* The unittest verifies we correctly modify the configuration
  • Loading branch information
jlewi authored Oct 4, 2024
1 parent e03b8fe commit e2feb99
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 34 deletions.
37 changes: 9 additions & 28 deletions app/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package cmd
import (
"fmt"
"os"
"strings"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/pkg/errors"
Expand Down Expand Up @@ -31,39 +30,21 @@ func NewSetConfigCmd() *cobra.Command {
Run: func(cmd *cobra.Command, args []string) {

err := func() error {
if err := config.InitViper(cmd); err != nil {
v := viper.GetViper()

if err := config.InitViperInstance(v, cmd); err != nil {
return err
}

pieces := strings.Split(args[0], "=")
cfgName := pieces[0]

var fConfig *config.Config
switch cfgName {
case "azureOpenAI.deployments":
if len(pieces) != 3 {
return errors.New("Invalid argument; argument is not in the form azureOpenAI.deployments=<model>=<deployment>")
}

d := config.AzureDeployment{
Model: pieces[1],
Deployment: pieces[2],
}

fConfig = config.GetConfig()
config.SetAzureDeployment(fConfig, d)
default:
if len(pieces) < 2 {
return errors.New("Invalid usage; set expects an argument in the form <NAME>=<VALUE>")
}
cfgValue := pieces[1]
viper.Set(cfgName, cfgValue)
fConfig = config.GetConfig()
fConfig, err := config.UpdateViperConfig(v, args[0])

if err != nil {
return errors.Wrap(err, "Failed to update configuration")
}

file := viper.ConfigFileUsed()
file := fConfig.GetConfigFile()
if file == "" {
file = config.DefaultConfigFile()
return errors.New("Failed to get configuration file")
}
// Persist the configuration
return fConfig.Write(file)
Expand Down
2 changes: 2 additions & 0 deletions app/pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ func (c *Config) DeepCopy() Config {
func InitViper(cmd *cobra.Command) error {
// N.B. we need to set globalV because the subsequent call GetConfig will use that viper instance.
// Would it make sense to combine InitViper and Get into one command that returns a config object?
// TODO(jeremy): Could we just use viper.GetViper() to get the global instance?
globalV = viper.New()
return InitViperInstance(globalV, cmd)
}
Expand Down Expand Up @@ -407,6 +408,7 @@ func InitViperInstance(v *viper.Viper, cmd *cobra.Command) error {

// Ensure the path for the config file path is set
// Required since we use viper to persist the location of the config file so can save to it.
// This allows us to overwrite the config file location with the --config flag.
cfgFile := v.GetString(ConfigFlagName)
if cfgFile != "" {
v.SetConfigFile(cfgFile)
Expand Down
6 changes: 0 additions & 6 deletions app/pkg/config/test_data/partial.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,7 @@ openai:
telemetry:
honeycomb:
apiKeyFile: /Users/fred/secrets/honeycomb.api.key
eval:
gcpServiceAccount: [email protected]
learner:
logDirs: []
exampleDirs:
- /Users/fred/.foyle/training
replicate:
apiKeyFile: /Users/fred/replicate/secrets/apikey
anthropic:
apiKeyFile: /Users/fred/secrets/anthropic.key
41 changes: 41 additions & 0 deletions app/pkg/config/update.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package config

import (
"strings"

"github.com/pkg/errors"
"github.com/spf13/viper"
)

// UpdateViperConfig update the viper configuration with the given expression.
// expression should be a value such as "agent.model=gpt-4o-mini"
// The input is a viper configuration because we leverage viper to handle setting most keys.
// However, in some special cases we use custom functions. This is why we return a Config object.
func UpdateViperConfig(v *viper.Viper, expression string) (*Config, error) {
pieces := strings.Split(expression, "=")
cfgName := pieces[0]

var fConfig *Config

switch cfgName {
case "azureOpenAI.deployments":
if len(pieces) != 3 {
return fConfig, errors.New("Invalid argument; argument is not in the form azureOpenAI.deployments=<model>=<deployment>")
}

d := AzureDeployment{
Model: pieces[1],
Deployment: pieces[2],
}

SetAzureDeployment(fConfig, d)
default:
if len(pieces) < 2 {
return fConfig, errors.New("Invalid usage; set expects an argument in the form <NAME>=<VALUE>")
}
cfgValue := pieces[1]
v.Set(cfgName, cfgValue)
}

return getConfigFromViper(v)
}
88 changes: 88 additions & 0 deletions app/pkg/config/update_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package config

import (
"os"
"path/filepath"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/jlewi/foyle/app/api"
"github.com/spf13/viper"
)

func Test_UpdateViperConfig(t *testing.T) {
type testCase struct {
name string
configFile string
expression string
expected *Config
}

cases := []testCase{
{
name: "model",
configFile: "partial.yaml",
expression: "agent.model=some-other-model",
expected: &Config{
Logging: Logging{
Level: "info",
Sinks: []LogSink{{JSON: true, Path: "gcplogs:///projects/fred-dev/logs/foyle"}, {Path: "stderr"}},
},
Agent: &api.AgentConfig{
Model: "some-other-model",
ModelProvider: "anthropic",
RAG: &api.RAGConfig{
Enabled: true,
MaxResults: 3,
},
},
Server: ServerConfig{
BindAddress: "0.0.0.0",
GRPCPort: 9080,
HttpPort: 8877,
HttpMaxReadTimeout: time.Minute,
HttpMaxWriteTimeout: time.Minute,
},
OpenAI: &OpenAIConfig{
APIKeyFile: "/Users/red/secrets/openapi.api.key",
},
Telemetry: &TelemetryConfig{
Honeycomb: &HoneycombConfig{
APIKeyFile: "/Users/fred/secrets/honeycomb.api.key",
},
},
Learner: &LearnerConfig{LogDirs: []string{}, ExampleDirs: []string{"/Users/fred/.foyle/training"}},
},
},
}

cwd, err := os.Getwd()
if err != nil {
t.Fatalf("Failed to get working directory")
}
tDir := filepath.Join(cwd, "test_data")

for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
// Create an empty configuration file and run various assertions on it
v := viper.New()
v.SetConfigFile(filepath.Join(tDir, c.configFile))

if err := InitViperInstance(v, nil); err != nil {
t.Fatalf("Failed to initialize the configuration.")
}

cfg, err := UpdateViperConfig(v, c.expression)
if err != nil {
t.Fatalf("Failed to update config; %+v", err)
}

opts := cmpopts.IgnoreUnexported(Config{})
if d := cmp.Diff(c.expected, cfg, opts); d != "" {
t.Fatalf("Unexpected diff:\n%+v", d)
}
})
}
}

0 comments on commit e2feb99

Please sign in to comment.