Skip to content

Commit d7418f3

Browse files
authored
Merge pull request #199 from keboola/AI-1172-client-typed-response
Ai 1172 Typed response for global search client method Make typed response in the global search client method Add integ tests for global search.
2 parents d29e098 + 6bd0591 commit d7418f3

File tree

3 files changed

+89
-5
lines changed

3 files changed

+89
-5
lines changed

integtests/test_client.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import logging
2+
3+
import pytest
4+
5+
from integtests.conftest import ProjectDef, TableDef
6+
from keboola_mcp_server.client import AsyncStorageClient, GlobalSearchResponse, KeboolaClient
7+
8+
LOG = logging.getLogger(__name__)
9+
10+
11+
class TestAsyncStorageClient:
12+
13+
@pytest.fixture
14+
def storage_client(self, keboola_client: KeboolaClient, keboola_project: ProjectDef) -> AsyncStorageClient:
15+
return keboola_client.storage_client
16+
17+
@pytest.mark.asyncio
18+
async def test_global_search(self, storage_client: AsyncStorageClient):
19+
not_existing_id = 'not-existing-id'
20+
ret = await storage_client.global_search(query=not_existing_id)
21+
assert isinstance(ret, GlobalSearchResponse)
22+
assert ret.all == 0
23+
assert ret.items == []
24+
assert ret.by_type == {'total': 0}
25+
assert ret.by_project == {}
26+
27+
@pytest.mark.asyncio
28+
async def test_global_search_with_results(self, storage_client: AsyncStorageClient, tables: list[TableDef]):
29+
search_for_name = 'test'
30+
is_global_search_enabled = await storage_client.is_enabled('global-search')
31+
if not is_global_search_enabled:
32+
LOG.warning('Global search is not enabled in the project. Skipping test. Please enable it in the project.')
33+
pytest.skip('Global search is not enabled in the project. Skipping test.')
34+
35+
ret = await storage_client.global_search(query=search_for_name, types=['table'])
36+
assert isinstance(ret, GlobalSearchResponse)
37+
assert ret.all == len(tables)
38+
assert len(ret.items) == len(tables)
39+
assert all(isinstance(item, GlobalSearchResponse.GlobalSearchResponseItem) for item in ret.items)
40+
assert all(item.type == 'table' for item in ret.items)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "keboola-mcp-server"
7-
version = "1.9.0"
7+
version = "1.9.1"
88
description = "MCP server for interacting with Keboola Connection"
99
readme = "README.md"
1010
requires-python = ">=3.10"

src/keboola_mcp_server/client.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import importlib.metadata
44
import logging
55
import os
6+
from datetime import datetime
67
from typing import Any, Iterable, Literal, Mapping, Optional, Union, cast
78

89
import httpx
9-
from pydantic import BaseModel, Field
10+
from pydantic import BaseModel, Field, field_validator
1011

1112
LOG = logging.getLogger(__name__)
1213

@@ -355,6 +356,48 @@ async def delete(
355356
return await self.raw_client.delete(endpoint=endpoint)
356357

357358

359+
class GlobalSearchResponse(BaseModel):
360+
"""The SAPI global search response."""
361+
362+
class GlobalSearchResponseItem(BaseModel):
363+
id: str = Field(description='The id of the item.')
364+
name: str = Field(description='The name of the item.')
365+
type: GlobalSearchTypes = Field(description='The type of the item.')
366+
full_path: dict[str, Any] = Field(
367+
description=(
368+
'The full path of the item containing project, branch and other information depending on the '
369+
'type of the item.'
370+
),
371+
alias='fullPath',
372+
)
373+
component_id: Optional[str] = Field(
374+
default=None, description='The id of the component the item belongs to.', alias='componentId'
375+
)
376+
organization_id: int = Field(
377+
description='The id of the organization the item belongs to.', alias='organizationId'
378+
)
379+
project_id: int = Field(description='The id of the project the item belongs to.', alias='projectId')
380+
project_name: str = Field(description='The name of the project the item belongs to.', alias='projectName')
381+
created: datetime = Field(description='The date and time the item was created in ISO format.')
382+
383+
all: int = Field(description='Total number of found results.')
384+
items: list[GlobalSearchResponseItem] = Field(
385+
description='List of search results containing the items of the GlobalSearchType.'
386+
)
387+
by_type: dict[str, int] = Field(
388+
description='Mapping of found types to the number of corresponding results.', alias='byType'
389+
)
390+
by_project: dict[str, str] = Field(description='Mapping of project id to project name.', alias='byProject')
391+
392+
@field_validator('by_type', 'by_project', mode='before')
393+
@classmethod
394+
def validate_dict_fields(cls, current_value: Any) -> Any:
395+
# If the value is empty-list/None, return an empty dictionary, otherwise return the value
396+
if not current_value:
397+
return dict()
398+
return current_value
399+
400+
358401
class AsyncStorageClient(KeboolaServiceClient):
359402

360403
def __init__(self, raw_client: RawKeboolaClient, branch_id: str = 'default') -> None:
@@ -838,7 +881,7 @@ async def global_search(
838881
limit: int = 100,
839882
offset: int = 0,
840883
types: list[GlobalSearchTypes] | None = None,
841-
) -> JsonDict:
884+
) -> GlobalSearchResponse:
842885
"""
843886
Searches for items in the storage. It allows you to search for entities by name across all projects within an
844887
organization, even those you do not have direct access to. The search is conducted only through entity names to
@@ -849,7 +892,7 @@ async def global_search(
849892
:param offset: The offset to start from, pagination parameter.
850893
:param types: The types of items to search for.
851894
"""
852-
params : dict[str, Any] = {
895+
params: dict[str, Any] = {
853896
'query': query,
854897
'projectIds[]': [await self.project_id()],
855898
'branchTypes[]': 'production',
@@ -858,7 +901,8 @@ async def global_search(
858901
'offset': offset,
859902
}
860903
params = {k: v for k, v in params.items() if v}
861-
return cast(JsonDict, await self.get(endpoint='global-search', params=params))
904+
raw_resp = await self.get(endpoint='global-search', params=params)
905+
return GlobalSearchResponse.model_validate(raw_resp)
862906

863907
async def table_detail(self, table_id: str) -> JsonDict:
864908
"""

0 commit comments

Comments
 (0)