@@ -3,6 +3,8 @@ package main
3
3
import (
4
4
"context"
5
5
"crypto/sha1"
6
+ "crypto/tls"
7
+ "crypto/x509"
6
8
"database/sql"
7
9
sqldrv "database/sql/driver"
8
10
"encoding/hex"
@@ -18,6 +20,7 @@ import (
18
20
"github.com/pingcap/go-tpc/pkg/util"
19
21
"github.com/spf13/cobra"
20
22
_ "go.uber.org/automaxprocs"
23
+
21
24
// mysql package
22
25
"github.com/go-sql-driver/mysql"
23
26
// pg
@@ -47,15 +50,19 @@ var (
47
50
connParams string
48
51
outputStyle string
49
52
targets []string
53
+ sslCA string
54
+ sslCert string
55
+ sslKey string
50
56
51
57
globalDB * sql.DB
52
58
globalCtx context.Context
53
59
)
54
60
55
61
const (
56
- createDBDDL = "CREATE DATABASE "
57
- mysqlDriver = "mysql"
58
- pgDriver = "postgres"
62
+ createDBDDL = "CREATE DATABASE "
63
+ mysqlDriver = "mysql"
64
+ pgDriver = "postgres"
65
+ customTlsName = "custom"
59
66
)
60
67
61
68
type MuxDriver struct {
@@ -93,18 +100,31 @@ func newDB(targets []string, driver string, user string, password string, dbName
93
100
hash .Write ([]byte (password ))
94
101
hash .Write ([]byte (dbName ))
95
102
hash .Write ([]byte (connParams ))
103
+
104
+ if driver == mysqlDriver && (len (sslCA ) > 0 || len (sslCert ) > 0 || len (sslKey ) > 0 ) {
105
+ registerMysqlTLSConfig ()
106
+ }
107
+
96
108
for i , addr := range targets {
97
109
hash .Write ([]byte (addr ))
98
110
switch driver {
99
111
case mysqlDriver :
112
+ var tlsName string = "preferred"
113
+ if len (sslCA ) > 0 {
114
+ tlsName = customTlsName
115
+ }
100
116
// allow multiple statements in one query to allow q15 on the TPC-H
101
- dsn := fmt .Sprintf ("%s:%s@tcp(%s)/%s?multiStatements=true&tls=preferred " , user , password , addr , dbName )
117
+ dsn := fmt .Sprintf ("%s:%s@tcp(%s)/%s?multiStatements=true&tls=%s " , user , password , addr , dbName , tlsName )
102
118
if len (connParams ) > 0 {
103
119
dsn = dsn + "&" + connParams
104
120
}
105
121
names [i ] = dsn
106
122
drv = & mysql.MySQLDriver {}
107
123
case pgDriver :
124
+ if len (sslCA ) > 0 || len (sslKey ) > 0 || len (sslCert ) > 0 {
125
+ panic ("postgresql driver doesn't support TLS yet" )
126
+ }
127
+
108
128
dsn := fmt .Sprintf ("postgres://%s:%s@%s/%s" , user , password , addr , dbName )
109
129
if len (connParams ) > 0 {
110
130
dsn = dsn + "?" + connParams
@@ -150,9 +170,10 @@ func openDB() {
150
170
tmpDB , _ = newDB (targets , driver , user , password , "" , connParams )
151
171
defer tmpDB .Close ()
152
172
if _ , err := tmpDB .Exec (createDBDDL + dbName ); err != nil {
153
- panic (fmt .Errorf ("failed to create database, err %v\n " , err ))
173
+ panic (fmt .Errorf ("failed to create database, err %v" , err ))
154
174
}
155
175
} else {
176
+ fmt .Printf ("failed to ping db, err %v\n " , err )
156
177
globalDB = nil
157
178
}
158
179
} else {
@@ -209,6 +230,9 @@ func main() {
209
230
rootCmd .PersistentFlags ().StringVar (& outputStyle , "output" , util .OutputStylePlain , "output style, valid values can be { plain | table | json }" )
210
231
rootCmd .PersistentFlags ().StringSliceVar (& targets , "targets" , nil , "Target database addresses" )
211
232
rootCmd .PersistentFlags ().MarkHidden ("targets" )
233
+ rootCmd .PersistentFlags ().StringVar (& sslCA , "ssl-ca" , "" , "Path of file that contains list of trusted SSL CAs for connection" )
234
+ rootCmd .PersistentFlags ().StringVar (& sslCert , "ssl-cert" , "" , "Path of file that contains X509 certificate in PEM format for connection" )
235
+ rootCmd .PersistentFlags ().StringVar (& sslKey , "ssl-key" , "" , "Path of file that contains X509 key in PEM format for connection" )
212
236
213
237
cobra .EnablePrefixMatching = true
214
238
@@ -251,3 +275,43 @@ func main() {
251
275
252
276
cancel ()
253
277
}
278
+
279
+ // registerMysqlTLSConfig constructs a `*tls.Config` from the CA, certification and key
280
+ // paths, and register to mysql client.
281
+ func registerMysqlTLSConfig () {
282
+ // Load the client certificates from disk
283
+ var certificates []tls.Certificate
284
+ if len (sslCert ) != 0 && len (sslKey ) != 0 {
285
+ cert , err := tls .LoadX509KeyPair (sslCert , sslKey )
286
+ if err != nil {
287
+ panic (fmt .Errorf ("could not load client key pair, err %v" , err ))
288
+ }
289
+ certificates = []tls.Certificate {cert }
290
+ } else if len (sslCert ) > 0 || len (sslKey ) > 0 {
291
+ panic ("incomplete key pair configuration" )
292
+ }
293
+
294
+ // Create a certificate pool from CA
295
+ certPool := x509 .NewCertPool ()
296
+ ca , err := os .ReadFile (sslCA )
297
+ if err != nil {
298
+ panic (fmt .Errorf ("could not read CA certificate, err %v" , err ))
299
+ }
300
+
301
+ // Append the certificates from the CA
302
+ if ! certPool .AppendCertsFromPEM (ca ) {
303
+ panic ("failed to append CA certs" )
304
+ }
305
+
306
+ tlsConfig := & tls.Config {
307
+ MinVersion : tls .VersionTLS12 ,
308
+ Certificates : certificates ,
309
+ RootCAs : certPool ,
310
+ ClientCAs : certPool ,
311
+ }
312
+
313
+ err = mysql .RegisterTLSConfig (customTlsName , tlsConfig )
314
+ if err != nil {
315
+ panic (fmt .Errorf ("failed to register TLS config, err %v" , err ))
316
+ }
317
+ }
0 commit comments