From 71d99f16fd3a24934d79ba0d43df816972b9dd0a Mon Sep 17 00:00:00 2001 From: Matthias Langbein Date: Thu, 4 Mar 2021 18:59:39 +1300 Subject: [PATCH] Configure maximum number of concurrent requests via same AppServer --- Dockerfile | 2 ++ config.go | 46 ++++++++++++++++++++++++++++++++++++++++------ main.go | 22 ++++++++++++++++++++-- main_test.go | 4 +++- 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index e219af6..9883a7f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,4 +4,6 @@ RUN apt-get update && apt-get install -y --no-install-recommends default-mysql-c ENV GO111MODULE=off +RUN git clone https://github.com/golang/sync $GOPATH/src/golang.org/x/sync + ENTRYPOINT [ "go", "test", "-v", "." ] diff --git a/config.go b/config.go index 2b05f02..b241d7f 100644 --- a/config.go +++ b/config.go @@ -9,6 +9,10 @@ import ( "strings" ) +type settings struct { + MaxAppServerConnections int64 +} + type database struct { AppServer string DbServer string @@ -18,9 +22,8 @@ type database struct { SQLType string } -func mustReadDatabasesConfigFile() map[string]database { +func getConfigPaths(fileName string) []string { var paths []string - databases := map[string]database{} usr, err := user.Current() if err != nil { @@ -33,7 +36,7 @@ func mustReadDatabasesConfigFile() map[string]database { if xdgHome == "" { xdgHome = fmt.Sprintf("%v/.config/", home) } - xdgHome += "sql/.databases.json" + xdgHome += fmt.Sprintf("sql/%v", fileName) paths = append(paths, xdgHome) @@ -41,21 +44,35 @@ func mustReadDatabasesConfigFile() map[string]database { xdgConfigDirs = append(xdgConfigDirs, "/etc/xdg") for _, d := range xdgConfigDirs { if d != "" { - paths = append(paths, fmt.Sprintf("%v/sql/.databases.json", d)) + paths = append(paths, fmt.Sprintf("%v/sql/%v", d, fileName)) } } - paths = append(paths, fmt.Sprintf("%v/.databases.json", home)) + paths = append(paths, fmt.Sprintf("%v/%v", home, fileName)) + return paths +} +func readFileContent(paths []string) ([]byte, error) { var byts []byte + var err error for _, p := range paths { if byts, err = ioutil.ReadFile(p); err != nil { continue } break } + return byts, err +} + +func mustReadDatabasesConfigFile() map[string]database { + databases := map[string]database{} + + fileName := ".databases.json" + paths := getConfigPaths(fileName) + byts, err := readFileContent(paths) if err != nil { - usage("Couldn't find .databases.json in the following paths [%v]. err=%v", paths, err) + usage("Couldn't find .%v in the following paths [%v]. err=%v", fileName, paths, err) + } err = json.Unmarshal(byts, &databases) @@ -69,3 +86,20 @@ func mustReadDatabasesConfigFile() map[string]database { return databases } + +func readSettingsFile() *settings { + s := new(settings) + fileName := ".settings.json" + paths := getConfigPaths(fileName) + byts, err := readFileContent(paths) + if err == nil { + err = json.Unmarshal(byts, s) + if err != nil { + usage("Found but couldn't JSON unmarshal %v. Looked like this:\n\n%v\n\nerr=%v", fileName, string(byts), err) + } + } + if s.MaxAppServerConnections == 0 { + s.MaxAppServerConnections = 5 + } + return s +} diff --git a/main.go b/main.go index 2eba9e8..0eb8f6a 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,8 @@ import ( "log" "os" "sync" + + "golang.org/x/sync/semaphore" ) func main() { @@ -28,6 +30,7 @@ func main() { } databases := mustReadDatabasesConfigFile() + settings := readSettingsFile() if len(os.Args[1:]) == 0 { usage("Target database unspecified; where should I run the query?") @@ -56,10 +59,10 @@ func main() { usage("No SQL to run. Exiting.") } - os.Exit(_main(databases, databasesArgs, query, newThreadSafePrintliner(os.Stdout).println)) + os.Exit(_main(settings, databases, databasesArgs, query, newThreadSafePrintliner(os.Stdout).println)) } -func _main(databases map[string]database, databasesArgs []string, query string, println func(string)) int { +func _main(settings *settings, databases map[string]database, databasesArgs []string, query string, println func(string)) int { targetDatabases := []string{} for _, k := range databasesArgs { if _, ok := databases[k]; k != "all" && !ok { @@ -81,12 +84,27 @@ func _main(databases map[string]database, databasesArgs []string, query string, var wg sync.WaitGroup wg.Add(len(targetDatabases)) + appServerSemaphors := make(map[string]*semaphore.Weighted) + for _, k := range targetDatabases { + var appServer = databases[k].AppServer + if appServer != "" && appServerSemaphors[appServer] == nil { + appServerSemaphors[appServer] = semaphore.NewWeighted(settings.MaxAppServerConnections) + } + } + sqlRunner := mustNewSQLRunner(quitContext, println, query, len(targetDatabases) > 1) returnCode := 0 for _, k := range targetDatabases { go func(db database, k string) { defer wg.Done() + if db.AppServer != "" { + fmt.Print("Aquiring lock from app server", db.AppServer, "\n") + var sem = appServerSemaphors[db.AppServer] + sem.Acquire(quitContext, 1) + fmt.Print("Aquired lock from app server", db.AppServer, "\n") + defer sem.Release(1) + } if r := sqlRunner.runSQL(db, k); !r { returnCode = 1 } diff --git a/main_test.go b/main_test.go index e4c092b..0738cf8 100644 --- a/main_test.go +++ b/main_test.go @@ -94,6 +94,8 @@ var mysqlTests = tests{{ }, } +var testSettings = settings{MaxAppServerConnections: 5} + func Test_MySQL(t *testing.T) { awaitDB(mySQL, t) @@ -196,7 +198,7 @@ func runTests(ts tests, testConfig testConfig, t *testing.T) { for _, tc := range ts { t.Run(tc.name, func(t *testing.T) { var buf = bytes.Buffer{} - _main(testConfig, tc.targetDBs, tc.query, newThreadSafePrintliner(&buf).println) + _main(&testSettings, testConfig, tc.targetDBs, tc.query, newThreadSafePrintliner(&buf).println) var actual = strings.Split(buf.String(), "\n") sort.Strings(actual) if !reflect.DeepEqual(tc.expected, actual) {