Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/ddl/backfilling_dist_scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) {
ext.(*ddl.LitBackfillScheduler).GlobalSort = true
sch.Extension = ext

taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, "", 0, task.Meta)
taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, "", 0, proto.ExtraParams{}, task.Meta)
require.NoError(t, err)
task.ID = taskID
execIDs := []string{":4000"}
Expand Down
8 changes: 8 additions & 0 deletions pkg/disttask/framework/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@
// └─────────►│cancelling├────┘
// └──────────┘
//
// Note: if ManualRecovery is enabled, when some subtask failed, the task will
// move to `awaiting-resolution` state, and manual operation is needed for the
// task to continue. This mechanism is used for debugging, some bug such as those
// on global-sort are harder to investigate without the intermediate files, or to
// manually recover from some error when importing large mount of data using
// global-sort where one round of import takes a lot of time, it might be more
// flexible and efficient than retrying the whole task.
//
// pause/resume state transition:
// as we don't know the state of the task before `paused`, so the state after
// `resuming` is always `running`.
Expand Down
2 changes: 1 addition & 1 deletion pkg/disttask/framework/handle/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func SubmitTask(ctx context.Context, taskKey string, taskType proto.TaskType, co
return nil, storage.ErrTaskAlreadyExists
}

taskID, err := taskManager.CreateTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, taskMeta)
taskID, err := taskManager.CreateTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, proto.ExtraParams{}, taskMeta)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/disttask/framework/integrationtests/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ go_test(
],
flaky = True,
race = "off",
shard_count = 23,
shard_count = 22,
deps = [
"//pkg/config",
"//pkg/ddl",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,84 @@
package integrationtests

import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/disttask/framework/storage"
"github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute"
"github.com/pingcap/tidb/pkg/disttask/framework/testutil"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/stretchr/testify/require"
)

