Skip to content

Commit

Permalink
mysql: pass TLS config directly to MySQL's config (#3348)
Browse files Browse the repository at this point in the history
  • Loading branch information
giautm authored Dec 28, 2023
1 parent 97fe9d0 commit 92114ef
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 46 deletions.
19 changes: 5 additions & 14 deletions mysql/awsmysql/awsmysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import (
"database/sql/driver"
"fmt"
"net/url"
"sync/atomic"

"contrib.go.opencensus.io/integrations/ocsql"
"github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -112,20 +111,14 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
c.sem <- struct{}{} // release
return nil, fmt.Errorf("connect RDS: %v", err)
}
// TODO(light): Avoid global registry once https://github.com/go-sql-driver/mysql/issues/771 is fixed.
tlsConfigName := fmt.Sprintf(
"gocloud.dev/mysql/awsmysql/%d",
atomic.AddUint32(&tlsConfigCounter, 1),
)
err = mysql.RegisterTLSConfig(tlsConfigName, &tls.Config{
RootCAs: certPool,
})
cfg, err := mysql.ParseDSN(c.dsn)
if err != nil {
c.sem <- struct{}{} // release
return nil, fmt.Errorf("connect RDS: register TLS: %v", err)
return nil, fmt.Errorf("connect RDS: parse DSN: %v", err)
}
cfg.TLS = &tls.Config{
RootCAs: certPool,
}
cfg, _ := mysql.ParseDSN(c.dsn)
cfg.TLSConfig = tlsConfigName
c.dsn = cfg.FormatDSN()
close(c.ready)
// Don't release sem: make it block forever, so this case won't be run again.
Expand All @@ -141,8 +134,6 @@ func (c *connector) Driver() driver.Driver {
return ocsql.Wrap(mysql.MySQLDriver{}, c.traceOpts...)
}

var tlsConfigCounter uint32

// A CertPoolProvider obtains a certificate pool that contains the RDS CA certificate.
type CertPoolProvider = rds.CertPoolProvider

Expand Down
22 changes: 1 addition & 21 deletions mysql/azuremysql/azuremysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"fmt"
"net/url"
"strings"
"sync"

"contrib.go.opencensus.io/integrations/ocsql"
"github.com/go-sql-driver/mysql"
Expand Down Expand Up @@ -106,26 +105,12 @@ func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
c.sem <- struct{}{} // release
return nil, fmt.Errorf("connect Azure MySql: %v", err)
}

// TODO(light): Avoid global registry once https://github.com/go-sql-driver/mysql/issues/771 is fixed.
tlsConfigCounter.mu.Lock()
tlsConfigNum := tlsConfigCounter.n
tlsConfigCounter.n++
tlsConfigCounter.mu.Unlock()
tlsConfigName := fmt.Sprintf("gocloud.dev/mysql/azuremysql/%d", tlsConfigNum)
err = mysql.RegisterTLSConfig(tlsConfigName, &tls.Config{
RootCAs: certPool,
})
if err != nil {
c.sem <- struct{}{} // release
return nil, fmt.Errorf("connect Azure MySql: register TLS: %v", err)
}
cfg := &mysql.Config{
Net: "tcp",
Addr: c.addr,
User: c.user,
Passwd: c.password,
TLSConfig: tlsConfigName,
TLS: &tls.Config{RootCAs: certPool},
AllowCleartextPasswords: true,
AllowNativePasswords: true,
DBName: c.dbName,
Expand All @@ -145,11 +130,6 @@ func (c *connector) Driver() driver.Driver {
return ocsql.Wrap(mysql.MySQLDriver{}, c.traceOpts...)
}

var tlsConfigCounter struct {
mu sync.Mutex
n int
}

// A CertPoolProvider obtains a certificate pool that contains the Azure CA certificate.
type CertPoolProvider = azuredb.CertPoolProvider

Expand Down
16 changes: 5 additions & 11 deletions mysql/gcpmysql/gcpmysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"net/url"
"strings"
"sync"
"sync/atomic"

"contrib.go.opencensus.io/integrations/ocsql"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
Expand Down Expand Up @@ -97,12 +98,8 @@ func (uo *URLOpener) OpenMySQLURL(ctx context.Context, u *url.URL) (*sql.DB, err
if uo.CertSource == nil {
return nil, fmt.Errorf("gcpmysql: URLOpener CertSource is nil")
}
// TODO(light): Avoid global registry once https://github.com/go-sql-driver/mysql/issues/771 is fixed.
dialerCounter.mu.Lock()
dialerNum := dialerCounter.n
dialerCounter.mu.Unlock()
dialerName := fmt.Sprintf("gocloud.dev/mysql/gcpmysql/%d", dialerNum)

dialerName := fmt.Sprintf("gocloud.dev/mysql/gcpmysql/%d",
atomic.AddUint32(&dialerCounter, 1))
cfg, err := configFromURL(u, dialerName)
if err != nil {
return nil, fmt.Errorf("gcpmysql: open config %v", err)
Expand All @@ -112,7 +109,7 @@ func (uo *URLOpener) OpenMySQLURL(ctx context.Context, u *url.URL) (*sql.DB, err
Port: 3307,
Certs: uo.CertSource,
}
mysql.RegisterDial(dialerName, client.Dial)
mysql.RegisterDialContext(dialerName, client.DialContext)

db := sql.OpenDB(connector{cfg.FormatDSN(), uo.TraceOpts})
return db, nil
Expand Down Expand Up @@ -161,10 +158,7 @@ func instanceFromURL(u *url.URL) (instance, db string, _ error) {
return parts[0] + ":" + parts[1] + ":" + parts[2], parts[3], nil
}

var dialerCounter struct {
mu sync.Mutex
n int
}
var dialerCounter uint32

type connector struct {
dsn string
Expand Down

0 comments on commit 92114ef

Please sign in to comment.