Skip to content

Commit 44c12e6

Browse files
committed
Add AI Plugins
1 parent ad8cd5c commit 44c12e6

File tree

2 files changed

+175
-5
lines changed

2 files changed

+175
-5
lines changed

packages/ai/src/microsoft/teams/ai/chat_prompt.py

Lines changed: 90 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
"""
55

66
from dataclasses import dataclass
7+
from inspect import isawaitable
78
from typing import Any, Awaitable, Callable, TypeVar
89

910
from pydantic import BaseModel
1011

1112
from .ai_model import AIModel
12-
from .function import Function
13+
from .function import Function, FunctionHandler
1314
from .memory import Memory
14-
from .message import Message, ModelMessage, UserMessage
15+
from .message import Message, ModelMessage, SystemMessage, UserMessage
16+
from .plugin import AIPluginProtocol
1517

1618
T = TypeVar("T", bound=BaseModel)
1719

@@ -22,26 +24,109 @@ class ChatSendResult:
2224

2325

2426
class ChatPrompt:
25-
def __init__(self, model: AIModel, *, functions: list[Function[Any]] | None = None):
27+
def __init__(
28+
self,
29+
model: AIModel,
30+
*,
31+
functions: list[Function[Any]] | None = None,
32+
plugins: list[AIPluginProtocol] | None = None,
33+
):
2634
self.model = model
2735
self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {}
36+
self.plugins: list[AIPluginProtocol] = plugins or []
2837

2938
def with_function(self, function: Function[T]) -> "ChatPrompt":
3039
self.functions[function.name] = function
3140
return self
3241

42+
def with_plugin(self, plugin: AIPluginProtocol) -> "ChatPrompt":
43+
"""Add a plugin to the chat prompt."""
44+
self.plugins.append(plugin)
45+
return self
46+
3347
async def send(
3448
self,
3549
input: str | Message,
3650
*,
3751
memory: Memory | None = None,
52+
system_message: SystemMessage | None = None,
3853
on_chunk: Callable[[str], Awaitable[None]] | None = None,
3954
) -> ChatSendResult:
4055
if isinstance(input, str):
4156
input = UserMessage(content=input)
4257

58+
# Allow plugins to modify the input before sending
59+
current_input = input
60+
for plugin in self.plugins:
61+
plugin_result = await plugin.on_before_send(current_input)
62+
if plugin_result is not None:
63+
current_input = plugin_result
64+
65+
# Allow plugins to modify the system message
66+
current_system_message = system_message
67+
for plugin in self.plugins:
68+
plugin_result = await plugin.on_build_system_message(current_system_message)
69+
if plugin_result is not None:
70+
current_system_message = plugin_result
71+
72+
# Wrap functions with plugin hooks
73+
wrapped_functions: dict[str, Function[BaseModel]] | None = None
74+
if self.functions:
75+
wrapped_functions = {}
76+
for name, func in self.functions.items():
77+
wrapped_functions[name] = Function[BaseModel](
78+
name=func.name,
79+
description=func.description,
80+
parameter_schema=func.parameter_schema,
81+
handler=self._wrap_function_handler(func.handler, name),
82+
)
83+
84+
# Allow plugins to modify the functions before sending to model
85+
if wrapped_functions:
86+
functions_list = list(wrapped_functions.values())
87+
for plugin in self.plugins:
88+
plugin_result = await plugin.on_build_functions(functions_list)
89+
if plugin_result is not None:
90+
functions_list = plugin_result
91+
92+
# Convert back to dict for model
93+
wrapped_functions = {func.name: func for func in functions_list}
94+
4395
response = await self.model.generate_text(
44-
input, memory=memory, functions=self.functions if self.functions else None, on_chunk=on_chunk
96+
current_input, memory=memory, functions=wrapped_functions, on_chunk=on_chunk
4597
)
4698

47-
return ChatSendResult(response=response)
99+
# Allow plugins to modify the response after receiving
100+
current_response = response
101+
for plugin in self.plugins:
102+
plugin_result = await plugin.on_after_send(current_response)
103+
if plugin_result is not None:
104+
current_response = plugin_result
105+
106+
return ChatSendResult(response=current_response)
107+
108+
def _wrap_function_handler(
109+
self, original_handler: FunctionHandler[BaseModel], function_name: str
110+
) -> FunctionHandler[BaseModel]:
111+
"""Wrap a function handler with plugin before/after hooks."""
112+
113+
async def wrapped_handler(params: BaseModel) -> str:
114+
# Run before function call hooks
115+
for plugin in self.plugins:
116+
await plugin.on_before_function_call(function_name, params)
117+
118+
# Call the original function (could be sync or async)
119+
result = original_handler(params)
120+
if isawaitable(result):
121+
result = await result
122+
123+
# Run after function call hooks
124+
current_result = result
125+
for plugin in self.plugins:
126+
plugin_result = await plugin.on_after_function_call(function_name, params, current_result)
127+
if plugin_result is not None:
128+
current_result = plugin_result
129+
130+
return current_result
131+
132+
return wrapped_handler
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
Copyright (c) Microsoft Corporation. All rights reserved.
3+
Licensed under the MIT License.
4+
"""
5+
6+
from abc import abstractmethod
7+
from typing import Protocol, TypeVar, runtime_checkable
8+
9+
from pydantic import BaseModel
10+
11+
from .function import Function
12+
from .message import Message, ModelMessage, SystemMessage
13+
14+
T = TypeVar("T")
15+
16+
17+
@runtime_checkable
18+
class AIPluginProtocol(Protocol):
19+
"""Protocol defining the interface for AI plugins."""
20+
21+
@property
22+
@abstractmethod
23+
def name(self) -> str:
24+
"""Unique name of the plugin."""
25+
...
26+
27+
async def on_before_send(self, input: Message) -> Message | None:
28+
"""Modify input before sending to model."""
29+
...
30+
31+
async def on_after_send(self, response: ModelMessage) -> ModelMessage | None:
32+
"""Modify response after receiving from model."""
33+
...
34+
35+
async def on_before_function_call(self, function_name: str, args: BaseModel) -> None:
36+
"""Called before a function is executed."""
37+
...
38+
39+
async def on_after_function_call(self, function_name: str, args: BaseModel, result: str) -> str | None:
40+
"""Called after a function is executed."""
41+
...
42+
43+
async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None:
44+
"""Modify the functions array passed to the model."""
45+
...
46+
47+
async def on_build_system_message(self, system_message: SystemMessage | None) -> SystemMessage | None:
48+
"""Modify the system message before sending to model."""
49+
...
50+
51+
52+
class BaseAIPlugin:
53+
"""Base implementation of AIPlugin with no-op methods."""
54+
55+
def __init__(self, name: str):
56+
self._name = name
57+
58+
@property
59+
def name(self) -> str:
60+
"""Unique name of the plugin."""
61+
return self._name
62+
63+
async def on_before_send(self, input: Message) -> Message | None:
64+
"""Modify input before sending to model."""
65+
return input
66+
67+
async def on_after_send(self, response: ModelMessage) -> ModelMessage | None:
68+
"""Modify response after receiving from model."""
69+
return response
70+
71+
async def on_before_function_call(self, function_name: str, args: BaseModel) -> None:
72+
"""Called before a function is executed."""
73+
pass
74+
75+
async def on_after_function_call(self, function_name: str, args: BaseModel, result: str) -> str | None:
76+
"""Called after a function is executed."""
77+
return result
78+
79+
async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None:
80+
"""Modify the functions array passed to the model."""
81+
return functions
82+
83+
async def on_build_system_message(self, system_message: SystemMessage | None) -> SystemMessage | None:
84+
"""Modify the system message before sending to model."""
85+
return system_message

0 commit comments

Comments
 (0)