Skip to content

Commit c019bad

Browse files
authored
ldap: add timeout and retry-backoff for ldap (#51927) (#52551)
close #51883
1 parent 710538b commit c019bad

File tree

3 files changed

+113
-9
lines changed

3 files changed

+113
-9
lines changed

pkg/privilege/privileges/ldap/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@ go_library(
1212
visibility = ["//visibility:public"],
1313
deps = [
1414
"//pkg/privilege/conn",
15+
"//pkg/util/intest",
16+
"//pkg/util/logutil",
1517
"@com_github_go_ldap_ldap_v3//:ldap",
1618
"@com_github_ngaut_pools//:pools",
1719
"@com_github_pingcap_errors//:errors",
20+
"@org_uber_go_zap//:zap",
1821
],
1922
)
2023

@@ -29,5 +32,6 @@ go_test(
2932
"test/ldap.key",
3033
],
3134
flaky = True,
35+
shard_count = 3,
3236
deps = ["@com_github_stretchr_testify//require"],
3337
)

pkg/privilege/privileges/ldap/ldap_common.go

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,25 @@ import (
2222
"os"
2323
"strconv"
2424
"sync"
25+
"time"
2526

2627
"github.com/go-ldap/ldap/v3"
2728
"github.com/ngaut/pools"
2829
"github.com/pingcap/errors"
30+
"github.com/pingcap/tidb/pkg/util/intest"
31+
"github.com/pingcap/tidb/pkg/util/logutil"
32+
"go.uber.org/zap"
2933
)
3034

35+
// ldapTimeout is set to 10s. It works on both the TCP/TLS dialing timeout, and the LDAP request timeout. For connection with TLS, the
36+
// user may find that it fails after 2*ldapTimeout, because TiDB will try to connect through both `StartTLS` (from a normal TCP connection)
37+
// and `TLS`, therefore the total time is 2*ldapTimeout.
38+
var ldapTimeout = 10 * time.Second
39+
40+
// skipTLSForTest is used to skip trying to connect with TLS directly in tests. If it's set to true, connection will only try to
41+
// use `StartTLS`
42+
var skipTLSForTest = false
43+
3144
// ldapAuthImpl gives the internal utilities of authentication with LDAP.
3245
// The getter and setter methods will lock the mutex inside, while all other methods don't, so all other method call
3346
// should be protected by `impl.Lock()`.
@@ -115,10 +128,13 @@ func (impl *ldapAuthImpl) initializeCAPool() error {
115128
}
116129

117130
func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.Conn, error) {
118-
ldapConnection, err := ldap.Dial("tcp", address)
131+
ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{
132+
Timeout: ldapTimeout,
133+
}))
119134
if err != nil {
120135
return nil, err
121136
}
137+
ldapConnection.SetTimeout(ldapTimeout)
122138

