5
5
6
6
import inspect
7
7
from dataclasses import dataclass
8
+ from inspect import isawaitable
8
9
from typing import Any , Awaitable , Callable , TypeVar
9
10
10
11
from pydantic import BaseModel
11
12
12
13
from .ai_model import AIModel
13
- from .function import Function
14
+ from .function import Function , FunctionHandler
14
15
from .memory import Memory
15
- from .message import Message , ModelMessage , UserMessage
16
+ from .message import Message , ModelMessage , SystemMessage , UserMessage
17
+ from .plugin import AIPluginProtocol
16
18
17
19
T = TypeVar ("T" , bound = BaseModel )
18
20
@@ -23,24 +25,74 @@ class ChatSendResult:
23
25
24
26
25
27
class ChatPrompt :
26
- def __init__ (self , model : AIModel , * , functions : list [Function [Any ]] | None = None ):
28
+ def __init__ (
29
+ self ,
30
+ model : AIModel ,
31
+ * ,
32
+ functions : list [Function [Any ]] | None = None ,
33
+ plugins : list [AIPluginProtocol ] | None = None ,
34
+ ):
27
35
self .model = model
28
36
self .functions : dict [str , Function [Any ]] = {func .name : func for func in functions } if functions else {}
37
+ self .plugins : list [AIPluginProtocol ] = plugins or []
29
38
30
39
def with_function (self , function : Function [T ]) -> "ChatPrompt" :
31
40
self .functions [function .name ] = function
32
41
return self
33
42
43
+ def with_plugin (self , plugin : AIPluginProtocol ) -> "ChatPrompt" :
44
+ """Add a plugin to the chat prompt."""
45
+ self .plugins .append (plugin )
46
+ return self
47
+
34
48
async def send (
35
49
self ,
36
50
input : str | Message ,
37
51
* ,
38
52
memory : Memory | None = None ,
39
53
on_chunk : Callable [[str ], Awaitable [None ]] | Callable [[str ], None ] | None = None ,
54
+ system_message : SystemMessage | None = None ,
40
55
) -> ChatSendResult :
41
56
if isinstance (input , str ):
42
57
input = UserMessage (content = input )
43
58
59
+ # Allow plugins to modify the input before sending
60
+ current_input = input
61
+ for plugin in self .plugins :
62
+ plugin_result = await plugin .on_before_send (current_input )
63
+ if plugin_result is not None :
64
+ current_input = plugin_result
65
+
66
+ # Allow plugins to modify the system message
67
+ current_system_message = system_message
68
+ for plugin in self .plugins :
69
+ plugin_result = await plugin .on_build_system_message (current_system_message )
70
+ if plugin_result is not None :
71
+ current_system_message = plugin_result
72
+
73
+ # Wrap functions with plugin hooks
74
+ wrapped_functions : dict [str , Function [BaseModel ]] | None = None
75
+ if self .functions :
76
+ wrapped_functions = {}
77
+ for name , func in self .functions .items ():
78
+ wrapped_functions [name ] = Function [BaseModel ](
79
+ name = func .name ,
80
+ description = func .description ,
81
+ parameter_schema = func .parameter_schema ,
82
+ handler = self ._wrap_function_handler (func .handler , name ),
83
+ )
84
+
85
+ # Allow plugins to modify the functions before sending to model
86
+ if wrapped_functions :
87
+ functions_list = list (wrapped_functions .values ())
88
+ for plugin in self .plugins :
89
+ plugin_result = await plugin .on_build_functions (functions_list )
90
+ if plugin_result is not None :
91
+ functions_list = plugin_result
92
+
93
+ # Convert back to dict for model
94
+ wrapped_functions = {func .name : func for func in functions_list }
95
+
44
96
async def on_chunk_fn (chunk : str ):
45
97
if not on_chunk :
46
98
return
@@ -49,10 +101,40 @@ async def on_chunk_fn(chunk: str):
49
101
await res
50
102
51
103
response = await self .model .generate_text (
52
- input ,
53
- memory = memory ,
54
- functions = self .functions if self .functions else None ,
55
- on_chunk = on_chunk_fn if on_chunk else None ,
104
+ current_input , memory = memory , functions = wrapped_functions , on_chunk = on_chunk_fn if on_chunk else None
56
105
)
57
106
58
- return ChatSendResult (response = response )
107
+ # Allow plugins to modify the response after receiving
108
+ current_response = response
109
+ for plugin in self .plugins :
110
+ plugin_result = await plugin .on_after_send (current_response )
111
+ if plugin_result is not None :
112
+ current_response = plugin_result
113
+
114
+ return ChatSendResult (response = current_response )
115
+
116
+ def _wrap_function_handler (
117
+ self , original_handler : FunctionHandler [BaseModel ], function_name : str
118
+ ) -> FunctionHandler [BaseModel ]:
119
+ """Wrap a function handler with plugin before/after hooks."""
120
+
121
+ async def wrapped_handler (params : BaseModel ) -> str :
122
+ # Run before function call hooks
123
+ for plugin in self .plugins :
124
+ await plugin .on_before_function_call (function_name , params )
125
+
126
+ # Call the original function (could be sync or async)
127
+ result = original_handler (params )
128
+ if isawaitable (result ):
129
+ result = await result
130
+
131
+ # Run after function call hooks
132
+ current_result = result
133
+ for plugin in self .plugins :
134
+ plugin_result = await plugin .on_after_function_call (function_name , params , current_result )
135
+ if plugin_result is not None :
136
+ current_result = plugin_result
137
+
138
+ return current_result
139
+
140
+ return wrapped_handler
0 commit comments