@@ -36,14 +36,18 @@ package oracles
36
36
37
37
import (
38
38
"context"
39
+ "fmt"
39
40
"math"
40
41
"sync"
41
42
"sync/atomic"
42
43
"testing"
43
44
"time"
44
45
46
+ "github.com/pingcap/failpoint"
45
47
"github.com/stretchr/testify/assert"
48
+ "github.com/stretchr/testify/require"
46
49
"github.com/tikv/client-go/v2/oracle"
50
+ "github.com/tikv/client-go/v2/util"
47
51
pd "github.com/tikv/pd/client"
48
52
)
49
53
@@ -374,10 +378,15 @@ func TestValidateReadTS(t *testing.T) {
374
378
// the fetching-from-PD path, and it can get the previous ts + 1, which can allow this validation to pass.
375
379
err = o .ValidateReadTS (ctx , ts + 1 , staleRead , opt )
376
380
assert .NoError (t , err )
377
- // It can't pass if the readTS is newer than previous ts + 2.
381
+ // It can also pass if the readTS is previous ts + 2, as it can perform a retry .
378
382
ts , err = o .GetTimestamp (ctx , opt )
379
383
assert .NoError (t , err )
380
384
err = o .ValidateReadTS (ctx , ts + 2 , staleRead , opt )
385
+ assert .NoError (t , err )
386
+ // As it retries at most once, it can't pass the check if the readTS is newer than previous ts + 3
387
+ ts , err = o .GetTimestamp (ctx , opt )
388
+ assert .NoError (t , err )
389
+ err = o .ValidateReadTS (ctx , ts + 3 , staleRead , opt )
381
390
assert .Error (t , err )
382
391
383
392
// Simulate other PD clients requests a timestamp.
@@ -412,6 +421,8 @@ func (c *MockPDClientWithPause) Resume() {
412
421
}
413
422
414
423
func TestValidateReadTSForStaleReadReusingGetTSResult (t * testing.T ) {
424
+ util .EnableFailpoints ()
425
+
415
426
pdClient := & MockPDClientWithPause {}
416
427
o , err := NewPdOracle (pdClient , & PDOracleOptions {
417
428
UpdateInterval : time .Second * 2 ,
@@ -420,6 +431,11 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
420
431
assert .NoError (t , err )
421
432
defer o .Close ()
422
433
434
+ assert .NoError (t , failpoint .Enable ("tikvclient/validateReadTSRetryGetTS" , `return("skip")` ))
435
+ defer func () {
436
+ assert .NoError (t , failpoint .Disable ("tikvclient/validateReadTSRetryGetTS" ))
437
+ }()
438
+
423
439
asyncValidate := func (ctx context.Context , readTS uint64 ) chan error {
424
440
ch := make (chan error , 1 )
425
441
go func () {
@@ -429,21 +445,21 @@ func TestValidateReadTSForStaleReadReusingGetTSResult(t *testing.T) {
429
445
return ch
430
446
}
431
447
432
- noResult := func (ch chan error ) {
448
+ noResult := func (ch chan error , additionalMsg ... interface {} ) {
433
449
select {
434
450
case <- ch :
435
- assert .FailNow (t , "a ValidateReadTS operation is not blocked while it's expected to be blocked" )
451
+ assert .FailNow (t , "a ValidateReadTS operation is not blocked while it's expected to be blocked" , additionalMsg ... )
436
452
default :
437
453
}
438
454
}
439
455
440
456
cancelIndices := []int {- 1 , - 1 , 0 , 1 }
441
- for i , ts := range []uint64 {100 , 200 , 300 , 400 } {
457
+ for caseIndex , ts := range []uint64 {100 , 200 , 300 , 400 } {
442
458
// Note: the ts is the result that the next GetTS will return. Any validation with readTS <= ts should pass, otherwise fail.
443
459
444
460
// We will cancel the cancelIndex-th validation call. This is for testing that canceling some of the calls
445
- // doesn't affect other calls that are waiting
446
- cancelIndex := cancelIndices [i ]
461
+ // doesn't affect other calls that are waiting.
462
+ cancelIndex := cancelIndices [caseIndex ]
447
463
448
464
pdClient .Pause ()
449
465
@@ -541,13 +557,18 @@ func TestValidateReadTSForNormalReadDoNotAffectUpdateInterval(t *testing.T) {
541
557
assert .NoError (t , err )
542
558
mustNoNotify ()
543
559
544
- // It loads `ts + 1` from the mock PD, and the check cannot pass .
560
+ // It loads `ts + 1` from the mock PD, and then retries `ts + 2` and passes .
545
561
err = o .ValidateReadTS (ctx , ts + 2 , false , opt )
562
+ assert .NoError (t , err )
563
+ mustNoNotify ()
564
+
565
+ // It loads `ts + 3` and `ts + 4` from the mock PD, and the check cannot pass.
566
+ err = o .ValidateReadTS (ctx , ts + 5 , false , opt )
546
567
assert .Error (t , err )
547
568
mustNoNotify ()
548
569
549
- // Do the check again. It loads `ts + 2 ` from the mock PD, and the check passes.
550
- err = o .ValidateReadTS (ctx , ts + 2 , false , opt )
570
+ // Do the check again. It loads `ts + 5 ` from the mock PD, and the check passes.
571
+ err = o .ValidateReadTS (ctx , ts + 5 , false , opt )
551
572
assert .NoError (t , err )
552
573
mustNoNotify ()
553
574
}
@@ -586,3 +607,68 @@ func TestSetLastTSAlwaysPushTS(t *testing.T) {
586
607
close (cancel )
587
608
wg .Wait ()
588
609
}
610
+
611
+ func TestValidateReadTSFromDifferentSource (t * testing.T ) {
612
+ // If a ts is fetched from a different client to the same cluster, the ts might not be cached by the low resolution
613
+ // ts. In this case, the validation should not be false positive.
614
+ util .EnableFailpoints ()
615
+ pdClient := MockPdClient {}
616
+ o , err := NewPdOracle (& pdClient , & PDOracleOptions {
617
+ UpdateInterval : time .Second * 2 ,
618
+ NoUpdateTS : true ,
619
+ })
620
+ assert .NoError (t , err )
621
+ defer o .Close ()
622
+
623
+ // Construct the situation that the low resolution ts is lower than the ts fetched from another client.
624
+ ts , err := o .GetTimestamp (context .Background (), & oracle.Option {TxnScope : oracle .GlobalTxnScope })
625
+ assert .NoError (t , err )
626
+ lowResolutionTS , err := o .GetLowResolutionTimestamp (context .Background (), & oracle.Option {TxnScope : oracle .GlobalTxnScope })
627
+ assert .NoError (t , err )
628
+ assert .Equal (t , ts , lowResolutionTS )
629
+
630
+ assert .NoError (t , failpoint .Enable ("tikvclient/getCurrentTSForValidationBeforeReturn" , "pause" ))
631
+ defer func () {
632
+ assert .NoError (t , failpoint .Disable ("tikvclient/getCurrentTSForValidationBeforeReturn" ))
633
+ }()
634
+
635
+ // Trigger getting ts from PD for validation, which causes a previously-started concurrent call. We block it during
636
+ // getting the ts by the failpoint. So that when the second call starts, it will reuse the same singleflight
637
+ // for getting the ts, which return a older ts to it.
638
+ firstResCh := make (chan error )
639
+ go func () {
640
+ firstResCh <- o .ValidateReadTS (context .Background (), ts + 1 , false , & oracle.Option {TxnScope : oracle .GlobalTxnScope })
641
+ }()
642
+
643
+ select {
644
+ case err = <- firstResCh :
645
+ assert .FailNow (t , fmt .Sprintf ("expected to be blocked, but got result: %v" , err ))
646
+ case <- time .After (time .Millisecond * 50 ):
647
+ }
648
+
649
+ pdClient .logicalTimestamp .Add (10 )
650
+ physical , logical , err := pdClient .GetTS (context .Background ())
651
+ assert .NoError (t , err )
652
+ // The next ts should be the previous `ts + 1 (fetched by the ValidateReadTS call) + 10 (advanced manually) + 1`.
653
+ nextTS := oracle .ComposeTS (physical , logical )
654
+ // The low resolution ts is not updated since the validation.
655
+ nextLowResolutionTS , err := o .GetLowResolutionTimestamp (context .Background (), & oracle.Option {TxnScope : oracle .GlobalTxnScope })
656
+ assert .NoError (t , err )
657
+ assert .Equal (t , ts + 1 , nextLowResolutionTS )
658
+ assert .Equal (t , nextTS - 11 , nextLowResolutionTS )
659
+
660
+ // The second check reuses the singleflight to get the ts and the result can be older than `nextTS`.
661
+ secondResCh := make (chan error )
662
+ go func () {
663
+ secondResCh <- o .ValidateReadTS (context .Background (), nextTS , false , & oracle.Option {TxnScope : oracle .GlobalTxnScope })
664
+ }()
665
+ select {
666
+ case err = <- firstResCh :
667
+ assert .FailNow (t , fmt .Sprintf ("expected to be blocked, but got result: %v" , err ))
668
+ case <- time .After (time .Millisecond * 50 ):
669
+ }
670
+
671
+ assert .NoError (t , failpoint .Disable ("tikvclient/getCurrentTSForValidationBeforeReturn" ))
672
+ require .NoError (t , <- firstResCh )
673
+ require .NoError (t , <- secondResCh )
674
+ }
0 commit comments