Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
"asyncmy",
"autobegin",
"autouse",
"BAAI",
"bindparam",
"checkfirst",
"cymysql",
"filterwarnings",
"fulltext",
"funcfilter",
"getenv",
"huggingface",
"ischema",
"jina",
"JINAAI",
Expand All @@ -44,6 +46,7 @@
"Pydantic",
"pymysql",
"pytest",
"Qwen",
"Rerank",
"reranked",
"reranker",
Expand Down
4 changes: 1 addition & 3 deletions pytidb/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from .base import BaseEmbeddingFunction
from .builtin import BuiltInEmbeddingFunction

EmbeddingFunction = BuiltInEmbeddingFunction
from .builtin import EmbeddingFunction

__all__ = ["BaseEmbeddingFunction", "EmbeddingFunction"]
31 changes: 17 additions & 14 deletions pytidb/embeddings/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytidb.embeddings.base import BaseEmbeddingFunction, EmbeddingSourceType
from pytidb.embeddings.dimensions import get_model_dimensions
from pytidb.embeddings.utils import (
deep_merge,
encode_local_file_to_base64,
encode_pil_image_to_base64,
parse_url_safely,
Expand All @@ -25,18 +26,23 @@
}


PROVIDER_DEFAULT_EMBED_PARAMS = {
"jina_ai": {
"task": "retrieval.passage",
"task@search": "retrieval.query",
},
}
EmbeddingInput = Union[str, Path, "Image"]


EmbeddingInput = Union[str, Path, "Image"]
def _convert_dimensions_param(provider: str, dimensions: int) -> dict[str, Any]:
if provider == "cohere":
return {"dimension": dimensions}
elif provider == "gemini":
return {"output_dimensionality": dimensions}
elif provider == "nvidia_nim":
# Notice: Nvidia NIM doesn't support dimensions parameter.
return {}
else:
# OpenAI, Jina AI follow the same convention.
return {"dimensions": dimensions}


