Skip to content

Commit

Permalink
chore: add linter for correct usage of CommitOrRollback (#2306)
Browse files Browse the repository at this point in the history
  • Loading branch information
alecthomas authored Aug 9, 2024
1 parent bdc2dae commit 1d4bf29
Show file tree
Hide file tree
Showing 12 changed files with 263 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ jobs:
uses: ./.github/actions/build-cache
- name: golangci-lint
run: golangci-lint run --new-from-rev="$(git merge-base origin/main HEAD)" --out-format github-actions ./...
- name: lint-commit-or-rollback
run: lint-commit-or-rollback ./backend/... 2>&1 | to-annotation
- name: go-check-sumtype
shell: bash
run: go-check-sumtype ./... 2>&1 | to-annotation
Expand Down
1 change: 1 addition & 0 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ lint-frontend: build-frontend
# Lint the backend
lint-backend:
@golangci-lint run --new-from-rev=$(git merge-base origin/main HEAD) ./...
@lint-commit-or-rollback ./backend/...

lint-scripts:
@shellcheck -f gcc -e SC2016 $(find scripts -type f -not -path scripts/tests) | to-annotation
Expand Down
2 changes: 1 addition & 1 deletion backend/controller/dal/dal.go
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ func (p *postgresClaim) Rollback(ctx context.Context) error {
func (p *postgresClaim) Runner() Runner { return p.runner }

// SetDeploymentReplicas activates the given deployment.
func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey, minReplicas int) error {
func (d *DAL) SetDeploymentReplicas(ctx context.Context, key model.DeploymentKey, minReplicas int) (err error) {
// Start the transaction
tx, err := d.db.Begin(ctx)
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions cmd/lint-commit-or-rollback/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module github.com/tbdeng/pfi/cmd/lint-commit-or-rollback

go 1.22.0

require golang.org/x/tools v0.18.0

require (
github.com/alecthomas/repr v0.4.0 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
)

require (
github.com/alecthomas/assert/v2 v2.6.0
golang.org/x/mod v0.15.0 // indirect
)
12 changes: 12 additions & 0 deletions cmd/lint-commit-or-rollback/go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

94 changes: 94 additions & 0 deletions cmd/lint-commit-or-rollback/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package main

import (
"go/ast"
"go/token"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/singlechecker"
)

var Analyzer = &analysis.Analyzer{
Name: "commitorrollback",
Doc: "Detects misues of dal.TX.CommitOrRollback",
Run: run,
}

// Detect that any use of dal.TX.CommitOrRollback is in a defer statement and
// takes a reference to a named error return parameter.
//
// ie. Must be in the following form
//
// func myFunc() (err error) {
// // ...
// defer tx.CommitOrRollback(&err)
// }
func run(pass *analysis.Pass) (interface{}, error) {
var inspect func(n ast.Node) bool
funcStack := []*ast.FuncType{}
inspect = func(n ast.Node) bool {
switch n := n.(type) {
case nil:
return false

case *ast.FuncLit:
funcStack = append(funcStack, n.Type)
ast.Inspect(n.Body, inspect)
funcStack = funcStack[:len(funcStack)-1]
return false

case *ast.FuncDecl:
funcStack = append(funcStack, n.Type)
ast.Inspect(n.Body, inspect)
funcStack = funcStack[:len(funcStack)-1]
return false

case *ast.CallExpr:
sel, ok := n.Fun.(*ast.SelectorExpr)
if !ok {
return true
}
x, ok := sel.X.(*ast.Ident)
if !ok || x.Name != "tx" || sel.Sel.Name != "CommitOrRollback" || len(n.Args) != 2 {
return true
}
arg0, ok := n.Args[1].(*ast.UnaryExpr)
if !ok || arg0.Op != token.AND {
return true
}
arg0Ident, ok := arg0.X.(*ast.Ident)
if !ok {
return true
}
funcDecl := funcStack[len(funcStack)-1]
funcPos := pass.Fset.Position(funcDecl.Func)
if funcDecl.Results == nil {
pass.Reportf(arg0.Pos(), "defer tx.CommitOrRollback(ctx, &err) should be deferred with a named error return parameter but the function at %s has no named return parameters", funcPos)
return true
}
for _, field := range funcDecl.Results.List {
if result, ok := field.Type.(*ast.Ident); ok && result.Name == "error" {
if len(field.Names) == 0 {
pass.Reportf(arg0.Pos(), "defer tx.CommitOrRollback(ctx, &err) should be deferred with a reference to a named error return parameter, but the function at %s has no named return parameters", funcPos)
}
for _, name := range field.Names {
if name.Name != arg0Ident.Name {
namePos := pass.Fset.Position(name.NamePos)
pass.Reportf(arg0.Pos(), "defer tx.CommitOrRollback(&ctx, %s) should be deferred with the named error return parameter here %s", arg0Ident.Name, namePos)
}
}
}
}
}
return true
}
for _, file := range pass.Files {
funcStack = []*ast.FuncType{}
ast.Inspect(file, inspect)
}
return nil, nil
}

func main() {
singlechecker.Main(Analyzer)
}
25 changes: 25 additions & 0 deletions cmd/lint-commit-or-rollback/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package main

import (
"os"
"os/exec"
"strings"
"testing"

"github.com/alecthomas/assert/v2"
)

func TestLinter(t *testing.T) {
pwd, err := os.Getwd()
assert.NoError(t, err)
cmd := exec.Command("lint-commit-or-rollback", ".")
cmd.Dir = "testdata"
output, err := cmd.CombinedOutput()
assert.Error(t, err)
expected := `
` + pwd + `/testdata/main.go:35:29: defer tx.CommitOrRollback(&err) should be deferred with a reference to a named error return parameter, but the function at /Users/alec/dev/pfi/cmd/lint-commit-or-rollback/testdata/main.go:29:6 has no named return parameters
` + pwd + `/testdata/main.go:44:28: defer tx.CommitOrRollback(&err) should be deferred with a reference to a named error return parameter, but the function at /Users/alec/dev/pfi/cmd/lint-commit-or-rollback/testdata/main.go:28:1 has no named return parameters
` + pwd + `/testdata/main.go:55:29: defer tx.CommitOrRollback(&err) should be deferred with a reference to a named error return parameter, but the function at /Users/alec/dev/pfi/cmd/lint-commit-or-rollback/testdata/main.go:49:6 has no named return parameters
`
assert.Equal(t, strings.TrimSpace(expected), strings.TrimSpace(string(output)))
}
7 changes: 7 additions & 0 deletions cmd/lint-commit-or-rollback/testdata/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module main

go 1.22.2

replace pfi/backend => ../../../backend

require pfi/backend v0.0.0-00010101000000-000000000000
34 changes: 34 additions & 0 deletions cmd/lint-commit-or-rollback/testdata/go.sum

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 70 additions & 0 deletions cmd/lint-commit-or-rollback/testdata/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package main

import (
"context"
"database/sql"

libdal "pfi/backend/libs/dal"
)

type Tx struct {
*libdal.Tx[Tx]
*DAL
}

// DAL is the data access layer for the IDV module.
type DAL struct {
db libdal.Conn
*libdal.DAL[Tx]
}

// NewDAL creates a new Data Access Layer instance.
func NewDAL(conn *sql.DB) *DAL {
return &DAL{db: conn, DAL: libdal.New(conn, func(tx *sql.Tx, t *libdal.Tx[Tx]) *Tx {
return &Tx{DAL: &DAL{db: tx, DAL: t.DAL}, Tx: t}
})}
}

func failure() error {
_ = func() error {
dal := DAL{}
tx, err := dal.Begin(context.Background())
if err != nil {
return err
}
defer tx.CommitOrRollback(&err) // Should error
return nil
}

dal := DAL{}
tx, err := dal.Begin(context.Background())
if err != nil {
return err
}
defer tx.CommitOrRollback(&err) // Should error
return nil
}

func success() (err error) {
_ = func() error {
dal := DAL{}
tx, err := dal.Begin(context.Background())
if err != nil {
return err
}
defer tx.CommitOrRollback(&err) // Should error
return nil
}
dal := DAL{}
tx, err := dal.Begin(context.Background())
if err != nil {
return err
}
defer tx.CommitOrRollback(&err) // Should NOT error
return nil
}

func main() {
_ = failure()
_ = success()
}
2 changes: 1 addition & 1 deletion scripts/ftl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ ftldir="$(dirname "$0")/.."
name="$(basename "$0")"
dest="${ftldir}/build/devel"
mkdir -p "$dest"
(cd "${ftldir}" && ./bin/go build -ldflags="-s -w -buildid=" -o "$dest/${name}" "./cmd/${name}") && exec "$dest/${name}" "$@"
(cd "${ftldir}/cmd/${name}" && "${ftldir}/bin/go" build -ldflags="-s -w -buildid=" -o "$dest/${name}" ./) && exec "$dest/${name}" "$@"
1 change: 1 addition & 0 deletions scripts/lint-commit-or-rollback

0 comments on commit 1d4bf29

Please sign in to comment.