Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/keboola_mcp_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,7 @@ async def lifespan(app: Starlette):
async with sse_app.lifespan(app):
yield

app = Starlette(
middleware=[Middleware(ForwardSlashMiddleware)],
lifespan=lifespan
)
app = Starlette(middleware=[Middleware(ForwardSlashMiddleware)], lifespan=lifespan)
app.mount('/mcp', http_app)
app.mount('/sse', sse_app) # serves /sse/ and /messages
custom_routes.add_to_starlette(app)
Expand Down
185 changes: 116 additions & 69 deletions src/keboola_mcp_server/client.py

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion src/keboola_mcp_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class Config:
"""The URL to the Storage API."""
storage_token: Optional[str] = field(default=None, metadata={'aliases': ['storage_api_token']})
"""The token to access the storage API using the MCP tools."""
branch_id: Optional[str] = None
"""The branch ID to access the storage API using the MCP tools."""
workspace_schema: Optional[str] = None
"""Workspace schema to access the buckets, tables and execute sql queries."""
accept_secrets_in_url: Optional[bool] = None
Expand Down Expand Up @@ -45,6 +47,9 @@ def __post_init__(self) -> None:
value = f'https://{value}'
object.__setattr__(self, f.name, value)

if self.branch_id is not None and self.branch_id.lower() in ['', 'none', 'null', 'default', 'production']:
object.__setattr__(self, 'branch_id', None)

@staticmethod
def _normalize(name: str) -> str:
"""Removes dashes and underscores from the input string and turns it into lowercase."""
Expand Down Expand Up @@ -134,7 +139,7 @@ class MetadataField:
# expected value: 'true'
UPDATED_BY_MCP_PREFIX = 'KBC.MCP.updatedBy.version.'

# Brnach filtering works only for "fake development branches"
# Branch filtering works only for "fake development branches"
FAKE_DEVELOPMENT_BRANCH = 'KBC.createdBy.branch.id'

# Data type metadata fields
Expand Down
2 changes: 1 addition & 1 deletion src/keboola_mcp_server/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, base_url: str, project_id: str):

@classmethod
async def from_client(cls, client: KeboolaClient) -> 'ProjectLinksManager':
base_url = client.storage_client.base_api_url
base_url = client.storage_api_url
project_id = await client.storage_client.project_id()
return cls(base_url=base_url, project_id=project_id)

