diff --git a/bindings/mysql/mysql.go b/bindings/mysql/mysql.go index c68ac1da8d..5c15f59910 100644 --- a/bindings/mysql/mysql.go +++ b/bindings/mysql/mysql.go @@ -20,11 +20,13 @@ import ( "database/sql" "database/sql/driver" "encoding/json" + "encoding/pem" "errors" "fmt" "os" "reflect" "strconv" + "strings" "sync/atomic" "time" @@ -125,11 +127,26 @@ func (m *Mysql) Init(ctx context.Context, md bindings.Metadata) error { // meta.PemContents supersedes meta.PemPath if both are provided. if meta.PemContents != "" { + // Reformat the PEM to standard format + meta.PemContents = reformatPEM(meta.PemContents) pemContents = []byte(meta.PemContents) } else if meta.PemPath != "" { pemContents, err = os.ReadFile(meta.PemPath) if err != nil { - return fmt.Errorf("unable to read pem file: %w", err) + return fmt.Errorf("unable to read PEM file: %w", err) + } + } + + // Decode PEM contents and parse certificate to ensure it's valid. + if len(pemContents) != 0 { + block, _ := pem.Decode(pemContents) + if block == nil { + return errors.New("failed to decode PEM") + } + + _, err = x509.ParseCertificate(block.Bytes) + if err != nil { + return fmt.Errorf("failed to parse PEM contents: %w", err) } } @@ -296,6 +313,7 @@ func initDB(url string, pemContents []byte) (*sql.DB, error) { if len(pemContents) != 0 { rootCertPool := x509.NewCertPool() + ok := rootCertPool.AppendCertsFromPEM(pemContents) if !ok { return nil, errors.New("failed to append PEM") @@ -323,6 +341,32 @@ func initDB(url string, pemContents []byte) (*sql.DB, error) { return db, nil } +// Helper function to reformat a single-line PEM into standard PEM format +func reformatPEM(pemStr string) string { + // Ensure headers and footers are on their own lines + pemStr = strings.ReplaceAll(pemStr, "-----BEGIN CERTIFICATE-----", "\n-----BEGIN CERTIFICATE-----\n") + pemStr = strings.ReplaceAll(pemStr, "-----END CERTIFICATE-----", "\n-----END CERTIFICATE-----") + + // Split into base64-encoded content and reformat into 64-character lines + lines := strings.Split(pemStr, "\n") + if len(lines) >= 3 { + encodedContent := lines[1] + lines[1] = strings.Join(chunkString(encodedContent, 64), "\n") + } + return strings.Join(lines, "\n") +} + +// Helper function to split a string into chunks of a given size +func chunkString(s string, chunkSize int) []string { + var chunks []string + for len(s) > chunkSize { + chunks = append(chunks, s[:chunkSize]) + s = s[chunkSize:] + } + chunks = append(chunks, s) + return chunks +} + func (m *Mysql) jsonify(rows *sql.Rows) ([]byte, error) { columnTypes, err := rows.ColumnTypes() if err != nil {