Skip to content

Commit e546304

Browse files
committed
Add MCP Client Plugin
1 parent ca9ea18 commit e546304

File tree

12 files changed

+672
-3
lines changed

12 files changed

+672
-3
lines changed

packages/ai/src/microsoft/teams/ai/agent.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from typing import Any, Awaitable, Callable
77

8+
from microsoft.teams.ai.plugin import AIPluginProtocol
9+
810
from .ai_model import AIModel
911
from .chat_prompt import ChatPrompt, ChatSendResult
1012
from .function import Function
@@ -18,8 +20,15 @@ class Agent(ChatPrompt):
1820
through the existence of the Agent.
1921
"""
2022

21-
def __init__(self, model: AIModel, *, memory: Memory | None = None, functions: list[Function[Any]] | None = None):
22-
super().__init__(model, functions=functions)
23+
def __init__(
24+
self,
25+
model: AIModel,
26+
*,
27+
memory: Memory | None = None,
28+
functions: list[Function[Any]] | None = None,
29+
plugins: list[AIPluginProtocol] | None = None,
30+
):
31+
super().__init__(model, functions=functions, plugins=plugins)
2332
self.memory = memory or ListMemory()
2433

2534
async def send(

packages/mcp/README.md

Whitespace-only changes.

packages/mcp/pyproject.toml

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[project]
2+
name = "microsoft.teams.mcp"
3+
version = "0.0.1-alpha.1"
4+
description = "library for handling mcp with teams ai library"
5+
authors = [{ name = "Microsoft", email = "[email protected]" }]
6+
readme = "README.md"
7+
requires-python = ">=3.12"
8+
repository = "https://github.com/microsoft/teams.py"
9+
keywords = ["microsoft", "teams", "ai", "bot", "agents"]
10+
license = "MIT"
11+
dependencies = [
12+
"mcp>=1.13.1",
13+
]
14+
15+
[build-system]
16+
requires = ["hatchling"]
17+
build-backend = "hatchling.build"
18+
19+
[tool.hatch.build.targets.wheel]
20+
packages = ["src/microsoft"]
21+
22+
[tool.hatch.build.targets.sdist]
23+
include = ["src"]
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
"""
2+
Copyright (c) Microsoft Corporation. All rights reserved.
3+
Licensed under the MIT License.
4+
"""
5+
6+
from .ai_plugin import McpClientPlugin, McpClientPluginParams, McpToolDetails
7+
8+
__all__ = ["McpClientPlugin", "McpClientPluginParams", "McpToolDetails"]
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
"""
2+
Copyright (c) Microsoft Corporation. All rights reserved.
3+
Licensed under the MIT License.
4+
"""
5+
6+
import asyncio
7+
import json
8+
import logging
9+
import time
10+
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
11+
12+
from microsoft.teams.ai.function import Function
13+
from microsoft.teams.ai.plugin import BaseAIPlugin
14+
from microsoft.teams.common.logging import ConsoleLogger
15+
from pydantic import BaseModel
16+
17+
from mcp import ClientSession
18+
from mcp.types import TextContent
19+
20+
from .transport import create_transport
21+
22+
23+
class McpToolDetails(BaseModel):
24+
"""Details of an MCP tool."""
25+
26+
name: str
27+
description: str
28+
input_schema: Dict[str, Any]
29+
30+
31+
class McpCachedValue:
32+
"""Cached value for MCP server data."""
33+
34+
def __init__(
35+
self,
36+
transport: Optional[str] = None,
37+
available_tools: Optional[List[McpToolDetails]] = None,
38+
last_attempted_fetch: Optional[float] = None,
39+
):
40+
self.transport = transport
41+
self.available_tools = available_tools or []
42+
self.last_attempted_fetch = last_attempted_fetch
43+
44+
45+
class McpClientPluginParams:
46+
"""Parameters for MCP client plugin configuration."""
47+
48+
def __init__(
49+
self,
50+
transport: Optional[str] = "streamable_http",
51+
available_tools: Optional[List[McpToolDetails]] = None,
52+
headers: Optional[Dict[str, Union[str, Callable[[], Union[str, Awaitable[str]]]]]] = None,
53+
skip_if_unavailable: Optional[bool] = True,
54+
refetch_timeout_ms: Optional[int] = None,
55+
):
56+
self.transport = transport
57+
self.available_tools = available_tools
58+
self.headers = headers
59+
self.skip_if_unavailable = skip_if_unavailable
60+
self.refetch_timeout_ms = refetch_timeout_ms
61+
62+
63+
class McpClientPlugin(BaseAIPlugin):
64+
"""MCP Client Plugin for Teams AI integration."""
65+
66+
def __init__(
67+
self,
68+
name: str = "mcp_client",
69+
version: str = "0.0.0",
70+
cache: Optional[Dict[str, McpCachedValue]] = None,
71+
logger: Optional[logging.Logger] = None,
72+
refetch_timeout_ms: int = 24 * 60 * 60 * 1000, # 1 day
73+
):
74+
super().__init__(name)
75+
76+
self._version = version
77+
self._cache: Dict[str, McpCachedValue] = cache or {}
78+
self._logger = logger or ConsoleLogger().create_logger(self.name)
79+
self._refetch_timeout_ms = refetch_timeout_ms
80+
81+
# Track MCP server URLs and their parameters
82+
self._mcp_server_params: Dict[str, McpClientPluginParams] = {}
83+
84+
@property
85+
def version(self) -> str:
86+
"""Get the plugin version."""
87+
return self._version
88+
89+
@property
90+
def cache(self) -> Dict[str, McpCachedValue]:
91+
"""Get the plugin cache."""
92+
return self._cache
93+
94+
@property
95+
def refetch_timeout_ms(self) -> int:
96+
"""Get the refetch timeout in milliseconds."""
97+
return self._refetch_timeout_ms
98+
99+
def add_mcp_server(self, url: str, params: Optional[McpClientPluginParams] = None) -> None:
100+
"""Add an MCP server to be used by this plugin."""
101+
self._mcp_server_params[url] = params or McpClientPluginParams()
102+
103+
# Update cache if tools are provided
104+
if params and params.available_tools:
105+
self._cache[url] = McpCachedValue(
106+
transport=params.transport,
107+
available_tools=params.available_tools,
108+
last_attempted_fetch=None,
109+
)
110+
111+
async def on_build_functions(self, functions: List[Function[BaseModel]]) -> List[Function[BaseModel]]:
112+
"""Build functions from MCP tools."""
113+
await self._fetch_tools_if_needed()
114+
115+
# Create functions from cached tools
116+
all_functions = list(functions)
117+
118+
for url, params in self._mcp_server_params.items():
119+
cached_data = self._cache.get(url)
120+
available_tools = cached_data.available_tools if cached_data else []
121+
122+
for tool in available_tools:
123+
# Create a function for each tool
124+
function = self._create_function_from_tool(url, tool, params)
125+
all_functions.append(function)
126+
127+
return all_functions
128+
129+
async def _fetch_tools_if_needed(self) -> None:
130+
"""Fetch tools from MCP servers if needed."""
131+
fetch_needed: List[Tuple[str, McpClientPluginParams]] = []
132+
current_time = time.time() * 1000 # Convert to milliseconds
133+
134+
for url, params in self._mcp_server_params.items():
135+
# Skip if tools are explicitly provided
136+
if params.available_tools:
137+
continue
138+
139+
cached_data = self._cache.get(url)
140+
should_fetch = (
141+
not cached_data
142+
or not cached_data.available_tools
143+
or not cached_data.last_attempted_fetch
144+
or (current_time - cached_data.last_attempted_fetch)
145+
> (params.refetch_timeout_ms or self._refetch_timeout_ms)
146+
)
147+
148+
if should_fetch:
149+
fetch_needed.append((url, params))
150+
151+
# Fetch tools in parallel
152+
if fetch_needed:
153+
tasks = [self._fetch_tools_from_server(url, params) for url, params in fetch_needed]
154+
results = await asyncio.gather(*tasks, return_exceptions=True)
155+
156+
for i, (url, params) in enumerate(fetch_needed):
157+
result = results[i]
158+
if isinstance(result, Exception):
159+
self._logger.error(f"Failed to fetch tools from {url}: {result}")
160+
if not params.skip_if_unavailable:
161+
raise result
162+
elif isinstance(result, list):
163+
# Update cache with fetched tools
164+
if url not in self._cache:
165+
self._cache[url] = McpCachedValue()
166+
self._cache[url].available_tools = result
167+
self._cache[url].last_attempted_fetch = current_time
168+
self._cache[url].transport = params.transport
169+
170+
self._logger.debug(f"Cached {len(result)} tools for {url}")
171+
172+
def _create_function_from_tool(
173+
self, url: str, tool: Union[McpToolDetails, Dict[str, Any]], plugin_params: McpClientPluginParams
174+
) -> Function[BaseModel]:
175+
"""Create a Teams AI function from an MCP tool."""
176+
if isinstance(tool, dict):
177+
tool_name = tool["name"]
178+
tool_description = tool["description"]
179+
else:
180+
tool_name = tool.name
181+
tool_description = tool.description
182+
183+
async def handler(params: BaseModel) -> str:
184+
"""Handle MCP tool call."""
185+
try:
186+
result = await self._call_mcp_tool(url, tool_name, params.model_dump(), plugin_params)
187+
return str(result)
188+
except Exception as e:
189+
self._logger.error(f"Error calling tool {tool_name} on {url}: {e}")
190+
raise
191+
192+
return Function(name=tool_name, description=tool_description, parameter_schema=BaseModel, handler=handler)
193+
194+
async def _fetch_tools_from_server(self, url: str, params: McpClientPluginParams) -> List[McpToolDetails]:
195+
"""Fetch tools from a specific MCP server."""
196+
transport_context = create_transport(url, params.transport or "streamable_http", params.headers)
197+
198+
async with transport_context as (read_stream, write_stream):
199+
async with ClientSession(read_stream, write_stream) as session:
200+
# Initialize the connection
201+
await session.initialize()
202+
203+
# List available tools
204+
tools_response = await session.list_tools()
205+
206+
# Convert MCP tools to our format
207+
tools: list[McpToolDetails] = []
208+
for tool in tools_response.tools:
209+
tools.append(
210+
McpToolDetails(
211+
name=tool.name, description=tool.description or "", input_schema=tool.inputSchema or {}
212+
)
213+
)
214+
215+
return tools
216+
217+
async def _call_mcp_tool(
218+
self, url: str, tool_name: str, arguments: Dict[str, Any], params: McpClientPluginParams
219+
) -> Optional[Union[str, List[str]]]:
220+
"""Call a specific tool on an MCP server."""
221+
transport_context = create_transport(url, params.transport or "streamable_http", params.headers)
222+
223+
async with transport_context as (read_stream, write_stream):
224+
async with ClientSession(read_stream, write_stream) as session:
225+
# Initialize the connection
226+
await session.initialize()
227+
228+
# Call the tool
229+
result = await session.call_tool(tool_name, arguments)
230+
231+
# Return the content from the result
232+
if result.content:
233+
if len(result.content) == 1:
234+
content_item = result.content[0]
235+
if isinstance(content_item, TextContent):
236+
return content_item.text
237+
else:
238+
return str(content_item)
239+
else:
240+
contents: list[str] = []
241+
for item in result.content:
242+
if isinstance(item, TextContent):
243+
contents.append(item.text)
244+
else:
245+
contents.append(json.dumps(item))
246+
return contents
247+
248+
return None
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""
2+
Copyright (c) Microsoft Corporation. All rights reserved.
3+
Licensed under the MIT License.
4+
"""
5+
6+
import asyncio
7+
from contextlib import asynccontextmanager
8+
from typing import Awaitable, Callable, Dict, Optional, Union
9+
10+
from mcp.client.sse import sse_client
11+
from mcp.client.streamable_http import streamablehttp_client
12+
13+
ValueOrFactory = Union[str, Callable[[], Union[str, Awaitable[str]]]]
14+
15+
16+
@asynccontextmanager
17+
async def create_streamable_http_transport(
18+
url: str,
19+
headers: Optional[Dict[str, ValueOrFactory]] = None,
20+
):
21+
"""Create a streamable HTTP transport for MCP communication."""
22+
resolved_headers: Dict[str, str] = {}
23+
if headers:
24+
for key, value in headers.items():
25+
if callable(value):
26+
resolved_value = value()
27+
if asyncio.iscoroutine(resolved_value):
28+
resolved_value = await resolved_value
29+
resolved_headers[key] = str(resolved_value)
30+
else:
31+
resolved_headers[key] = str(value)
32+
33+
async with streamablehttp_client(url, headers=resolved_headers) as (read_stream, write_stream, _):
34+
yield read_stream, write_stream
35+
36+
37+
@asynccontextmanager
38+
async def create_sse_transport(
39+
url: str,
40+
headers: Optional[Dict[str, ValueOrFactory]] = None,
41+
):
42+
"""Create an SSE transport for MCP communication."""
43+
resolved_headers: Dict[str, str] = {}
44+
if headers:
45+
for key, value in headers.items():
46+
if callable(value):
47+
resolved_value = value()
48+
if asyncio.iscoroutine(resolved_value):
49+
resolved_value = await resolved_value
50+
resolved_headers[key] = str(resolved_value)
51+
else:
52+
resolved_headers[key] = str(value)
53+
54+
async with sse_client(url, headers=resolved_headers) as (read_stream, write_stream):
55+
yield read_stream, write_stream
56+
57+
58+
def create_transport(
59+
url: str,
60+
transport_type: str = "streamable_http",
61+
headers: Optional[Dict[str, ValueOrFactory]] = None,
62+
):
63+
"""Create the appropriate transport based on transport type."""
64+
if transport_type == "streamable_http":
65+
return create_streamable_http_transport(url, headers)
66+
elif transport_type == "sse":
67+
return create_sse_transport(url, headers)
68+
else:
69+
raise ValueError(f"Unsupported transport type: {transport_type}")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"microsoft-teams-graph" = { workspace = true }
88
"microsoft-teams-ai" = { workspace = true }
99
"microsoft-teams-openai" = { workspace = true }
10+
"microsoft-teams-mcp" = { workspace = true }
1011

1112
[tool.uv.workspace]
1213
members = ["packages/*", "tests/*"]

pyrightconfig.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
"packages/graph/src",
1414
"packages/devtools/src",
1515
"packages/ai/src",
16-
"packages/openai/src"
16+
"packages/openai/src",
17+
"packages/mcp/src"
1718
],
1819
"typeCheckingMode": "strict",
1920
"executionEnvironments": [

tests/mcp/README.md

Whitespace-only changes.

0 commit comments

Comments
 (0)