Skip to content

Commit f997732

Browse files
committed
Merge AI-1172-global-search-tests to AI-1172-refactor
2 parents 51147ff + eded2ae commit f997732

File tree

8 files changed

+152
-126
lines changed

8 files changed

+152
-126
lines changed

integtests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def _create_tables(storage_client: SyncStorageClient) -> list[TableDef]:
174174
ConfigDef(
175175
component_id='ex-generic-v2',
176176
configuration_id=None,
177-
internal_id='config1',
177+
internal_id='test_config1',
178178
),
179179
]
180180

File renamed without changes.

integtests/tools/test_search.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from integtests.conftest import BucketDef, ConfigDef, TableDef
77
from keboola_mcp_server.client import KeboolaClient, SuggestedComponent
8-
from keboola_mcp_server.tools.search import GlobalSearchOutput, find_component_id, global_search
8+
from keboola_mcp_server.tools.search import GlobalSearchOutput, find_component_id, find_ids_by_name
99

1010
LOG = logging.getLogger(__name__)
1111

@@ -19,7 +19,7 @@ async def test_global_search_end_to_end(
1919
configs: list[ConfigDef],
2020
) -> None:
2121
"""
22-
Test the global_search tool end-to-end by searching for entities that exist in the test project.
22+
Test the find_ids_by_name tool end-to-end by searching for items that exist in the test project.
2323
This verifies that the search returns expected results for buckets, tables, and configurations.
2424
"""
2525