class BuiltInEmbeddingFunction(BaseEmbeddingFunction):
class EmbeddingFunction(BaseEmbeddingFunction):
api_key: Optional[str] = Field(None, description="The API key for authentication.")
api_base: Optional[str] = Field(
None, description="The base URL of the model provider."
Expand Down Expand Up @@ -78,11 +84,8 @@ def __init__(
dimensions = get_model_dimensions(model_name)

provider = model_name.split("/")[0] if "/" in model_name else "openai"
server_embed_params = (
server_embed_params
if server_embed_params is not None
else PROVIDER_DEFAULT_EMBED_PARAMS.get(provider)
)
dimensions_param = _convert_dimensions_param(provider, dimensions)
_server_embed_params = deep_merge(dimensions_param, server_embed_params or {})

super().__init__(
model_name=model_name,
Expand All @@ -93,7 +96,7 @@ def __init__(
timeout=timeout,
caching=caching,
use_server=use_server,
server_embed_params=server_embed_params,
server_embed_params=_server_embed_params,
multimodal=multimodal,
**kwargs,
)
Expand Down
64 changes: 60 additions & 4 deletions pytidb/embeddings/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,72 @@
"openai/text-embedding-3-large": 3072,
"openai/text-embedding-ada-002": 1536,
# Cohere models
"cohere/embed-v4.0": 1536,
"cohere/embed-english-v3.0": 1024,
"cohere/embed-multilingual-v3.0": 1024,
# Jina AI models
"jina_ai/jina-embeddings-v4": 2048,
"jina_ai/jina-embeddings-v3": 1024,
"jina_ai/jina-clip-v2": 1024,
# TODO: remove these after jina_ai is released on prod.
"jina/jina-embeddings-v4": 2048,
"jina/jina-embeddings-v3": 1024,
"jina/jina-clip-v2": 1024,
# Gemini models
"gemini/gemini-embedding-001": 3072,
# Hugging Face models
"huggingface/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct": 1536,
"huggingface/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct": 3584,
"huggingface/Alibaba-NLP/gte-Qwen1.5-7B-instruct": 4096,
"huggingface/Alibaba-NLP/gte-Qwen2-1.5B-instruct": 8960,
"huggingface/Alibaba-NLP/gte-Qwen2-7B-instruct": 3584,
"huggingface/Alibaba-NLP/gte-multilingual-base": 768,
"huggingface/Alibaba-NLP/gte-modernbert-base": 768,
"huggingface/Alibaba-NLP/gte-base-en-v1.5": 768,
"huggingface/BAAI/bge-base-en": 768,
"huggingface/BAAI/bge-base-en-v1.5": 768,
"huggingface/BAAI/bge-base-zh": 768,
"huggingface/BAAI/bge-base-zh-v1.5": 768,
"huggingface/BAAI/bge-en-icl": 4096,
"huggingface/BAAI/bge-large-en": 1024,
"huggingface/BAAI/bge-large-en-v1.5": 1024,
"huggingface/BAAI/bge-large-zh": 1024,
"huggingface/BAAI/bge-large-zh-v1.5": 1024,
"huggingface/BAAI/bge-m3": 1024,
"huggingface/BAAI/bge-m3-unsupervised": 1024,
"huggingface/BAAI/bge-multilingual-gemma2": 3584,
"huggingface/BAAI/bge-small-en": 512,
"huggingface/BAAI/bge-small-en-v1.5": 512,
"huggingface/BAAI/bge-small-zh": 512,
"huggingface/BAAI/bge-small-zh-v1.5": 512,
"huggingface/Cohere/Cohere-embed-multilingual-v3.0": 1024,
"huggingface/Qwen/Qwen3-Embedding-0.6B": 1024,
"huggingface/Qwen/Qwen3-Embedding-4B": 2560,
"huggingface/Qwen/Qwen3-Embedding-8B": 4096,
"huggingface/Snowflake/snowflake-arctic-embed-l": 1024,
"huggingface/Snowflake/snowflake-arctic-embed-l-v2.0": 1024,
"huggingface/Snowflake/snowflake-arctic-embed-m": 768,
"huggingface/Snowflake/snowflake-arctic-embed-m-long": 768,
"huggingface/Snowflake/snowflake-arctic-embed-m-v1.5": 768,
"huggingface/Snowflake/snowflake-arctic-embed-m-v2.0": 768,
"huggingface/Snowflake/snowflake-arctic-embed-s": 384,
"huggingface/Snowflake/snowflake-arctic-embed-xs": 384,
"huggingface/intfloat/e5-base": 768,
"huggingface/intfloat/e5-base-v2": 768,
"huggingface/intfloat/e5-large": 1024,
"huggingface/intfloat/e5-large-v2": 1024,
"huggingface/intfloat/e5-mistral-7b-instruct": 4096,
"huggingface/intfloat/e5-small": 384,
"huggingface/intfloat/e5-small-v2": 384,
"huggingface/intfloat/multilingual-e5-base": 768,
"huggingface/intfloat/multilingual-e5-large": 1024,
"huggingface/intfloat/multilingual-e5-large-instruct": 1024,
"huggingface/intfloat/multilingual-e5-small": 384,
"huggingface/jinaai/jina-embedding-b-en-v1": 768,
"huggingface/jinaai/jina-embedding-s-en-v1": 512,
"huggingface/jinaai/jina-embeddings-v2-base-en": 768,
"huggingface/jinaai/jina-embeddings-v2-small-en": 512,
"huggingface/jinaai/jina-embeddings-v3": 1024,
"huggingface/jinaai/jina-embeddings-v4": 2048,
# Nvidia NIM models
"nvidia_nim/baai/bge-m3": 1024,
"nvidia_nim/nvidia/nv-embed-v1": 4096,
}

# Mapping of model aliases to their full names for backward compatibility
Expand Down
28 changes: 28 additions & 0 deletions pytidb/embeddings/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import base64
from collections.abc import Mapping
import io
from pathlib import Path
from typing import Optional, TYPE_CHECKING, Union
Expand Down Expand Up @@ -153,3 +154,30 @@ def encode_pil_image_to_base64(
return base64.b64encode(buffer.getvalue()).decode("utf-8")
except Exception as e:
raise ValueError(f"Failed to encode PIL Image to base64: {str(e)}")


def deep_merge(*dicts: Optional[dict]) -> dict:
    """
    Deeply merge zero or more dictionaries into a new dictionary.

    Dictionaries are applied from left to right, so values from later
    dictionaries override values from earlier ones. When both sides hold a
    mapping under the same key, the two mappings are merged recursively
    instead of one replacing the other. ``None`` arguments are ignored.

    The inputs are never modified: every level that receives merged keys is
    rebuilt as a fresh ``dict``. Nested values that are carried over without
    being merged are still shared by reference with the input they came from.

    Args:
        *dicts: The dictionaries to merge; ``None`` entries are skipped.

    Returns:
        A new dictionary with the merged content, or an empty dictionary when
        no non-``None`` dictionary is given.
    """

    def _merged(base: Mapping, override: Mapping) -> dict:
        # Copy this level first so that neither `base` nor `override` is
        # mutated; a shallow copy of only the top level would still let the
        # recursion write into nested dicts owned by the caller.
        out = dict(base)
        for key, value in override.items():
            current = out.get(key)
            if isinstance(current, Mapping) and isinstance(value, Mapping):
                out[key] = _merged(current, value)
            else:
                out[key] = value
        return out

    result: dict = {}
    for d in dicts:
        if d is not None:
            result = _merged(result, d)
    return result
124 changes: 93 additions & 31 deletions tests/test_auto_embedding_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,95 @@
from pytidb.embeddings import EmbeddingFunction
from pytidb.schema import TableModel, Field


EMBEDDING_MODELS = [
{
"id": "openai",
"model_name": "openai/text-embedding-3-small",
"expected_similarity": 0.9,
},
{
"id": "jina_ai",
"model_name": "jina_ai/jina-embeddings-v3",
"expected_similarity": 0.7,
},
{
"id": "tidbcloud_free",
"model_name": "tidbcloud_free/amazon/titan-embed-text-v2",
"expected_similarity": 0.9,
},
{
"id": "cohere",
"model_name": "cohere/embed-v4.0",
"expected_similarity": 0.7,
},
{
"id": "gemini",
"model_name": "gemini/gemini-embedding-001",
"expected_similarity": 0.8,
},
{
"id": "huggingface",
"model_name": "huggingface/intfloat/multilingual-e5-large",
"expected_similarity": 0.9,
},
{
"id": "nvidia_nim",
"model_name": "nvidia_nim/baai/bge-m3",
"expected_similarity": 0.9,
},
# TODO: uncomment these after jina_ai and tidbcloud_free are released on prod.
# {
# "id": "jina_ai",
# "model_name": "jina_ai/jina-embeddings-v4",
# },
# {
# "id": "tidbcloud_free",
# "model_name": "tidbcloud_free/amazon/titan-embed-text-v2",
# },
]


def _should_skip(shared_client: TiDBClient, model_id: str) -> Optional[str]:
    """
    Decide whether the auto embedding test for ``model_id`` has to be skipped.

    When the provider's API key is available in the environment, it is also
    registered on ``shared_client`` via ``configure_embedding_provider``.

    Args:
        shared_client: The client the test runs against.
        model_id: The ``id`` of the entry in ``EMBEDDING_MODELS`` under test.

    Returns:
        The reason for skipping the test, or ``None`` if the test can run.
    """
    # Skip auto embedding tests if not connected to TiDB Serverless.
    if not shared_client.is_serverless:
        return "Currently, Only TiDB Serverless supports auto embedding"

    # Environment variable that holds the API key of each provider.
    # Providers missing from this table (e.g. tidbcloud_free) don't need
    # additional API key configuration.
    api_key_env_vars = {
        "openai": "OPENAI_API_KEY",
        "jina_ai": "JINA_AI_API_KEY",
        "cohere": "COHERE_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "huggingface": "HUGGINGFACE_API_KEY",
        "nvidia_nim": "NVIDIA_NIM_API_KEY",
    }

    env_var = api_key_env_vars.get(model_id)
    if env_var is None:
        return None

    api_key = os.getenv(env_var)
    if not api_key:
        return f"{env_var} is not set"

    # The provider name passed to the client is the same as the model id.
    shared_client.configure_embedding_provider(model_id, api_key)
    return None


@pytest.fixture(
scope="module",
params=EMBEDDING_MODELS,
Expand All @@ -41,27 +113,15 @@ def text_embed(request):


def test_auto_embedding(shared_client: TiDBClient, text_embed: EmbeddingFunction):
# Skip auto embedding tests if not connected to TiDB Serverless
if not shared_client.is_serverless:
pytest.skip("Currently, Only TiDB Serverless supports auto embedding")

model_id = text_embed._model_config["id"]

# Configure embedding provider based on model
if model_id == "openai":
shared_client.configure_embedding_provider(
"openai", os.getenv("OPENAI_API_KEY")
)
elif model_id == "jina_ai":
shared_client.configure_embedding_provider(
"jina_ai", os.getenv("JINA_AI_API_KEY")
)
elif model_id == "tidbcloud_free":
# tidbcloud_free doesn't need additional API key configuration
pass
# Check if test should be skipped
skip_reason = _should_skip(shared_client, model_id)
if skip_reason:
pytest.skip(skip_reason)

class ChunkBase(TableModel, table=False):
__tablename__ = f"chunks_auto_embedding_with_{model_id}"
__tablename__ = f"chunks_with_{model_id}_auto_embedding"
id: int = Field(primary_key=True)
text: Optional[str] = Field()
text_vec: Optional[list[float]] = text_embed.VectorField(
Expand All @@ -71,7 +131,7 @@ class ChunkBase(TableModel, table=False):
user_id: int = Field()

Chunk = type(
f"ChunkAutoEmbeddingWith{model_id.capitalize()}", (ChunkBase,), {}, table=True
f"ChunkWith{model_id.capitalize()}AutoEmbedding", (ChunkBase,), {}, table=True
)
tbl = shared_client.create_table(schema=Chunk, if_exists="overwrite")

Expand All @@ -84,15 +144,15 @@ class ChunkBase(TableModel, table=False):
assert len(chunk.text_vec) == text_embed.dimensions

# Test bulk_insert with auto embedding (including empty text case)
chunks_via_model_instance = [
chunk_entities = [
Chunk(id=3, text="baz", user_id=2),
Chunk(id=4, text=None, user_id=2), # None will skip auto embedding.
]
chunks_via_dict = [
chunk_dicts = [
{"id": 5, "text": "qux", "user_id": 3},
{"id": 6, "text": None, "user_id": 3}, # None will skip auto embedding.
]
chunks = tbl.bulk_insert(chunks_via_model_instance + chunks_via_dict)
chunks = tbl.bulk_insert(chunk_entities + chunk_dicts)
for chunk in chunks:
if chunk.text is None:
assert chunk.text_vec is None
Expand All @@ -104,7 +164,9 @@ class ChunkBase(TableModel, table=False):
assert len(results) == 1
assert results[0].id == 2
assert results[0].text == "bar"
assert results[0].similarity_score >= 0.9
assert (
results[0].similarity_score >= text_embed._model_config["expected_similarity"]
)

# Test update with auto embedding, from empty to non-empty string
chunk = tbl.get(4)
Expand Down