diff --git a/.vscode/settings.json b/.vscode/settings.json index 7b160ba..bcc72de 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -13,6 +13,7 @@ "asyncmy", "autobegin", "autouse", + "BAAI", "bindparam", "checkfirst", "cymysql", @@ -20,6 +21,7 @@ "fulltext", "funcfilter", "getenv", + "huggingface", "ischema", "jina", "JINAAI", @@ -44,6 +46,7 @@ "Pydantic", "pymysql", "pytest", + "Qwen", "Rerank", "reranked", "reranker", diff --git a/pytidb/embeddings/__init__.py b/pytidb/embeddings/__init__.py index cee8718..5b650eb 100644 --- a/pytidb/embeddings/__init__.py +++ b/pytidb/embeddings/__init__.py @@ -1,6 +1,4 @@ from .base import BaseEmbeddingFunction -from .builtin import BuiltInEmbeddingFunction - -EmbeddingFunction = BuiltInEmbeddingFunction +from .builtin import EmbeddingFunction __all__ = ["BaseEmbeddingFunction", "EmbeddingFunction"] diff --git a/pytidb/embeddings/builtin.py b/pytidb/embeddings/builtin.py index d1d8ba4..533d67e 100644 --- a/pytidb/embeddings/builtin.py +++ b/pytidb/embeddings/builtin.py @@ -5,6 +5,7 @@ from pytidb.embeddings.base import BaseEmbeddingFunction, EmbeddingSourceType from pytidb.embeddings.dimensions import get_model_dimensions from pytidb.embeddings.utils import ( + deep_merge, encode_local_file_to_base64, encode_pil_image_to_base64, parse_url_safely, @@ -25,18 +26,23 @@ } -PROVIDER_DEFAULT_EMBED_PARAMS = { - "jina_ai": { - "task": "retrieval.passage", - "task@search": "retrieval.query", - }, -} +EmbeddingInput = Union[str, Path, "Image"] -EmbeddingInput = Union[str, Path, "Image"] +def _convert_dimensions_param(provider: str, dimensions: int) -> dict[str, Any]: + if provider == "cohere": + return {"dimension": dimensions} + elif provider == "gemini": + return {"output_dimensionality": dimensions} + elif provider == "nvidia_nim": + # Notice: Nvidia NIM doesn't support dimensions parameter. + return {} + else: + # OpenAI, Jina AI follow the same convention. + return {"dimensions": dimensions} -class BuiltInEmbeddingFunction(BaseEmbeddingFunction): +class EmbeddingFunction(BaseEmbeddingFunction): api_key: Optional[str] = Field(None, description="The API key for authentication.") api_base: Optional[str] = Field( None, description="The base URL of the model provider." @@ -78,11 +84,8 @@ def __init__( dimensions = get_model_dimensions(model_name) provider = model_name.split("/")[0] if "/" in model_name else "openai" - server_embed_params = ( - server_embed_params - if server_embed_params is not None - else PROVIDER_DEFAULT_EMBED_PARAMS.get(provider) - ) + dimensions_param = _convert_dimensions_param(provider, dimensions) + _server_embed_params = deep_merge(dimensions_param, server_embed_params or {}) super().__init__( model_name=model_name, @@ -93,7 +96,7 @@ def __init__( timeout=timeout, caching=caching, use_server=use_server, - server_embed_params=server_embed_params, + server_embed_params=_server_embed_params, multimodal=multimodal, **kwargs, ) diff --git a/pytidb/embeddings/dimensions.py b/pytidb/embeddings/dimensions.py index a941f74..b6d81c4 100644 --- a/pytidb/embeddings/dimensions.py +++ b/pytidb/embeddings/dimensions.py @@ -19,16 +19,72 @@ "openai/text-embedding-3-large": 3072, "openai/text-embedding-ada-002": 1536, # Cohere models + "cohere/embed-v4.0": 1536, "cohere/embed-english-v3.0": 1024, "cohere/embed-multilingual-v3.0": 1024, # Jina AI models "jina_ai/jina-embeddings-v4": 2048, "jina_ai/jina-embeddings-v3": 1024, "jina_ai/jina-clip-v2": 1024, - # TODO: remove these after jina_ai is released on prod. - "jina/jina-embeddings-v4": 2048, - "jina/jina-embeddings-v3": 1024, - "jina/jina-clip-v2": 1024, + # Gemini models + "gemini/gemini-embedding-001": 3072, + # Hugging Face models + "huggingface/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct": 1536, + "huggingface/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct": 3584, + "huggingface/Alibaba-NLP/gte-Qwen1.5-7B-instruct": 4096, + "huggingface/Alibaba-NLP/gte-Qwen2-1.5B-instruct": 8960, + "huggingface/Alibaba-NLP/gte-Qwen2-7B-instruct": 3584, + "huggingface/Alibaba-NLP/gte-multilingual-base": 768, + "huggingface/Alibaba-NLP/gte-modernbert-base": 768, + "huggingface/Alibaba-NLP/gte-base-en-v1.5": 768, + "huggingface/BAAI/bge-base-en": 768, + "huggingface/BAAI/bge-base-en-v1.5": 768, + "huggingface/BAAI/bge-base-zh": 768, + "huggingface/BAAI/bge-base-zh-v1.5": 768, + "huggingface/BAAI/bge-en-icl": 4096, + "huggingface/BAAI/bge-large-en": 1024, + "huggingface/BAAI/bge-large-en-v1.5": 1024, + "huggingface/BAAI/bge-large-zh": 1024, + "huggingface/BAAI/bge-large-zh-v1.5": 1024, + "huggingface/BAAI/bge-m3": 1024, + "huggingface/BAAI/bge-m3-unsupervised": 1024, + "huggingface/BAAI/bge-multilingual-gemma2": 3584, + "huggingface/BAAI/bge-small-en": 512, + "huggingface/BAAI/bge-small-en-v1.5": 512, + "huggingface/BAAI/bge-small-zh": 512, + "huggingface/BAAI/bge-small-zh-v1.5": 512, + "huggingface/Cohere/Cohere-embed-multilingual-v3.0": 1024, + "huggingface/Qwen/Qwen3-Embedding-0.6B": 1024, + "huggingface/Qwen/Qwen3-Embedding-4B": 2560, + "huggingface/Qwen/Qwen3-Embedding-8B": 4096, + "huggingface/Snowflake/snowflake-arctic-embed-l": 1024, + "huggingface/Snowflake/snowflake-arctic-embed-l-v2.0": 1024, + "huggingface/Snowflake/snowflake-arctic-embed-m": 768, + "huggingface/Snowflake/snowflake-arctic-embed-m-long": 768, + "huggingface/Snowflake/snowflake-arctic-embed-m-v1.5": 768, + "huggingface/Snowflake/snowflake-arctic-embed-m-v2.0": 768, + "huggingface/Snowflake/snowflake-arctic-embed-s": 384, + "huggingface/Snowflake/snowflake-arctic-embed-xs": 384, + "huggingface/intfloat/e5-base": 768, + "huggingface/intfloat/e5-base-v2": 768, + "huggingface/intfloat/e5-large": 1024, + "huggingface/intfloat/e5-large-v2": 1024, + "huggingface/intfloat/e5-mistral-7b-instruct": 4096, + "huggingface/intfloat/e5-small": 384, + "huggingface/intfloat/e5-small-v2": 384, + "huggingface/intfloat/multilingual-e5-base": 768, + "huggingface/intfloat/multilingual-e5-large": 1024, + "huggingface/intfloat/multilingual-e5-large-instruct": 1024, + "huggingface/intfloat/multilingual-e5-small": 384, + "huggingface/jinaai/jina-embedding-b-en-v1": 768, + "huggingface/jinaai/jina-embedding-s-en-v1": 512, + "huggingface/jinaai/jina-embeddings-v2-base-en": 768, + "huggingface/jinaai/jina-embeddings-v2-small-en": 512, + "huggingface/jinaai/jina-embeddings-v3": 1024, + "huggingface/jinaai/jina-embeddings-v4": 2048, + # Nvidia NIM models + "nvidia_nim/baai/bge-m3": 1024, + "nvidia_nim/nvidia/nv-embed-v1": 4096, } # Mapping of model aliases to their full names for backward compatibility diff --git a/pytidb/embeddings/utils.py b/pytidb/embeddings/utils.py index 7912b5c..dff6036 100644 --- a/pytidb/embeddings/utils.py +++ b/pytidb/embeddings/utils.py @@ -3,6 +3,7 @@ """ import base64 +from collections.abc import Mapping import io from pathlib import Path from typing import Optional, TYPE_CHECKING, Union @@ -153,3 +154,30 @@ def encode_pil_image_to_base64( return base64.b64encode(buffer.getvalue()).decode("utf-8") except Exception as e: raise ValueError(f"Failed to encode PIL Image to base64: {str(e)}") + + +def deep_merge(*dicts: Optional[dict]) -> dict: + """ + Deeply merge one or more dictionaries into the first one (in-place). + Later dictionaries override earlier ones. + """ + if not dicts: + return {} + + # Filter out None values and convert to list + valid_dicts = [d for d in dicts if d is not None] + if not valid_dicts: + return {} + + def _deep_merge(d: dict, u: dict) -> dict: + for k, v in u.items(): + if k in d and isinstance(d[k], Mapping) and isinstance(v, Mapping): + _deep_merge(d[k], v) + else: + d[k] = v + return d + + result = valid_dicts[0].copy() + for u in valid_dicts[1:]: + _deep_merge(result, u) + return result diff --git a/tests/test_auto_embedding_server.py b/tests/test_auto_embedding_server.py index 0918a17..e76bab6 100644 --- a/tests/test_auto_embedding_server.py +++ b/tests/test_auto_embedding_server.py @@ -5,23 +5,95 @@ from pytidb.embeddings import EmbeddingFunction from pytidb.schema import TableModel, Field + EMBEDDING_MODELS = [ { "id": "openai", "model_name": "openai/text-embedding-3-small", + "expected_similarity": 0.9, + }, + { + "id": "jina_ai", + "model_name": "jina_ai/jina-embeddings-v3", + "expected_similarity": 0.7, + }, + { + "id": "tidbcloud_free", + "model_name": "tidbcloud_free/amazon/titan-embed-text-v2", + "expected_similarity": 0.9, + }, + { + "id": "cohere", + "model_name": "cohere/embed-v4.0", + "expected_similarity": 0.7, + }, + { + "id": "gemini", + "model_name": "gemini/gemini-embedding-001", + "expected_similarity": 0.8, + }, + { + "id": "huggingface", + "model_name": "huggingface/intfloat/multilingual-e5-large", + "expected_similarity": 0.9, + }, + { + "id": "nvidia_nim", + "model_name": "nvidia_nim/baai/bge-m3", + "expected_similarity": 0.9, }, - # TODO: uncomment these after jina_ai and tidbcloud_free are released on prod. - # { - # "id": "jina_ai", - # "model_name": "jina_ai/jina-embeddings-v4", - # }, - # { - # "id": "tidbcloud_free", - # "model_name": "tidbcloud_free/amazon/titan-embed-text-v2", - # }, ] +def _should_skip(shared_client: TiDBClient, model_id: str) -> Optional[str]: + # Skip auto embedding tests if not connected to TiDB Serverless + if not shared_client.is_serverless: + return "Currently, Only TiDB Serverless supports auto embedding" + + # Configure embedding provider based on model + if model_id == "openai": + if not os.getenv("OPENAI_API_KEY"): + return "OPENAI_API_KEY is not set" + shared_client.configure_embedding_provider( + "openai", os.getenv("OPENAI_API_KEY") + ) + elif model_id == "jina_ai": + if not os.getenv("JINA_AI_API_KEY"): + return "JINA_AI_API_KEY is not set" + shared_client.configure_embedding_provider( + "jina_ai", os.getenv("JINA_AI_API_KEY") + ) + elif model_id == "cohere": + if not os.getenv("COHERE_API_KEY"): + return "COHERE_API_KEY is not set" + shared_client.configure_embedding_provider( + "cohere", os.getenv("COHERE_API_KEY") + ) + elif model_id == "gemini": + if not os.getenv("GEMINI_API_KEY"): + return "GEMINI_API_KEY is not set" + shared_client.configure_embedding_provider( + "gemini", os.getenv("GEMINI_API_KEY") + ) + elif model_id == "huggingface": + if not os.getenv("HUGGINGFACE_API_KEY"): + return "HUGGINGFACE_API_KEY is not set" + shared_client.configure_embedding_provider( + "huggingface", os.getenv("HUGGINGFACE_API_KEY") + ) + elif model_id == "nvidia_nim": + if not os.getenv("NVIDIA_NIM_API_KEY"): + return "NVIDIA_NIM_API_KEY is not set" + shared_client.configure_embedding_provider( + "nvidia_nim", os.getenv("NVIDIA_NIM_API_KEY") + ) + elif model_id == "tidbcloud_free": + # tidbcloud_free doesn't need additional API key configuration + pass + + return None + + @pytest.fixture( scope="module", params=EMBEDDING_MODELS, @@ -41,27 +113,15 @@ def text_embed(request): def test_auto_embedding(shared_client: TiDBClient, text_embed: EmbeddingFunction): - # Skip auto embedding tests if not connected to TiDB Serverless - if not shared_client.is_serverless: - pytest.skip("Currently, Only TiDB Serverless supports auto embedding") - model_id = text_embed._model_config["id"] - # Configure embedding provider based on model - if model_id == "openai": - shared_client.configure_embedding_provider( - "openai", os.getenv("OPENAI_API_KEY") - ) - elif model_id == "jina_ai": - shared_client.configure_embedding_provider( - "jina_ai", os.getenv("JINA_AI_API_KEY") - ) - elif model_id == "tidbcloud_free": - # tidbcloud_free doesn't need additional API key configuration - pass + # Check if test should be skipped + skip_reason = _should_skip(shared_client, model_id) + if skip_reason: + pytest.skip(skip_reason) class ChunkBase(TableModel, table=False): - __tablename__ = f"chunks_auto_embedding_with_{model_id}" + __tablename__ = f"chunks_with_{model_id}_auto_embedding" id: int = Field(primary_key=True) text: Optional[str] = Field() text_vec: Optional[list[float]] = text_embed.VectorField( @@ -71,7 +131,7 @@ class ChunkBase(TableModel, table=False): user_id: int = Field() Chunk = type( - f"ChunkAutoEmbeddingWith{model_id.capitalize()}", (ChunkBase,), {}, table=True + f"ChunkWith{model_id.capitalize()}AutoEmbedding", (ChunkBase,), {}, table=True ) tbl = shared_client.create_table(schema=Chunk, if_exists="overwrite") @@ -84,15 +144,15 @@ class ChunkBase(TableModel, table=False): assert len(chunk.text_vec) == text_embed.dimensions # Test bulk_insert with auto embedding (including empty text case) - chunks_via_model_instance = [ + chunk_entities = [ Chunk(id=3, text="baz", user_id=2), Chunk(id=4, text=None, user_id=2), # None will skip auto embedding. ] - chunks_via_dict = [ + chunk_dicts = [ {"id": 5, "text": "qux", "user_id": 3}, {"id": 6, "text": None, "user_id": 3}, # None will skip auto embedding. ] - chunks = tbl.bulk_insert(chunks_via_model_instance + chunks_via_dict) + chunks = tbl.bulk_insert(chunk_entities + chunk_dicts) for chunk in chunks: if chunk.text is None: assert chunk.text_vec is None @@ -104,7 +164,9 @@ class ChunkBase(TableModel, table=False): assert len(results) == 1 assert results[0].id == 2 assert results[0].text == "bar" - assert results[0].similarity_score >= 0.9 + assert ( + results[0].similarity_score >= text_embed._model_config["expected_similarity"] + ) # Test update with auto embedding, from empty to non-empty string chunk = tbl.get(4)