Skip to content

Commit d38eb8e

Browse files
breezewish authored and zeminzhou committed
ddl: Fix vector index for high dimensional vectors (pingcap#58717)
ref pingcap#54245
1 parent 89fc14a commit d38eb8e

File tree

5 files changed

+126
-0
lines changed

5 files changed

+126
-0
lines changed

pkg/ddl/index.go

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -285,6 +285,11 @@ func getIndexColumnLength(col *model.ColumnInfo, colLen int) (int, error) {
285285
}
286286

287287
switch col.GetType() {
288+
case mysql.TypeTiDBVectorFloat32:
289+
// Vector Index does not actually create KV index, so it has length of 0.
290+
// However, 0 may cause some issues in other calculations, so we use 1 here.
291+
// 1 is also minimal enough anyway.
292+
return 1, nil
288293
case mysql.TypeBit:
289294
return (length + 7) >> 3, nil
290295
case mysql.TypeVarchar, mysql.TypeString, mysql.TypeVarString, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeBlob, mysql.TypeLongBlob:
@@ -2930,6 +2935,9 @@ func newCleanUpIndexWorker(id int, t table.PhysicalTable, decodeColMap map[int64
29302935
indexes := make([]table.Index, 0, len(t.Indices()))
29312936
rowDecoder := decoder.NewRowDecoder(t, t.WritableCols(), decodeColMap)
29322937
for _, index := range t.Indices() {
2938+
if index.Meta().IsTiFlashLocalIndex() {
2939+
continue
2940+
}
29332941
if index.Meta().Global {
29342942
indexes = append(indexes, index)
29352943
}

pkg/expression/integration_test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ go_test(
1212
deps = [
1313
"//pkg/config",
1414
"//pkg/domain",
15+
"//pkg/domain/infosync",
1516
"//pkg/errno",
1617
"//pkg/expression",
1718
"//pkg/kv",

pkg/expression/integration_test/integration_test.go

Lines changed: 96 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@ import (
3333
"github.com/pingcap/failpoint"
3434
"github.com/pingcap/tidb/pkg/config"
3535
"github.com/pingcap/tidb/pkg/domain"
36+
"github.com/pingcap/tidb/pkg/domain/infosync"
3637
"github.com/pingcap/tidb/pkg/errno"
3738
"github.com/pingcap/tidb/pkg/expression"
3839
"github.com/pingcap/tidb/pkg/kv"
@@ -61,6 +62,101 @@ import (
6162
"github.com/tikv/client-go/v2/oracle"
6263
)
6364

65+
// TestVectorLong exercises insert, select, delete, update and
// vec_l2_distance ordering on a vector(16383) column in three setups: a table
// without a vector index, a table created with a VECTOR INDEX, and a table
// that receives a VECTOR INDEX through ALTER TABLE after a TiFlash replica
// has been set.
func TestVectorLong(t *testing.T) {
66+
// Mock store with a 1s schema lease and mock TiFlash enabled
// (the argument 2 is presumably the number of mock TiFlash nodes — confirm).
store := testkit.CreateMockStoreWithSchemaLease(t, 1*time.Second, mockstore.WithMockTiFlash(2))
67+
68+
tk := testkit.NewTestKit(t, store)
69+
70+
// Install a mock TiFlash into infosync; its status server is closed under
// the mock's lock when the test returns.
tiflash := infosync.NewMockTiFlash()
71+
infosync.SetMockTiFlash(tiflash)
72+
defer func() {
73+
tiflash.Lock()
74+
tiflash.StatusServer.Close()
75+
tiflash.Unlock()
76+
}()
77+
78+
// genVec renders a d-element vector literal such as "[100,200,300]":
// the first element is startValue and each following element is 100 larger.
genVec := func(d int, startValue int) string {
79+
vb := strings.Builder{}
80+
vb.WriteString("[")
81+
value := startValue
82+
for i := 0; i < d; i++ {
83+
if i > 0 {
84+
vb.WriteString(",")
85+
}
86+
vb.WriteString(strconv.FormatInt(int64(value), 10))
87+
value += 100
88+
}
89+
vb.WriteString("]")
90+
return vb.String()
91+
}
92+
93+
// NOTE(review): the error returned by failpoint.Enable is discarded here,
// while the matching Disable below is checked with require.NoError.
failpoint.Enable("github.com/pingcap/tidb/pkg/ddl/MockCheckVectorIndexProcess", `return(1)`)
94+
defer func() {
95+
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/ddl/MockCheckVectorIndexProcess"))
96+
}()
97+
98+
// runWorkload runs one fixed DML and query sequence against whichever t1
// currently exists, so every table setup below is checked the same way.
runWorkload := func() {
99+
tk.MustExec(fmt.Sprintf(`insert into t1 values (1, '%s')`, genVec(16383, 100)))
100+
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows("1 " + genVec(16383, 100)))
101+
// The check after this delete asserts that the row is still present.
tk.MustExec(fmt.Sprintf(`delete from t1 where vec > '%s'`, genVec(16383, 200)))
102+
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows("1 " + genVec(16383, 100)))
103+
// The check after this delete asserts that the table is now empty.
tk.MustExec(fmt.Sprintf(`delete from t1 where vec > '%s'`, genVec(16383, 50)))
104+
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows())
105+
tk.MustExec(fmt.Sprintf(`insert into t1 values (1, '%s')`, genVec(16383, 100)))
106+
tk.MustExec(fmt.Sprintf(`insert into t1 values (2, '%s')`, genVec(16383, 200)))
107+
tk.MustExec(fmt.Sprintf(`insert into t1 values (3, '%s')`, genVec(16383, 300)))
108+
// Two nearest rows to the vector starting at 180: ids 2 then 1 are expected.
tk.MustQuery(fmt.Sprintf(`select id from t1 order by vec_l2_distance(vec, '%s') limit 2`, genVec(16383, 180))).Check(testkit.Rows(
109+
"2",
110+
"1",
111+
))
112+
tk.MustExec(fmt.Sprintf(`update t1 set vec = '%s' where id = 1`, genVec(16383, 500)))
113+
tk.MustQuery(`select * from t1 order by id`).Check(testkit.Rows(
114+
"1 "+genVec(16383, 500),
115+
"2 "+genVec(16383, 200),
116+
"3 "+genVec(16383, 300),
117+
))
118+
// After row 1 was moved to start at 500, the expected nearest rows become ids 2 then 3.
tk.MustQuery(fmt.Sprintf(`select id from t1 order by vec_l2_distance(vec, '%s') limit 2`, genVec(16383, 180))).Check(testkit.Rows(
119+
"2",
120+
"3",
121+
))
122+
}
123+
124+
tk.MustExec("use test")
125+
// Setup 1: plain table, no vector index.
tk.MustExec(`
126+
create table t1 (
127+
id int primary key,
128+
vec vector(16383)
129+
)
130+
`)
131+
runWorkload()
132+
tk.MustExec("drop table t1")
133+
134+
// Setup 2: vector index declared inside CREATE TABLE.
tk.MustExec(`
135+
create table t1 (
136+
id int primary key,
137+
vec vector(16383),
138+
VECTOR INDEX ((vec_cosine_distance(vec)))
139+
)
140+
`)
141+
runWorkload()
142+
tk.MustExec("drop table if exists t1")
143+
// Setup 3: vector index added afterwards through ALTER TABLE.
tk.MustExec(`
144+
create table t1 (
145+
id int primary key,
146+
vec vector(16383)
147+
)
148+
`)
149+
tk.MustExec(`alter table t1 set tiflash replica 1`)
150+
// NOTE(review): the error from TableByName is discarded and tbl is
// dereferenced on the next statement.
tbl, _ := domain.GetDomain(tk.Session()).InfoSchema().TableByName(context.Background(), ast.NewCIStr("test"), ast.NewCIStr("t1"))
151+
// Overwrite the replica info on the table meta so the replica is reported
// as available (presumably required before the index can be added — confirm).
tbl.Meta().TiFlashReplica = &model.TiFlashReplicaInfo{
152+
Count: 1,
153+
Available: true,
154+
}
155+
tk.MustExec(`alter table t1 add VECTOR INDEX ((vec_cosine_distance(vec)))`)
156+
runWorkload()
157+
tk.MustExec("drop table if exists t1")
158+
}
159+
64160
func TestVectorDefaultValue(t *testing.T) {
65161
store := testkit.CreateMockStore(t)
66162
tk := testkit.NewTestKit(t, store)

pkg/meta/model/index.go

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,12 @@ func (index *IndexInfo) IsPublic() bool {
143143
return index.State == StatePublic
144144
}
145145

146+
// IsTiFlashLocalIndex reports whether the index is a TiFlash local index,
// which is the case exactly when index.VectorInfo is set.
147+
// For a TiFlash local index, no actual index data needs to be written to the KV layer.
148+
func (index *IndexInfo) IsTiFlashLocalIndex() bool {
149+
return index.VectorInfo != nil
150+
}
151+
146152
// FindIndexByColumns find IndexInfo in indices which is cover the specified columns.
147153
func FindIndexByColumns(tbInfo *TableInfo, indices []*IndexInfo, cols ...ast.CIStr) *IndexInfo {
148154
for _, index := range indices {

pkg/table/tables/tables.go

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -306,6 +306,9 @@ func GetWritableIndexByName(idxName string, t table.Table) table.Index {
306306
if !IsIndexWritable(idx) {
307307
continue
308308
}
309+
if idx.Meta().IsTiFlashLocalIndex() {
310+
continue
311+
}
309312
if idxName == idx.Meta().Name.L {
310313
return idx
311314
}
@@ -547,6 +550,9 @@ func (t *TableCommon) rebuildUpdateRecordIndices(
547550
if t.meta.IsCommonHandle && idx.Meta().Primary {
548551
continue
549552
}
553+
if idx.Meta().IsTiFlashLocalIndex() {
554+
continue
555+
}
550556
for _, ic := range idx.Meta().Columns {
551557
if !touched[ic.Offset] {
552558
continue
@@ -566,6 +572,9 @@ func (t *TableCommon) rebuildUpdateRecordIndices(
566572
if !IsIndexWritable(idx) {
567573
continue
568574
}
575+
if idx.Meta().IsTiFlashLocalIndex() {
576+
continue
577+
}
569578
if t.meta.IsCommonHandle && idx.Meta().Primary {
570579
continue
571580
}
@@ -926,6 +935,9 @@ func (t *TableCommon) addIndices(sctx table.MutateContext, recordID kv.Handle, r
926935
if !IsIndexWritable(v) {
927936
continue
928937
}
938+
if v.Meta().IsTiFlashLocalIndex() {
939+
continue
940+
}
929941
if t.meta.IsCommonHandle && v.Meta().Primary {
930942
continue
931943
}
@@ -1185,6 +1197,9 @@ func (t *TableCommon) removeRowIndices(ctx table.MutateContext, txn kv.Transacti
11851197
if v.Meta().Primary && (t.Meta().IsCommonHandle || t.Meta().PKIsHandle) {
11861198
continue
11871199
}
1200+
if v.Meta().IsTiFlashLocalIndex() {
1201+
continue
1202+
}
11881203
var vals []types.Datum
11891204
if opt.HasIndexesLayout() {
11901205
vals, err = fetchIndexRow(v.Meta(), rec, nil, opt.GetIndexLayout(v.Meta().ID))

0 commit comments

Comments
 (0)