forked from bitcomplete/sqltestutil
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatabase_suite.go
89 lines (77 loc) · 2.25 KB
/
database_suite.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package sqltestutil
import (
"context"
"fmt"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/suite"
)
// Suite is a testify suite [1] that provides a database connection for running
// tests against a SQL database. Each test gets its own transaction, begun
// lazily on the first call to Tx within that test and rolled back by
// TearDownTest when the test ends, so that each test can operate on a clean
// slate. Here's an example of how to use it:
//
//	type ExampleTestSuite struct {
//		sqltestutil.Suite
//	}
//
//	func (s *ExampleTestSuite) TestExample() {
//		_, err := s.Tx().Exec("INSERT INTO foo (bar) VALUES ($1)", "baz")
//		s.Assert().NoError(err)
//	}
//
//	func TestExampleTestSuite(t *testing.T) {
//		suite.Run(t, &ExampleTestSuite{
//			Suite: sqltestutil.Suite{
//				Context:        context.Background(),
//				DriverName:     "pgx",
//				DataSourceName: "postgres://localhost:5432/example",
//			},
//		})
//	}
//
// Placeholder syntax is driver-specific: PostgreSQL drivers such as pgx use
// $1, $2, ..., while others (e.g. MySQL, SQLite) use ?.
//
// Suite implements SetupSuite, TearDownTest and TearDownSuite. A suite that
// embeds it and defines its own version of any of these must call the embedded
// method (e.g. s.Suite.TearDownTest()); otherwise the database is not opened,
// the per-test transaction is not rolled back, or the database is not closed.
//
// [1]: https://pkg.go.dev/github.com/stretchr/testify/suite#Suite
type Suite struct {
	suite.Suite
	// Context is a required field for constructing a Suite, and is used for
	// database operations within a suite (for example, Tx passes it to
	// BeginTxx). It's public because it's convenient to have access to it in
	// tests.
	context.Context
	// DriverName is a required field for constructing a Suite, and is used to
	// connect to the underlying SQL database.
	DriverName string
	// DataSourceName is a required field for constructing a Suite, and is used to
	// connect to the underlying SQL database.
	DataSourceName string
	// db is the database handle opened by SetupSuite and closed by
	// TearDownSuite.
	db *sqlx.DB
	// tx is the current test's transaction: nil until Tx is first called in a
	// test; TearDownTest rolls it back and resets it to nil.
	tx *sqlx.Tx
}
// DB returns the database handle opened by SetupSuite (nil before it runs).
// Statements executed directly on it are not part of the per-test transaction
// returned by Tx, so TearDownTest does not roll them back.
func (s *Suite) DB() *sqlx.DB {
	handle := s.db
	return handle
}
// Tx returns the transaction for the current test. The transaction is begun
// lazily on the first call within a test (using s.Context and the driver's
// default options), and the same *sqlx.Tx is returned by every later call
// until TearDownTest rolls it back. If the transaction cannot be started, the
// current test is failed immediately via Require.
func (s *Suite) Tx() *sqlx.Tx {
	if s.tx != nil {
		// Already started earlier in this test; keep using it.
		return s.tx
	}
	tx, err := s.db.BeginTxx(s.Context, nil)
	s.tx = tx
	s.Require().NoError(err)
	return tx
}
// TearDownTest rolls back the transaction started by Tx during the test, if
// any, so nothing written through Tx persists into later tests. testify calls
// it after every test method; a suite that defines its own TearDownTest must
// call s.Suite.TearDownTest() to keep this behavior.
//
// The cached transaction is cleared before the rollback result is checked.
// If Rollback fails, Require aborts this method. A still-set s.tx would then
// be returned by Tx in the next test even though it is already finished,
// turning one failure into a cascade.
func (s *Suite) TearDownTest() {
	tx := s.tx
	if tx == nil {
		// Tx was never called during this test; nothing to roll back.
		return
	}
	s.tx = nil
	s.Require().NoError(tx.Rollback(), "rolling back test transaction")
}
// SetupSuite opens the database described by DriverName and DataSourceName and
// verifies that it is reachable. testify calls it once before the first test;
// a suite that defines its own SetupSuite must call s.Suite.SetupSuite().
//
// sqlx.Open may only validate its arguments without connecting. The handle is
// therefore pinged here, so that an unreachable or misconfigured database is
// reported once as a setup failure rather than as an error in every test.
func (s *Suite) SetupSuite() {
	db, err := sqlx.Open(s.DriverName, s.DataSourceName)
	s.Require().NoError(err, "opening database")

	ctx := s.Context
	if ctx == nil {
		// Context is documented as required, but only Tx strictly depends on
		// it; fall back instead of panicking inside PingContext.
		ctx = context.Background()
	}
	err = db.PingContext(ctx)
	if err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = db.Close()
	}
	s.Require().NoError(err, "pinging database")

	s.db = db
}
// TearDownSuite closes the database handle opened by SetupSuite. testify calls
// it once after all tests have run; a suite that defines its own TearDownSuite
// must call s.Suite.TearDownSuite().
//
// Closing is best-effort: a close error is printed rather than failing the
// suite. If no handle was ever opened (e.g. an embedding suite overrode
// SetupSuite without calling it), this is a no-op instead of a nil-pointer
// panic.
func (s *Suite) TearDownSuite() {
	db := s.db
	if db == nil {
		return
	}
	// Drop the reference so DB() does not hand out a closed handle and a
	// repeated teardown is harmless.
	s.db = nil
	if err := db.Close(); err != nil {
		fmt.Println("error in database close:", err)
	}
}