From 6cd9f74be874a66920d87d957c8038b49d45e45e Mon Sep 17 00:00:00 2001 From: will <87208113+db-will@users.noreply.github.com> Date: Thu, 9 Jan 2025 08:20:10 -0500 Subject: [PATCH] Add tls support for mysql client (#186) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add tls support for mysql client * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Adjust code based on reviews * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update error message * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden * Update cmd/go-tpc/main.go --------- Co-authored-by: Daniël van Eeden --- cmd/go-tpc/main.go | 74 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 69 insertions(+), 5 deletions(-) diff --git a/cmd/go-tpc/main.go b/cmd/go-tpc/main.go index 4b4134b..c412558 100644 --- a/cmd/go-tpc/main.go +++ b/cmd/go-tpc/main.go @@ -3,6 +3,8 @@ package main import ( "context" "crypto/sha1" + "crypto/tls" + "crypto/x509" "database/sql" sqldrv "database/sql/driver" "encoding/hex" @@ -18,6 +20,7 @@ import ( "github.com/pingcap/go-tpc/pkg/util" "github.com/spf13/cobra" _ "go.uber.org/automaxprocs" + // mysql package "github.com/go-sql-driver/mysql" // pg @@ -47,15 +50,19 @@ var ( connParams string outputStyle string targets []string + sslCA string + sslCert string + sslKey string globalDB *sql.DB globalCtx context.Context ) const ( - createDBDDL = "CREATE DATABASE " - mysqlDriver = "mysql" - pgDriver = "postgres" + createDBDDL = "CREATE DATABASE " + mysqlDriver = "mysql" + pgDriver = "postgres" + customTlsName = "custom" ) type MuxDriver struct { @@ -93,18 +100,31 @@ func newDB(targets []string, driver string, user string, password string, dbName hash.Write([]byte(password)) hash.Write([]byte(dbName)) hash.Write([]byte(connParams)) + + if driver == mysqlDriver && (len(sslCA) > 0 || len(sslCert) > 0 || len(sslKey) > 0) { + registerMysqlTLSConfig() + } + for i, addr := range targets { hash.Write([]byte(addr)) switch driver { case mysqlDriver: + var tlsName string = "preferred" + if len(sslCA) > 0 { + tlsName = customTlsName + } // allow multiple statements in one query to allow q15 on the TPC-H - dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true&tls=preferred", user, password, addr, dbName) + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true&tls=%s", user, password, addr, dbName, tlsName) if len(connParams) > 0 { dsn = dsn + "&" + connParams } names[i] = dsn drv = &mysql.MySQLDriver{} case pgDriver: + if len(sslCA) > 0 || len(sslKey) > 0 || len(sslCert) > 0 { + panic("postgresql driver doesn't support TLS yet") + } + dsn := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, addr, dbName) if len(connParams) > 0 { dsn = dsn + "?" + connParams @@ -150,9 +170,10 @@ func openDB() { tmpDB, _ = newDB(targets, driver, user, password, "", connParams) defer tmpDB.Close() if _, err := tmpDB.Exec(createDBDDL + dbName); err != nil { - panic(fmt.Errorf("failed to create database, err %v\n", err)) + panic(fmt.Errorf("failed to create database, err %v", err)) } } else { + fmt.Printf("failed to ping db, err %v\n", err) globalDB = nil } } else { @@ -209,6 +230,9 @@ func main() { rootCmd.PersistentFlags().StringVar(&outputStyle, "output", util.OutputStylePlain, "output style, valid values can be { plain | table | json }") rootCmd.PersistentFlags().StringSliceVar(&targets, "targets", nil, "Target database addresses") rootCmd.PersistentFlags().MarkHidden("targets") + rootCmd.PersistentFlags().StringVar(&sslCA, "ssl-ca", "", "Path of file that contains list of trusted SSL CAs for connection") + rootCmd.PersistentFlags().StringVar(&sslCert, "ssl-cert", "", "Path of file that contains X509 certificate in PEM format for connection") + rootCmd.PersistentFlags().StringVar(&sslKey, "ssl-key", "", "Path of file that contains X509 key in PEM format for connection") cobra.EnablePrefixMatching = true @@ -251,3 +275,43 @@ func main() { cancel() } + +// registerMysqlTLSConfig constructs a `*tls.Config` from the CA, certification and key +// paths, and register to mysql client. +func registerMysqlTLSConfig() { + // Load the client certificates from disk + var certificates []tls.Certificate + if len(sslCert) != 0 && len(sslKey) != 0 { + cert, err := tls.LoadX509KeyPair(sslCert, sslKey) + if err != nil { + panic(fmt.Errorf("could not load client key pair, err %v", err)) + } + certificates = []tls.Certificate{cert} + } else if len(sslCert) > 0 || len(sslKey) > 0 { + panic("incomplete key pair configuration") + } + + // Create a certificate pool from CA + certPool := x509.NewCertPool() + ca, err := os.ReadFile(sslCA) + if err != nil { + panic(fmt.Errorf("could not read CA certificate, err %v", err)) + } + + // Append the certificates from the CA + if !certPool.AppendCertsFromPEM(ca) { + panic("failed to append CA certs") + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: certificates, + RootCAs: certPool, + ClientCAs: certPool, + } + + err = mysql.RegisterTLSConfig(customTlsName, tlsConfig) + if err != nil { + panic(fmt.Errorf("failed to register TLS config, err %v", err)) + } +}