123139
err = ldapConnection.StartTLS(&tls.Config{
124140
RootCAs: impl.caPool,
@@ -129,17 +145,22 @@ func (impl *ldapAuthImpl) tryConnectLDAPThroughStartTLS(address string) (*ldap.C
129145

130146
return nil, err
131147
}
148+
132149
return ldapConnection, nil
133150
}
134151

135152
func (impl *ldapAuthImpl) tryConnectLDAPThroughTLS(address string) (*ldap.Conn, error) {
136-
ldapConnection, err := ldap.DialTLS("tcp", address, &tls.Config{
153+
tlsConfig := &tls.Config{
137154
RootCAs: impl.caPool,
138155
ServerName: impl.ldapServerHost,
139-
})
156+
}
157+
ldapConnection, err := ldap.DialURL("ldaps://"+address, ldap.DialWithTLSDialer(tlsConfig, &net.Dialer{
158+
Timeout: ldapTimeout,
159+
}))
140160
if err != nil {
141161
return nil, err
142162
}
163+
ldapConnection.SetTimeout(ldapTimeout)
143164

144165
return ldapConnection, nil
145166
}
@@ -152,6 +173,10 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
152173
if impl.enableTLS {
153174
ldapConnection, err := impl.tryConnectLDAPThroughStartTLS(address)
154175
if err != nil {
176+
if intest.InTest && skipTLSForTest {
177+
return nil, err
178+
}
179+
155180
ldapConnection, err = impl.tryConnectLDAPThroughTLS(address)
156181
if err != nil {
157182
return nil, errors.Wrap(err, "create ldap connection")
@@ -160,15 +185,19 @@ func (impl *ldapAuthImpl) connectionFactory() (pools.Resource, error) {
160185

161186
return ldapConnection, nil
162187
}
163-
ldapConnection, err := ldap.Dial("tcp", address)
188+
ldapConnection, err := ldap.DialURL("ldap://"+address, ldap.DialWithDialer(&net.Dialer{
189+
Timeout: ldapTimeout,
190+
}))
164191
if err != nil {
165192
return nil, errors.Wrap(err, "create ldap connection")
166193
}
194+
ldapConnection.SetTimeout(ldapTimeout)
167195

168196
return ldapConnection, nil
169197
}
170198

171199
const getConnectionMaxRetry = 10
200+
const getConnectionRetryInterval = 500 * time.Millisecond
172201

173202
func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
174203
retryCount := 0
@@ -189,13 +218,19 @@ func (impl *ldapAuthImpl) getConnection() (*ldap.Conn, error) {
189218
Password: impl.bindRootPWD,
190219
})
191220
if err != nil {
221+
logutil.BgLogger().Warn("fail to use LDAP connection bind to anonymous user. Retrying", zap.Error(err),
222+
zap.Duration("backoff", getConnectionRetryInterval))
223+
192224
// fail to bind to anonymous user, just release this connection and try to get a new one
193225
impl.ldapConnectionPool.Put(nil)
194226

195227
retryCount++
196228
if retryCount >= getConnectionMaxRetry {
197229
return nil, errors.Wrap(err, "fail to bind to anonymous user")
198230
}
231+
// Be careful that it's still holding the lock of the system variables, so it's not good to sleep here.
232+
// TODO: refactor the `RWLock` to avoid the problem of holding the lock.
233+
time.Sleep(getConnectionRetryInterval)
199234
continue
200235
}
201236

@@ -208,12 +243,12 @@ func (impl *ldapAuthImpl) putConnection(conn *ldap.Conn) {
208243
}
209244

210245
func (impl *ldapAuthImpl) initializePool() {
211-
if impl.ldapConnectionPool != nil {
212-
impl.ldapConnectionPool.Close()
213-
}
214-
215-
// skip initialization when the variables are not correct
246+
// skip re-initialization when the variables are not correct
216247
if impl.initCapacity > 0 && impl.maxCapacity >= impl.initCapacity {
248+
if impl.ldapConnectionPool != nil {
249+
impl.ldapConnectionPool.Close()
250+
}
251+
217252
impl.ldapConnectionPool = pools.NewResourcePool(impl.connectionFactory, impl.initCapacity, impl.maxCapacity, 0)
218253
}
219254
}
@@ -258,6 +293,7 @@ func (impl *ldapAuthImpl) SetLDAPServerHost(ldapServerHost string) {
258293

259294
if ldapServerHost != impl.ldapServerHost {
260295
impl.ldapServerHost = ldapServerHost
296+
impl.initializePool()
261297
}
262298
}
263299

@@ -268,6 +304,7 @@ func (impl *ldapAuthImpl) SetLDAPServerPort(ldapServerPort int) {
268304

269305
if ldapServerPort != impl.ldapServerPort {
270306
impl.ldapServerPort = ldapServerPort
307+
impl.initializePool()
271308
}
272309
}
273310

@@ -278,6 +315,7 @@ func (impl *ldapAuthImpl) SetEnableTLS(enableTLS bool) {
278315

279316
if enableTLS != impl.enableTLS {
280317
impl.enableTLS = enableTLS
318+
impl.initializePool()
281319
}
282320
}
283321

pkg/privilege/privileges/ldap/ldap_common_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"net"
2525
"sync"
2626
"testing"
27+
"time"
2728

2829
"github.com/stretchr/testify/require"
2930
)
@@ -108,3 +109,64 @@ func TestConnectThrough636(t *testing.T) {
108109
require.NoError(t, err)
109110
defer conn.Close()
110111
}
112+
113+
func TestLDAPStartTLSTimeout(t *testing.T) {
114+
originalTimeout := ldapTimeout
115+
ldapTimeout = time.Second * 2
116+
skipTLSForTest = true
117+
defer func() {
118+
ldapTimeout = originalTimeout
119+
skipTLSForTest = false
120+
}()
121+
122+
var ln net.Listener
123+
startListen := make(chan struct{})
124+
afterTimeout := make(chan struct{})
125+
defer close(afterTimeout)
126+
127+
// this test only tests whether connecting to LDAP with TLS enabled times out during StartTLS (the fallback to direct TLS is skipped via skipTLSForTest)
128+
randomTLSServicePort := rand.Int()%10000 + 10000
129+
serverWg := &sync.WaitGroup{}
130+
serverWg.Add(1)
131+
go func() {
132+
var err error
133+
defer close(startListen)
134+
defer serverWg.Done()
135+
136+
ln, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", randomTLSServicePort))
137+
require.NoError(t, err)
138+
startListen <- struct{}{}
139+
140+
conn, err := ln.Accept()
141+
require.NoError(t, err)
142+
143+
<-afterTimeout
144+
require.NoError(t, conn.Close())
145+
146+
// close the server
147+
require.NoError(t, ln.Close())
148+
}()
149+
150+
<-startListen
151+
defer func() {
152+
serverWg.Wait()
153+
}()
154+
155+
impl := &ldapAuthImpl{}
156+
impl.SetEnableTLS(true)
157+
impl.SetLDAPServerHost("localhost")
158+
impl.SetLDAPServerPort(randomTLSServicePort)
159+
160+
impl.caPool = x509.NewCertPool()
161+
require.True(t, impl.caPool.AppendCertsFromPEM(tlsCAStr))
162+
impl.SetInitCapacity(1)
163+
impl.SetMaxCapacity(1)
164+
165+
now := time.Now()
166+
_, err := impl.connectionFactory()
167+
afterTimeout <- struct{}{}
168+
dur := time.Since(now)
169+
require.Greater(t, dur, 2*time.Second)
170+
require.Less(t, dur, 3*time.Second)
171+
require.ErrorContains(t, err, "connection timed out")
172+
}

0 commit comments

Comments
 (0)