@@ -28,48 +28,55 @@ async def test_global_search_end_to_end(
2828
LOG.warning('Global search is not available. Please enable it in the project settings.')
2929
pytest.skip('Global search is not available. Please enable it in the project settings.')
3030

31-
# Search for test entities by name prefix 'test' which should match our test data
32-
result = await global_search(
33-
ctx=mcp_context, name_prefixes=['test'], entity_types=tuple(), limit=50, offset=0 # Search all types
31+
# Search for test items by name prefix 'test' which should match our test data
32+
result = await find_ids_by_name(
33+
ctx=mcp_context, name_prefixes=['test'], item_types=tuple(), limit=50, offset=0 # Search all types
3434
)
3535

3636
# Verify the result structure
3737
assert isinstance(result, GlobalSearchOutput)
3838
assert isinstance(result.counts, dict)
39-
assert isinstance(result.type_groups, list)
39+
assert isinstance(result.groups, dict)
4040
assert 'total' in result.counts
4141

4242
# Verify we found some results
43-
assert result.counts['total'] > 0, 'Should find at least some test entities'
43+
assert result.counts['total'] > 0, 'Should find at least some test items'
4444

4545
# Create sets of expected IDs for verification
4646
expected_bucket_ids = {bucket.bucket_id for bucket in buckets}
4747
expected_table_ids = {table.table_id for table in tables}
4848
expected_config_ids = {config.configuration_id for config in configs if config.configuration_id}
4949

5050
# Check that we can find test buckets
51-
bucket_groups = [group for group in result.type_groups if group.group_type == 'bucket']
52-
if bucket_groups:
53-
bucket_group = bucket_groups[0]
54-
found_bucket_ids = {item.id for item in bucket_group.group_items}
55-
# At least some test buckets should be found
56-
assert found_bucket_ids.intersection(expected_bucket_ids), 'Should find at least one test bucket'
51+
bucket_groups = [group for group in result.groups.values() if group.type == 'bucket']
52+
assert len(bucket_groups) == 1
53+
bucket_group = bucket_groups[0]
54+
found_bucket_ids = {item.id for item in bucket_group.items}
55+
# At least some test buckets should be found
56+
assert found_bucket_ids.intersection(expected_bucket_ids), 'Should find at least one test bucket'
5757

5858
# Check that we can find test tables
59-
table_groups = [group for group in result.type_groups if group.group_type == 'table']
60-
if table_groups:
61-
table_group = table_groups[0]
62-
found_table_ids = {item.id for item in table_group.group_items}
63-
# At least some test tables should be found
64-
assert found_table_ids.intersection(expected_table_ids), 'Should find at least one test table'
59+
table_groups = [group for group in result.groups.values() if group.type == 'table']
60+
assert len(table_groups) == 1
61+
table_group = table_groups[0]
62+
found_table_ids = {item.id for item in table_group.items}
63+
# At least some test tables should be found
64+
assert found_table_ids.intersection(expected_table_ids), 'Should find at least one test table'
6565

6666
# Check that we can find test configurations
67-
config_groups = [group for group in result.type_groups if group.group_type == 'configuration']
68-
if config_groups:
69-
config_group = config_groups[0]
70-
found_config_ids = {item.id for item in config_group.group_items}
71-
# At least some test configurations should be found
72-
assert found_config_ids.intersection(expected_config_ids), 'Should find at least one test configuration'
67+
config_groups = [group for group in result.groups.values() if group.type == 'configuration']
68+
assert len(config_groups) == 1
69+
config_group = config_groups[0]
70+
found_config_ids = {item.id for item in config_group.items}
71+
# At least some test configurations should be found
72+
assert found_config_ids.intersection(expected_config_ids), 'Should find at least one test configuration'
73+
74+
config_groups = [group for group in result.groups.values() if group.type == 'configuration']
75+
assert len(config_groups) == 1
76+
config_group = config_groups[0]
77+
found_config_ids = {item.id for item in config_group.items}
78+
# At least some test configurations should be found
79+
assert found_config_ids.intersection(expected_config_ids), 'Should find at least one test configuration'
7380

7481

7582
@pytest.mark.asyncio
@@ -82,7 +89,5 @@ async def test_find_component_id(mcp_context: Context):
8289

8390
assert isinstance(result, list)
8491
assert len(result) > 0
92+
assert all(isinstance(component, SuggestedComponent) for component in result)
8593
assert generic_extractor_id in [component.component_id for component in result]
86-
87-
for component in result:
88-
assert isinstance(component, SuggestedComponent)

src/keboola_mcp_server/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
# Project features that can be checked with the is_enabled method
2222
ProjectFeature = Literal['global-search']
2323
# Input types for the global search endpoint parameters
24-
GlobalSearchBranchType = Literal['production', 'development']
25-
GlobalSearchType = Literal[
24+
BranchType = Literal['production', 'development']
25+
ItemType = Literal[
2626
'flow',
2727
'bucket',
2828
'table',
@@ -362,7 +362,7 @@ class GlobalSearchResponse(BaseModel):
362362
class Item(BaseModel):
363363
id: str = Field(description='The id of the item.')
364364
name: str = Field(description='The name of the item.')
365-
type: GlobalSearchType = Field(description='The type of the item.')
365+
type: ItemType = Field(description='The type of the item.')
366366
full_path: dict[str, Any] = Field(
367367
description=(
368368
'The full path of the item containing project, branch and other information depending on the '
@@ -878,7 +878,7 @@ async def global_search(
878878
query: str,
879879
limit: int = 100,
880880
offset: int = 0,
881-
types: Sequence[GlobalSearchType] = tuple(),
881+
types: Sequence[ItemType] = tuple(),
882882
) -> GlobalSearchResponse:
883883
"""
884884
Searches for items in the storage. It allows you to search for entities by name across all projects within an

src/keboola_mcp_server/tools/doc.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111

1212
LOG = logging.getLogger(__name__)
1313

14-
MAX_GLOBAL_SEARCH_LIMIT = 100
15-
DEFAULT_GLOBAL_SEARCH_LIMIT = 50
16-
1714

1815
def add_doc_tools(mcp: FastMCP) -> None:
1916
"""Add tools to the MCP server."""

src/keboola_mcp_server/tools/search.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastmcp.tools import FunctionTool
88
from pydantic import BaseModel, Field
99

10-
from keboola_mcp_server.client import GlobalSearchResponse, GlobalSearchType, KeboolaClient, SuggestedComponent
10+
from keboola_mcp_server.client import GlobalSearchResponse, ItemType, KeboolaClient, SuggestedComponent
1111
from keboola_mcp_server.errors import tool_errors
1212
from keboola_mcp_server.mcp import with_session_state
1313

@@ -20,8 +20,8 @@
2020
def add_search_tools(mcp: FastMCP) -> None:
2121
"""Add tools to the MCP server."""
2222
search_tools = [
23-
global_search,
2423
find_component_id,
24+
find_ids_by_name,
2525
]
2626
for tool in search_tools:
2727
LOG.info(f'Adding tool {tool.__name__} to the MCP server.')
@@ -30,19 +30,19 @@ def add_search_tools(mcp: FastMCP) -> None:
3030
LOG.info('Search tools initialized.')
3131

3232

33-
class GlobalSearchItemsGroup(BaseModel):
33+
class ItemsGroup(BaseModel):
3434
"""Group of items of the same type found in the global search."""
3535

36-
class GroupItem(BaseModel):
36+
class Item(BaseModel):
3737
"""An item corresponding to its group type found in the global search."""
3838

3939
name: str = Field(description='The name of the item.')
4040
id: str = Field(description='The id of the item.')
41-
created: datetime = Field(description='The date and time the entity was created.')
41+
created: datetime = Field(description='The date and time the item was created.')
4242
additional_info: dict[str, Any] = Field(description='Additional information about the item.')
4343

4444
@classmethod
45-
def from_api_response(cls, item: GlobalSearchResponse.Item) -> 'GlobalSearchItemsGroup.GroupItem':
45+
def from_api_response(cls, item: GlobalSearchResponse.Item) -> 'ItemsGroup.Item':
4646
"""Creates an Item from the item API response."""
4747
add_info = {}
4848
if item.type == 'table':
@@ -60,54 +60,51 @@ def from_api_response(cls, item: GlobalSearchResponse.Item) -> 'GlobalSearchItem
6060
add_info['configuration_name'] = configuration_info['name']
6161
return cls.model_construct(name=item.name, id=item.id, created=item.created, additional_info=add_info)
6262

63-
group_type: GlobalSearchType = Field(description='The type of the items in the group.')
64-
group_count: int = Field(description='Number of items in the group.')
65-
group_items: list[GroupItem] = Field(
63+
type: ItemType = Field(description='The type of the items in the group.')
64+
count: int = Field(description='Number of items in the group.')
65+
items: list[Item] = Field(
6666
description=('List of items for the type found in the global search, sorted by relevance and creation time.')
6767
)
6868

6969
@classmethod
70-
def from_api_response(
71-
cls, group_type: GlobalSearchType, group_items: list[GlobalSearchResponse.Item]
72-
) -> 'GlobalSearchItemsGroup':
73-
"""Creates a GlobalSearchItemsGroupedByType from the API response items and a type."""
70+
def from_api_response(cls, type: ItemType, items: list[GlobalSearchResponse.Item]) -> 'ItemsGroup':
71+
"""Creates an ItemsGroup from the API response items and a type."""
7472
# filter the items by the given type to be sure
75-
group_items = [item for item in group_items if item.type == group_type]
73+
items = [item for item in items if item.type == type]
7674
return cls.model_construct(
77-
group_type=group_type,
78-
group_count=len(group_items),
79-
group_items=[GlobalSearchItemsGroup.GroupItem.from_api_response(item) for item in group_items],
75+
type=type,
76+
count=len(items),
77+
items=[ItemsGroup.Item.from_api_response(item) for item in items],
8078
)
8179

8280

8381
class GlobalSearchOutput(BaseModel):
8482
"""A result of a global search query for multiple name prefixes."""
8583

86-
counts: dict[str, int] = Field(description='Number of items found for each type.')
87-
type_groups: list[GlobalSearchItemsGroup] = Field(description='List of results grouped by type.')
84+
counts: dict[str, int] = Field(description='Number of items in total and for each type.')
85+
groups: dict[ItemType, ItemsGroup] = Field(description='Search results.')
8886

8987
@classmethod
9088
def from_api_responses(cls, response: GlobalSearchResponse) -> 'GlobalSearchOutput':
91-
"""Creates a GlobalSearchResult from the API responses."""
92-
items_by_type = defaultdict(list)
89+
"""Creates a GlobalSearchOutput from the API responses."""
90+
items_by_type: defaultdict[ItemType, list[GlobalSearchResponse.Item]] = defaultdict(list)
9391
for item in response.items:
9492
items_by_type[item.type].append(item)
9593
return cls.model_construct(
9694
counts=response.by_type, # contains counts for "total", and for each found type.
97-
type_groups=[
98-
GlobalSearchItemsGroup.from_api_response(group_type=type, group_items=items)
99-
for type, items in sorted(items_by_type.items(), key=lambda x: x[0])
100-
],
95+
groups={
96+
type: ItemsGroup.from_api_response(type=type, items=items) for type, items in items_by_type.items()
97+
},
10198
)
10299

103100

104101
@tool_errors()
105102
@with_session_state()
106-
async def global_search(
103+
async def find_ids_by_name(
107104
ctx: Context,
108-
name_prefixes: Annotated[list[str], Field(description='Name prefixes to look for inside entity name.')],
109-
entity_types: Annotated[
110-
Sequence[GlobalSearchType], Field(description='Optional list of keboola object types to search for.')
105+
name_prefixes: Annotated[list[str], Field(description='Name prefixes to match against item names.')],
106+
item_types: Annotated[
107+
Sequence[ItemType], Field(description='Optional list of keboola item types to filter by.')
111108
] = tuple(),
112109
limit: Annotated[
113110
int,
@@ -116,14 +113,17 @@ async def global_search(
116113
f'{MAX_GLOBAL_SEARCH_LIMIT}).'
117114
),
118115
] = DEFAULT_GLOBAL_SEARCH_LIMIT,
119-
offset: Annotated[int, Field(description='How many matching items to skip, pagination.')] = 0,
120-
) -> Annotated[GlobalSearchOutput, Field(description='Search results ordered by relevance, then creation time.')]:
116+
offset: Annotated[int, Field(description='Number of matching items to skip, pagination.')] = 0,
117+
) -> Annotated[
118+
GlobalSearchOutput,
119+
Field(description='Search results grouped by item type, ordered by relevance and creation time.'),
120+
]:
121121
"""
122-
Searches for Keboola entities by each name prefix in the production branch of the current project, potentially
123-
narrowed down by entity type, limited and paginated. Results are ordered by relevance, then creation time.
122+
Searches for Keboola items in the production branch of the current project whose names match the given prefixes,
123+
potentially narrowed down by item type, limited and paginated. Results are ordered by relevance, then creation time.
124124
125125
Considerations:
126-
- The search is purely name-based, and an entity is returned when its name or any word in the name starts with any
126+
- The search is purely name-based, and an item is returned when its name or any word in the name starts with any
127127
of the "name_prefixes" parameter.
128128
"""
129129

@@ -144,7 +144,7 @@ async def global_search(
144144
# separately.
145145
joined_prefixes = ' '.join(name_prefixes)
146146
response = await client.storage_client.global_search(
147-
query=joined_prefixes, types=entity_types, limit=limit, offset=offset
147+
query=joined_prefixes, types=item_types, limit=limit, offset=offset
148148
)
149149
return GlobalSearchOutput.from_api_responses(response)
150150

tests/test_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ async def test_list_tools(self):
2626
'create_sql_transformation',
2727
'docs_query',
2828
'find_component_id',
29+
'find_ids_by_name',
2930
'get_bucket',
3031
'get_component',
3132
'get_config',
@@ -36,7 +37,6 @@ async def test_list_tools(self):
3637
'get_project_info',
3738
'get_sql_dialect',
3839
'get_table',
39-
'global_search',
4040
'list_buckets',
4141
'list_configs',
4242
'list_flows',

0 commit comments

Comments
 (0)