diff --git a/br/pkg/task/BUILD.bazel b/br/pkg/task/BUILD.bazel index 0dabd8646fc6d..424388f57e573 100644 --- a/br/pkg/task/BUILD.bazel +++ b/br/pkg/task/BUILD.bazel @@ -114,7 +114,7 @@ go_test( ], embed = [":task"], flaky = True, - shard_count = 38, + shard_count = 39, deps = [ "//br/pkg/backup", "//br/pkg/config", diff --git a/br/pkg/task/common.go b/br/pkg/task/common.go index 93f9b3a08c2d9..ebb53968d5fea 100644 --- a/br/pkg/task/common.go +++ b/br/pkg/task/common.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/tidb/br/pkg/metautil" "github.com/pingcap/tidb/br/pkg/storage" "github.com/pingcap/tidb/br/pkg/utils" - "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/sessionctx/variable" filter "github.com/pingcap/tidb/pkg/util/table-filter" "github.com/spf13/cobra" @@ -115,9 +114,7 @@ const ( ) const ( - // Once TableInfoVersion updated. BR need to check compatibility with - // new TableInfoVersion. both snapshot restore and pitr need to be checked. - CURRENT_BACKUP_SUPPORT_TABLE_INFO_VERSION = model.TableInfoVersion5 + cipherKeyNonHexErrorMsg = "cipher key must be a valid hexadecimal string" ) // FullBackupType type when doing full backup or restore @@ -464,34 +461,52 @@ func GetCipherKeyContent(cipherKey, cipherKeyFile string) ([]byte, error) { return nil, errors.Trace(err) } - // if cipher-key is valid, convert the hexadecimal string to bytes + var hexString string + + // Check if cipher-key is provided directly if len(cipherKey) > 0 { - return hex.DecodeString(cipherKey) + hexString = cipherKey + } else { + // Read content from cipher-file + content, err := os.ReadFile(cipherKeyFile) + if err != nil { + return nil, errors.Annotate(err, "failed to read cipher file") + } + hexString = string(bytes.TrimSuffix(content, []byte("\n"))) } - // convert the content(as hexadecimal string) from cipher-file to bytes - content, err := os.ReadFile(cipherKeyFile) + // Attempt to decode the hex string + decodedKey, err := hex.DecodeString(hexString) if err != nil { - return nil, errors.Annotate(err, "failed to read cipher file") + return nil, errors.Annotate(berrors.ErrInvalidArgument, cipherKeyNonHexErrorMsg) } - content = bytes.TrimSuffix(content, []byte("\n")) - return hex.DecodeString(string(content)) + return decodedKey, nil } -func checkCipherKeyMatch(cipher *backuppb.CipherInfo) bool { +func checkCipherKeyMatch(cipher *backuppb.CipherInfo) error { switch cipher.CipherType { case encryptionpb.EncryptionMethod_PLAINTEXT: - return true + return nil case encryptionpb.EncryptionMethod_AES128_CTR: - return len(cipher.CipherKey) == crypterAES128KeyLen + if len(cipher.CipherKey) != crypterAES128KeyLen { + return errors.Annotatef(berrors.ErrInvalidArgument, "AES-128 key length mismatch: expected %d, got %d", + crypterAES128KeyLen, len(cipher.CipherKey)) + } case encryptionpb.EncryptionMethod_AES192_CTR: - return len(cipher.CipherKey) == crypterAES192KeyLen + if len(cipher.CipherKey) != crypterAES192KeyLen { + return errors.Annotatef(berrors.ErrInvalidArgument, "AES-192 key length mismatch: expected %d, got %d", + crypterAES192KeyLen, len(cipher.CipherKey)) + } case encryptionpb.EncryptionMethod_AES256_CTR: - return len(cipher.CipherKey) == crypterAES256KeyLen + if len(cipher.CipherKey) != crypterAES256KeyLen { + return errors.Annotatef(berrors.ErrInvalidArgument, "AES-256 key length mismatch: expected %d, got %d", + crypterAES256KeyLen, len(cipher.CipherKey)) + } default: - return false + return errors.Errorf("Unknown encryption method: %v", cipher.CipherType) } + return nil } func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error { @@ -524,8 +539,9 @@ func (cfg *Config) parseCipherInfo(flags *pflag.FlagSet) error { return errors.Trace(err) } - if !checkCipherKeyMatch(&cfg.CipherInfo) { - return errors.Annotate(berrors.ErrInvalidArgument, "crypter method and key length not match") + err = checkCipherKeyMatch(&cfg.CipherInfo) + if err != nil { + return errors.Trace(err) } return nil @@ -561,8 +577,9 @@ func (cfg *Config) parseLogBackupCipherInfo(flags *pflag.FlagSet) (bool, error) return false, errors.Trace(err) } - if !checkCipherKeyMatch(&cfg.CipherInfo) { - return false, errors.Annotate(berrors.ErrInvalidArgument, "log backup encryption method and key length not match") + err = checkCipherKeyMatch(&cfg.CipherInfo) + if err != nil { + return false, errors.Trace(err) } return true, nil diff --git a/br/pkg/task/common_test.go b/br/pkg/task/common_test.go index c4433da574109..5979ef6eebeeb 100644 --- a/br/pkg/task/common_test.go +++ b/br/pkg/task/common_test.go @@ -3,7 +3,6 @@ package task import ( - "encoding/hex" "fmt" "testing" @@ -70,57 +69,89 @@ func TestStripingPDURL(t *testing.T) { func TestCheckCipherKeyMatch(t *testing.T) { cases := []struct { - CipherType encryptionpb.EncryptionMethod - CipherKey string - ok bool + name string + cipherInfo *backup.CipherInfo + expectErr bool + errMsg string }{ { - CipherType: encryptionpb.EncryptionMethod_PLAINTEXT, - ok: true, + name: "PLAINTEXT", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_PLAINTEXT, + }, + expectErr: false, }, { - CipherType: encryptionpb.EncryptionMethod_UNKNOWN, - ok: false, + name: "UNKNOWN", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_UNKNOWN, + }, + expectErr: true, + errMsg: "Unknown encryption method: UNKNOWN", }, { - CipherType: encryptionpb.EncryptionMethod_AES128_CTR, - CipherKey: "0123456789abcdef0123456789abcdef", - ok: true, + name: "AES128_CTR valid", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_AES128_CTR, + CipherKey: make([]byte, crypterAES128KeyLen), + }, + expectErr: false, }, { - CipherType: encryptionpb.EncryptionMethod_AES128_CTR, - CipherKey: "0123456789abcdef0123456789abcd", - ok: false, + name: "AES128_CTR invalid length", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_AES128_CTR, + CipherKey: make([]byte, crypterAES128KeyLen-1), + }, + expectErr: true, + errMsg: fmt.Sprintf("AES-128 key length mismatch: expected %d, got %d", crypterAES128KeyLen, crypterAES128KeyLen-1), }, { - CipherType: encryptionpb.EncryptionMethod_AES192_CTR, - CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdef", - ok: true, + name: "AES192_CTR valid", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_AES192_CTR, + CipherKey: make([]byte, crypterAES192KeyLen), + }, + expectErr: false, }, { - CipherType: encryptionpb.EncryptionMethod_AES192_CTR, - CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdefff", - ok: false, + name: "AES192_CTR invalid length", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_AES192_CTR, + CipherKey: make([]byte, crypterAES192KeyLen+1), + }, + expectErr: true, + errMsg: fmt.Sprintf("AES-192 key length mismatch: expected %d, got %d", crypterAES192KeyLen, crypterAES192KeyLen+1), }, { - CipherType: encryptionpb.EncryptionMethod_AES256_CTR, - CipherKey: "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", - ok: true, + name: "AES256_CTR valid", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_AES256_CTR, + CipherKey: make([]byte, crypterAES256KeyLen), + }, + expectErr: false, }, { - CipherType: encryptionpb.EncryptionMethod_AES256_CTR, - CipherKey: "", - ok: false, + name: "AES256_CTR invalid length", + cipherInfo: &backup.CipherInfo{ + CipherType: encryptionpb.EncryptionMethod_AES256_CTR, + CipherKey: make([]byte, 0), + }, + expectErr: true, + errMsg: fmt.Sprintf("AES-256 key length mismatch: expected %d, got %d", crypterAES256KeyLen, 0), }, } for _, c := range cases { - cipherKey, err := hex.DecodeString(c.CipherKey) - require.NoError(t, err) - require.Equal(t, c.ok, checkCipherKeyMatch(&backup.CipherInfo{ - CipherType: c.CipherType, - CipherKey: cipherKey, - })) + t.Run(c.name, func(t *testing.T) { + err := checkCipherKeyMatch(c.cipherInfo) + if c.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), c.errMsg) + } else { + require.NoError(t, err) + } + }) } } @@ -162,6 +193,13 @@ func TestCheckCipherKey(t *testing.T) { } } +func TestGetCipherKey(t *testing.T) { + nonHexKey := "this is not a hex string" + _, err := GetCipherKeyContent(nonHexKey, "") + require.Error(t, err) + require.Contains(t, err.Error(), cipherKeyNonHexErrorMsg) +} + func must[T any](t T, err error) T { if err != nil { panic(err)