Skip to content

Commit 10a84d0

Browse files
authored
Add a retry when getting ts from PD for validating read ts (#1600)
Signed-off-by: MyonKeminta <[email protected]>
1 parent 5ac118b commit 10a84d0

File tree

2 files changed

+156
-35
lines changed

2 files changed

+156
-35
lines changed

oracle/oracles/pd.go

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ import (
4747
"github.com/tikv/client-go/v2/internal/logutil"
4848
"github.com/tikv/client-go/v2/metrics"
4949
"github.com/tikv/client-go/v2/oracle"
50+
"github.com/tikv/client-go/v2/util"
5051
pd "github.com/tikv/pd/client"
5152
"github.com/tikv/pd/client/clients/tso"
5253
"go.uber.org/zap"
@@ -647,6 +648,7 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
647648
// waiting for reusing the same result should not be canceled. So pass context.Background() instead of the
648649
// current ctx.
649650
res, err := o.GetTimestamp(context.Background(), opt)
651+
_, _ = util.EvalFailpoint("getCurrentTSForValidationBeforeReturn")
650652
return res, err
651653
})
652654
select {
@@ -660,42 +662,75 @@ func (o *pdOracle) getCurrentTSForValidation(ctx context.Context, opt *oracle.Op
660662
}
661663
}
662664

663-
func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) (errRet error) {
665+
func (o *pdOracle) ValidateReadTS(ctx context.Context, readTS uint64, isStaleRead bool, opt *oracle.Option) error {
664666
if readTS == math.MaxUint64 {
665667
if isStaleRead {
666668
return oracle.ErrLatestStaleRead{}
667669
}
668670
return nil
669671
}
670672

671-
latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope)
672-
// If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check.
673-
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
674-
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
675-
if !exists || readTS > latestTSInfo.tso {
676-
currentTS, err := o.getCurrentTSForValidation(ctx, opt)
677-
if err != nil {
678-
return errors.Errorf("fail to validate read timestamp: %v", err)
679-
}
680-
if isStaleRead {
681-
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now())
682-
}
683-
if readTS > currentTS {
684-
return oracle.ErrFutureTSRead{
685-
ReadTS: readTS,
686-
CurrentTS: currentTS,
673+
retrying := false
674+
for {
675+
latestTSInfo, exists := o.getLastTSWithArrivalTS(opt.TxnScope)
676+
// If we fail to get latestTSInfo or the readTS exceeds it, get a timestamp from PD to double-check.
677+
// But we don't need to strictly fetch the latest TS. So if there are already concurrent calls to this function
678+
// loading the latest TS, we can just reuse the same result to avoid too many concurrent GetTS calls.
679+
if !exists || readTS > latestTSInfo.tso {
680+
currentTS, err := o.getCurrentTSForValidation(ctx, opt)
681+
if err != nil {
682+
return errors.Errorf("fail to validate read timestamp: %v", err)
683+
}
684+
if isStaleRead && !retrying {
685+
// Trigger the adjustment at most once in a single invocation.
686+
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, currentTS, time.Now())
687+
}
688+
if readTS > currentTS {
689+
// It's possible that the caller is checking a ts that's legal but not fetched from the current oracle
690+
// object. In this case, it's possible that:
691+
// * The ts is not cached by the low resolution ts (so that readTS > latestTSInfo.tso);
692+
// * ... and then the getCurrentTSForValidation (which uses a singleflight internally) reuses a
693+
// previously-started call and returns an older ts
694+
// so that the check may produce a false positive.
695+
// To handle this case, we do not fail immediately when the check doesn't pass at once; instead, retry one
696+
// more time. In the retry:
697+
// * Some other concurrent GetTimestamp operation may have already updated
698+
// the low resolution ts, so check it again. If it passes, then no need to get the next ts from PD,
699+
// which is slow.
700+
// * Then, call getCurrentTSForValidation and check again. As the current GetTimestamp operation
701+
// inside getCurrentTSForValidation must be started after finishing the previous one (while the
702+
// latter is finished after starting this invocation to ValidateReadTS), then we can conclude that
703+
// the next ts returned by getCurrentTSForValidation must be greater than any ts allocated by PD
704+
// before the current invocation to ValidateReadTS.
705+
skipRetry := false
706+
if val, err1 := util.EvalFailpoint("validateReadTSRetryGetTS"); err1 == nil {
707+
if str, ok := val.(string); ok {
708+
if str == "skip" {
709+
skipRetry = true
710+
}
711+
}
712+
}
713+
if !retrying && !skipRetry {
714+
retrying = true
715+
continue
716+
}
717+
return oracle.ErrFutureTSRead{
718+
ReadTS: readTS,
719+
CurrentTS: currentTS,
720+
}
721+
}
722+
} else if !retrying && isStaleRead {
723+
// Trigger the adjustment at most once in a single invocation.
724+
estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0)
725+
if err != nil {
726+
logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval",
727+
zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope))
728+
} else {
729+
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now())
687730
}
688731
}
689-
} else if isStaleRead {
690-
estimatedCurrentTS, err := o.getStaleTimestampWithLastTS(latestTSInfo, 0)
691-
if err != nil {
692-
logutil.Logger(ctx).Warn("failed to estimate current ts by getSlateTimestamp for auto-adjusting update low resolution ts interval",
693-
zap.Error(err), zap.Uint64("readTS", readTS), zap.String("txnScope", opt.TxnScope))
694-
} else {
695-
o.adjustUpdateLowResolutionTSIntervalWithRequestedStaleness(readTS, estimatedCurrentTS, time.Now())
696-
}
732+
return nil
697733
}
698-
return nil
699734
}
700735

