Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Support reused connection for better performance. #753

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ endif
test:
go test -count=1 -v $(shell go list ./... | grep -v "hub/test")

test-db:
go test -count=1 -timeout=6h -v ./database...

# Run Hub REST API tests.
test-api:
HUB_BASE_URL=$(HUB_BASE_URL) go test -count=1 -p=1 -v -failfast ./test/api/...
Expand Down
34 changes: 34 additions & 0 deletions database/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,40 @@ import (

var N, _ = env.GetInt("TEST_CONCURRENT", 10)

func TestDriver(t *testing.T) {
pid := os.Getpid()
Settings.DB.Path = fmt.Sprintf("/tmp/driver-%d.db", pid)
defer func() {
_ = os.Remove(Settings.DB.Path)
}()
db, err := Open(true)
if err != nil {
panic(err)
}
key := "driver"
m := &model.Setting{Key: key, Value: "Test"}
// insert.
err = db.Create(m).Error
if err != nil {
panic(err)
}
// update
err = db.Save(m).Error
if err != nil {
panic(err)
}
// select
err = db.First(m, m.ID).Error
if err != nil {
panic(err)
}
// delete
err = db.Delete(m).Error
if err != nil {
panic(err)
}
}

func TestConcurrent(t *testing.T) {
pid := os.Getpid()
Settings.DB.Path = fmt.Sprintf("/tmp/concurrent-%d.db", pid)
Expand Down
188 changes: 170 additions & 18 deletions database/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@ import (
"github.com/mattn/go-sqlite3"
)

// Driver is a wrapper around the SQLite driver.
// The purpose is to prevent database locked errors using
// a mutex around write operations.
type Driver struct {
mutex sync.Mutex
wrapped driver.Driver
dsn string
}

// Open a connection.
func (d *Driver) Open(dsn string) (conn driver.Conn, err error) {
d.wrapped = &sqlite3.SQLiteDriver{}
conn, err = d.wrapped.Open(dsn)
Expand All @@ -28,48 +32,67 @@ func (d *Driver) Open(dsn string) (conn driver.Conn, err error) {
return
}

// OpenConnector opens a connection.
func (d *Driver) OpenConnector(dsn string) (dc driver.Connector, err error) {
d.dsn = dsn
dc = d
return
}

// Connect opens a connection.
func (d *Driver) Connect(context.Context) (conn driver.Conn, err error) {
conn, err = d.Open(d.dsn)
return
}

// Driver returns the underlying driver.
func (d *Driver) Driver() driver.Driver {
return d
}

// Conn is a DB connection.
type Conn struct {
mutex *sync.Mutex
wrapped driver.Conn
hasMutex bool
hasTx bool
}

// Ping the DB.
func (c *Conn) Ping(ctx context.Context) (err error) {
if p, cast := c.wrapped.(driver.Pinger); cast {
err = p.Ping(ctx)
}
return
}

// ResetSession reset the connection.
// - Reset the Tx.
// - Release the mutex.
func (c *Conn) ResetSession(ctx context.Context) (err error) {
defer func() {
c.hasTx = false
c.release()
}()
if p, cast := c.wrapped.(driver.SessionResetter); cast {
err = p.ResetSession(ctx)
}
return
}

// IsValid returns true when the connection is valid.
// When true, the connection may be reused by the sql package.
func (c *Conn) IsValid() (b bool) {
b = true
if p, cast := c.wrapped.(driver.Validator); cast {
b = p.IsValid()
}
return
}

// QueryContext execute a query with context.
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Rows, err error) {
defer c.release()
if c.needsMutex(query) {
c.acquire()
}
Expand All @@ -79,57 +102,90 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return
}

func (c *Conn) PrepareContext(ctx context.Context, query string) (s driver.Stmt, err error) {
// ExecContext executes an SQL/DDL statement with context.
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
defer c.release()
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.ConnPrepareContext); cast {
s, err = p.PrepareContext(ctx, query)
if p, cast := c.wrapped.(driver.ExecerContext); cast {
r, err = p.ExecContext(ctx, query, args)
}
return
}

func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (r driver.Result, err error) {
// Begin a transaction.
func (c *Conn) Begin() (tx driver.Tx, err error) {
c.acquire()
if p, cast := c.wrapped.(driver.ExecerContext); cast {
r, err = p.ExecContext(ctx, query, args)
tx, err = c.wrapped.Begin()
if err != nil {
return
}
tx = &Tx{
conn: c,
wrapped: tx,
}
c.hasTx = true
return
}

// BeginTx begins a transaction.
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
c.acquire()
if p, cast := c.wrapped.(driver.ConnBeginTx); cast {
tx, err = p.BeginTx(ctx, opts)
} else {
tx, err = c.wrapped.Begin()
}
tx = &Tx{
conn: c,
wrapped: tx,
}
c.hasTx = true
return
}

func (c *Conn) Prepare(query string) (s driver.Stmt, err error) {
// Prepare a statement.
func (c *Conn) Prepare(query string) (stmt driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
}
s, err = c.wrapped.Prepare(query)
stmt, err = c.wrapped.Prepare(query)
stmt = &Stmt{
conn: c,
wrapped: stmt,
query: query,
}
return
}

func (c *Conn) Close() (err error) {
err = c.wrapped.Close()
c.release()
// PrepareContext prepares a statement with context.
func (c *Conn) PrepareContext(ctx context.Context, query string) (stmt driver.Stmt, err error) {
if c.needsMutex(query) {
c.acquire()
}
if p, cast := c.wrapped.(driver.ConnPrepareContext); cast {
stmt, err = p.PrepareContext(ctx, query)
} else {
stmt, err = c.Prepare(query)
}
stmt = &Stmt{
conn: c,
wrapped: stmt,
query: query,
}
return
}

func (c *Conn) Begin() (tx driver.Tx, err error) {
c.acquire()
tx, err = c.wrapped.Begin()
if err != nil {
return
}
// Close the connection.
func (c *Conn) Close() (err error) {
err = c.wrapped.Close()
c.hasMutex = false
c.release()
return
}

// needsMutex returns true when the query should is a write operation.
func (c *Conn) needsMutex(query string) (matched bool) {
if query == "" {
return
Expand All @@ -144,16 +200,112 @@ func (c *Conn) needsMutex(query string) (matched bool) {
return
}

// acquire the mutex.
// Since Locks are not reentrant, the mutex is acquired
// only if this connection has not already acquired it.
func (c *Conn) acquire() {
if !c.hasMutex {
c.mutex.Lock()
c.hasMutex = true
}
}

// release the mutex.
// Released only when:
// - This connection has acquired it
// - Not in a transaction.
func (c *Conn) release() {
if c.hasMutex {
if c.hasMutex && !c.hasTx {
c.mutex.Unlock()
c.hasMutex = false
}
}

// Stmt is a SQL/DDL statement.
type Stmt struct {
wrapped driver.Stmt
conn *Conn
query string
}

// Close the statement.
func (s *Stmt) Close() (err error) {
defer s.conn.release()
err = s.wrapped.Close()
return
}

// NumInput returns the number of (query) input parameters.
func (s *Stmt) NumInput() (n int) {
n = s.wrapped.NumInput()
return
}

// Exec executes the statement.
func (s *Stmt) Exec(args []driver.Value) (r driver.Result, err error) {
r, err = s.wrapped.Exec(args)
return
}

// ExecContext executes the statement with context.
func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (r driver.Result, err error) {
if p, cast := s.wrapped.(driver.StmtExecContext); cast {
r, err = p.ExecContext(ctx, args)
} else {
r, err = s.Exec(s.values(args))
}
return
}

// Query executes a query.
func (s *Stmt) Query(args []driver.Value) (r driver.Rows, err error) {
r, err = s.wrapped.Query(args)
return
}

// QueryContext executes a query.
func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (r driver.Rows, err error) {
if p, cast := s.wrapped.(driver.StmtQueryContext); cast {
r, err = p.QueryContext(ctx, args)
} else {
r, err = s.Query(s.values(args))
}
return
}

// values converts named-values to values.
func (s *Stmt) values(named []driver.NamedValue) (out []driver.Value) {
for i := range named {
out = append(out, named[i].Value)
}
return
}

// Tx is a transaction.
type Tx struct {
wrapped driver.Tx
conn *Conn
}

// Commit the transaction.
// Releases the mutex.
func (t *Tx) Commit() (err error) {
defer func() {
t.conn.hasTx = false
t.conn.release()
}()
err = t.wrapped.Commit()
return
}

//
// Rollback the transaction.
// Releases the mutex.
func (t *Tx) Rollback() (err error) {
defer func() {
t.conn.hasTx = false
t.conn.release()
}()
err = t.wrapped.Rollback()
return
}
Loading