diff --git a/pkg/ddl/backfilling_dist_scheduler_test.go b/pkg/ddl/backfilling_dist_scheduler_test.go index 4f471bb33d3e2..09eae5ac17041 100644 --- a/pkg/ddl/backfilling_dist_scheduler_test.go +++ b/pkg/ddl/backfilling_dist_scheduler_test.go @@ -158,7 +158,7 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) { ext.(*ddl.LitBackfillScheduler).GlobalSort = true sch.Extension = ext - taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, "", 0, task.Meta) + taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, "", 0, proto.ExtraParams{}, task.Meta) require.NoError(t, err) task.ID = taskID execIDs := []string{":4000"} diff --git a/pkg/disttask/framework/doc.go b/pkg/disttask/framework/doc.go index 435588ed271f3..7172c21d9c5c1 100644 --- a/pkg/disttask/framework/doc.go +++ b/pkg/disttask/framework/doc.go @@ -125,6 +125,14 @@ // └─────────►│cancelling├────┘ // └──────────┘ // +// Note: if ManualRecovery is enabled, when some subtask failed, the task will +// move to `awaiting-resolution` state, and manual operation is needed for the +// task to continue. This mechanism is used for debugging, some bugs such as those +// on global-sort are harder to investigate without the intermediate files, or to +// manually recover from some error when importing a large amount of data using +// global-sort where one round of import takes a lot of time, it might be more +// flexible and efficient than retrying the whole task. +// // pause/resume state transition: // as we don't know the state of the task before `paused`, so the state after // `resuming` is always `running`. 
diff --git a/pkg/disttask/framework/handle/handle.go b/pkg/disttask/framework/handle/handle.go index a362444f3db98..82cf36e9f5047 100644 --- a/pkg/disttask/framework/handle/handle.go +++ b/pkg/disttask/framework/handle/handle.go @@ -69,7 +69,7 @@ func SubmitTask(ctx context.Context, taskKey string, taskType proto.TaskType, co return nil, storage.ErrTaskAlreadyExists } - taskID, err := taskManager.CreateTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, taskMeta) + taskID, err := taskManager.CreateTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, proto.ExtraParams{}, taskMeta) if err != nil { return nil, err } diff --git a/pkg/disttask/framework/integrationtests/BUILD.bazel b/pkg/disttask/framework/integrationtests/BUILD.bazel index 768ef148116d6..5290394c61708 100644 --- a/pkg/disttask/framework/integrationtests/BUILD.bazel +++ b/pkg/disttask/framework/integrationtests/BUILD.bazel @@ -17,7 +17,7 @@ go_test( ], flaky = True, race = "off", - shard_count = 23, + shard_count = 22, deps = [ "//pkg/config", "//pkg/ddl", diff --git a/pkg/disttask/framework/integrationtests/framework_err_handling_test.go b/pkg/disttask/framework/integrationtests/framework_err_handling_test.go index 31bff5cd1d1a7..26db8c5712c20 100644 --- a/pkg/disttask/framework/integrationtests/framework_err_handling_test.go +++ b/pkg/disttask/framework/integrationtests/framework_err_handling_test.go @@ -15,26 +15,84 @@ package integrationtests import ( + "context" + "fmt" + "sync/atomic" "testing" + "time" + "github.com/pingcap/errors" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/storage" + "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute" "github.com/pingcap/tidb/pkg/disttask/framework/testutil" + "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" ) -func TestRetryErrOnNextSubtasksBatch(t *testing.T) { 
+func TestOnTaskError(t *testing.T) { c := testutil.NewTestDXFContext(t, 2, 16, true) - registerExampleTask(t, c.MockCtrl, testutil.GetPlanErrSchedulerExt(c.MockCtrl, c.TestContext), c.TestContext, nil) - submitTaskAndCheckSuccessForBasic(c.Ctx, t, "key1", c.TestContext) -} -func TestPlanNotRetryableOnNextSubtasksBatchErr(t *testing.T) { - c := testutil.NewTestDXFContext(t, 2, 16, true) + t.Run("retryable error on OnNextSubtasksBatch", func(t *testing.T) { + registerExampleTask(t, c.MockCtrl, testutil.GetPlanErrSchedulerExt(c.MockCtrl, c.TestContext), c.TestContext, nil) + submitTaskAndCheckSuccessForBasic(c.Ctx, t, "key1", c.TestContext) + }) + + t.Run("non retryable error on OnNextSubtasksBatch", func(t *testing.T) { + registerExampleTask(t, c.MockCtrl, testutil.GetPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil) + task := testutil.SubmitAndWaitTask(c.Ctx, t, "key2-1", "", 1) + require.Equal(t, proto.TaskStateReverted, task.State) + registerExampleTask(t, c.MockCtrl, testutil.GetStepTwoPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil) + task = testutil.SubmitAndWaitTask(c.Ctx, t, "key2-2", "", 1) + require.Equal(t, proto.TaskStateReverted, task.State) + }) + + prepareForAwaitingResolutionTestFn := func(t *testing.T, taskKey string) int64 { + subtaskErrRetryable := atomic.Bool{} + executorExt := testutil.GetTaskExecutorExt(c.MockCtrl, + func(task *proto.Task) (execute.StepExecutor, error) { + return testutil.GetCommonStepExecutor(c.MockCtrl, task.Step, func(ctx context.Context, subtask *proto.Subtask) error { + if !subtaskErrRetryable.Load() { + return errors.New("non retryable subtask error") + } + return nil + }), nil + }, + func(error) bool { + return subtaskErrRetryable.Load() + }, + ) + testutil.RegisterExampleTask(t, testutil.GetPlanErrSchedulerExt(c.MockCtrl, c.TestContext), + executorExt, testutil.GetCommonCleanUpRoutine(c.MockCtrl)) + tm, err := storage.GetTaskManager() + require.NoError(t, err) + taskID, err := 
tm.CreateTask(c.Ctx, taskKey, proto.TaskTypeExample, 1, "", + 2, proto.ExtraParams{ManualRecovery: true}, nil) + require.NoError(t, err) + require.Eventually(t, func() bool { + task, err := tm.GetTaskByID(c.Ctx, taskID) + require.NoError(t, err) + return task.State == proto.TaskStateAwaitingResolution + }, 10*time.Second, 100*time.Millisecond) + subtaskErrRetryable.Store(true) + return taskID + } + + t.Run("task enter awaiting-resolution state if ManualRecovery set, success after manual recover", func(t *testing.T) { + taskKey := "key3-1" + taskID := prepareForAwaitingResolutionTestFn(t, taskKey) + tk := testkit.NewTestKit(t, c.Store) + tk.MustExec(fmt.Sprintf("update mysql.tidb_background_subtask set state='pending' where state='failed' and task_key= %d", taskID)) + tk.MustExec(fmt.Sprintf("update mysql.tidb_global_task set state='running' where id = %d", taskID)) + task := testutil.WaitTaskDone(c.Ctx, t, taskKey) + require.Equal(t, proto.TaskStateSucceed, task.State) + }) - registerExampleTask(t, c.MockCtrl, testutil.GetPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil) - task := testutil.SubmitAndWaitTask(c.Ctx, t, "key1", "", 1) - require.Equal(t, proto.TaskStateReverted, task.State) - registerExampleTask(t, c.MockCtrl, testutil.GetStepTwoPlanNotRetryableErrSchedulerExt(c.MockCtrl), c.TestContext, nil) - task = testutil.SubmitAndWaitTask(c.Ctx, t, "key2", "", 1) - require.Equal(t, proto.TaskStateReverted, task.State) + t.Run("task enter awaiting-resolution state if ManualRecovery set, cancel also works", func(t *testing.T) { + taskKey := "key4-1" + taskID := prepareForAwaitingResolutionTestFn(t, taskKey) + require.NoError(t, c.TaskMgr.CancelTask(c.Ctx, taskID)) + task := testutil.WaitTaskDone(c.Ctx, t, taskKey) + require.Equal(t, proto.TaskStateReverted, task.State) + }) } diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index 5932d1b3c4221..85c6f46ff4eea 100644 --- 
a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -268,6 +268,20 @@ func (m *MockTaskManager) ISGOMOCK() struct{} { return struct{}{} } +// AwaitingResolveTask mocks base method. +func (m *MockTaskManager) AwaitingResolveTask(arg0 context.Context, arg1 int64, arg2 proto.TaskState, arg3 error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AwaitingResolveTask", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(error) + return ret0 +} + +// AwaitingResolveTask indicates an expected call of AwaitingResolveTask. +func (mr *MockTaskManagerMockRecorder) AwaitingResolveTask(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AwaitingResolveTask", reflect.TypeOf((*MockTaskManager)(nil).AwaitingResolveTask), arg0, arg1, arg2, arg3) +} + // CancelTask mocks base method. func (m *MockTaskManager) CancelTask(arg0 context.Context, arg1 int64) error { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/planner/planner.go b/pkg/disttask/framework/planner/planner.go index 850d3d4a58dd3..fb25bd554d61d 100644 --- a/pkg/disttask/framework/planner/planner.go +++ b/pkg/disttask/framework/planner/planner.go @@ -16,6 +16,7 @@ package planner import ( "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" "github.com/pingcap/tidb/pkg/disttask/framework/storage" ) @@ -47,6 +48,7 @@ func (*Planner) Run(planCtx PlanCtx, plan LogicalPlan) (int64, error) { planCtx.ThreadCnt, config.GetGlobalConfig().Instance.TiDBServiceScope, planCtx.MaxNodeCnt, + proto.ExtraParams{}, taskMeta, ) } diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 14264d7082c9d..0709f77891256 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -22,17 +22,18 @@ import ( // see doc.go for more details. 
const ( - TaskStatePending TaskState = "pending" - TaskStateRunning TaskState = "running" - TaskStateSucceed TaskState = "succeed" - TaskStateFailed TaskState = "failed" - TaskStateReverting TaskState = "reverting" - TaskStateReverted TaskState = "reverted" - TaskStateCancelling TaskState = "cancelling" - TaskStatePausing TaskState = "pausing" - TaskStatePaused TaskState = "paused" - TaskStateResuming TaskState = "resuming" - TaskStateModifying TaskState = "modifying" + TaskStatePending TaskState = "pending" + TaskStateRunning TaskState = "running" + TaskStateSucceed TaskState = "succeed" + TaskStateFailed TaskState = "failed" + TaskStateReverting TaskState = "reverting" + TaskStateAwaitingResolution TaskState = "awaiting-resolution" + TaskStateReverted TaskState = "reverted" + TaskStateCancelling TaskState = "cancelling" + TaskStatePausing TaskState = "pausing" + TaskStatePaused TaskState = "paused" + TaskStateResuming TaskState = "resuming" + TaskStateModifying TaskState = "modifying" ) type ( @@ -66,6 +67,15 @@ const ( // TODO: remove this limit later. var MaxConcurrentTask = 16 +// ExtraParams is the extra params of task. +// Note: only store params that's not used for filter or sort in this struct. +type ExtraParams struct { + // ManualRecovery indicates whether the task can be recovered manually. + // if enabled, the task will enter 'awaiting-resolution' state when it failed, + // then the user can recover the task manually or fail it if it's not recoverable. + ManualRecovery bool `json:"manual_recovery"` +} + // TaskBase contains the basic information of a task. // we define this to avoid load task meta which might be very large into memory. type TaskBase struct { @@ -87,6 +97,7 @@ type TaskBase struct { TargetScope string CreateTime time.Time MaxNodeCount int + ExtraParams } // IsDone checks if the task is done. 
diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index a36615facad15..464a23bc53194 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -45,6 +45,8 @@ type TaskManager interface { FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error // RevertTask updates task state to reverting, and task error. RevertTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error + // AwaitingResolveTask updates task state to awaiting-resolution, also sets task error. + AwaitingResolveTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error // RevertedTask updates task state to reverted. RevertedTask(ctx context.Context, taskID int64) error // PauseTask updated task state to pausing. diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index 72f99e081cd79..411660d32177c 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -388,7 +388,7 @@ func (s *BaseScheduler) onRunning() error { if len(subTaskErrs) > 0 { s.logger.Warn("subtasks encounter errors", zap.Errors("subtask-errs", subTaskErrs)) // we only store the first error as task error. 
- return s.revertTask(subTaskErrs[0]) + return s.revertTaskOrManualRecover(subTaskErrs[0]) } } else if s.isStepSucceed(cntByStates) { return s.switch2NextStep() @@ -586,6 +586,20 @@ func (s *BaseScheduler) revertTask(taskErr error) error { return nil } +func (s *BaseScheduler) revertTaskOrManualRecover(taskErr error) error { + task := s.getTaskClone() + if task.ManualRecovery { + if err := s.taskMgr.AwaitingResolveTask(s.ctx, task.ID, task.State, taskErr); err != nil { + return err + } + task.State = proto.TaskStateAwaitingResolution + task.Error = taskErr + s.task.Store(task) + return nil + } + return s.revertTask(taskErr) +} + // MockServerInfo exported for scheduler_test.go var MockServerInfo atomic.Pointer[[]string] diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go index f9d4d0fcdc8df..fcc65e557676d 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -46,7 +46,7 @@ func TestCleanUpRoutine(t *testing.T) { mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() sch.Start() defer sch.Stop() - taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, "", 0, nil) + taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil) require.NoError(t, err) checkTaskRunningCnt := func() []*proto.Task { diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index e7ba6784490ca..e4f160ec1fbb9 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -130,7 +130,7 @@ func TestTaskFailInManager(t *testing.T) { defer schManager.Stop() // unknown task type - taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, "", 0, nil) + taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, "", 0, 
proto.ExtraParams{}, nil) require.NoError(t, err) require.Eventually(t, func() bool { task, err := mgr.GetTaskByID(ctx, taskID) @@ -140,7 +140,7 @@ func TestTaskFailInManager(t *testing.T) { }, time.Second*10, time.Millisecond*300) // scheduler init error - taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, "", 0, nil) + taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil) require.NoError(t, err) require.Eventually(t, func() bool { task, err := mgr.GetTaskByID(ctx, taskID) @@ -215,7 +215,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, // Mock add tasks. taskIDs := make([]int64, 0, taskCnt) for i := 0; i < taskCnt; i++ { - taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, "background", 0, nil) + taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, "background", 0, proto.ExtraParams{}, nil) require.NoError(t, err) taskIDs = append(taskIDs, taskID) } @@ -225,7 +225,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, checkSubtaskCnt(tasks, taskIDs) // test parallelism control if taskCnt == 1 { - taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, "background", 0, nil) + taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, "background", 0, proto.ExtraParams{}, nil) require.NoError(t, err) checkGetRunningTaskCnt(taskCnt) // Clean the task. 
@@ -460,7 +460,7 @@ func TestManagerScheduleLoop(t *testing.T) { }, ) for i := 0; i < len(concurrencies); i++ { - _, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], "", 0, []byte("{}")) + _, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], "", 0, proto.ExtraParams{}, []byte("{}")) require.NoError(t, err) } getRunningTaskKeys := func() []string { diff --git a/pkg/disttask/framework/storage/converter.go b/pkg/disttask/framework/storage/converter.go index dbc55aa64110c..45aa54183c769 100644 --- a/pkg/disttask/framework/storage/converter.go +++ b/pkg/disttask/framework/storage/converter.go @@ -27,17 +27,27 @@ import ( ) func row2TaskBasic(r chunk.Row) *proto.TaskBase { + createTime, _ := r.GetTime(7).GoTime(time.Local) + extraParams := proto.ExtraParams{} + if !r.IsNull(10) { + str := r.GetJSON(10).String() + if err := json.Unmarshal([]byte(str), &extraParams); err != nil { + logutil.BgLogger().Error("unmarshal task extra params", zap.Error(err)) + } + } task := &proto.TaskBase{ - ID: r.GetInt64(0), - Key: r.GetString(1), - Type: proto.TaskType(r.GetString(2)), - State: proto.TaskState(r.GetString(3)), - Step: proto.Step(r.GetInt64(4)), - Priority: int(r.GetInt64(5)), - Concurrency: int(r.GetInt64(6)), - TargetScope: r.GetString(8), + ID: r.GetInt64(0), + Key: r.GetString(1), + Type: proto.TaskType(r.GetString(2)), + State: proto.TaskState(r.GetString(3)), + Step: proto.Step(r.GetInt64(4)), + Priority: int(r.GetInt64(5)), + Concurrency: int(r.GetInt64(6)), + CreateTime: createTime, + TargetScope: r.GetString(8), + MaxNodeCount: int(r.GetInt64(9)), + ExtraParams: extraParams, } - task.CreateTime, _ = r.GetTime(7).GoTime(time.Local) return task } @@ -46,18 +56,18 @@ func Row2Task(r chunk.Row) *proto.Task { taskBase := row2TaskBasic(r) task := &proto.Task{TaskBase: *taskBase} var startTime, updateTime time.Time - if !r.IsNull(9) { - startTime, _ = 
r.GetTime(9).GoTime(time.Local) + if !r.IsNull(11) { + startTime, _ = r.GetTime(11).GoTime(time.Local) } - if !r.IsNull(10) { - updateTime, _ = r.GetTime(10).GoTime(time.Local) + if !r.IsNull(12) { + updateTime, _ = r.GetTime(12).GoTime(time.Local) } task.StartTime = startTime task.StateUpdateTime = updateTime - task.Meta = r.GetBytes(11) - task.SchedulerID = r.GetString(12) - if !r.IsNull(13) { - errBytes := r.GetBytes(13) + task.Meta = r.GetBytes(13) + task.SchedulerID = r.GetString(14) + if !r.IsNull(15) { + errBytes := r.GetBytes(15) stdErr := errors.Normalize("") err := stdErr.UnmarshalJSON(errBytes) if err != nil { @@ -67,14 +77,12 @@ func Row2Task(r chunk.Row) *proto.Task { task.Error = stdErr } } - if !r.IsNull(14) { - str := r.GetJSON(14).String() + if !r.IsNull(16) { + str := r.GetJSON(16).String() if err := json.Unmarshal([]byte(str), &task.ModifyParam); err != nil { logutil.BgLogger().Error("unmarshal task modify param", zap.Error(err)) } } - maxNodeCnt := r.GetInt64(15) - task.MaxNodeCount = int(maxNodeCnt) return task } diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index f538929c623ad..7c73578f4e4c6 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -46,15 +46,16 @@ func TestTaskTable(t *testing.T) { require.NoError(t, gm.InitMeta(ctx, ":4000", "")) - _, err := gm.CreateTask(ctx, "key1", "test", 999, "", 0, []byte("test")) + _, err := gm.CreateTask(ctx, "key1", "test", 999, "", 0, proto.ExtraParams{}, []byte("test")) require.ErrorContains(t, err, "task concurrency(999) larger than cpu count") timeBeforeCreate := time.Unix(time.Now().Unix(), 0) - id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + id, err := gm.CreateTask(ctx, "key1", "test", 4, "aaa", + 12, proto.ExtraParams{ManualRecovery: true}, []byte("testmeta")) require.NoError(t, err) require.Equal(t, int64(1), id) - task, err := testutil.GetOneTask(ctx, 
gm) + task, err := gm.GetTaskByID(ctx, id) require.NoError(t, err) require.Equal(t, int64(1), task.ID) require.Equal(t, "key1", task.Key) @@ -63,7 +64,10 @@ func TestTaskTable(t *testing.T) { require.Equal(t, proto.NormalPriority, task.Priority) require.Equal(t, 4, task.Concurrency) require.Equal(t, proto.StepInit, task.Step) - require.Equal(t, []byte("test"), task.Meta) + require.Equal(t, "aaa", task.TargetScope) + require.Equal(t, 12, task.MaxNodeCount) + require.Equal(t, proto.ExtraParams{ManualRecovery: true}, task.ExtraParams) + require.Equal(t, []byte("testmeta"), task.Meta) require.GreaterOrEqual(t, task.CreateTime, timeBeforeCreate) require.Zero(t, task.StartTime) require.Zero(t, task.StateUpdateTime) @@ -99,11 +103,11 @@ func TestTaskTable(t *testing.T) { require.Equal(t, task.State, task6.State) // test cannot insert task with dup key - _, err = gm.CreateTask(ctx, "key1", "test2", 4, "", 0, []byte("test2")) + _, err = gm.CreateTask(ctx, "key1", "test2", 4, "", 0, proto.ExtraParams{}, []byte("test2")) require.EqualError(t, err, "[kv:1062]Duplicate entry 'key1' for key 'tidb_global_task.task_key'") // test cancel task - id, err = gm.CreateTask(ctx, "key2", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key2", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) cancelling, err := testutil.IsTaskCancelling(ctx, gm, id) @@ -115,7 +119,7 @@ func TestTaskTable(t *testing.T) { require.NoError(t, err) require.True(t, cancelling) - id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, "", 0, []byte("test2")) + id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, "", 0, proto.ExtraParams{}, []byte("test2")) require.NoError(t, err) // state not right, update nothing require.NoError(t, gm.FailTask(ctx, id, proto.TaskStateRunning, errors.New("test error"))) @@ -135,7 +139,7 @@ func TestTaskTable(t *testing.T) { require.GreaterOrEqual(t, endTime, curTime) // succeed a pending task, no effect - id, err = gm.CreateTask(ctx, 
"key-success", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key-success", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.NoError(t, gm.SucceedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) @@ -154,7 +158,7 @@ func TestTaskTable(t *testing.T) { require.GreaterOrEqual(t, task.StateUpdateTime, startTime) // reverted a pending task, no effect - id, err = gm.CreateTask(ctx, "key-reverted", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key-reverted", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.NoError(t, gm.RevertedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) @@ -178,7 +182,7 @@ func TestTaskTable(t *testing.T) { require.Equal(t, proto.TaskStateReverted, task.State) // paused - id, err = gm.CreateTask(ctx, "key-paused", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key-paused", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.NoError(t, gm.PausedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) @@ -229,7 +233,7 @@ func TestSwitchTaskStep(t *testing.T) { tk := testkit.NewTestKit(t, store) require.NoError(t, tm.InitMeta(ctx, ":4000", "")) - taskID, err := tm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + taskID, err := tm.CreateTask(ctx, "key1", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) require.NoError(t, err) @@ -281,7 +285,7 @@ func TestSwitchTaskStepInBatch(t *testing.T) { require.NoError(t, tm.InitMeta(ctx, ":4000", "")) // normal flow prepare := func(taskKey string) (*proto.Task, []*proto.Subtask) { - taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, "", 0, []byte("test")) + taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) require.NoError(t, err) @@ -373,7 +377,7 @@ func 
TestGetTopUnfinishedTasks(t *testing.T) { } for i, state := range taskStates { taskKey := fmt.Sprintf("key/%d", i) - _, err := gm.CreateTask(ctx, taskKey, "test", 4, "", 0, []byte("test")) + _, err := gm.CreateTask(ctx, taskKey, "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { _, err := se.GetSQLExecutor().ExecuteInternal(ctx, ` @@ -446,7 +450,7 @@ func TestGetUsedSlotsOnNodes(t *testing.T) { func TestGetActiveSubtasks(t *testing.T) { _, tm, ctx := testutil.InitTableTest(t) require.NoError(t, tm.InitMeta(ctx, ":4000", "")) - id, err := tm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + id, err := tm.CreateTask(ctx, "key1", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) task, err := tm.GetTaskByID(ctx, id) @@ -478,7 +482,7 @@ func TestSubTaskTable(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) timeBeforeCreate := time.Unix(time.Now().Unix(), 0) require.NoError(t, sm.InitMeta(ctx, ":4000", "")) - id, err := sm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + id, err := sm.CreateTask(ctx, "key1", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) err = sm.SwitchTaskStep( @@ -634,7 +638,7 @@ func TestSubTaskTable(t *testing.T) { func TestBothTaskAndSubTaskTable(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) require.NoError(t, sm.InitMeta(ctx, ":4000", "")) - id, err := sm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + id, err := sm.CreateTask(ctx, "key1", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) @@ -843,9 +847,9 @@ func TestTaskHistoryTable(t *testing.T) { _, gm, ctx := testutil.InitTableTest(t) require.NoError(t, gm.InitMeta(ctx, ":4000", "")) - _, err := gm.CreateTask(ctx, "1", proto.TaskTypeExample, 1, "", 0, nil) + _, err 
:= gm.CreateTask(ctx, "1", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil) require.NoError(t, err) - taskID, err := gm.CreateTask(ctx, "2", proto.TaskTypeExample, 1, "", 0, nil) + taskID, err := gm.CreateTask(ctx, "2", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil) require.NoError(t, err) tasks, err := gm.GetTasksInStates(ctx, proto.TaskStatePending) @@ -878,7 +882,7 @@ func TestTaskHistoryTable(t *testing.T) { require.NotNil(t, task) // task with fail transfer - _, err = gm.CreateTask(ctx, "3", proto.TaskTypeExample, 1, "", 0, nil) + _, err = gm.CreateTask(ctx, "3", proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil) require.NoError(t, err) tasks, err = gm.GetTasksInStates(ctx, proto.TaskStatePending) require.NoError(t, err) @@ -1135,7 +1139,7 @@ func TestGetActiveTaskExecInfo(t *testing.T) { taskStates := []proto.TaskState{proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateReverting, proto.TaskStatePausing} tasks := make([]*proto.Task, 0, len(taskStates)) for i, expectedState := range taskStates { - taskID, err := tm.CreateTask(ctx, fmt.Sprintf("key-%d", i), proto.TaskTypeExample, 8, "", 0, []byte("")) + taskID, err := tm.CreateTask(ctx, fmt.Sprintf("key-%d", i), proto.TaskTypeExample, 8, "", 0, proto.ExtraParams{}, []byte("")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) require.NoError(t, err) diff --git a/pkg/disttask/framework/storage/task_state.go b/pkg/disttask/framework/storage/task_state.go index 8e4c472791c14..6b208d656c877 100644 --- a/pkg/disttask/framework/storage/task_state.go +++ b/pkg/disttask/framework/storage/task_state.go @@ -31,8 +31,9 @@ func (mgr *TaskManager) CancelTask(ctx context.Context, taskID int64) error { `update mysql.tidb_global_task set state = %?, state_update_time = CURRENT_TIMESTAMP() - where id = %? and state in (%?, %?)`, + where id = %? 
and state in (%?, %?, %?)`, proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning, + proto.TaskStateAwaitingResolution, ) return err } @@ -43,8 +44,10 @@ func (*TaskManager) CancelTaskByKeySession(ctx context.Context, se sessionctx.Co `update mysql.tidb_global_task set state = %?, state_update_time = CURRENT_TIMESTAMP() - where task_key = %? and state in (%?, %?)`, - proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning) + where task_key = %? and state in (%?, %?, %?)`, + proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning, + proto.TaskStateAwaitingResolution, + ) return err } @@ -64,17 +67,26 @@ func (mgr *TaskManager) FailTask(ctx context.Context, taskID int64, currentState // RevertTask implements the scheduler.TaskManager interface. func (mgr *TaskManager) RevertTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error { + return mgr.transitTaskStateOnErr(ctx, taskID, taskState, proto.TaskStateReverting, taskErr) +} + +func (mgr *TaskManager) transitTaskStateOnErr(ctx context.Context, taskID int64, currState, targetState proto.TaskState, taskErr error) error { _, err := mgr.ExecuteSQLWithNewSession(ctx, ` update mysql.tidb_global_task set state = %?, error = %?, state_update_time = CURRENT_TIMESTAMP() where id = %? and state = %?`, - proto.TaskStateReverting, serializeErr(taskErr), taskID, taskState, + targetState, serializeErr(taskErr), taskID, currState, ) return err } +// AwaitingResolveTask implements the scheduler.TaskManager interface. +func (mgr *TaskManager) AwaitingResolveTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error { + return mgr.transitTaskStateOnErr(ctx, taskID, taskState, proto.TaskStateAwaitingResolution, taskErr) +} + // RevertedTask implements the scheduler.TaskManager interface. 
func (mgr *TaskManager) RevertedTask(ctx context.Context, taskID int64) error { _, err := mgr.ExecuteSQLWithNewSession(ctx, diff --git a/pkg/disttask/framework/storage/task_state_test.go b/pkg/disttask/framework/storage/task_state_test.go index 3afa557d8c0a5..ea734e2ee2704 100644 --- a/pkg/disttask/framework/storage/task_state_test.go +++ b/pkg/disttask/framework/storage/task_state_test.go @@ -40,7 +40,7 @@ func TestTaskState(t *testing.T) { require.NoError(t, gm.InitMeta(ctx, ":4000", "")) // 1. cancel task - id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(1), id) TODO: unstable for infoschema v2 require.NoError(t, gm.CancelTask(ctx, id)) @@ -49,7 +49,7 @@ func TestTaskState(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStateCancelling, proto.StepInit) // 2. cancel task by key session - id, err = gm.CreateTask(ctx, "key2", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key2", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(2), id) TODO: unstable for infoschema v2 require.NoError(t, gm.WithNewTxn(ctx, func(se sessionctx.Context) error { @@ -61,7 +61,7 @@ func TestTaskState(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStateCancelling, proto.StepInit) // 3. fail task - id, err = gm.CreateTask(ctx, "key3", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key3", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(3), id) TODO: unstable for infoschema v2 failedErr := errors.New("test err") @@ -72,7 +72,7 @@ func TestTaskState(t *testing.T) { require.ErrorContains(t, task.Error, "test err") // 4. 
Reverted task - id, err = gm.CreateTask(ctx, "key4", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key4", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(4), id) TODO: unstable for infoschema v2 task, err = gm.GetTaskByID(ctx, id) @@ -90,7 +90,7 @@ func TestTaskState(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStateReverted, proto.StepInit) // 5. pause task - id, err = gm.CreateTask(ctx, "key5", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key5", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(5), id) TODO: unstable for infoschema v2 found, err := gm.PauseTask(ctx, "key5") @@ -119,7 +119,7 @@ func TestTaskState(t *testing.T) { require.Equal(t, proto.TaskStateRunning, task.State) // 8. succeed task - id, err = gm.CreateTask(ctx, "key6", "test", 4, "", 0, []byte("test")) + id, err = gm.CreateTask(ctx, "key6", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(6), id) TODO: unstable for infoschema v2 task, err = gm.GetTaskByID(ctx, id) @@ -139,7 +139,7 @@ func TestModifyTask(t *testing.T) { _, gm, ctx := testutil.InitTableTest(t) require.NoError(t, gm.InitMeta(ctx, ":4000", "")) - id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) + id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, proto.ExtraParams{}, []byte("test")) require.NoError(t, err) require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 65f76474c6f1e..f07ab42987b3f 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -16,6 +16,7 @@ package storage import ( "context" + "encoding/json" goerrors "errors" "strconv" "strings" @@ -37,12 +38,12 @@ import ( const ( defaultSubtaskKeepDays = 14 - 
basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time, t.target_scope` + basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time, t.target_scope, t.max_node_count, t.extra_params` // TaskColumns is the columns for task. // TODO: dispatcher_id will update to scheduler_id later - TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error, t.modify_params, t.max_node_count` + TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error, t.modify_params` // InsertTaskColumns is the columns used in insert task. - InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope, max_node_count` + InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope, max_node_count, extra_params` basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal, start_time` // SubtaskColumns is the columns for subtask. 
SubtaskColumns = basicSubtaskColumns + `, state_update_time, meta, summary` @@ -219,11 +220,12 @@ func (mgr *TaskManager) CreateTask( concurrency int, targetScope string, maxNodeCnt int, + extraParams proto.ExtraParams, meta []byte, ) (taskID int64, err error) { err = mgr.WithNewSession(func(se sessionctx.Context) error { var err2 error - taskID, err2 = mgr.CreateTaskWithSession(ctx, se, key, tp, concurrency, targetScope, maxNodeCnt, meta) + taskID, err2 = mgr.CreateTaskWithSession(ctx, se, key, tp, concurrency, targetScope, maxNodeCnt, extraParams, meta) return err2 }) return @@ -238,6 +240,7 @@ func (mgr *TaskManager) CreateTaskWithSession( concurrency int, targetScope string, maxNodeCount int, + extraParams proto.ExtraParams, meta []byte, ) (taskID int64, err error) { cpuCount, err := mgr.getCPUCountOfNode(ctx, se) @@ -247,10 +250,15 @@ func (mgr *TaskManager) CreateTaskWithSession( if concurrency > cpuCount { return 0, errors.Errorf("task concurrency(%d) larger than cpu count(%d) of managed node", concurrency, cpuCount) } + extraParamBytes, err := json.Marshal(extraParams) + if err != nil { + return 0, errors.Trace(err) + } _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` insert into mysql.tidb_global_task(`+InsertTaskColumns+`) - values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), %?, %?)`, - key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, proto.StepInit, meta, targetScope, maxNodeCount) + values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), %?, %?, %?)`, + key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, + proto.StepInit, meta, targetScope, maxNodeCount, json.RawMessage(extraParamBytes)) if err != nil { return 0, err } diff --git a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go index 055932e87a894..25522a34d559b 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go +++ 
b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go @@ -36,7 +36,7 @@ import ( ) func runOneTask(ctx context.Context, t *testing.T, mgr *storage.TaskManager, taskKey string, subtaskCnt int) { - taskID, err := mgr.CreateTask(ctx, taskKey, proto.TaskTypeExample, 1, "", 0, nil) + taskID, err := mgr.CreateTask(ctx, taskKey, proto.TaskTypeExample, 1, "", 0, proto.ExtraParams{}, nil) require.NoError(t, err) task, err := mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) diff --git a/pkg/disttask/framework/testutil/disttest_util.go b/pkg/disttask/framework/testutil/disttest_util.go index 7f7db281101f2..838e3ba7fd19e 100644 --- a/pkg/disttask/framework/testutil/disttest_util.go +++ b/pkg/disttask/framework/testutil/disttest_util.go @@ -32,10 +32,15 @@ import ( // GetCommonTaskExecutorExt returns a common task executor extension. func GetCommonTaskExecutorExt(ctrl *gomock.Controller, getStepExecFn func(*proto.Task) (execute.StepExecutor, error)) *mock.MockExtension { + return GetTaskExecutorExt(ctrl, getStepExecFn, func(error) bool { return false }) +} + +// GetTaskExecutorExt returns a task executor extension. 
+func GetTaskExecutorExt(ctrl *gomock.Controller, getStepExecFn func(*proto.Task) (execute.StepExecutor, error), isRetryableErrorFn func(error) bool) *mock.MockExtension { executorExt := mock.NewMockExtension(ctrl) executorExt.EXPECT().IsIdempotent(gomock.Any()).Return(true).AnyTimes() executorExt.EXPECT().GetStepExecutor(gomock.Any()).DoAndReturn(getStepExecFn).AnyTimes() - executorExt.EXPECT().IsRetryableError(gomock.Any()).Return(false).AnyTimes() + executorExt.EXPECT().IsRetryableError(gomock.Any()).DoAndReturn(isRetryableErrorFn).AnyTimes() return executorExt } diff --git a/pkg/disttask/importinto/job_testkit_test.go b/pkg/disttask/importinto/job_testkit_test.go index 888b00d71a63e..5c6de824e272f 100644 --- a/pkg/disttask/importinto/job_testkit_test.go +++ b/pkg/disttask/importinto/job_testkit_test.go @@ -53,7 +53,7 @@ func TestGetTaskImportedRows(t *testing.T) { } bytes, err := json.Marshal(taskMeta) require.NoError(t, err) - taskID, err := manager.CreateTask(ctx, importinto.TaskKey(111), proto.ImportInto, 1, "", 0, bytes) + taskID, err := manager.CreateTask(ctx, importinto.TaskKey(111), proto.ImportInto, 1, "", 0, proto.ExtraParams{}, bytes) require.NoError(t, err) importStepMetas := []*importinto.ImportStepMeta{ { @@ -85,7 +85,7 @@ func TestGetTaskImportedRows(t *testing.T) { } bytes, err = json.Marshal(taskMeta) require.NoError(t, err) - taskID, err = manager.CreateTask(ctx, importinto.TaskKey(222), proto.ImportInto, 1, "", 0, bytes) + taskID, err = manager.CreateTask(ctx, importinto.TaskKey(222), proto.ImportInto, 1, "", 0, proto.ExtraParams{}, bytes) require.NoError(t, err) ingestStepMetas := []*importinto.WriteIngestStepMeta{ { diff --git a/pkg/disttask/importinto/scheduler_testkit_test.go b/pkg/disttask/importinto/scheduler_testkit_test.go index 7e6ab15b90bab..69c099603c99f 100644 --- a/pkg/disttask/importinto/scheduler_testkit_test.go +++ b/pkg/disttask/importinto/scheduler_testkit_test.go @@ -86,7 +86,7 @@ func TestSchedulerExtLocalSort(t *testing.T) 
{ require.NoError(t, err) taskMeta, err := json.Marshal(task) require.NoError(t, err) - taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", 0, taskMeta) + taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", 0, proto.ExtraParams{}, taskMeta) require.NoError(t, err) task.ID = taskID @@ -144,6 +144,7 @@ func TestSchedulerExtLocalSort(t *testing.T) { require.NoError(t, err) task.Meta = bs require.NoError(t, importer.StartJob(ctx, conn, jobID, importer.JobStepImporting)) + task.State = proto.TaskStateReverting task.Error = errors.New("met error") require.NoError(t, ext.OnDone(ctx, d, task)) require.NoError(t, err) @@ -160,6 +161,7 @@ func TestSchedulerExtLocalSort(t *testing.T) { require.NoError(t, err) task.Meta = bs require.NoError(t, importer.StartJob(ctx, conn, jobID, importer.JobStepImporting)) + task.State = proto.TaskStateReverting task.Error = errors.New("cancelled by user") require.NoError(t, ext.OnDone(ctx, d, task)) require.NoError(t, err) @@ -229,7 +231,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { require.NoError(t, err) taskMeta, err := json.Marshal(task) require.NoError(t, err) - taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", 0, taskMeta) + taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", 0, proto.ExtraParams{}, taskMeta) require.NoError(t, err) task.ID = taskID diff --git a/pkg/session/bootstrap.go b/pkg/session/bootstrap.go index ee9ffca4f3ea4..9e1188c90f1b1 100644 --- a/pkg/session/bootstrap.go +++ b/pkg/session/bootstrap.go @@ -601,6 +601,7 @@ const ( error BLOB, modify_params json, max_node_count INT DEFAULT 0, + extra_params json, key(state), UNIQUE KEY task_key(task_key) );` @@ -624,6 +625,7 @@ const ( error BLOB, modify_params json, max_node_count INT DEFAULT 0, + extra_params json, key(state), UNIQUE KEY task_key(task_key) );` @@ -1264,6 +1266,7 @@ const ( version242 = 242 // 
Add max_node_count column to tidb_global_task and tidb_global_task_history. + // Add extra_params to tidb_global_task and tidb_global_task_history. version243 = 243 ) @@ -3373,6 +3376,8 @@ func upgradeToVer243(s sessiontypes.Session, ver int64) { } doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task ADD COLUMN max_node_count INT DEFAULT 0 AFTER `modify_params`;", infoschema.ErrColumnExists) doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task_history ADD COLUMN max_node_count INT DEFAULT 0 AFTER `modify_params`;", infoschema.ErrColumnExists) + doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task ADD COLUMN extra_params json AFTER max_node_count;", infoschema.ErrColumnExists) + doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task_history ADD COLUMN extra_params json AFTER max_node_count;", infoschema.ErrColumnExists) } // initGlobalVariableIfNotExists initialize a global variable with specific val if it does not exist.