Skip to content

Commit

Permalink
Optionally put databases in .settings.json
Browse files Browse the repository at this point in the history
  • Loading branch information
langbeinmovio committed Mar 4, 2021
1 parent 71d99f1 commit 9e1d932
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 17 deletions.
10 changes: 8 additions & 2 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@ import (
"strings"
)

const DefaultMaxAppServerConnections = 5

type settings struct {
MaxAppServerConnections int64
Databases map[string]database
}

type database struct {
Expand Down Expand Up @@ -87,7 +90,7 @@ func mustReadDatabasesConfigFile() map[string]database {
return databases
}

func readSettingsFile() *settings {
func mustReadSettings() *settings {
s := new(settings)
fileName := ".settings.json"
paths := getConfigPaths(fileName)
Expand All @@ -99,7 +102,10 @@ func readSettingsFile() *settings {
}
}
if s.MaxAppServerConnections == 0 {
s.MaxAppServerConnections = 5
s.MaxAppServerConnections = DefaultMaxAppServerConnections
}
if len(s.Databases) == 0 {
s.Databases = mustReadDatabasesConfigFile()
}
return s
}
22 changes: 10 additions & 12 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ func main() {
if *flagHelp {
usage("")
}

settings := mustReadSettings()

if *flagListDBs { // for auto-completion
for dbName := range mustReadDatabasesConfigFile() {
for dbName := range settings.Databases {
fmt.Print(dbName, " ")
}
fmt.Println()
os.Exit(0)
}

databases := mustReadDatabasesConfigFile()
settings := readSettingsFile()

if len(os.Args[1:]) == 0 {
usage("Target database unspecified; where should I run the query?")
}
Expand Down Expand Up @@ -59,18 +59,18 @@ func main() {
usage("No SQL to run. Exiting.")
}

os.Exit(_main(settings, databases, databasesArgs, query, newThreadSafePrintliner(os.Stdout).println))
os.Exit(_main(settings, databasesArgs, query, newThreadSafePrintliner(os.Stdout).println))
}

func _main(settings *settings, databases map[string]database, databasesArgs []string, query string, println func(string)) int {
func _main(settings *settings, databasesArgs []string, query string, println func(string)) int {
targetDatabases := []string{}
for _, k := range databasesArgs {
if _, ok := databases[k]; k != "all" && !ok {
if _, ok := settings.Databases[k]; k != "all" && !ok {
usage("Target database unknown: [%v]", k)
}
if k == "all" {
targetDatabases = nil
for k := range databases {
for k := range settings.Databases {
targetDatabases = append(targetDatabases, k)
}
break
Expand All @@ -86,7 +86,7 @@ func _main(settings *settings, databases map[string]database, databasesArgs []st

appServerSemaphors := make(map[string]*semaphore.Weighted)
for _, k := range targetDatabases {
var appServer = databases[k].AppServer
var appServer = settings.Databases[k].AppServer
if appServer != "" && appServerSemaphors[appServer] == nil {
appServerSemaphors[appServer] = semaphore.NewWeighted(settings.MaxAppServerConnections)
}
Expand All @@ -99,16 +99,14 @@ func _main(settings *settings, databases map[string]database, databasesArgs []st
go func(db database, k string) {
defer wg.Done()
if db.AppServer != "" {
fmt.Print("Aquiring lock from app server", db.AppServer, "\n")
var sem = appServerSemaphors[db.AppServer]
sem.Acquire(quitContext, 1)
fmt.Print("Aquired lock from app server", db.AppServer, "\n")
defer sem.Release(1)
}
if r := sqlRunner.runSQL(db, k); !r {
returnCode = 1
}
}(databases[k], k)
}(settings.Databases[k], k)
}

wg.Wait()
Expand Down
5 changes: 2 additions & 3 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ var mysqlTests = tests{{
},
}

var testSettings = settings{MaxAppServerConnections: 5}

func Test_MySQL(t *testing.T) {
awaitDB(mySQL, t)

Expand Down Expand Up @@ -198,7 +196,8 @@ func runTests(ts tests, testConfig testConfig, t *testing.T) {
for _, tc := range ts {
t.Run(tc.name, func(t *testing.T) {
var buf = bytes.Buffer{}
_main(&testSettings, testConfig, tc.targetDBs, tc.query, newThreadSafePrintliner(&buf).println)
var testSettings = settings{MaxAppServerConnections: 5, Databases: testConfig}
_main(&testSettings, tc.targetDBs, tc.query, newThreadSafePrintliner(&buf).println)
var actual = strings.Split(buf.String(), "\n")
sort.Strings(actual)
if !reflect.DeepEqual(tc.expected, actual) {
Expand Down

0 comments on commit 9e1d932

Please sign in to comment.