Skip to content

Commit 6cd9f74

Browse files
db-willdveeden
andauthored
Add tls support for mysql client (#186)
* Add tls support for mysql client * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Adjust code based on reviews * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update error message * Update cmd/go-tpc/main.go Co-authored-by: Daniël van Eeden <[email protected]> * Update cmd/go-tpc/main.go --------- Co-authored-by: Daniël van Eeden <[email protected]>
1 parent 01c0653 commit 6cd9f74

File tree

1 file changed

+69
-5
lines changed

1 file changed

+69
-5
lines changed

cmd/go-tpc/main.go

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package main
33
import (
44
"context"
55
"crypto/sha1"
6+
"crypto/tls"
7+
"crypto/x509"
68
"database/sql"
79
sqldrv "database/sql/driver"
810
"encoding/hex"
@@ -18,6 +20,7 @@ import (
1820
"github.com/pingcap/go-tpc/pkg/util"
1921
"github.com/spf13/cobra"
2022
_ "go.uber.org/automaxprocs"
23+
2124
// mysql package
2225
"github.com/go-sql-driver/mysql"
2326
// pg
@@ -47,15 +50,19 @@ var (
4750
connParams string
4851
outputStyle string
4952
targets []string
53+
sslCA string
54+
sslCert string
55+
sslKey string
5056

5157
globalDB *sql.DB
5258
globalCtx context.Context
5359
)
5460

5561
const (
56-
createDBDDL = "CREATE DATABASE "
57-
mysqlDriver = "mysql"
58-
pgDriver = "postgres"
62+
createDBDDL = "CREATE DATABASE "
63+
mysqlDriver = "mysql"
64+
pgDriver = "postgres"
65+
customTlsName = "custom"
5966
)
6067

6168
type MuxDriver struct {
@@ -93,18 +100,31 @@ func newDB(targets []string, driver string, user string, password string, dbName
93100
hash.Write([]byte(password))
94101
hash.Write([]byte(dbName))
95102
hash.Write([]byte(connParams))
103+
104+
if driver == mysqlDriver && (len(sslCA) > 0 || len(sslCert) > 0 || len(sslKey) > 0) {
105+
registerMysqlTLSConfig()
106+
}
107+
96108
for i, addr := range targets {
97109
hash.Write([]byte(addr))
98110
switch driver {
99111
case mysqlDriver:
112+
var tlsName string = "preferred"
113+
if len(sslCA) > 0 {
114+
tlsName = customTlsName
115+
}
100116
// allow multiple statements in one query to allow q15 on the TPC-H
101-
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true&tls=preferred", user, password, addr, dbName)
117+
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true&tls=%s", user, password, addr, dbName, tlsName)
102118
if len(connParams) > 0 {
103119
dsn = dsn + "&" + connParams
104120
}
105121
names[i] = dsn
106122
drv = &mysql.MySQLDriver{}
107123
case pgDriver:
124+
if len(sslCA) > 0 || len(sslKey) > 0 || len(sslCert) > 0 {
125+
panic("postgresql driver doesn't support TLS yet")
126+
}
127+
108128
dsn := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, addr, dbName)
109129
if len(connParams) > 0 {
110130
dsn = dsn + "?" + connParams
@@ -150,9 +170,10 @@ func openDB() {
150170
tmpDB, _ = newDB(targets, driver, user, password, "", connParams)
151171
defer tmpDB.Close()
152172
if _, err := tmpDB.Exec(createDBDDL + dbName); err != nil {
153-
panic(fmt.Errorf("failed to create database, err %v\n", err))
173+
panic(fmt.Errorf("failed to create database, err %v", err))
154174
}
155175
} else {
176+
fmt.Printf("failed to ping db, err %v\n", err)
156177
globalDB = nil
157178
}
158179
} else {
@@ -209,6 +230,9 @@ func main() {
209230
rootCmd.PersistentFlags().StringVar(&outputStyle, "output", util.OutputStylePlain, "output style, valid values can be { plain | table | json }")
210231
rootCmd.PersistentFlags().StringSliceVar(&targets, "targets", nil, "Target database addresses")
211232
rootCmd.PersistentFlags().MarkHidden("targets")
233+
rootCmd.PersistentFlags().StringVar(&sslCA, "ssl-ca", "", "Path of file that contains list of trusted SSL CAs for connection")
234+
rootCmd.PersistentFlags().StringVar(&sslCert, "ssl-cert", "", "Path of file that contains X509 certificate in PEM format for connection")
235+
rootCmd.PersistentFlags().StringVar(&sslKey, "ssl-key", "", "Path of file that contains X509 key in PEM format for connection")
212236

213237
cobra.EnablePrefixMatching = true
214238

@@ -251,3 +275,43 @@ func main() {
251275

252276
cancel()
253277
}
278+
279+
// registerMysqlTLSConfig constructs a `*tls.Config` from the CA, certification and key
280+
// paths, and register to mysql client.
281+
func registerMysqlTLSConfig() {
282+
// Load the client certificates from disk
283+
var certificates []tls.Certificate
284+
if len(sslCert) != 0 && len(sslKey) != 0 {
285+
cert, err := tls.LoadX509KeyPair(sslCert, sslKey)
286+
if err != nil {
287+
panic(fmt.Errorf("could not load client key pair, err %v", err))
288+
}
289+
certificates = []tls.Certificate{cert}
290+
} else if len(sslCert) > 0 || len(sslKey) > 0 {
291+
panic("incomplete key pair configuration")
292+
}
293+
294+
// Create a certificate pool from CA
295+
certPool := x509.NewCertPool()
296+
ca, err := os.ReadFile(sslCA)
297+
if err != nil {
298+
panic(fmt.Errorf("could not read CA certificate, err %v", err))
299+
}
300+
301+
// Append the certificates from the CA
302+
if !certPool.AppendCertsFromPEM(ca) {
303+
panic("failed to append CA certs")
304+
}
305+
306+
tlsConfig := &tls.Config{
307+
MinVersion: tls.VersionTLS12,
308+
Certificates: certificates,
309+
RootCAs: certPool,
310+
ClientCAs: certPool,
311+
}
312+
313+
err = mysql.RegisterTLSConfig(customTlsName, tlsConfig)
314+
if err != nil {
315+
panic(fmt.Errorf("failed to register TLS config, err %v", err))
316+
}
317+
}

0 commit comments

Comments
 (0)