Skip to content

Commit 5b7cdce

Browse files
committed
Add AI Plugins
1 parent 22eff9c commit 5b7cdce

File tree

2 files changed

+175
-8
lines changed

2 files changed

+175
-8
lines changed

packages/ai/src/microsoft/teams/ai/chat_prompt.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
import inspect
77
from dataclasses import dataclass
8+
from inspect import isawaitable
89
from typing import Any, Awaitable, Callable, TypeVar
910

1011
from pydantic import BaseModel
1112

1213
from .ai_model import AIModel
13-
from .function import Function
14+
from .function import Function, FunctionHandler
1415
from .memory import Memory
15-
from .message import Message, ModelMessage, UserMessage
16+
from .message import Message, ModelMessage, SystemMessage, UserMessage
17+
from .plugin import AIPluginProtocol
1618

1719
T = TypeVar("T", bound=BaseModel)
1820

@@ -23,24 +25,74 @@ class ChatSendResult:
2325

2426

2527
class ChatPrompt:
26-
def __init__(self, model: AIModel, *, functions: list[Function[Any]] | None = None):
28+
def __init__(
29+
self,
30+
model: AIModel,
31+
*,
32+
functions: list[Function[Any]] | None = None,
33+
plugins: list[AIPluginProtocol] | None = None,
34+
):
2735
self.model = model
2836
self.functions: dict[str, Function[Any]] = {func.name: func for func in functions} if functions else {}
37+
self.plugins: list[AIPluginProtocol] = plugins or []
2938

3039
def with_function(self, function: Function[T]) -> "ChatPrompt":
3140
self.functions[function.name] = function
3241
return self
3342

43+
def with_plugin(self, plugin: AIPluginProtocol) -> "ChatPrompt":
44+
"""Add a plugin to the chat prompt."""
45+
self.plugins.append(plugin)
46+
return self
47+
3448
async def send(
3549
self,
3650
input: str | Message,
3751
*,
3852
memory: Memory | None = None,
3953
on_chunk: Callable[[str], Awaitable[None]] | Callable[[str], None] | None = None,
54+
system_message: SystemMessage | None = None,
4055
) -> ChatSendResult:
4156
if isinstance(input, str):
4257
input = UserMessage(content=input)
4358

59+
# Allow plugins to modify the input before sending
60+
current_input = input
61+
for plugin in self.plugins:
62+
plugin_result = await plugin.on_before_send(current_input)
63+
if plugin_result is not None:
64+
current_input = plugin_result
65+
66+
# Allow plugins to modify the system message
67+
current_system_message = system_message
68+
for plugin in self.plugins:
69+
plugin_result = await plugin.on_build_system_message(current_system_message)
70+
if plugin_result is not None:
71+
current_system_message = plugin_result
72+
73+
# Wrap functions with plugin hooks
74+
wrapped_functions: dict[str, Function[BaseModel]] | None = None
75+
if self.functions:
76+
wrapped_functions = {}
77+
for name, func in self.functions.items():
78+
wrapped_functions[name] = Function[BaseModel](
79+
name=func.name,
80+
description=func.description,
81+
parameter_schema=func.parameter_schema,
82+
handler=self._wrap_function_handler(func.handler, name),
83+
)
84+
85+
# Allow plugins to modify the functions before sending to model
86+
if wrapped_functions:
87+
functions_list = list(wrapped_functions.values())
88+
for plugin in self.plugins:
89+
plugin_result = await plugin.on_build_functions(functions_list)
90+
if plugin_result is not None:
91+
functions_list = plugin_result
92+
93+
# Convert back to dict for model
94+
wrapped_functions = {func.name: func for func in functions_list}
95+
4496
async def on_chunk_fn(chunk: str):
4597
if not on_chunk:
4698
return
@@ -49,10 +101,40 @@ async def on_chunk_fn(chunk: str):
49101
await res
50102

51103
response = await self.model.generate_text(
52-
input,
53-
memory=memory,
54-
functions=self.functions if self.functions else None,
55-
on_chunk=on_chunk_fn if on_chunk else None,
104+
current_input, memory=memory, functions=wrapped_functions, on_chunk=on_chunk_fn if on_chunk else None
56105
)
57106

