diff --git a/.golangci.yml b/.golangci.yml
index f159c65..d3416a8 100755
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,3 +1,7 @@
+run:
+  # The default runtime timeout is 1m, which doesn't work well on GitHub Actions.
+  timeout: 4m
+
 # NOTE: This file is populated by the lint-install tool. Local adjustments may be overwritten.
 linters-settings:
   cyclop:
@@ -18,9 +22,19 @@ linters-settings:
     min-occurrences: 5
     ignore-tests: true
 
+  gosec:
+    excludes:
+      - G107 # Potential HTTP request made with variable url
+      - G204 # Subprocess launched with function call as argument or cmd arguments
+      - G404 # Use of weak random number generator (math/rand instead of crypto/rand)
+
   errorlint:
     # these are still common in Go: for instance, exit errors.
     asserts: false
+    # Forcing %w in error wrapping forces authors to make errors part of their package APIs.
+    # Making an error part of a package API should be a conscious decision by the author.
+    # Also see Hyrum's Law.
+    errorf: false
 
   exhaustive:
     default-signifies-exhaustive: true
@@ -29,9 +43,9 @@ linters-settings:
     min-complexity: 8
 
   nolintlint:
-      require-explanation: true
-      allow-unused: false
-      require-specific: true
+    require-explanation: true
+    allow-unused: false
+    require-specific: true
 
   revive:
     ignore-generated-header: true
@@ -78,10 +92,10 @@ linters-settings:
       - name: waitgroup-by-value
 
   staticcheck:
-    go: "1.20"
+    go: "1.18"
 
   unused:
-    go: "1.20"
+    go: "1.18"
 
 output:
   sort-results: true
@@ -96,8 +110,7 @@ linters:
     - dupl
     - durationcheck
    - errcheck
-    # errname is only available in golangci-lint v1.42.0+ - wait until v1.43 is available to settle
-    #- errname
+    - errname
     - errorlint
     - exhaustive
     - exportloopref
@@ -107,6 +120,8 @@
     - gocritic
     - godot
     - gofmt
+    - gofumpt
+    - gosec
     - goheader
     - goimports
     - goprintffuncname
@@ -123,11 +138,12 @@
     - nolintlint
     - predeclared
     # disabling for the initial iteration of the linting tool
-    #- promlinter
+    # - promlinter
     - revive
-    - rowserrcheck
-    - sqlclosecheck
+    # - rowserrcheck - disabled because of generics, https://github.com/golangci/golangci-lint/issues/2649
+    # - sqlclosecheck - disabled because of generics, https://github.com/golangci/golangci-lint/issues/2649
     - staticcheck
+    # - structcheck - disabled because of generics, https://github.com/golangci/golangci-lint/issues/2649
     - stylecheck
     - thelper
     - tparallel
@@ -135,7 +151,7 @@
     - unconvert
     - unparam
     - unused
-    - wastedassign
+    # - wastedassign - disabled because of generics, https://github.com/golangci/golangci-lint/issues/2649
     - whitespace
 
   # Disabled linters, due to being misaligned with Go practices
@@ -150,28 +166,25 @@
   # - nlreturn
   # - testpackage
   # - wsl
-
   # Disabled linters, due to not being relevant to our code base:
   # - maligned
   # - prealloc "For most programs usage of prealloc will be a premature optimization."
 
-  # Disabled linters due to bad error messages or bugs
-  # - gofumpt
-  # - gosec
   # - tagliatelle
-
 
 issues:
   # Excluding configuration per-path, per-linter, per-text and per-source
   exclude-rules:
     - path: _test\.go
       linters:
-        - gocyclo
-        - errcheck
         - dupl
+        - errcheck
+        - forcetypeassert
+        - gocyclo
         - gosec
+        - noctx
 
-    - path: cmd.*
+    - path: .*cmd.*
       linters:
         - noctx
@@ -179,7 +192,7 @@
       linters:
         - noctx
 
-    - path: cmd.*
+    - path: .*cmd.*
       text: "deep-exit"
 
     - path: main\.go
diff --git a/Makefile b/Makefile
index c1530f3..d342952 100644
--- a/Makefile
+++ b/Makefile
@@ -101,14 +101,16 @@ pbs-docker-image: ## generate container image for building protocol buffers
 run-image: ## run PBnJ container image
 	scripts/run-image.sh
 
-# BEGIN: lint-install .
+# BEGIN: lint-install github.com/tinkerbell/pbnj
 # http://github.com/tinkerbell/lint-install
 
-GOLINT_VERSION ?= v1.52.2
-HADOLINT_VERSION ?= v2.7.0
-SHELLCHECK_VERSION ?= v0.7.2
-LINT_OS := $(shell uname)
+.PHONY: lint
+lint: _lint
+
 LINT_ARCH := $(shell uname -m)
+LINT_OS := $(shell uname)
+LINT_OS_LOWER := $(shell echo $(LINT_OS) | tr '[:upper:]' '[:lower:]')
+LINT_ROOT := $(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))
 
 # shellcheck and hadolint lack arm64 native binaries: rely on x86-64 emulation
 ifeq ($(LINT_OS),Darwin)
@@ -117,30 +119,34 @@ ifeq ($(LINT_OS),Darwin)
 	endif
 endif
 
-LINT_LOWER_OS = $(shell echo $(LINT_OS) | tr '[:upper:]' '[:lower:]')
-GOLINT_CONFIG:=$(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))/.golangci.yml
+LINTERS :=
+FIXERS :=
+
+GOLANGCI_LINT_CONFIG := $(LINT_ROOT)/.golangci.yml
+GOLANGCI_LINT_VERSION ?= v1.53.3
+GOLANGCI_LINT_BIN := $(LINT_ROOT)/out/linters/golangci-lint-$(GOLANGCI_LINT_VERSION)-$(LINT_ARCH)
+$(GOLANGCI_LINT_BIN):
+	mkdir -p $(LINT_ROOT)/out/linters
+	rm -rf $(LINT_ROOT)/out/linters/golangci-lint-*
+	curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(LINT_ROOT)/out/linters $(GOLANGCI_LINT_VERSION)
+	mv $(LINT_ROOT)/out/linters/golangci-lint $@
+
+LINTERS += golangci-lint-lint
+golangci-lint-lint: $(GOLANGCI_LINT_BIN)
+	find . -name go.mod -execdir "$(GOLANGCI_LINT_BIN)" run -c "$(GOLANGCI_LINT_CONFIG)" \;
 
-lint: out/linters/shellcheck-$(SHELLCHECK_VERSION)-$(LINT_ARCH)/shellcheck out/linters/hadolint-$(HADOLINT_VERSION)-$(LINT_ARCH) out/linters/golangci-lint-$(GOLINT_VERSION)-$(LINT_ARCH)
-	out/linters/golangci-lint-$(GOLINT_VERSION)-$(LINT_ARCH) run
-	out/linters/hadolint-$(HADOLINT_VERSION)-$(LINT_ARCH) -t info $(shell find . -name "*Dockerfile")
-	out/linters/shellcheck-$(SHELLCHECK_VERSION)-$(LINT_ARCH)/shellcheck $(shell find . -name "*.sh")
+FIXERS += golangci-lint-fix
+golangci-lint-fix: $(GOLANGCI_LINT_BIN)
+	find . -name go.mod -execdir "$(GOLANGCI_LINT_BIN)" run -c "$(GOLANGCI_LINT_CONFIG)" --fix \;
 
-out/linters/shellcheck-$(SHELLCHECK_VERSION)-$(LINT_ARCH)/shellcheck:
-	mkdir -p out/linters
-	curl -sSfL https://github.com/koalaman/shellcheck/releases/download/$(SHELLCHECK_VERSION)/shellcheck-$(SHELLCHECK_VERSION).$(LINT_LOWER_OS).$(LINT_ARCH).tar.xz | tar -C out/linters -xJf -
-	mv out/linters/shellcheck-$(SHELLCHECK_VERSION) out/linters/shellcheck-$(SHELLCHECK_VERSION)-$(LINT_ARCH)
+.PHONY: _lint $(LINTERS)
+_lint: $(LINTERS)
 
-out/linters/hadolint-$(HADOLINT_VERSION)-$(LINT_ARCH):
-	mkdir -p out/linters
-	curl -sfL https://github.com/hadolint/hadolint/releases/download/v2.6.1/hadolint-$(LINT_OS)-$(LINT_ARCH) > out/linters/hadolint-$(HADOLINT_VERSION)-$(LINT_ARCH)
-	chmod u+x out/linters/hadolint-$(HADOLINT_VERSION)-$(LINT_ARCH)
+.PHONY: fix $(FIXERS)
+fix: $(FIXERS)
 
-out/linters/golangci-lint-$(GOLINT_VERSION)-$(LINT_ARCH):
-	mkdir -p out/linters
-	curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b out/linters $(GOLINT_VERSION)
-	mv out/linters/golangci-lint out/linters/golangci-lint-$(GOLINT_VERSION)-$(LINT_ARCH)
+# END: lint-install github.com/tinkerbell/pbnj
 
-# END: lint-install .
 ##@ Clients
diff --git a/client/client.go b/client/client.go
index 3f37c96..af2a457 100644
--- a/client/client.go
+++ b/client/client.go
@@ -118,7 +118,7 @@ func Screenshot(ctx context.Context, client v1.DiagnosticClient, request *v1.Scr
 
 	filename := fmt.Sprintf("%s.%s", time.Now().String(), screenshotResponse.Filetype)
 
-	if err := os.WriteFile(filename, screenshotResponse.Image, 0755); err != nil {
+	if err := os.WriteFile(filename, screenshotResponse.Image, 0o755); err != nil { //nolint:gosec // Can we make this 0600?
 		return "", err
 	}
 
diff --git a/cmd/server.go b/cmd/server.go
index 3a3d006..fb0f2f4 100644
--- a/cmd/server.go
+++ b/cmd/server.go
@@ -48,6 +48,13 @@ var (
 	//
 	// for more information see https://github.com/bmc-toolbox/bmclib#bmc-connections
 	skipRedfishVersions string
+	// maxWorkers is the maximum number of concurrent workers that will be allowed to handle BMC tasks.
+	maxWorkers int
+	// workerIdleTimeout is the idle timeout for workers. If no tasks are received within the timeout, the worker will exit.
+	workerIdleTimeout time.Duration
+	// maxIngestionWorkers is the maximum number of concurrent ingestion workers that will be allowed.
+	// These are the workers that handle ingesting tasks from RPC endpoints and writing them to the map of per-host-ID queues.
+	maxIngestionWorkers int
 	// serverCmd represents the server command.
 	serverCmd = &cobra.Command{
 		Use: "server",
@@ -91,7 +98,11 @@ var (
 			httpServer := http.NewServer(metricsAddr)
 			httpServer.WithLogger(logger)
 
-			opts := []grpcsvr.ServerOption{grpcsvr.WithBmcTimeout(bmcTimeout)}
+			opts := []grpcsvr.ServerOption{
+				grpcsvr.WithBmcTimeout(bmcTimeout),
+				grpcsvr.WithMaxWorkers(maxWorkers),
+				grpcsvr.WithWorkerIdleTimeout(workerIdleTimeout),
+			}
 
 			if skipRedfishVersions != "" {
 				versions := strings.Split(skipRedfishVersions, ",")
@@ -114,6 +125,9 @@ func init() {
 	serverCmd.PersistentFlags().StringVar(&rsPubKey, "rsPubKey", "", "RS public key")
 	serverCmd.PersistentFlags().DurationVar(&bmcTimeout, "bmcTimeout", oob.DefaultBMCTimeout, "Timeout for BMC calls")
 	serverCmd.PersistentFlags().StringVar(&skipRedfishVersions, "skipRedfishVersions", "", "Ignore the redfish endpoint on BMCs running the given version(s)")
+	serverCmd.PersistentFlags().IntVar(&maxWorkers, "maxWorkers", 1000, "Maximum number of concurrent workers that will be allowed to handle BMC tasks")
+	serverCmd.PersistentFlags().DurationVar(&workerIdleTimeout, "workerIdleTimeout", 30*time.Second, "Idle timeout for workers. If no tasks are received within the timeout, the worker will exit. New tasks will spawn a new worker if there isn't a worker running")
+	serverCmd.PersistentFlags().IntVar(&maxIngestionWorkers, "maxIngestionWorkers", 1000, "Maximum number of concurrent ingestion workers that will be allowed. These are the workers that handle ingesting tasks from RPC endpoints and writing them to the map of per-host-ID queues")
 	rootCmd.AddCommand(serverCmd)
 }
 
diff --git a/go.mod b/go.mod
index e6b5afd..c947741 100644
--- a/go.mod
+++ b/go.mod
@@ -90,7 +90,7 @@ require (
 	golang.org/x/crypto v0.6.0 // indirect
 	golang.org/x/exp v0.0.0-20230212135524-a684f29349b6 // indirect
 	golang.org/x/net v0.8.0 // indirect
-	golang.org/x/sys v0.6.0 // indirect
+	golang.org/x/sys v0.7.0 // indirect
 	golang.org/x/text v0.8.0 // indirect
 	google.golang.org/genproto v0.0.0-20230209215440-0dfe4f8abfcc // indirect
 	gopkg.in/go-playground/validator.v9 v9.31.0 // indirect
diff --git a/go.sum b/go.sum
index 858bb36..a39eb31 100644
--- a/go.sum
+++ b/go.sum
@@ -52,8 +52,6 @@ github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
 github.com/bmc-toolbox/bmclib v0.5.7 h1:v3CqOJCMUuH+kA+xi7CdY5EuzUhMH9gsBkYTQMYlbog=
 github.com/bmc-toolbox/bmclib v0.5.7/go.mod h1:jSCb2/o2bZhTTg3IgShckCfFxkX4yqQC065tuYh2pKk=
-github.com/bmc-toolbox/bmclib/v2 v2.0.1-0.20230324092939-d39fb75b6aa9 h1:UNtiASZUNvhF7Dr2qdqQy63VjddnpxS4bH3f+SQc/yQ=
-github.com/bmc-toolbox/bmclib/v2 v2.0.1-0.20230324092939-d39fb75b6aa9/go.mod h1:iRhgD8P0gvy95wYXA3FDCKbo/aRiKBaodBBgoUG/+Qg=
 github.com/bmc-toolbox/bmclib/v2 v2.0.1-0.20230515164712-2714c7479477 h1:2GKBUqU+hrthvhEJyvJMj473uUQ7ByufchSftLNLS8E=
 github.com/bmc-toolbox/bmclib/v2 v2.0.1-0.20230515164712-2714c7479477/go.mod h1:a3Ra0ce/LV3wAj7AHuphlHNTx5Sg67iQqtLGr1zoqio=
 github.com/bmc-toolbox/common v0.0.0-20230220061748-93ff001f4a1d h1:cQ30Wa8mhLzK1TSOG+g3FlneIsXtFgun61mmPwVPmD0=
@@ -236,12 +234,8 @@ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:
 github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
-github.com/jacobweinstock/iamt v0.0.0-20230304043040-a6b4a1001123 h1:Mh2eOcadGcu/6E0bq5FfaGlFYFe2oyNOeRjpgC1vOq0=
-github.com/jacobweinstock/iamt v0.0.0-20230304043040-a6b4a1001123/go.mod h1:FgmiLTU6cJewV4Xgrq6m5o8CUlTQOJtqzaFLGA0mG+E=
 github.com/jacobweinstock/iamt v0.0.0-20230502042727-d7cdbe67d9ef h1:G4k02HGmBUfJFSNu3gfKJ+ki+B3qutKsYzYndkqqKc4=
 github.com/jacobweinstock/iamt v0.0.0-20230502042727-d7cdbe67d9ef/go.mod h1:FgmiLTU6cJewV4Xgrq6m5o8CUlTQOJtqzaFLGA0mG+E=
-github.com/jacobweinstock/registrar v0.4.6 h1:0O3g2jT2Lx+Bf+yl4QsMUN48fVZxUpM3kS+NtIJ+ucw=
-github.com/jacobweinstock/registrar v0.4.6/go.mod h1:IDx65tQ7DLJ2UqiVjE1zo74jMZZfel9YZW8VrC26m6o=
 github.com/jacobweinstock/registrar v0.4.7 h1:s4dOExccgD+Pc7rJC+f3Mc3D+NXHcXUaOibtcEsPxOc=
 github.com/jacobweinstock/registrar v0.4.7/go.mod h1:PWmkdGFG5/ZdCqgMo7pvB3pXABOLHc5l8oQ0sgmBNDU=
 github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
@@ -344,8 +338,6 @@ github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
 github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
 github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU=
 github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA=
-github.com/stmcginnis/gofish v0.13.1-0.20230130123602-c77017d5737a h1:FJWP5Gv6qUSa+hfkOMpWtVdpWEl/jl+QZfwRIA1mK9E=
-github.com/stmcginnis/gofish v0.13.1-0.20230130123602-c77017d5737a/go.mod h1:BLDSFTp8pDlf/xDbLZa+F7f7eW0E/CHCboggsu8CznI=
 github.com/stmcginnis/gofish v0.14.0 h1:geECNAiG33JDB2x2xDkerpOOuXFqxp5YP3EFE3vd5iM=
 github.com/stmcginnis/gofish v0.14.0/go.mod h1:BLDSFTp8pDlf/xDbLZa+F7f7eW0E/CHCboggsu8CznI=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -396,7 +388,7 @@ go.opentelemetry.io/proto/otlp v0.19.0/go.mod h1:H7XAot3MsfNsj7EXtrA2q5xSNQ10UqI
 go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
 go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A=
-go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk=
+go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
 go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
 go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
 go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
@@ -554,8 +546,8 @@ golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
-golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
+golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw=
 golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
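The new --maxWorkers and --workerIdleTimeout flags above are wired into the gRPC server through the functional options grpcsvr.WithMaxWorkers and grpcsvr.WithWorkerIdleTimeout (defined in grpc/server.go further down in this diff). A minimal, standalone sketch of that option pattern follows — the lower-case names here are illustrative stand-ins, not PBnJ's exported API:

```go
package main

import (
	"fmt"
	"time"
)

// server stands in for the diff's grpc.Server config struct; only the two new
// fields are modeled here.
type server struct {
	maxWorkers        int
	workerIdleTimeout time.Duration
}

// serverOption mirrors the ServerOption pattern used in grpc/server.go:
// a function that mutates the server being constructed.
type serverOption func(*server)

func withMaxWorkers(max int) serverOption {
	return func(s *server) { s.maxWorkers = max }
}

func withWorkerIdleTimeout(t time.Duration) serverOption {
	return func(s *server) { s.workerIdleTimeout = t }
}

func main() {
	// Defaults match the diff: 1000 workers, 30s idle timeout. Options built
	// from flag values override them.
	s := &server{maxWorkers: 1000, workerIdleTimeout: 30 * time.Second}
	for _, opt := range []serverOption{withMaxWorkers(2000), withWorkerIdleTimeout(time.Minute)} {
		opt(s)
	}
	fmt.Println(s.maxWorkers, s.workerIdleTimeout) // 2000 1m0s
}
```

The benefit of this style, as the diff uses it, is that new tuning knobs can be added without breaking existing RunServer callers.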
diff --git a/grpc/rpc/bmc.go b/grpc/rpc/bmc.go
index 3d5fe62..1a74627 100644
--- a/grpc/rpc/bmc.go
+++ b/grpc/rpc/bmc.go
@@ -45,7 +45,7 @@ func (b *BmcService) Reset(ctx context.Context, in *v1.ResetRequest) (*v1.ResetR
 
 	l.Info(
 		"start Reset request",
-		"username", in.Authn.GetDirectAuthn().GetUsername(),
+		"username", in.GetAuthn().GetDirectAuthn().GetUsername(),
 		"vendor", in.Vendor.GetName(),
 		"resetKind", in.GetResetKind().String(),
 	)
@@ -66,20 +66,20 @@ func (b *BmcService) Reset(ctx context.Context, in *v1.ResetRequest) (*v1.ResetR
 		defer cancel()
 		return "", t.BMCReset(taskCtx, in.ResetKind.String())
 	}
-	b.TaskRunner.Execute(ctx, l, "bmc reset", taskID, execFunc)
+	b.TaskRunner.Execute(ctx, l, "bmc reset", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc)
 
 	return &v1.ResetResponse{TaskId: taskID}, nil
 }
 
 // CreateUser creates a user on a BMC.
-func (b *BmcService) CreateUser(ctx context.Context, in *v1.CreateUserRequest) (*v1.CreateUserResponse, error) {
+func (b *BmcService) CreateUser(ctx context.Context, in *v1.CreateUserRequest) (*v1.CreateUserResponse, error) { //nolint:dupl // there is enough difference to not be a duplicate.
 	l := logging.ExtractLogr(ctx)
 	taskID := xid.New().String()
 	l = l.WithValues("taskID", taskID)
 
 	l.Info(
 		"start CreateUser request",
-		"username", in.Authn.GetDirectAuthn().GetUsername(),
+		"username", in.GetAuthn().GetDirectAuthn().GetUsername(),
 		"vendor", in.Vendor.GetName(),
 		"userCreds.Username", in.UserCreds.Username,
 		"userCreds.UserRole", in.UserCreds.UserRole,
@@ -101,20 +101,20 @@ func (b *BmcService) CreateUser(ctx context.Context, in *v1.CreateUserRequest) (
 		defer cancel()
 		return "", t.CreateUser(taskCtx)
 	}
-	b.TaskRunner.Execute(ctx, l, "creating user", taskID, execFunc)
+	b.TaskRunner.Execute(ctx, l, "creating user", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc)
 
 	return &v1.CreateUserResponse{TaskId: taskID}, nil
 }
 
 // UpdateUser updates a user's credentials on a BMC.
-func (b *BmcService) UpdateUser(ctx context.Context, in *v1.UpdateUserRequest) (*v1.UpdateUserResponse, error) {
+func (b *BmcService) UpdateUser(ctx context.Context, in *v1.UpdateUserRequest) (*v1.UpdateUserResponse, error) { //nolint:dupl // there is enough difference to not be a duplicate.
 	l := logging.ExtractLogr(ctx)
 	taskID := xid.New().String()
 	l = l.WithValues("taskID", taskID)
 
 	l.Info(
 		"start UpdateUser request",
-		"username", in.Authn.GetDirectAuthn().GetUsername(),
+		"username", in.GetAuthn().GetDirectAuthn().GetUsername(),
 		"vendor", in.Vendor.GetName(),
 		"userCreds.Username", in.UserCreds.Username,
 		"userCreds.UserRole", in.UserCreds.UserRole,
@@ -136,7 +136,7 @@ func (b *BmcService) UpdateUser(ctx context.Context, in *v1.UpdateUserRequest) (
 		defer cancel()
 		return "", t.UpdateUser(taskCtx)
 	}
-	b.TaskRunner.Execute(ctx, l, "updating user", taskID, execFunc)
+	b.TaskRunner.Execute(ctx, l, "updating user", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc)
 
 	return &v1.UpdateUserResponse{TaskId: taskID}, nil
 }
@@ -148,7 +148,7 @@ func (b *BmcService) DeleteUser(ctx context.Context, in *v1.DeleteUserRequest) (
 	l = l.WithValues("taskID", taskID)
 	l.Info(
 		"start DeleteUser request",
-		"username", in.Authn.GetDirectAuthn().GetUsername(),
+		"username", in.GetAuthn().GetDirectAuthn().GetUsername(),
 		"vendor", in.Vendor.GetName(),
 		"userCreds.Username", in.Username,
 	)
@@ -169,7 +169,7 @@ func (b *BmcService) DeleteUser(ctx context.Context, in *v1.DeleteUserRequest) (
 		defer cancel()
 		return "", t.DeleteUser(taskCtx)
 	}
-	b.TaskRunner.Execute(ctx, l, "deleting user", taskID, execFunc)
+	b.TaskRunner.Execute(ctx, l, "deleting user", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc)
 
 	return &v1.DeleteUserResponse{TaskId: taskID}, nil
 }
diff --git a/grpc/rpc/bmc_test.go b/grpc/rpc/bmc_test.go
index 5f009c5..b1c3af9 100644
--- a/grpc/rpc/bmc_test.go
+++ b/grpc/rpc/bmc_test.go
@@ -7,6 +7,7 @@ import (
 	"os"
 	"os/exec"
 	"testing"
+	"time"
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/philippgille/gokv"
@@ -19,9 +20,10 @@ import (
 const tempIPMITool = "/tmp/ipmitool"
 
 var (
-	ctx        context.Context
-	taskRunner *taskrunner.Runner
-	bmcService BmcService
+	tr             *taskrunner.Runner
+	bmcService     BmcService
+	taskService    TaskService
+	machineService MachineService
 )
 
 func TestMain(m *testing.M) {
@@ -32,7 +34,7 @@ func TestMain(m *testing.M) {
 }
 
 func setup() {
-	ctx = context.Background()
+	ctx := context.Background()
 	f := freecache.NewStore(freecache.DefaultOptions)
 	s := gokv.Store(f)
 	repo := &persistence.GoKV{
@@ -40,14 +42,18 @@ func setup() {
 		Ctx: ctx,
 	}
 
-	taskRunner = &taskrunner.Runner{
-		Repository: repo,
-		Ctx:        ctx,
-	}
+	tr = taskrunner.NewRunner(repo, 100, time.Second)
+	tr.Start(ctx)
 	bmcService = BmcService{
-		TaskRunner:             taskRunner,
+		TaskRunner:             tr,
 		UnimplementedBMCServer: v1.UnimplementedBMCServer{},
 	}
+	taskService = TaskService{
+		TaskRunner: tr,
+	}
+	machineService = MachineService{
+		TaskRunner: tr,
+	}
 	_, err := exec.LookPath("ipmitool")
 	if err != nil {
 		err := os.WriteFile(tempIPMITool, []byte{}, 0o777)
@@ -98,7 +104,7 @@ func TestConfigNetworkSource(t *testing.T) {
 	for _, tc := range testCases {
 		testCase := tc
 		t.Run(testCase.name, func(t *testing.T) {
-			response, err := bmcService.NetworkSource(ctx, testCase.req)
+			response, err := bmcService.NetworkSource(context.Background(), testCase.req)
 			if response != nil {
 				t.Fatalf("response should be nil, got: %v", response)
 			}
@@ -155,7 +161,7 @@ func TestReset(t *testing.T) {
 	for _, tc := range testCases {
 		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
-			response, err := bmcService.Reset(ctx, tc.in)
+			response, err := bmcService.Reset(context.Background(), tc.in)
 			if err != nil {
 				diff := cmp.Diff(tc.expectedErr.Error(), err.Error())
 				if diff != "" {
@@ -218,7 +224,7 @@ func TestCreateUser(t *testing.T) {
 	for _, tc := range testCases {
 		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
-			response, err := bmcService.CreateUser(ctx, tc.in)
+			response, err := bmcService.CreateUser(context.Background(), tc.in)
 			if err != nil {
 				diff := cmp.Diff(tc.expectedErr.Error(), err.Error())
 				if diff != "" {
@@ -281,7 +287,7 @@ func TestUpdateUser(t *testing.T) {
 	for _, tc := range testCases {
 		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
-			response, err := bmcService.UpdateUser(ctx, tc.in)
+			response, err := bmcService.UpdateUser(context.Background(), tc.in)
 			if err != nil {
 				diff := cmp.Diff(tc.expectedErr.Error(), err.Error())
 				if diff != "" {
@@ -340,7 +346,7 @@ func TestDeleteUser(t *testing.T) {
 	for _, tc := range testCases {
 		tc := tc
 		t.Run(tc.name, func(t *testing.T) {
-			response, err := bmcService.DeleteUser(ctx, tc.in)
+			response, err := bmcService.DeleteUser(context.Background(), tc.in)
 			if err != nil {
 				diff := cmp.Diff(tc.expectedErr.Error(), err.Error())
 				if diff != "" {
diff --git a/grpc/rpc/machine.go b/grpc/rpc/machine.go
index ac475ce..69c1afa 100644
--- a/grpc/rpc/machine.go
+++ b/grpc/rpc/machine.go
@@ -31,7 +31,7 @@ func (m *MachineService) BootDevice(ctx context.Context, in *v1.DeviceRequest) (
 
 	l.Info(
 		"start BootDevice request",
-		"username", in.Authn.GetDirectAuthn().GetUsername(),
+		"username", in.GetAuthn().GetDirectAuthn().GetUsername(),
 		"vendor", in.Vendor.GetName(),
 		"bootDevice", in.BootDevice.String(),
 		"persistent", in.Persistent,
@@ -54,7 +54,7 @@ func (m *MachineService) BootDevice(ctx context.Context, in *v1.DeviceRequest) (
 		defer cancel()
 		return mbd.BootDeviceSet(taskCtx, in.BootDevice.String(), in.Persistent, in.EfiBoot)
 	}
-	m.TaskRunner.Execute(ctx, l, "setting boot device", taskID, execFunc)
+	m.TaskRunner.Execute(ctx, l, "setting boot device", taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc)
 
 	return &v1.DeviceResponse{TaskId: taskID}, nil
 }
@@ -63,10 +63,10 @@ func (m *MachineService) BootDevice(ctx context.Context, in *v1.DeviceRequest) (
 func (m *MachineService) Power(ctx context.Context, in *v1.PowerRequest) (*v1.PowerResponse, error) {
 	l := logging.ExtractLogr(ctx)
 	taskID := xid.New().String()
-	l = l.WithValues("taskID", taskID, "bmcIP", in.Authn.GetDirectAuthn().GetHost().GetHost())
+	l = l.WithValues("taskID", taskID, "bmcIP", in.GetAuthn().GetDirectAuthn().GetHost().GetHost())
 	l.Info(
 		"start Power request",
-		"username", in.Authn.GetDirectAuthn().GetUsername(),
+		"username", in.GetAuthn().GetDirectAuthn().GetUsername(),
 		"vendor", in.Vendor.GetName(),
 		"powerAction", in.GetPowerAction().String(),
 		"softTimeout", in.SoftTimeout,
@@ -89,7 +89,7 @@ func (m *MachineService) Power(ctx context.Context, in *v1.PowerRequest) (*v1.Po
 		defer cancel()
 		return mp.PowerSet(taskCtx, in.PowerAction.String())
 	}
-	m.TaskRunner.Execute(ctx, l, "power action: "+in.GetPowerAction().String(), taskID, execFunc)
+	m.TaskRunner.Execute(ctx, l, "power action: "+in.GetPowerAction().String(), taskID, in.GetAuthn().GetDirectAuthn().GetHost().GetHost(), execFunc)
 
 	return &v1.PowerResponse{TaskId: taskID}, nil
 }
diff --git a/grpc/rpc/machine_test.go b/grpc/rpc/machine_test.go
index 0fe8fe5..3c4104d 100644
--- a/grpc/rpc/machine_test.go
+++ b/grpc/rpc/machine_test.go
@@ -7,11 +7,7 @@ import (
 
 	"github.com/google/go-cmp/cmp"
 	"github.com/onsi/gomega"
-	"github.com/philippgille/gokv"
-	"github.com/philippgille/gokv/freecache"
 	v1 "github.com/tinkerbell/pbnj/api/v1"
-	"github.com/tinkerbell/pbnj/grpc/persistence"
-	"github.com/tinkerbell/pbnj/grpc/taskrunner"
 )
 
 func TestDevice(t *testing.T) {
@@ -58,22 +54,7 @@ func TestDevice(t *testing.T) {
 			g := gomega.NewGomegaWithT(t)
 
 			ctx := context.Background()
-
-			f := freecache.NewStore(freecache.DefaultOptions)
-			s := gokv.Store(f)
-			repo := &persistence.GoKV{
-				Store: s,
-				Ctx:   ctx,
-			}
-
-			taskRunner := &taskrunner.Runner{
-				Repository: repo,
-				Ctx:        ctx,
-			}
-			machineSvc := MachineService{
-				TaskRunner: taskRunner,
-			}
-			response, err := machineSvc.BootDevice(ctx, testCase.req)
+			response, err := machineService.BootDevice(ctx, testCase.req)
 
 			t.Log("Got : ", response)
 			if err != nil {
@@ -157,22 +138,7 @@ func TestPower(t *testing.T) {
 			g := gomega.NewGomegaWithT(t)
 
 			ctx := context.Background()
-
-			f := freecache.NewStore(freecache.DefaultOptions)
-			s := gokv.Store(f)
-			repo := &persistence.GoKV{
-				Store: s,
-				Ctx:   ctx,
-			}
-
-			taskRunner := &taskrunner.Runner{
-				Repository: repo,
-				Ctx:        ctx,
-			}
-			machineSvc := MachineService{
-				TaskRunner: taskRunner,
-			}
-			response, err := machineSvc.Power(ctx, testCase.req)
+			response, err := machineService.Power(ctx, testCase.req)
 
 			t.Log("Got response: ", response)
 			t.Log("Got err: ", err)
diff --git a/grpc/rpc/task_test.go b/grpc/rpc/task_test.go
index d138415..f9b4faa 100644
--- a/grpc/rpc/task_test.go
+++ b/grpc/rpc/task_test.go
@@ -5,51 +5,40 @@ import (
 	"testing"
 	"time"
 
-	"github.com/go-logr/logr"
 	"github.com/onsi/gomega"
-	"github.com/philippgille/gokv"
-	"github.com/philippgille/gokv/freecache"
-	"github.com/rs/xid"
 	v1 "github.com/tinkerbell/pbnj/api/v1"
-	"github.com/tinkerbell/pbnj/grpc/persistence"
-	"github.com/tinkerbell/pbnj/grpc/taskrunner"
-	"github.com/tinkerbell/pbnj/pkg/repository"
 )
 
 func TestTaskFound(t *testing.T) {
-	// create a task
-	ctx := context.Background()
-	defaultError := &repository.Error{
-		Code:    0,
-		Message: "",
-		Details: nil,
-	}
-	logger := logr.Discard()
-	f := freecache.NewStore(freecache.DefaultOptions)
-	s := gokv.Store(f)
-	repo := &persistence.GoKV{Store: s, Ctx: ctx}
-
-	taskRunner := &taskrunner.Runner{
-		Repository: repo,
-		Ctx:        ctx,
-	}
-	taskID := xid.New().String()
-	taskRunner.Execute(ctx, logger, "test", taskID, func(s chan string) (string, error) {
-		return "doing cool stuff", defaultError
-	})
-
-	taskReq := &v1.StatusRequest{TaskId: taskID}
-
-	taskSvc := TaskService{
-		TaskRunner: taskRunner,
+	pr := &v1.PowerRequest{
+		Authn: &v1.Authn{
+			Authn: &v1.Authn_DirectAuthn{
+				DirectAuthn: &v1.DirectAuthn{
+					Host: &v1.Host{
+						Host: "10.1.1.1",
+					},
+					Username: "admin",
+					Password: "admin",
+				},
+			},
+		},
+		Vendor: &v1.Vendor{
+			Name: "",
+		},
+		PowerAction: v1.PowerAction_POWER_ACTION_STATUS,
+		SoftTimeout: 0,
+		OffDuration: 0,
 	}
-
-	time.Sleep(10 * time.Millisecond)
-	taskResp, err := taskSvc.Status(ctx, taskReq)
+	resp, err := machineService.Power(context.Background(), pr)
 	if err != nil {
-		t.Fatal(err)
+		t.Fatalf("expected no error, got: %v", err)
 	}
-	if taskResp.Id != taskID {
+
+	time.Sleep(time.Second * 3)
+	taskReq := &v1.StatusRequest{TaskId: resp.TaskId}
+	taskResp, _ := taskService.Status(context.Background(), taskReq)
+	t.Logf("Got response: %+v", taskResp)
+	if taskResp.Id != resp.TaskId {
 		t.Fatalf("got: %+v", taskResp)
 	}
 }
@@ -76,19 +65,7 @@ func TestRecordNotFound(t *testing.T) {
 			g := gomega.NewGomegaWithT(t)
 
 			ctx := context.Background()
-
-			f := freecache.NewStore(freecache.DefaultOptions)
-			s := gokv.Store(f)
-			repo := &persistence.GoKV{Store: s, Ctx: ctx}
-
-			taskRunner := &taskrunner.Runner{
-				Repository: repo,
-				Ctx:        ctx,
-			}
-			taskSvc := TaskService{
-				TaskRunner: taskRunner,
-			}
-			response, err := taskSvc.Status(ctx, testCase.req)
+			response, err := taskService.Status(ctx, testCase.req)
 
 			t.Log("Got response: ", response)
 			t.Log("Got err: ", err)
diff --git a/grpc/server.go b/grpc/server.go
index d512f0a..82fc72d 100644
--- a/grpc/server.go
+++ b/grpc/server.go
@@ -21,6 +21,7 @@ import (
 	"github.com/tinkerbell/pbnj/pkg/repository"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/health/grpc_health_v1"
+	"google.golang.org/grpc/reflection"
 )
 
 // Server options.
@@ -35,6 +36,10 @@ type Server struct {
 	//
 	// for more information see https://github.com/bmc-toolbox/bmclib#bmc-connections
 	skipRedfishVersions []string
+	// maxWorkers is the maximum number of concurrent workers that will be allowed to handle BMC tasks.
+	maxWorkers int
+	// workerIdleTimeout is the idle timeout for workers. If no tasks are received within the timeout, the worker will exit.
+	workerIdleTimeout time.Duration
 }
 
 // ServerOption for setting optional values.
@@ -55,6 +60,18 @@ func WithSkipRedfishVersions(versions []string) ServerOption {
 	return func(args *Server) { args.skipRedfishVersions = versions }
 }
 
+// WithMaxWorkers sets the maximum number of concurrent workers that handle BMC tasks.
+func WithMaxWorkers(max int) ServerOption {
+	return func(args *Server) { args.maxWorkers = max }
+}
+
+// WithWorkerIdleTimeout sets the idle timeout for workers.
+// If no tasks are received within the timeout, the worker will exit.
+// New tasks will spawn a new worker if there isn't a worker running.
+func WithWorkerIdleTimeout(t time.Duration) ServerOption {
+	return func(args *Server) { args.workerIdleTimeout = t }
+}
+
 // RunServer registers all services and runs the server.
 func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, port string, httpServer *http.Server, opts ...ServerOption) error {
 	ctx, cancel := context.WithCancel(ctx)
@@ -69,27 +86,27 @@ func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, po
 	}
 
 	defaultServer := &Server{
-		Actions:    repo,
-		bmcTimeout: oob.DefaultBMCTimeout,
+		Actions:           repo,
+		bmcTimeout:        oob.DefaultBMCTimeout,
+		maxWorkers:        1000,
+		workerIdleTimeout: time.Second * 30,
 	}
 	for _, opt := range opts {
 		opt(defaultServer)
 	}
 
-	taskRunner := &taskrunner.Runner{
-		Repository: defaultServer.Actions,
-		Ctx:        ctx,
-	}
+	tr := taskrunner.NewRunner(repo, defaultServer.maxWorkers, defaultServer.workerIdleTimeout)
+	tr.Start(ctx)
 	ms := rpc.MachineService{
-		TaskRunner: taskRunner,
+		TaskRunner: tr,
 		Timeout:    defaultServer.bmcTimeout,
 	}
 	v1.RegisterMachineServer(grpcServer, &ms)
 
 	bs := rpc.BmcService{
-		TaskRunner:          taskRunner,
+		TaskRunner:          tr,
 		Timeout:             defaultServer.bmcTimeout,
 		SkipRedfishVersions: defaultServer.skipRedfishVersions,
 	}
@@ -99,7 +116,7 @@ func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, po
 	v1.RegisterDiagnosticServer(grpcServer, &ds)
 
 	ts := rpc.TaskService{
-		TaskRunner: taskRunner,
+		TaskRunner: tr,
 	}
 	v1.RegisterTaskServer(grpcServer, &ts)
 
@@ -113,10 +130,11 @@ func RunServer(ctx context.Context, log logr.Logger, grpcServer *grpc.Server, po
 		return err
 	}
 
-	httpServer.WithTaskRunner(taskRunner)
+	httpServer.WithTaskRunner(tr)
+	reflection.Register(grpcServer)
 
 	go func() {
-		err := httpServer.Run()
+		err := httpServer.Run(ctx)
 		if err != nil {
 			log.Error(err, "failed to serve http")
 			os.Exit(1) //nolint:revive // removing deep-exit requires a significant refactor
diff --git a/grpc/taskrunner/manager.go b/grpc/taskrunner/manager.go
new file mode 100644
index 0000000..ba6df9b
--- /dev/null
+++ b/grpc/taskrunner/manager.go
@@ -0,0 +1,122 @@
+package taskrunner
+
+// copied and modified from https://github.com/zenthangplus/goccm
+
+import "sync/atomic"
+
+type concurrencyManager struct {
+	// The maximum number of goroutines that are allowed to run concurrently.
+	max int
+
+	// The manager channel to coordinate the number of concurrent goroutines.
+	managerCh chan interface{}
+
+	// The done channel indicates when a single goroutine has finished its job.
+	doneCh chan bool
+
+	// This channel indicates when all goroutines have finished their job.
+	allDoneCh chan bool
+
+	// The closed channel is closed to signal the controller that the manager is shutting down.
+	closed chan bool
+
+	// The running count tracks the number of goroutines that are running.
+	runningCount atomic.Int32
+}
+
+// newManager creates a concurrencyManager that allows up to maxGoRoutines concurrent goroutines.
+func newManager(maxGoRoutines int) *concurrencyManager {
+	// Initiate the manager object
+	c := concurrencyManager{
+		max:       maxGoRoutines,
+		managerCh: make(chan interface{}, maxGoRoutines),
+		doneCh:    make(chan bool),
+		allDoneCh: make(chan bool),
+		closed:    make(chan bool),
+	}
+
+	// Fill the manager channel with placeholder values
+	for i := 0; i < c.max; i++ {
+		c.managerCh <- nil
+	}
+
+	// Start the controller to collect all the jobs
+	go c.controller()
+
+	return &c
+}
+
+// Create the controller to collect all the jobs.
+// When a goroutine is finished, we can release a slot for another goroutine.
+func (c *concurrencyManager) controller() {
+	for {
+		// This will block until a goroutine is finished
+		<-c.doneCh
+
+		// Say that another goroutine can now start
+		c.managerCh <- nil
+
+		// When the closed flag is set,
+		// we need to close the manager if it doesn't have any running goroutines
+		if c.isClosed() && c.RunningCount() == 0 {
+			break
+		}
+	}
+
+	// Say that all goroutines are finished, so we can close the manager
+	c.allDoneCh <- true
+}
+
+// Wait until a slot is available for the new goroutine.
+// A goroutine has to start after this function returns.
+func (c *concurrencyManager) Wait() {
+	// Try to receive from the manager channel. When we have something,
+	// it means a slot is available and we can start a new goroutine.
+	// Otherwise, it will block until a slot is available.
+	<-c.managerCh
+
+	// Increase the running count to track how many goroutines are running.
+	c.runningCount.Add(1)
+}
+
+// Done marks a goroutine as finished.
+func (c *concurrencyManager) Done() {
+	// Decrease the running count
+	c.runningCount.Add(-1)
+	c.doneCh <- true
+}
+
+// Close the manager manually;
+// this is a no-op if the manager is already closed.
+func (c *concurrencyManager) Close() {
+	// terminate if channel is already closed
+	select {
+	case <-c.closed:
+		return
+	default:
+		close(c.closed)
+	}
+}
+
+func (c *concurrencyManager) isClosed() bool {
+	select {
+	case <-c.closed:
+		return true
+	default:
+		return false
+	}
+}
+
+// WaitAllDone waits until all goroutines are done.
+func (c *concurrencyManager) WaitAllDone() {
+	// Close the manager automatically
+	c.Close()
+
+	// This will block until allDoneCh is signaled
+	<-c.allDoneCh
+}
+
+// RunningCount returns the number of goroutines that are currently running.
+func (c *concurrencyManager) RunningCount() int32 {
+	return c.runningCount.Load()
+}
diff --git a/grpc/taskrunner/run.go b/grpc/taskrunner/run.go
new file mode 100644
index 0000000..38abded
--- /dev/null
+++ b/grpc/taskrunner/run.go
@@ -0,0 +1,102 @@
+package taskrunner
+
+import (
+	"context"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/tinkerbell/pbnj/pkg/metrics"
+)
+
+type orchestrator struct {
+	workers           sync.Map
+	manager           *concurrencyManager
+	workerIdleTimeout time.Duration
+	fifoChan          chan string
+	// perHostChan is a map of hostID to a channel of tasks.
+	perHostChan sync.Map
+	ingestChan  chan Task
+}
+
+// ingest takes a task off the ingestion queue, puts it on the perID queue,
+// and adds the host ID to the fcfs (first-come, first-served) queue.
+func (r *Runner) ingest(ctx context.Context) {
+	// dequeue from ingestion queue
+	// enqueue to perID queue
+	// enqueue to fcfs queue
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case t := <-r.orchestrator.ingestChan:
+
+			// 2. enqueue to perID queue
+			ch := make(chan Task, 10)
+			q, exists := r.orchestrator.perHostChan.LoadOrStore(t.Host, ch)
+			v, ok := q.(chan Task)
+			if !ok {
+				fmt.Println("bad type: IngestQueue")
+				return
+			}
+			if exists {
+				close(ch)
+			}
+			v <- t
+			metrics.PerIDQueue.WithLabelValues(t.Host).Inc()
+			metrics.IngestionQueue.Dec()
+			metrics.NumPerIDEnqueued.Inc()
+			r.orchestrator.workers.Store(t.Host, false)
+		}
+	}
+}
+
+// 1. dequeue from fcfs queue
+// 2. dequeue from perID queue
+// 3. if worker id exists, send task to worker. else continue.
+// 4. if maxWorkers is reached, wait for available worker. else create worker and send task to worker.
+func (r *Runner) orchestrate(ctx context.Context) {
+	// 1. dequeue from fcfs queue
+	// 2. start workers
+	for {
+		// time.Sleep(time.Second * 3) - this potentially helps with ingestion
+		r.orchestrator.workers.Range(func(key, value interface{}) bool {
+			// if the worker id exists in o.workers, then move on because the worker is already running.
+			if value.(bool) { //nolint: forcetypeassert // values are always certain.
+				return true
+			}
+
+			// wait for a worker to become available
+			r.orchestrator.manager.Wait()
+
+			r.orchestrator.workers.Store(key.(string), true) //nolint: forcetypeassert // values are always certain.
+			v, found := r.orchestrator.perHostChan.Load(key.(string))
+			if !found {
+				return false
+			}
+			go r.worker(ctx, key.(string), v.(chan Task)) //nolint: forcetypeassert // values are always certain.
+			return true
+		})
+	}
+}
+
+func (r *Runner) worker(ctx context.Context, hostID string, q chan Task) {
+	defer r.orchestrator.manager.Done()
+	defer r.orchestrator.workers.Delete(hostID)
+
+	for {
+		select {
+		case <-ctx.Done():
+			// TODO: check queue length before returning maybe?
+			// For 175000 tasks, I found there would occasionally be 1 or 2 that didn't get processed;
+			// they still seemed to be in the queue/chan.
+			return
+		case t := <-q:
+			r.process(ctx, t.Log, t.Description, t.ID, t.Action)
+			metrics.PerIDQueue.WithLabelValues(hostID).Dec()
+		case <-time.After(r.orchestrator.workerIdleTimeout):
+			// TODO: check queue length before returning maybe?
+			return
+		}
+	}
+}
diff --git a/grpc/taskrunner/taskrunner.go b/grpc/taskrunner/taskrunner.go
index b14d347..fce4975 100644
--- a/grpc/taskrunner/taskrunner.go
+++ b/grpc/taskrunner/taskrunner.go
@@ -5,11 +5,11 @@ import (
 	"net"
 	"net/url"
 	"sync"
+	"sync/atomic"
 	"syscall"
 	"time"
 
 	"github.com/go-logr/logr"
-	"github.com/hashicorp/go-multierror"
 	"github.com/pkg/errors"
 	"github.com/tinkerbell/pbnj/pkg/metrics"
@@ -18,57 +18,127 @@ import (
 
 // Runner for executing a task.
 type Runner struct {
-	Repository repository.Actions
-	Ctx        context.Context
-	active     int
-	total      int
-	counterMu  sync.RWMutex
+	Repository   repository.Actions
+	active       atomic.Int32
+	total        atomic.Int32
+	orchestrator *orchestrator
+}
+
+type Task struct {
+	ID          string                            `json:"id"`
+	Host        string                            `json:"host"`
+	Description string                            `json:"description"`
+	Action      func(chan string) (string, error) `json:"-"`
+	Log         logr.Logger                       `json:"-"`
+}
+
+// NewRunner returns a task runner that manages tasks, workers, queues, and persistence.
+//
+// maxWorkers is the maximum number of concurrent workers that will be allowed to handle BMC tasks.
+//
+// workerIdleTimeout is the idle timeout for workers. If no tasks are received within the timeout, the worker will exit.
+func NewRunner(repo repository.Actions, maxWorkers int, workerIdleTimeout time.Duration) *Runner {
+	o := &orchestrator{
+		workers:  sync.Map{},
+		fifoChan: make(chan string, 10000),
+		// perHostChan is a map of hostID to a channel of tasks.
+		perHostChan:       sync.Map{},
+		manager:           newManager(maxWorkers),
+		workerIdleTimeout: workerIdleTimeout,
+		ingestChan:        make(chan Task, 10000),
+	}
+
+	return &Runner{
+		Repository:   repo,
+		orchestrator: o,
+	}
 }
 
 // ActiveWorkers returns a count of currently active worker jobs.
 func (r *Runner) ActiveWorkers() int {
-	r.counterMu.RLock()
-	defer r.counterMu.RUnlock()
-	return r.active
+	return int(r.active.Load())
 }
 
 // TotalWorkers returns a count of total workers executed.
 func (r *Runner) TotalWorkers() int {
-	r.counterMu.RLock()
-	defer r.counterMu.RUnlock()
-	return r.total
+	return int(r.total.Load())
+}
+
+func (r *Runner) Start(ctx context.Context) {
+	go func() {
+		ticker := time.NewTicker(3 * time.Second)
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				metrics.NumWorkers.Set(float64(r.orchestrator.manager.RunningCount()))
+				var size int
+				r.orchestrator.workers.Range(func(key, value interface{}) bool {
+					size++
+					return true
+				})
+				metrics.WorkerMap.Set(float64(size))
+			}
+		}
+	}()
+	go r.ingest(ctx)
+	go r.orchestrate(ctx)
 }
 
 // Execute a task, update repository with status.
-func (r *Runner) Execute(ctx context.Context, l logr.Logger, description, taskID string, action func(chan string) (string, error)) {
-	go r.worker(ctx, l, description, taskID, action)
+func (r *Runner) Execute(_ context.Context, l logr.Logger, description, taskID, hostID string, action func(chan string) (string, error)) {
+	i := Task{
+		ID:          taskID,
+		Host:        hostID,
+		Description: description,
+		Action:      action,
+		Log:         l,
+	}
+
+	r.orchestrator.ingestChan <- i
+	metrics.IngestionQueue.Inc()
+	metrics.Ingested.Inc()
+}
+
+func (r *Runner) updateMessages(ctx context.Context, taskID string, ch chan string) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case msg := <-ch:
+			record, err := r.Repository.Get(taskID)
+			if err != nil {
+				return
+			}
+			record.Messages = append(record.Messages, msg)
+			if err := r.Repository.Update(taskID, record); err != nil {
+				return
+			}
+		}
+	}
 }
 
 // does the work, updates the repo record
 // TODO handle retries, use a timeout.
-func (r *Runner) worker(_ context.Context, logger logr.Logger, description, taskID string, action func(chan string) (string, error)) {
+func (r *Runner) process(ctx context.Context, logger logr.Logger, description, taskID string, action func(chan string) (string, error)) {
 	logger = logger.WithValues("taskID", taskID, "description", description)
-	r.counterMu.Lock()
-	r.active++
-	r.total++
-	r.counterMu.Unlock()
+	r.active.Add(1)
+	r.total.Add(1)
 	defer func() {
-		r.counterMu.Lock()
-		r.active--
-		r.counterMu.Unlock()
+		r.active.Add(-1)
 	}()
-	metrics.TasksTotal.Inc()
+	defer metrics.TasksTotal.Inc()
+	defer metrics.TotalGauge.Inc()
 	metrics.TasksActive.Inc()
 	defer metrics.TasksActive.Dec()
 
 	messagesChan := make(chan string)
-	actionACK := make(chan bool, 1)
-	actionSyn := make(chan bool, 1)
 	defer close(messagesChan)
-	defer close(actionACK)
-	defer close(actionSyn)
-	repo := r.Repository
 	sessionRecord := repository.Record{
 		ID:          taskID,
 		Description: description,
@@ -80,59 +150,49 @@ func (r *Runner) worker(_ context.Context, logger logr.Logger, description, task
 			Details: nil,
 		},
 	}
-
-	err := repo.Create(taskID, sessionRecord)
+	err := r.Repository.Create(taskID, sessionRecord)
 	if err != nil {
-		// TODO how to handle unable to create record; ie network error, persistence error, etc?
-		logger.Error(err, "task complete", "complete", true)
 		return
 	}
+	cctx, done := context.WithCancel(ctx)
+	defer done()
+	go r.updateMessages(cctx, taskID, messagesChan)
 
-	go func() {
-		for {
-			select {
-			case msg := <-messagesChan:
-				currStatus, _ := repo.Get(taskID)
-				sessionRecord.Messages = append(currStatus.Messages, msg) //nolint:gocritic // apparently this is the right slice
-				_ = repo.Update(taskID, sessionRecord)
-			case <-actionSyn:
-				actionACK <- true
-				return
-			default:
-			}
-			time.Sleep(10 * time.Millisecond)
-		}
-	}()
-
-	sessionRecord.Result, err = action(messagesChan)
-	actionSyn <- true
-	<-actionACK
-	sessionRecord.State = "complete"
-	sessionRecord.Complete = true
-	var finalErr error
+	resultRecord := repository.Record{
+		State:    "complete",
+		Complete: true,
+		Error: &repository.Error{
+			Code:    0,
+			Message: "",
+			Details: nil,
+		},
+	}
+	result, err := action(messagesChan)
 	if err != nil {
-		finalErr = multierror.Append(finalErr, err)
-		sessionRecord.Result = "action failed"
+		resultRecord.Result = "action failed"
 		re, ok := err.(*repository.Error)
 		if ok {
-			sessionRecord.Error = re.StructuredError()
+			resultRecord.Error = re.StructuredError()
 		} else {
-			sessionRecord.Error.Message = err.Error()
+			resultRecord.Error.Message = err.Error()
 		}
 		var foundErr *repository.Error
 		if errors.As(err, &foundErr) {
-			sessionRecord.Error = foundErr.StructuredError()
+			resultRecord.Error = foundErr.StructuredError()
 		}
 	}
-	// TODO handle unable to update record; ie network error, persistence error, etc
-	if err := repo.Update(taskID, sessionRecord); err != nil {
-		finalErr = multierror.Append(finalErr, err)
+	record, err := r.Repository.Get(taskID)
+	if err != nil {
+		return
 	}
-	if finalErr != nil {
-		logger.Error(finalErr, "task complete", "complete", true)
-	} else {
-		logger.Info("task complete", "complete", true)
+	record.Complete = resultRecord.Complete
+	record.State = resultRecord.State
+	record.Result = result
+	record.Error = resultRecord.Error
+
+	if err := r.Repository.Update(taskID, record); err != nil {
+		logger.Error(err, "failed to update record")
 	}
 }
diff --git a/grpc/taskrunner/taskrunner_test.go b/grpc/taskrunner/taskrunner_test.go
index 9b44330..e299db8 100644
--- a/grpc/taskrunner/taskrunner_test.go
+++ b/grpc/taskrunner/taskrunner_test.go
@@ -26,22 +26,20 @@ func TestRoundTrip(t *testing.T) {
 	defer s.Close()
 	repo := &persistence.GoKV{Store: s, Ctx: ctx}
 	logger := logr.Discard()
-	runner := Runner{
-		Repository: repo,
-		Ctx:        ctx,
-	}
+	runner := NewRunner(repo, 100, time.Second)
+	runner.Start(ctx)
+	time.Sleep(time.Millisecond * 100)
 	taskID := xid.New().String()
-	runner.Execute(ctx, logger, description, taskID, func(s chan string) (string, error) {
-		return "didnt do anything", defaultError
-	})
-
 	if len(taskID) != 20 {
 		t.Fatalf("expected id of length 20, got: %v (%v)", len(taskID), taskID)
 	}
+	runner.Execute(ctx, logger, description, taskID, "123", func(s chan string) (string, error) {
+		return "didnt do anything", defaultError
+	})
 
 	// must be min of 3 because we sleep 2 seconds in worker function to allow final status messages to be written
-	time.Sleep(500 * time.Millisecond)
+	time.Sleep(time.Second * 2)
 	record, err := runner.Status(ctx, taskID)
 	if err != nil {
 		t.Fatal(err)
diff --git a/pkg/http/http.go b/pkg/http/http.go
index f43ea76..c2f70e4 100644
--- a/pkg/http/http.go
+++ b/pkg/http/http.go
@@ -1,7 +1,9 @@
 package http
 
 import (
+	"context"
 	"net/http"
+	"time"
 
 	"github.com/go-logr/logr"
 	"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -33,8 +35,26 @@ func (h *Server) init() {
 	h.mux.HandleFunc("/_/live", h.handleLive)
 }
 
-func (h *Server) Run() error {
-	return http.ListenAndServe(h.address, h.mux)
+func (h *Server) Run(ctx context.Context) error {
+	svr := &http.Server{
+		Addr:    h.address,
+		Handler: h.mux,
+		// Mitigate Slowloris attacks. 20 seconds is based on Apache's recommended 20-40 second range.
+		// PBnJ doesn't really have many headers, so 20s should be plenty of time.
+		// https://en.wikipedia.org/wiki/Slowloris_(computer_security)
+		ReadHeaderTimeout: 20 * time.Second,
+	}
+
+	go func() {
+		<-ctx.Done()
+		_ = svr.Shutdown(ctx)
+	}()
+
+	if err := svr.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+		return err
+	}
+
+	return nil
 }
 
 func NewServer(addr string) *Server {
diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go
index 8f403e1..76aa4e8 100644
--- a/pkg/metrics/metrics.go
+++ b/pkg/metrics/metrics.go
@@ -6,9 +6,17 @@ import (
 )
 
 var (
-	ActionDuration prometheus.ObserverVec
-	TasksTotal     prometheus.Counter
-	TasksActive    prometheus.Gauge
+	ActionDuration   prometheus.ObserverVec
+	TasksTotal       prometheus.Counter
+	TotalGauge       prometheus.Gauge
+	TasksActive      prometheus.Gauge
+	PerIDQueue       prometheus.GaugeVec
+	IngestionQueue   prometheus.Gauge
+	Ingested         prometheus.Gauge
+	FIFOQueue        prometheus.Gauge
+	NumWorkers       prometheus.Gauge
+	NumPerIDEnqueued prometheus.Gauge
+	WorkerMap        prometheus.Gauge
 )
 
 func init() {
@@ -29,13 +37,45 @@ func init() {
 	initObserverLabels(ActionDuration, labelValues)
 
 	TasksTotal = promauto.NewCounter(prometheus.CounterOpts{
-		Name: "pbnj_tasks_total",
+		Name: "pbnj_tasks_processes",
 		Help: "Total number of tasks executed.",
 	})
 	TasksActive = promauto.NewGauge(prometheus.GaugeOpts{
 		Name: "pbnj_tasks_active",
 		Help: "Number of tasks currently active.",
 	})
+	PerIDQueue = *promauto.NewGaugeVec(prometheus.GaugeOpts{
+		Name: "pbnj_per_id_queue",
+		Help: "Number of tasks in perID queue.",
+	}, []string{"host"})
+	IngestionQueue = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_ingestion_queue",
+		Help: "Number of tasks in ingestion queue.",
+	})
+	Ingested = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_ingested",
+		Help: "Number of tasks ingested.",
+	})
+	FIFOQueue = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_fifo_queue",
+		Help: "Number of tasks in FIFO queue.",
+	})
+	TotalGauge = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_total",
+		Help: "Total number of tasks.",
+	})
+	NumWorkers = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_num_workers",
+		Help: "Number of workers.",
+	})
+	NumPerIDEnqueued = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_num_per_id_enqueued",
+		Help: "Number of per-ID tasks enqueued.",
+	})
+	WorkerMap = promauto.NewGauge(prometheus.GaugeOpts{
+		Name: "pbnj_worker_map_size",
+		Help: "Worker map size.",
+	})
 }
 
 func initObserverLabels(m prometheus.ObserverVec, l []prometheus.Labels) {
diff --git a/pkg/task/task.go b/pkg/task/task.go
index bcd63f3..8847b95 100644
--- a/pkg/task/task.go
+++ b/pkg/task/task.go
@@ -9,6 +9,6 @@ import (
 
 // Task interface for doing BMC actions.
 type Task interface {
-	Execute(ctx context.Context, l logr.Logger, description, taskID string, action func(chan string) (string, error))
+	Execute(ctx context.Context, l logr.Logger, description, taskID, host string, action func(chan string) (string, error))
 	Status(ctx context.Context, taskID string) (record repository.Record, err error)
 }
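Taken together, the taskrunner changes replace the old goroutine-per-task model: Execute now enqueues a Task onto an ingestion channel, ingest routes it into a per-host-ID queue, and orchestrate runs at most one worker per host, with total concurrency bounded by the concurrencyManager. The sketch below is a compressed, self-contained model of that flow, not PBnJ's actual implementation — the names, buffer sizes, and 200ms idle timeout are illustrative, and it intentionally shares the idle-timeout edge case that run.go's own TODO comments note (a task can land in a queue whose worker has just exited):

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// task mirrors the shape of taskrunner.Task: an ID, the BMC host it targets,
// and the action to run.
type task struct {
	id     string
	host   string
	action func() string
}

// runner models the orchestrator: one ingestion channel, a queue per host,
// and at most one worker per host so tasks against the same BMC run serially.
type runner struct {
	ingest chan task
	byHost map[string]chan task
	mu     sync.Mutex
	sem    chan struct{} // bounds total workers, like the concurrencyManager
}

func newRunner(maxWorkers int) *runner {
	return &runner{
		ingest: make(chan task, 100),
		byHost: make(map[string]chan task),
		sem:    make(chan struct{}, maxWorkers),
	}
}

func (r *runner) start() {
	go func() {
		for t := range r.ingest {
			r.mu.Lock()
			q, ok := r.byHost[t.host]
			if !ok {
				q = make(chan task, 10)
				r.byHost[t.host] = q
			}
			r.mu.Unlock()
			if !ok {
				r.sem <- struct{}{} // blocks until a worker slot is free
				go r.worker(t.host, q)
			}
			q <- t
		}
	}()
}

// worker drains one host's queue serially and exits when idle, releasing
// both its map entry and its concurrency slot.
func (r *runner) worker(host string, q chan task) {
	defer func() {
		r.mu.Lock()
		delete(r.byHost, host)
		r.mu.Unlock()
		<-r.sem
	}()
	for {
		select {
		case t := <-q:
			fmt.Printf("host %s: task %s -> %s\n", host, t.id, t.action())
		case <-time.After(200 * time.Millisecond): // idle timeout, as in run.go
			return
		}
	}
}

func main() {
	r := newRunner(2)
	r.start()
	for i := 0; i < 4; i++ {
		id := fmt.Sprintf("t%d", i)
		host := fmt.Sprintf("10.0.0.%d", i%2)
		r.ingest <- task{id: id, host: host, action: func() string { return "ok" }}
	}
	time.Sleep(time.Second) // crude wait for the demo; real code would track completion
}
```

The key property this buys, and what the diff's design is after, is that requests against a single BMC are serialized while unrelated BMCs proceed in parallel, with the semaphore capping total concurrency at maxWorkers.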