func TestRetryErrOnNextSubtasksBatch(t *testing.T) {
func TestOnTaskError(t *testing.T) {
c := testutil.NewTestDXFContext(t, 2, 16, true)
registerExampleTask(t, c.MockCtrl, testutil.GetPlanErrSchedulerExt(c.MockCtrl, c.TestContext), c.TestContext, nil)
submitTaskAndCheckSuccessForBasic(c.Ctx, t, "key1", c.TestContext)
}

func TestPlanNotRetryableOnNextSubtasksBatchErr(t *testing.T) {
c := testutil.NewTestDXFContext(t, 2, 16, true)
t.Run("retryable error on OnNextSubtasksBatch", func(t *testing.T) {
registerExampleTask(t, c.MockCtrl, testutil.GetPlanErrSchedulerExt(c.MockCtrl, c.TestContext), c.TestContext, nil)
submitTaskAndCheckSuccessForBasic(c.Ctx, t, "key1", c.TestContext)
})

t.Run("non retryable error on OnNextSubtasksBatch", func(t *testing.T) {
registerExampleTask(t, c.MockCtrl, testutil.GetPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil)
task := testutil.SubmitAndWaitTask(c.Ctx, t, "key2-1", "", 1)
require.Equal(t, proto.TaskStateReverted, task.State)
registerExampleTask(t, c.MockCtrl, testutil.GetStepTwoPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil)
task = testutil.SubmitAndWaitTask(c.Ctx, t, "key2-2", "", 1)
require.Equal(t, proto.TaskStateReverted, task.State)
})

prepareForAwaitingResolutionTestFn := func(t *testing.T, taskKey string) int64 {
subtaskErrRetryable := atomic.Bool{}
executorExt := testutil.GetTaskExecutorExt(c.MockCtrl,
func(task *proto.Task) (execute.StepExecutor, error) {
return testutil.GetCommonStepExecutor(c.MockCtrl, task.Step, func(ctx context.Context, subtask *proto.Subtask) error {
if !subtaskErrRetryable.Load() {
return errors.New("non retryable subtask error")
}
return nil
}), nil
},
func(error) bool {
return subtaskErrRetryable.Load()
},
)
testutil.RegisterExampleTask(t, testutil.GetPlanErrSchedulerExt(c.MockCtrl, c.TestContext),
executorExt, testutil.GetCommonCleanUpRoutine(c.MockCtrl))
tm, err := storage.GetTaskManager()
require.NoError(t, err)
taskID, err := tm.CreateTask(c.Ctx, taskKey, proto.TaskTypeExample, 1, "",
2, proto.ExtraParams{ManualRecovery: true}, nil)
require.NoError(t, err)
require.Eventually(t, func() bool {
task, err := tm.GetTaskByID(c.Ctx, taskID)
require.NoError(t, err)
return task.State == proto.TaskStateAwaitingResolution
}, 10*time.Second, 100*time.Millisecond)
subtaskErrRetryable.Store(true)
return taskID
}

t.Run("task enter awaiting-resolution state if ManualRecovery set, success after manual recover", func(t *testing.T) {
taskKey := "key3-1"
taskID := prepareForAwaitingResolutionTestFn(t, taskKey)
tk := testkit.NewTestKit(t, c.Store)
tk.MustExec(fmt.Sprintf("update mysql.tidb_background_subtask set state='pending' where state='failed' and task_key= %d", taskID))
tk.MustExec(fmt.Sprintf("update mysql.tidb_global_task set state='running' where id = %d", taskID))
Comment on lines +85 to +86
Copy link
Preview

Copilot AI Feb 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The task key is being used directly in the SQL query, which could lead to SQL injection vulnerabilities. Use parameterized queries to avoid potential issues.

Suggested change
tk.MustExec(fmt.Sprintf("update mysql.tidb_background_subtask set state='pending' where state='failed' and task_key= %d", taskID))
tk.MustExec(fmt.Sprintf("update mysql.tidb_global_task set state='running' where id = %d", taskID))
tk.MustExec("update mysql.tidb_background_subtask set state='pending' where state='failed' and task_key=?", taskID)
tk.MustExec("update mysql.tidb_global_task set state='running' where id=?", taskID)

Copilot uses AI. Check for mistakes.

Comment on lines +85 to +86
Copy link
Preview

Copilot AI Feb 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The task ID is being used directly in the SQL query, which could lead to SQL injection vulnerabilities. Use parameterized queries to avoid potential issues.

Suggested change
tk.MustExec(fmt.Sprintf("update mysql.tidb_background_subtask set state='pending' where state='failed' and task_key= %d", taskID))
tk.MustExec(fmt.Sprintf("update mysql.tidb_global_task set state='running' where id = %d", taskID))
tk.MustExec("update mysql.tidb_background_subtask set state='pending' where state='failed' and task_key= ?", taskID)
tk.MustExec("update mysql.tidb_global_task set state='running' where id = ?", taskID)

Copilot uses AI. Check for mistakes.

task := testutil.WaitTaskDone(c.Ctx, t, taskKey)
require.Equal(t, proto.TaskStateSucceed, task.State)
})

registerExampleTask(t, c.MockCtrl, testutil.GetPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil)
task := testutil.SubmitAndWaitTask(c.Ctx, t, "key1", "", 1)
require.Equal(t, proto.TaskStateReverted, task.State)
registerExampleTask(t, c.MockCtrl, testutil.GetStepTwoPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil)
task = testutil.SubmitAndWaitTask(c.Ctx, t, "key2", "", 1)
require.Equal(t, proto.TaskStateReverted, task.State)
t.Run("task enter awaiting-resolution state if ManualRecovery set, cancel also works", func(t *testing.T) {
taskKey := "key4-1"
taskID := prepareForAwaitingResolutionTestFn(t, taskKey)
require.NoError(t, c.TaskMgr.CancelTask(c.Ctx, taskID))
task := testutil.WaitTaskDone(c.Ctx, t, taskKey)
require.Equal(t, proto.TaskStateReverted, task.State)
})
}
14 changes: 14 additions & 0 deletions pkg/disttask/framework/mock/scheduler_mock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/disttask/framework/planner/planner.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package planner

import (
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/disttask/framework/storage"
)

Expand Down Expand Up @@ -47,6 +48,7 @@ func (*Planner) Run(planCtx PlanCtx, plan LogicalPlan) (int64, error) {
planCtx.ThreadCnt,
config.GetGlobalConfig().Instance.TiDBServiceScope,
planCtx.MaxNodeCnt,
proto.ExtraParams{},
taskMeta,
)
}
33 changes: 22 additions & 11 deletions pkg/disttask/framework/proto/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@ import (

// see doc.go for more details.
const (
TaskStatePending TaskState = "pending"
TaskStateRunning TaskState = "running"
TaskStateSucceed TaskState = "succeed"
TaskStateFailed TaskState = "failed"
TaskStateReverting TaskState = "reverting"
TaskStateReverted TaskState = "reverted"
TaskStateCancelling TaskState = "cancelling"
TaskStatePausing TaskState = "pausing"
TaskStatePaused TaskState = "paused"
TaskStateResuming TaskState = "resuming"
TaskStateModifying TaskState = "modifying"
TaskStatePending TaskState = "pending"
TaskStateRunning TaskState = "running"
TaskStateSucceed TaskState = "succeed"
TaskStateFailed TaskState = "failed"
TaskStateReverting TaskState = "reverting"
TaskStateAwaitingResolution TaskState = "awaiting-resolution"
TaskStateReverted TaskState = "reverted"
TaskStateCancelling TaskState = "cancelling"
TaskStatePausing TaskState = "pausing"
TaskStatePaused TaskState = "paused"
TaskStateResuming TaskState = "resuming"
TaskStateModifying TaskState = "modifying"
)

type (
Expand Down Expand Up @@ -66,6 +67,15 @@ const (
// TODO: remove this limit later.
var MaxConcurrentTask = 16

// ExtraParams is the extra params of task.
// Note: only store params that's not used for filter or sort in this struct.
type ExtraParams struct {
// ManualRecovery indicates whether the task can be recovered manually.
// if enabled, the task will enter 'awaiting-resolution' state when it failed,
// then the user can recover the task manually or fail it if it's not recoverable.
ManualRecovery bool `json:"manual_recovery"`
}

// TaskBase contains the basic information of a task.
// we define this to avoid load task meta which might be very large into memory.
type TaskBase struct {
Expand All @@ -87,6 +97,7 @@ type TaskBase struct {
TargetScope string
CreateTime time.Time
MaxNodeCount int
ExtraParams
}

// IsDone checks if the task is done.
Expand Down
2 changes: 2 additions & 0 deletions pkg/disttask/framework/scheduler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ type TaskManager interface {
FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error
// RevertTask updates task state to reverting, and task error.
RevertTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error
// AwaitingResolveTask updates task state to awaiting-resolve, also set task err.
AwaitingResolveTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error
// RevertedTask updates task state to reverted.
RevertedTask(ctx context.Context, taskID int64) error
// PauseTask updated task state to pausing.
Expand Down
16 changes: 15 additions & 1 deletion pkg/disttask/framework/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (s *BaseScheduler) onRunning() error {
if len(subTaskErrs) > 0 {
s.logger.Warn("subtasks encounter errors", zap.Errors("subtask-errs", subTaskErrs))
// we only store the first error as task error.
return s.revertTask(subTaskErrs[0])
return s.revertTaskOrManualRecover(subTaskErrs[0])
}
} else if s.isStepSucceed(cntByStates) {
return s.switch2NextStep()
Expand Down Expand Up @@ -586,6 +586,20 @@ func (s *BaseScheduler) revertTask(taskErr error) error {
return nil
}

func (s *BaseScheduler) revertTaskOrManualRecover(taskErr error) error {
task := s.getTaskClone()
if task.ManualRecovery {
if err := s.taskMgr.AwaitingResolveTask(s.ctx, task.ID, task.State, taskErr); err != nil {
return err
}
task.State = proto.TaskStateAwaitingResolution
task.Error = taskErr
s.task.Store(task)
return nil
}
return s.revertTask(taskErr)
}

// MockServerInfo exported for scheduler_test.go
var MockServerInfo atomic.Pointer[[]string]

Expand Down
2 changes: 1 addition & 1 deletion pkg/disttask/framework/scheduler/scheduler_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestCleanUpRoutine(t *testing.T) {
mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes()
sch.Start()
defer sch.Stop()
taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, "", 0, nil)
taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil)
require.NoError(t, err)

checkTaskRunningCnt := func() []*proto.Task {
Expand Down
10 changes: 5 additions & 5 deletions pkg/disttask/framework/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ func TestTaskFailInManager(t *testing.T) {
defer schManager.Stop()

// unknown task type
taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, "", 0, nil)
taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, "", 0, proto.ExtraParams{}, nil)
require.NoError(t, err)
require.Eventually(t, func() bool {
task, err := mgr.GetTaskByID(ctx, taskID)
Expand All @@ -140,7 +140,7 @@ func TestTaskFailInManager(t *testing.T) {
}, time.Second*10, time.Millisecond*300)

// scheduler init error
taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, "", 0, nil)
taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil)
require.NoError(t, err)
require.Eventually(t, func() bool {
task, err := mgr.GetTaskByID(ctx, taskID)
Expand Down Expand Up @@ -215,7 +215,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel,
// Mock add tasks.
taskIDs := make([]int64, 0, taskCnt)
for i := 0; i < taskCnt; i++ {
taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, "background", 0, nil)
taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, "background", 0, proto.ExtraParams{}, nil)
require.NoError(t, err)
taskIDs = append(taskIDs, taskID)
}
Expand All @@ -225,7 +225,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel,
checkSubtaskCnt(tasks, taskIDs)
// test parallelism control
if taskCnt == 1 {
taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, "background", 0, nil)
taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, "background", 0, proto.ExtraParams{}, nil)
require.NoError(t, err)
checkGetRunningTaskCnt(taskCnt)
// Clean the task.
Expand Down Expand Up @@ -460,7 +460,7 @@ func TestManagerScheduleLoop(t *testing.T) {
},
)
for i := 0; i < len(concurrencies); i++ {
_, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], "", 0, []byte("{}"))
_, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], "", 0, proto.ExtraParams{}, []byte("{}"))
require.NoError(t, err)
}
getRunningTaskKeys := func() []string {
Expand Down
Loading