diff --git a/integtests/test_errors.py b/integtests/test_errors.py index 14fbc964..934405c7 100644 --- a/integtests/test_errors.py +++ b/integtests/test_errors.py @@ -74,7 +74,7 @@ async def test_sql_api_invalid_query_error(self, mcp_context: Context): re.IGNORECASE ) with pytest.raises(ValueError, match=match): - await query_data('INVALID SQL SYNTAX HERE', mcp_context) + await query_data('INVALID SQL SYNTAX HERE', 'Invalid SQL query.', mcp_context) @pytest.mark.asyncio async def test_concurrent_error_handling(self, mcp_context: Context): diff --git a/integtests/tools/test_sql.py b/integtests/tools/test_sql.py index 2ea48968..a8ae627e 100644 --- a/integtests/tools/test_sql.py +++ b/integtests/tools/test_sql.py @@ -5,7 +5,7 @@ import pytest from mcp.server.fastmcp import Context -from keboola_mcp_server.tools.sql import get_sql_dialect, query_data +from keboola_mcp_server.tools.sql import QueryDataOutput, get_sql_dialect, query_data from keboola_mcp_server.tools.storage import get_table, list_buckets, list_tables LOG = logging.getLogger(__name__) @@ -26,14 +26,16 @@ async def test_query_data(mcp_context: Context): assert table.fully_qualified_name is not None, 'Table should have fully qualified name' sql_query = f'SELECT COUNT(*) as row_count FROM {table.fully_qualified_name}' - result = await query_data(sql_query=sql_query, ctx=mcp_context) + result = await query_data(sql_query=sql_query, query_name='Row Count Query', ctx=mcp_context) - # Verify result is CSV formatted string - assert isinstance(result, str) - assert len(result) > 0 + # Verify result is structured output + assert isinstance(result, QueryDataOutput) + assert result.query_name == 'Row Count Query' + assert isinstance(result.csv_data, str) + assert len(result.csv_data) > 0 # Parse the CSV to verify structure - csv_reader = csv.reader(StringIO(result)) + csv_reader = csv.reader(StringIO(result.csv_data)) rows = list(csv_reader) # Should have header and one data row @@ -52,4 +54,4 @@ async def test_query_data_invalid_query(mcp_context: Context): invalid_sql = 'INVALID SQL SYNTAX SELECT * FROM' with pytest.raises(ValueError, match='Failed to run SQL query'): - await query_data(sql_query=invalid_sql, ctx=mcp_context) + await query_data(sql_query=invalid_sql, query_name='Invalid Query Test', ctx=mcp_context) diff --git a/pyproject.toml b/pyproject.toml index 3cef9221..48806ce9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "keboola-mcp-server" -version = "1.13.1" +version = "1.14.0" description = "MCP server for interacting with Keboola Connection" readme = "README.md" requires-python = ">=3.10" diff --git a/src/keboola_mcp_server/tools/sql.py b/src/keboola_mcp_server/tools/sql.py index 2337faf3..fcf16eeb 100644 --- a/src/keboola_mcp_server/tools/sql.py +++ b/src/keboola_mcp_server/tools/sql.py @@ -5,7 +5,7 @@ from fastmcp import Context, FastMCP from fastmcp.tools import FunctionTool -from pydantic import Field +from pydantic import BaseModel, Field from keboola_mcp_server.errors import tool_errors from keboola_mcp_server.mcp import with_session_state @@ -14,6 +14,13 @@ LOG = logging.getLogger(__name__) +class QueryDataOutput(BaseModel): + """Output model for SQL query results.""" + + query_name: str = Field(description='The name of the executed query') + csv_data: str = Field(description='The retrieved data in CSV format') + + def add_sql_tools(mcp: FastMCP) -> None: """Add tools to the MCP server.""" mcp.add_tool(FunctionTool.from_function(query_data)) @@ -34,8 +41,18 @@ async def get_sql_dialect( @with_session_state() async def query_data( sql_query: Annotated[str, Field(description='SQL SELECT query to run.')], + query_name: Annotated[ + str, + Field( + description=( + 'A concise, human-readable name for this query based on its purpose and what data it retrieves. ' + 'Use normal words with spaces (e.g., "Customer Orders Last Month", "Top Selling Products", ' + '"User Activity Summary").' + ) + ), + ], ctx: Context, -) -> Annotated[str, Field(description='The retrieved data in a CSV format.')]: +) -> Annotated[QueryDataOutput, Field(description='The query results with name and CSV data.')]: """ Executes an SQL SELECT query to get the data from the underlying database. * When constructing the SQL SELECT query make sure to check the SQL dialect @@ -62,7 +79,10 @@ async def query_data( writer.writeheader() writer.writerows(data.rows) - return output.getvalue() + return QueryDataOutput( + query_name=query_name, + csv_data=output.getvalue() + ) else: raise ValueError(f'Failed to run SQL query, error: {result.message}') diff --git a/tests/tools/test_sql.py b/tests/tools/test_sql.py index fcc9f8bb..2bdbf498 100644 --- a/tests/tools/test_sql.py +++ b/tests/tools/test_sql.py @@ -6,7 +6,7 @@ from pydantic import TypeAdapter from keboola_mcp_server.client import KeboolaClient -from keboola_mcp_server.tools.sql import get_sql_dialect, query_data +from keboola_mcp_server.tools.sql import QueryDataOutput, get_sql_dialect, query_data from keboola_mcp_server.workspace import ( QueryResult, SqlSelectData, @@ -17,15 +17,17 @@ @pytest.mark.asyncio @pytest.mark.parametrize( - ('query', 'result', 'expected'), + ('query', 'query_name', 'result', 'expected_csv'), [ ( 'select 1;', + 'Simple Count Query', QueryResult(status='ok', data=SqlSelectData(columns=['a'], rows=[{'a': 1}])), 'a\r\n1\r\n', # CSV ), ( 'select id, name, email from user;', + 'User Details List', QueryResult( status='ok', data=SqlSelectData( @@ -40,18 +42,23 @@ ), ( 'create table foo (id integer, name varchar);', + 'Create Table Operation', QueryResult(status='ok', message='1 table created'), 'message\r\n1 table created\r\n', # CSV ), ], ) -async def test_query_data(query: str, result: QueryResult, expected: str, mcp_context_client: Context, mocker): +async def test_query_data( + query: str, query_name: str, result: QueryResult, expected_csv: str, mcp_context_client: Context, mocker +): workspace_manager = mocker.AsyncMock(WorkspaceManager) workspace_manager.execute_query.return_value = result mcp_context_client.session.state[WorkspaceManager.STATE_KEY] = workspace_manager - result = await query_data(query, mcp_context_client) - assert result == expected + result = await query_data(query, query_name, mcp_context_client) + assert isinstance(result, QueryDataOutput) + assert result.query_name == query_name + assert result.csv_data == expected_csv @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index ca5c477c..e5cd64fe 100644 --- a/uv.lock +++ b/uv.lock @@ -902,7 +902,7 @@ wheels = [ [[package]] name = "keboola-mcp-server" -version = "1.13.1" +version = "1.14.0" source = { editable = "." } dependencies = [ { name = "fastmcp" },