58-
return ChatSendResult(response=response)
107+
# Allow plugins to modify the response after receiving
108+
current_response = response
109+
for plugin in self.plugins:
110+
plugin_result = await plugin.on_after_send(current_response)
111+
if plugin_result is not None:
112+
current_response = plugin_result
113+
114+
return ChatSendResult(response=current_response)
115+
116+
def _wrap_function_handler(
117+
self, original_handler: FunctionHandler[BaseModel], function_name: str
118+
) -> FunctionHandler[BaseModel]:
119+
"""Wrap a function handler with plugin before/after hooks."""
120+
121+
async def wrapped_handler(params: BaseModel) -> str:
122+
# Run before function call hooks
123+
for plugin in self.plugins:
124+
await plugin.on_before_function_call(function_name, params)
125+
126+
# Call the original function (could be sync or async)
127+
result = original_handler(params)
128+
if isawaitable(result):
129+
result = await result
130+
131+
# Run after function call hooks
132+
current_result = result
133+
for plugin in self.plugins:
134+
plugin_result = await plugin.on_after_function_call(function_name, params, current_result)
135+
if plugin_result is not None:
136+
current_result = plugin_result
137+
138+
return current_result
139+
140+
return wrapped_handler
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
Copyright (c) Microsoft Corporation. All rights reserved.
3+
Licensed under the MIT License.
4+
"""
5+
6+
from abc import abstractmethod
7+
from typing import Protocol, TypeVar, runtime_checkable
8+
9+
from pydantic import BaseModel
10+
11+
from .function import Function
12+
from .message import Message, ModelMessage, SystemMessage
13+
14+
T = TypeVar("T")
15+
16+
17+
@runtime_checkable
18+
class AIPluginProtocol(Protocol):
19+
"""Protocol defining the interface for AI plugins."""
20+
21+
@property
22+
@abstractmethod
23+
def name(self) -> str:
24+
"""Unique name of the plugin."""
25+
...
26+
27+
async def on_before_send(self, input: Message) -> Message | None:
28+
"""Modify input before sending to model."""
29+
...
30+
31+
async def on_after_send(self, response: ModelMessage) -> ModelMessage | None:
32+
"""Modify response after receiving from model."""
33+
...
34+
35+
async def on_before_function_call(self, function_name: str, args: BaseModel) -> None:
36+
"""Called before a function is executed."""
37+
...
38+
39+
async def on_after_function_call(self, function_name: str, args: BaseModel, result: str) -> str | None:
40+
"""Called after a function is executed."""
41+
...
42+
43+
async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None:
44+
"""Modify the functions array passed to the model."""
45+
...
46+
47+
async def on_build_system_message(self, system_message: SystemMessage | None) -> SystemMessage | None:
48+
"""Modify the system message before sending to model."""
49+
...
50+
51+
52+
class BaseAIPlugin:
53+
"""Base implementation of AIPlugin with no-op methods."""
54+
55+
def __init__(self, name: str):
56+
self._name = name
57+
58+
@property
59+
def name(self) -> str:
60+
"""Unique name of the plugin."""
61+
return self._name
62+
63+
async def on_before_send(self, input: Message) -> Message | None:
64+
"""Modify input before sending to model."""
65+
return input
66+
67+
async def on_after_send(self, response: ModelMessage) -> ModelMessage | None:
68+
"""Modify response after receiving from model."""
69+
return response
70+
71+
async def on_before_function_call(self, function_name: str, args: BaseModel) -> None:
72+
"""Called before a function is executed."""
73+
pass
74+
75+
async def on_after_function_call(self, function_name: str, args: BaseModel, result: str) -> str | None:
76+
"""Called after a function is executed."""
77+
return result
78+
79+
async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None:
80+
"""Modify the functions array passed to the model."""
81+
return functions
82+
83+
async def on_build_system_message(self, system_message: SystemMessage | None) -> SystemMessage | None:
84+
"""Modify the system message before sending to model."""
85+
return system_message

0 commit comments

Comments
 (0)