4
4
"""
5
5
6
6
from dataclasses import dataclass
7
+ from inspect import isawaitable
7
8
from typing import Any , Awaitable , Callable , TypeVar
8
9
9
10
from pydantic import BaseModel
10
11
11
12
from .ai_model import AIModel
12
- from .function import Function
13
+ from .function import Function , FunctionHandler
13
14
from .memory import Memory
14
- from .message import Message , ModelMessage , UserMessage
15
+ from .message import Message , ModelMessage , SystemMessage , UserMessage
16
+ from .plugin import AIPluginProtocol
15
17
16
18
T = TypeVar ("T" , bound = BaseModel )
17
19
@@ -22,26 +24,109 @@ class ChatSendResult:
22
24
23
25
24
26
class ChatPrompt :
25
- def __init__ (self , model : AIModel , * , functions : list [Function [Any ]] | None = None ):
27
+ def __init__ (
28
+ self ,
29
+ model : AIModel ,
30
+ * ,
31
+ functions : list [Function [Any ]] | None = None ,
32
+ plugins : list [AIPluginProtocol ] | None = None ,
33
+ ):
26
34
self .model = model
27
35
self .functions : dict [str , Function [Any ]] = {func .name : func for func in functions } if functions else {}
36
+ self .plugins : list [AIPluginProtocol ] = plugins or []
28
37
29
38
def with_function (self , function : Function [T ]) -> "ChatPrompt" :
30
39
self .functions [function .name ] = function
31
40
return self
32
41
42
+ def with_plugin (self , plugin : AIPluginProtocol ) -> "ChatPrompt" :
43
+ """Add a plugin to the chat prompt."""
44
+ self .plugins .append (plugin )
45
+ return self
46
+
33
47
async def send (
34
48
self ,
35
49
input : str | Message ,
36
50
* ,
37
51
memory : Memory | None = None ,
52
+ system_message : SystemMessage | None = None ,
38
53
on_chunk : Callable [[str ], Awaitable [None ]] | None = None ,
39
54
) -> ChatSendResult :
40
55
if isinstance (input , str ):
41
56
input = UserMessage (content = input )
42
57
58
+ # Allow plugins to modify the input before sending
59
+ current_input = input
60
+ for plugin in self .plugins :
61
+ plugin_result = await plugin .on_before_send (current_input )
62
+ if plugin_result is not None :
63
+ current_input = plugin_result
64
+
65
+ # Allow plugins to modify the system message
66
+ current_system_message = system_message
67
+ for plugin in self .plugins :
68
+ plugin_result = await plugin .on_build_system_message (current_system_message )
69
+ if plugin_result is not None :
70
+ current_system_message = plugin_result
71
+
72
+ # Wrap functions with plugin hooks
73
+ wrapped_functions : dict [str , Function [BaseModel ]] | None = None
74
+ if self .functions :
75
+ wrapped_functions = {}
76
+ for name , func in self .functions .items ():
77
+ wrapped_functions [name ] = Function [BaseModel ](
78
+ name = func .name ,
79
+ description = func .description ,
80
+ parameter_schema = func .parameter_schema ,
81
+ handler = self ._wrap_function_handler (func .handler , name ),
82
+ )
83
+
84
+ # Allow plugins to modify the functions before sending to model
85
+ if wrapped_functions :
86
+ functions_list = list (wrapped_functions .values ())
87
+ for plugin in self .plugins :
88
+ plugin_result = await plugin .on_build_functions (functions_list )
89
+ if plugin_result is not None :
90
+ functions_list = plugin_result
91
+
92
+ # Convert back to dict for model
93
+ wrapped_functions = {func .name : func for func in functions_list }
94
+
43
95
response = await self .model .generate_text (
44
- input , memory = memory , functions = self . functions if self . functions else None , on_chunk = on_chunk
96
+ current_input , memory = memory , functions = wrapped_functions , on_chunk = on_chunk
45
97
)
46
98
47
- return ChatSendResult (response = response )
99
+ # Allow plugins to modify the response after receiving
100
+ current_response = response
101
+ for plugin in self .plugins :
102
+ plugin_result = await plugin .on_after_send (current_response )
103
+ if plugin_result is not None :
104
+ current_response = plugin_result
105
+
106
+ return ChatSendResult (response = current_response )
107
+
108
+ def _wrap_function_handler (
109
+ self , original_handler : FunctionHandler [BaseModel ], function_name : str
110
+ ) -> FunctionHandler [BaseModel ]:
111
+ """Wrap a function handler with plugin before/after hooks."""
112
+
113
+ async def wrapped_handler (params : BaseModel ) -> str :
114
+ # Run before function call hooks
115
+ for plugin in self .plugins :
116
+ await plugin .on_before_function_call (function_name , params )
117
+
118
+ # Call the original function (could be sync or async)
119
+ result = original_handler (params )
120
+ if isawaitable (result ):
121
+ result = await result
122
+
123
+ # Run after function call hooks
124
+ current_result = result
125
+ for plugin in self .plugins :
126
+ plugin_result = await plugin .on_after_function_call (function_name , params , current_result )
127
+ if plugin_result is not None :
128
+ current_result = plugin_result
129
+
130
+ return current_result
131
+
132
+ return wrapped_handler
0 commit comments