@@ -22,12 +22,25 @@ import (
22
22
"os"
23
23
"strconv"
24
24
"sync"
25
+ "time"
25
26
26
27
"github.com/go-ldap/ldap/v3"
27
28
"github.com/ngaut/pools"
28
29
"github.com/pingcap/errors"
30
+ "github.com/pingcap/tidb/pkg/util/intest"
31
+ "github.com/pingcap/tidb/pkg/util/logutil"
32
+ "go.uber.org/zap"
29
33
)
30
34
35
+ // ldapTimeout is set to 10s. It works on both the TCP/TLS dialing timeout, and the LDAP request timeout. For connection with TLS, the
36
+ // user may find that it fails after 2*ldapTimeout, because TiDB will try to connect through both `StartTLS` (from a normal TCP connection)
37
+ // and `TLS`, therefore the total time is 2*ldapTimeout.
38
+ var ldapTimeout = 10 * time .Second
39
+
40
+ // skipTLSForTest is used to skip trying to connect with TLS directly in tests. If it's set to false, connection will only try to
41
+ // use `StartTLS`
42
+ var skipTLSForTest = false
43
+
31
44
// ldapAuthImpl gives the internal utilities of authentication with LDAP.
32
45
// The getter and setter methods will lock the mutex inside, while all other methods don't, so all other method call
33
46
// should be protected by `impl.Lock()`.
@@ -115,10 +128,13 @@ func (impl *ldapAuthImpl) initializeCAPool() error {
115
128
}
116
129
117
130
func (impl * ldapAuthImpl ) tryConnectLDAPThroughStartTLS (address string ) (* ldap.Conn , error ) {
118
- ldapConnection , err := ldap .Dial ("tcp" , address )
131
+ ldapConnection , err := ldap .DialURL ("ldap://" + address , ldap .DialWithDialer (& net.Dialer {
132
+ Timeout : ldapTimeout ,
133
+ }))
119
134
if err != nil {
120
135
return nil , err
121
136
}
137
+ ldapConnection .SetTimeout (ldapTimeout )
122
138
123
139
err = ldapConnection .StartTLS (& tls.Config {
124
140
RootCAs : impl .caPool ,
@@ -129,17 +145,22 @@ func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.C
129
145
130
146
return nil , err
131
147
}
148
+
132
149
return ldapConnection , nil
133
150
}
134
151
135
152
func (impl * ldapAuthImpl ) tryConnectLDAPThroughTLS (address string ) (* ldap.Conn , error ) {
136
- ldapConnection , err := ldap . DialTLS ( "tcp" , address , & tls.Config {
153
+ tlsConfig := & tls.Config {
137
154
RootCAs : impl .caPool ,
138
155
ServerName : impl .ldapServerHost ,
139
- })
156
+ }
157
+ ldapConnection , err := ldap .DialURL ("ldaps://" + address , ldap .DialWithTLSDialer (tlsConfig , & net.Dialer {
158
+ Timeout : ldapTimeout ,
159
+ }))
140
160
if err != nil {
141
161
return nil , err
142
162
}
163
+ ldapConnection .SetTimeout (ldapTimeout )
143
164
144
165
return ldapConnection , nil
145
166
}
@@ -152,6 +173,10 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
152
173
if impl .enableTLS {
153
174
ldapConnection , err := impl .tryConnectLDAPThroughStartTLS (address )
154
175
if err != nil {
176
+ if intest .InTest && skipTLSForTest {
177
+ return nil , err
178
+ }
179
+
155
180
ldapConnection , err = impl .tryConnectLDAPThroughTLS (address )
156
181
if err != nil {
157
182
return nil , errors .Wrap (err , "create ldap connection" )
@@ -160,15 +185,19 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
160
185
161
186
return ldapConnection , nil
162
187
}
163
- ldapConnection , err := ldap .Dial ("tcp" , address )
188
+ ldapConnection , err := ldap .DialURL ("ldap://" + address , ldap .DialWithDialer (& net.Dialer {
189
+ Timeout : ldapTimeout ,
190
+ }))
164
191
if err != nil {
165
192
return nil , errors .Wrap (err , "create ldap connection" )
166
193
}
194
+ ldapConnection .SetTimeout (ldapTimeout )
167
195
168
196
return ldapConnection , nil
169
197
}
170
198
171
199
const getConnectionMaxRetry = 10
200
+ const getConnectionRetryInterval = 500 * time .Millisecond
172
201
173
202
func (impl * ldapAuthImpl ) getConnection () (* ldap.Conn , error ) {
174
203
retryCount := 0
@@ -189,13 +218,19 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
189
218
Password : impl .bindRootPWD ,
190
219
})
191
220
if err != nil {
221
+ logutil .BgLogger ().Warn ("fail to use LDAP connection bind to anonymous user. Retrying" , zap .Error (err ),
222
+ zap .Duration ("backoff" , getConnectionRetryInterval ))
223
+
192
224
// fail to bind to anonymous user, just release this connection and try to get a new one
193
225
impl .ldapConnectionPool .Put (nil )
194
226
195
227
retryCount ++
196
228
if retryCount >= getConnectionMaxRetry {
197
229
return nil , errors .Wrap (err , "fail to bind to anonymous user" )
198
230
}
231
+ // Be careful that it's still holding the lock of the system variables, so it's not good to sleep here.
232
+ // TODO: refactor the `RWLock` to avoid the problem of holding the lock.
233
+ time .Sleep (getConnectionRetryInterval )
199
234
continue
200
235
}
201
236
@@ -208,12 +243,12 @@ func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) {
208
243
}
209
244
210
245
func (impl * ldapAuthImpl ) initializePool () {
211
- if impl .ldapConnectionPool != nil {
212
- impl .ldapConnectionPool .Close ()
213
- }
214
-
215
- // skip initialization when the variables are not correct
246
+ // skip re-initialization when the variables are not correct
216
247
if impl .initCapacity > 0 && impl .maxCapacity >= impl .initCapacity {
248
+ if impl .ldapConnectionPool != nil {
249
+ impl .ldapConnectionPool .Close ()
250
+ }
251
+
217
252
impl .ldapConnectionPool = pools .NewResourcePool (impl .connectionFactory , impl .initCapacity , impl .maxCapacity , 0 )
218
253
}
219
254
}
@@ -258,6 +293,7 @@ func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) {
258
293
259
294
if ldapServerHost != impl .ldapServerHost {
260
295
impl .ldapServerHost = ldapServerHost
296
+ impl .initializePool ()
261
297
}
262
298
}
263
299
@@ -268,6 +304,7 @@ func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) {
268
304
269
305
if ldapServerPort != impl .ldapServerPort {
270
306
impl .ldapServerPort = ldapServerPort
307
+ impl .initializePool ()
271
308
}
272
309
}
273
310
@@ -278,6 +315,7 @@ func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) {
278
315
279
316
if enableTLS != impl .enableTLS {
280
317
impl .enableTLS = enableTLS
318
+ impl .initializePool ()
281
319
}
282
320
}
283
321
0 commit comments