Expand Down
7 changes: 6 additions & 1 deletion src/keboola_mcp_server/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,12 @@ def _create_session_state(config: Config) -> dict[str, Any]:
raise ValueError('Storage API token is not provided.')
if not config.storage_api_url:
raise ValueError('Storage API URL is not provided.')
client = KeboolaClient(config.storage_token, config.storage_api_url, bearer_token=config.bearer_token)
client = KeboolaClient(
storage_api_url=config.storage_api_url,
storage_api_token=config.storage_token,
bearer_token=config.bearer_token,
branch_id=config.branch_id,
)
state[KeboolaClient.STATE_KEY] = client
LOG.info('Successfully initialized Storage API client.')
except Exception as e:
Expand Down
5 changes: 1 addition & 4 deletions src/keboola_mcp_server/tools/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,8 @@ async def create_oauth_url(
# Extract the token from response
sapi_token = token_response['token']

# Get the storage API URL from client
storage_api_url = client.storage_client.base_api_url

# Generate OAuth URL
query_params = urlencode({'token': sapi_token, 'sapiUrl': storage_api_url})
query_params = urlencode({'token': sapi_token, 'sapiUrl': client.storage_api_url})
fragment = f'/{component_id}/{config_id}'

oauth_url = urlunsplit(
Expand Down
40 changes: 29 additions & 11 deletions src/keboola_mcp_server/tools/storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Storage-related tools for the MCP server (buckets, tables, etc.)."""

import logging
from collections import defaultdict
from datetime import datetime
from typing import Annotated, Any, Literal, Optional, cast

Expand Down Expand Up @@ -68,10 +69,9 @@ def add_storage_tools(mcp: KeboolaMcpServer) -> None:

def _extract_description(values: dict[str, Any]) -> Optional[str]:
"""Extracts the description from values or metadata."""
if description := values.get('description'):
return description
else:
return get_metadata_property(values.get('metadata', []), MetadataField.DESCRIPTION)
if not (description := values.get('description')):
description = get_metadata_property(values.get('metadata', []), MetadataField.DESCRIPTION)
return description or None


class BucketDetail(BaseModel):
Expand All @@ -83,7 +83,7 @@ class BucketDetail(BaseModel):
serialization_alias='displayName',
)
description: Optional[str] = Field(None, description='Description of the bucket.')
stage: Optional[str] = Field(None, description='Stage of the bucket (in for input stage, out for output stage).')
stage: str = Field(description='Stage of the bucket (in for input stage, out for output stage).')
created: str = Field(description='Creation timestamp of the bucket.')
data_size_bytes: Optional[int] = Field(
None,
Expand Down Expand Up @@ -213,14 +213,32 @@ async def list_buckets(ctx: Context) -> ListBucketsOutput:
links_manager = await ProjectLinksManager.from_client(client)

raw_bucket_data = await client.storage_client.bucket_list(include=['metadata'])
production_branch_raw_buckets = [
bucket
for bucket in raw_bucket_data
if not (any(meta.get('key') == MetadataField.FAKE_DEVELOPMENT_BRANCH for meta in bucket.get('metadata', [])))
] # filter out buckets from "Fake development branches"

# group the buckets by the ID of the branch that they belong to
buckets_by_branch: dict[str, list[JsonDict]] = defaultdict(list)
for bucket in raw_bucket_data:
bucket_branch_id = get_metadata_property(bucket.get('metadata', []), MetadataField.FAKE_DEVELOPMENT_BRANCH)
bucket_branch_id = bucket_branch_id or '__PROD__'
buckets_by_branch[bucket_branch_id].append(bucket)

buckets: list[JsonDict] = []
if client.branch_id:
# add the dev branch buckets and collect the IDs of their production branch equivalents
id_prefix = f'c-{client.branch_id}-'
hidden_bucket_ids: set[str] = set()
for b in buckets_by_branch.get(client.branch_id, []):
buckets.append(b)
hidden_bucket_ids.add(b['id'].replace(id_prefix, 'c-'))

# add the production branch buckets that are not "shaded" by their dev branch equivalent
for b in buckets_by_branch.get('__PROD__', []):
if b['id'] not in hidden_bucket_ids:
buckets.append(b)
else:
buckets += buckets_by_branch.get('__PROD__', [])

return ListBucketsOutput(
buckets=[BucketDetail.model_validate(bucket) for bucket in production_branch_raw_buckets],
buckets=[BucketDetail.model_validate(bucket) for bucket in buckets],
links=[links_manager.get_bucket_dashboard_link()],
)

Expand Down
4 changes: 3 additions & 1 deletion src/keboola_mcp_server/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ def from_state(cls, state: Mapping[str, Any]) -> 'WorkspaceManager':
return instance

def __init__(self, client: KeboolaClient, workspace_schema: str | None = None):
self._client = client
# We use the read-only workspace with access to all project data which lives in the production branch.
# Hence we need KeboolaClient bound to the production/default branch.
self._client = client.with_branch_id(None)
self._workspace_schema = workspace_schema
self._workspace: _Workspace | None = None
self._table_fqn_cache: dict[str, TableFqn] = {}
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
def keboola_client(mocker) -> KeboolaClient:
"""Creates mocked `KeboolaClient` instance with mocked sub-clients."""
client = mocker.MagicMock(KeboolaClient)
client.storage_api_url = 'https://connection.test.keboola.com'
client.branch_id = None
client.with_branch_id.return_value = client

# Mock API clients
client.storage_client = mocker.MagicMock(AsyncStorageClient)
client.storage_client.base_api_url = 'test://api.keboola.com'
client.storage_client.branch_id = 'default'
client.storage_client.project_id.return_value = '69420'
client.jobs_queue_client = mocker.MagicMock(JobsQueueClient)
client.ai_service_client = mocker.MagicMock(AIServiceClient)
Expand Down
10 changes: 8 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,28 @@ def test_from_dict(self, d: Mapping[str, str], expected: Config) -> None:
{'storage_token': None, 'workspace_schema': 'bar'},
Config(workspace_schema='bar'),
),
(Config(branch_id='foo'), {'branch-id': ''}, Config()),
(Config(branch_id='foo'), {'branch-id': 'none'}, Config()),
(Config(branch_id='foo'), {'branch-id': 'Null'}, Config()),
(Config(branch_id='foo'), {'branch-id': 'Default'}, Config()),
(Config(branch_id='foo'), {'branch-id': 'pRoDuCtIoN'}, Config()),
],
)
def test_replace_by(self, orig: Config, d: Mapping[str, str], expected: Config) -> None:
assert orig.replace_by(d) == expected

def test_defaults(self) -> None:
config = Config()
assert config.storage_token is None
assert config.storage_api_url is None
assert config.storage_token is None
assert config.branch_id is None
assert config.workspace_schema is None
assert config.accept_secrets_in_url is None

def test_no_token_password_in_repr(self) -> None:
config = Config(storage_token='foo')
assert str(config) == (
"Config(storage_api_url=None, storage_token='****', workspace_schema=None, "
"Config(storage_api_url=None, storage_token='****', branch_id=None, workspace_schema=None, "
'accept_secrets_in_url=None, oauth_client_id=None, oauth_client_secret=None, '
'oauth_server_url=None, oauth_scope=None, mcp_server_url=None, '
'jwt_secret=None, bearer_token=None)'
Expand Down
4 changes: 2 additions & 2 deletions tests/tools/components/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,14 @@ async def test_get_config(
type='ui-detail',
title=f'Configuration: {mock_configuration["name"]}',
url=(
f'test://api.keboola.com/admin/projects/69420/components/'
f'https://connection.test.keboola.com/admin/projects/69420/components/'
f'{mock_component["id"]}/{mock_configuration["id"]}'
),
),
Link(
type='ui-dashboard',
title=f'{mock_component["id"]} Configurations Dashboard',
url=f'test://api.keboola.com/admin/projects/69420/components/{mock_component["id"]}',
url=f'https://connection.test.keboola.com/admin/projects/69420/components/{mock_component["id"]}',
),
}

Expand Down
10 changes: 8 additions & 2 deletions tests/tools/test_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,14 @@ async def test_run_job(
assert job_detail.config_id == configuration_id
assert job_detail.result == {}
assert set(job_detail.links) == {
Link(type='ui-detail', title='Job: 123', url='test://api.keboola.com/admin/projects/69420/queue/123'),
Link(type='ui-dashboard', title='Jobs in the project', url='test://api.keboola.com/admin/projects/69420/queue'),
Link(
type='ui-detail', title='Job: 123', url='https://connection.test.keboola.com/admin/projects/69420/queue/123'
),
Link(
type='ui-dashboard',
title='Jobs in the project',
url='https://connection.test.keboola.com/admin/projects/69420/queue',
),
}

keboola_client.jobs_queue_client.create_job.assert_called_once_with(
Expand Down
3 changes: 1 addition & 2 deletions tests/tools/test_oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def test_create_oauth_url_success(mcp_context_client: Context, mock_token_
# Mock the storage client's token_create method to return the token response
keboola_client = KeboolaClient.from_state(mcp_context_client.session.state)
keboola_client.storage_client.token_create.return_value = mock_token_response
keboola_client.storage_client.base_api_url = 'https://connection.test.keboola.com'
keboola_client.storage_api_url = 'https://connection.test.keboola.com'

component_id = 'keboola.ex-google-analytics-v4'
config_id = 'config-123'
Expand Down Expand Up @@ -70,7 +70,6 @@ async def test_create_oauth_url_different_components(
# Mock the storage client
keboola_client = KeboolaClient.from_state(mcp_context_client.session.state)
keboola_client.storage_client.token_create.return_value = mock_token_response
keboola_client.storage_client.base_api_url = 'https://connection.test.keboola.com'

result = await create_oauth_url(component_id=component_id, config_id=config_id, ctx=mcp_context_client)

Expand Down
1 change: 0 additions & 1 deletion tests/tools/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ async def test_get_project_info(mocker: MockerFixture, mcp_context_client: Conte
keboola_client = KeboolaClient.from_state(mcp_context_client.session.state)
keboola_client.storage_client.verify_token = mocker.AsyncMock(return_value=token_data)
keboola_client.storage_client.branch_metadata_get = mocker.AsyncMock(return_value=metadata)
keboola_client.storage_client.base_api_url = 'https://connection.test.keboola.com'
workspace_manager = WorkspaceManager.from_state(mcp_context_client.session.state)
workspace_manager.get_sql_dialect = mocker.AsyncMock(return_value='Snowflake')

Expand Down
Loading