701736
// adjustUpdateLowResolutionTSIntervalWithRequestedStaleness triggers adjustments to the update interval of low resolution

oracle/oracles/pd_test.go

Lines changed: 95 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,18 @@ package oracles
3636

3737
import (
3838
"context"
39+
"fmt"
3940
"math"
4041
"sync"
4142
"sync/atomic"
4243
"testing"
4344
"time"
4445

46+
"github.com/pingcap/failpoint"
4547
"github.com/stretchr/testify/assert"
48+
"github.com/stretchr/testify/require"
4649
"github.com/tikv/client-go/v2/oracle"
50+
"github.com/tikv/client-go/v2/util"
4751
pd "github.com/tikv/pd/client"
4852
)
4953

@@ -374,10 +378,15 @@ func TestValidateReadTS(t *testing.T) {
374378
// the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass.
375379
err = o.ValidateReadTS(ctx, ts+1, staleRead, opt)
376380
assert.NoError(t, err)
377-
// It can't pass if the readTS is newer than previous ts + 2.
381+
// It can also pass if the readTS is previous ts + 2, as it can perform a retry.
378382
ts, err = o.GetTimestamp(ctx, opt)
379383
assert.NoError(t, err)
380384
err = o.ValidateReadTS(ctx, ts+2, staleRead, opt)
385+
assert.NoError(t, err)
386+
// As it retries at most once, it can't pass the check if the readTS is newer than previous ts + 3
387+
ts, err = o.GetTimestamp(ctx, opt)
388+
assert.NoError(t, err)
389+
err = o.ValidateReadTS(ctx, ts+3, staleRead, opt)
381390
assert.Error(t, err)
382391

383392
// Simulate other PD clients requests a timestamp.
@@ -412,6 +421,8 @@ func (c *MockPDClientWithPause) Resume() {
412421
}
413422

414423
func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
424+
util.EnableFailpoints()
425+
415426
pdClient := &MockPDClientWithPause{}
416427
o, err := NewPdOracle(pdClient, &PDOracleOptions{
417428
UpdateInterval: time.Second * 2,
@@ -420,6 +431,11 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
420431
assert.NoError(t, err)
421432
defer o.Close()
422433

434+
assert.NoError(t, failpoint.Enable("tikvclient/validateReadTSRetryGetTS", `return("skip")`))
435+
defer func() {
436+
assert.NoError(t, failpoint.Disable("tikvclient/validateReadTSRetryGetTS"))
437+
}()
438+
423439
asyncValidate := func(ctx context.Context, readTS uint64) chan error {
424440
ch := make(chan error, 1)
425441
go func() {
@@ -429,21 +445,21 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
429445
return ch
430446
}
431447

432-
noResult := func(ch chan error) {
448+
noResult := func(ch chan error, additionalMsg ...interface{}) {
433449
select {
434450
case <-ch:
435-
assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked")
451+
assert.FailNow(t, "a ValidateReadTS operation is not blocked while it's expected to be blocked", additionalMsg...)
436452
default:
437453
}
438454
}
439455

440456
cancelIndices := []int{-1, -1, 0, 1}
441-
for i, ts := range []uint64{100, 200, 300, 400} {
457+
for caseIndex, ts := range []uint64{100, 200, 300, 400} {
442458
// Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail.
443459

444460
// We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls
445-
// doesn't affect other calls that are waiting
446-
cancelIndex := cancelIndices[i]
461+
// doesn't affect other calls that are waiting.
462+
cancelIndex := cancelIndices[caseIndex]
447463

448464
pdClient.Pause()
449465

@@ -541,13 +557,18 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) {
541557
assert.NoError(t, err)
542558
mustNoNotify()
543559

544-
// It loads `ts + 1` from the mock PD, and the check cannot pass.
560+
// It loads `ts + 1` from the mock PD, and then retries `ts + 2` and passes.
545561
err = o.ValidateReadTS(ctx, ts+2, false, opt)
562+
assert.NoError(t, err)
563+
mustNoNotify()
564+
565+
// It loads `ts + 3` and `ts + 4` from the mock PD, and the check cannot pass.
566+
err = o.ValidateReadTS(ctx, ts+5, false, opt)
546567
assert.Error(t, err)
547568
mustNoNotify()
548569

549-
// Do the check again. It loads `ts + 2` from the mock PD, and the check passes.
550-
err = o.ValidateReadTS(ctx, ts+2, false, opt)
570+
// Do the check again. It loads `ts + 5` from the mock PD, and the check passes.
571+
err = o.ValidateReadTS(ctx, ts+5, false, opt)
551572
assert.NoError(t, err)
552573
mustNoNotify()
553574
}
@@ -586,3 +607,68 @@ func TestSetLastTSAlwaysPushTS(t *testing.T) {
586607
close(cancel)
587608
wg.Wait()
588609
}
610+
611+
func TestValidateReadTSFromDifferentSource(t *testing.T) {
612+
// If a ts is fetched from a different client to the same cluster, the ts might not be cached by the low resolution
613+
// ts. In this case, the validation should not be false positive.
614+
util.EnableFailpoints()
615+
pdClient := MockPdClient{}
616+
o, err := NewPdOracle(&pdClient, &PDOracleOptions{
617+
UpdateInterval: time.Second * 2,
618+
NoUpdateTS: true,
619+
})
620+
assert.NoError(t, err)
621+
defer o.Close()
622+
623+
// Construct the situation that the low resolution ts is lower than the ts fetched from another client.
624+
ts, err := o.GetTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
625+
assert.NoError(t, err)
626+
lowResolutionTS, err := o.GetLowResolutionTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
627+
assert.NoError(t, err)
628+
assert.Equal(t, ts, lowResolutionTS)
629+
630+
assert.NoError(t, failpoint.Enable("tikvclient/getCurrentTSForValidationBeforeReturn", "pause"))
631+
defer func() {
632+
assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn"))
633+
}()
634+
635+
// Trigger getting ts from PD for validation, which causes a previously-started concurrent call. We block it during
636+
// getting the ts by the failpoint. So that when the second call starts, it will reuse the same singleflight
637+
// for getting the ts, which return a older ts to it.
638+
firstResCh := make(chan error)
639+
go func() {
640+
firstResCh <- o.ValidateReadTS(context.Background(), ts+1, false, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
641+
}()
642+
643+
select {
644+
case err = <-firstResCh:
645+
assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err))
646+
case <-time.After(time.Millisecond * 50):
647+
}
648+
649+
pdClient.logicalTimestamp.Add(10)
650+
physical, logical, err := pdClient.GetTS(context.Background())
651+
assert.NoError(t, err)
652+
// The next ts should be the previous `ts + 1 (fetched by the ValidateReadTS call) + 10 (advanced manually) + 1`.
653+
nextTS := oracle.ComposeTS(physical, logical)
654+
// The low resolution ts is not updated since the validation.
655+
nextLowResolutionTS, err := o.GetLowResolutionTimestamp(context.Background(), &oracle.Option{TxnScope: oracle.GlobalTxnScope})
656+
assert.NoError(t, err)
657+
assert.Equal(t, ts+1, nextLowResolutionTS)
658+
assert.Equal(t, nextTS-11, nextLowResolutionTS)
659+
660+
// The second check reuses the singleflight to get the ts and the result can be older than `nextTS`.
661+
secondResCh := make(chan error)
662+
go func() {
663+
secondResCh <- o.ValidateReadTS(context.Background(), nextTS, false, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
664+
}()
665+
select {
666+
case err = <-firstResCh:
667+
assert.FailNow(t, fmt.Sprintf("expected to be blocked, but got result: %v", err))
668+
case <-time.After(time.Millisecond * 50):
669+
}
670+
671+
assert.NoError(t, failpoint.Disable("tikvclient/getCurrentTSForValidationBeforeReturn"))
672+
require.NoError(t, <-firstResCh)
673+
require.NoError(t, <-secondResCh)
674+
}

0 commit comments

Comments
 (0)