|
14 | 14 | FunctionCall,
|
15 | 15 | ListMemory,
|
16 | 16 | Memory,
|
| 17 | + Message, |
17 | 18 | ModelMessage,
|
18 | 19 | SystemMessage,
|
19 | 20 | UserMessage,
|
20 | 21 | )
|
| 22 | +from microsoft.teams.ai.plugin import BaseAIPlugin |
21 | 23 | from pydantic import BaseModel
|
22 | 24 |
|
| 25 | +# pyright: basic |
| 26 | + |
23 | 27 |
|
24 | 28 | class MockFunctionParams(BaseModel):
|
25 | 29 | value: str
|
@@ -61,6 +65,11 @@ async def generate_text(
|
61 | 65 | try:
|
62 | 66 | params = function.parameter_schema(value="test_input")
|
63 | 67 | result = function.handler(params)
|
| 68 | + # Handle both sync and async results |
| 69 | + from inspect import isawaitable |
| 70 | + |
| 71 | + if isawaitable(result): |
| 72 | + result = await result |
64 | 73 | # In real implementation, function result would be added to memory
|
65 | 74 | # and conversation would continue recursively
|
66 | 75 | content += f" | Function result: {result}"
|
@@ -277,3 +286,284 @@ async def test_different_message_types(self, mock_model: MockAIModel) -> None:
|
277 | 286 | model_msg = ModelMessage(content="Model message", function_calls=None)
|
278 | 287 | result3 = await prompt.send(model_msg)
|
279 | 288 | assert result3.response.content == "GENERATED - Model message"
|
| 289 | + |
| 290 | + |
| 291 | +class MockPlugin(BaseAIPlugin): |
| 292 | + """Mock plugin for testing that tracks all hook calls""" |
| 293 | + |
| 294 | + def __init__(self, name: str): |
| 295 | + super().__init__(name) |
| 296 | + self.before_send_called = False |
| 297 | + self.after_send_called = False |
| 298 | + self.before_function_called: list[tuple[str, BaseModel]] = [] |
| 299 | + self.after_function_called: list[tuple[str, BaseModel, str]] = [] |
| 300 | + self.build_functions_called = False |
| 301 | + self.build_system_message_called = False |
| 302 | + self.input_modifications: list[str] = [] |
| 303 | + self.response_modifications: list[str] = [] |
| 304 | + self.function_result_modifications: list[str] = [] |
| 305 | + |
| 306 | + async def on_before_send(self, input: Message) -> Message | None: |
| 307 | + self.before_send_called = True |
| 308 | + if self.input_modifications: |
| 309 | + modification = self.input_modifications.pop(0) |
| 310 | + if isinstance(input, UserMessage): |
| 311 | + return UserMessage(content=f"{modification}: {input.content}") |
| 312 | + elif isinstance(input, ModelMessage): |
| 313 | + return ModelMessage( |
| 314 | + content=f"{modification}: {input.content}" if input.content else modification, |
| 315 | + function_calls=input.function_calls, |
| 316 | + ) |
| 317 | + return input |
| 318 | + |
| 319 | + async def on_after_send(self, response: ModelMessage) -> ModelMessage | None: |
| 320 | + self.after_send_called = True |
| 321 | + if self.response_modifications: |
| 322 | + modification = self.response_modifications.pop(0) |
| 323 | + return ModelMessage( |
| 324 | + content=f"{modification}: {response.content}" if response.content else modification, |
| 325 | + function_calls=response.function_calls, |
| 326 | + ) |
| 327 | + return response |
| 328 | + |
| 329 | + async def on_before_function_call(self, function_name: str, args: BaseModel) -> None: |
| 330 | + self.before_function_called.append((function_name, args)) |
| 331 | + |
| 332 | + async def on_after_function_call(self, function_name: str, args: BaseModel, result: str) -> str | None: |
| 333 | + self.after_function_called.append((function_name, args, result)) |
| 334 | + if self.function_result_modifications: |
| 335 | + modification = self.function_result_modifications.pop(0) |
| 336 | + return f"{modification}: {result}" |
| 337 | + return result |
| 338 | + |
| 339 | + async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None: |
| 340 | + self.build_functions_called = True |
| 341 | + return functions |
| 342 | + |
| 343 | + async def on_build_system_message(self, system_message: SystemMessage | None) -> SystemMessage | None: |
| 344 | + self.build_system_message_called = True |
| 345 | + if system_message is None: |
| 346 | + return SystemMessage(content="Plugin-generated system message") |
| 347 | + return SystemMessage(content=f"Plugin-modified: {system_message.content}") |
| 348 | + |
| 349 | + |
| 350 | +class TestChatPromptPlugins: |
| 351 | + """Test suite for plugin functionality in ChatPrompt""" |
| 352 | + |
| 353 | + def test_plugin_initialization_and_registration(self, mock_model: MockAIModel) -> None: |
| 354 | + """Test plugin initialization and registration""" |
| 355 | + plugin1 = MockPlugin("plugin1") |
| 356 | + plugin2 = MockPlugin("plugin2") |
| 357 | + |
| 358 | + # Test initialization with plugins |
| 359 | + prompt = ChatPrompt(mock_model, plugins=[plugin1]) |
| 360 | + assert len(prompt.plugins) == 1 |
| 361 | + assert prompt.plugins[0] is plugin1 |
| 362 | + |
| 363 | + # Test with_plugin method |
| 364 | + result = prompt.with_plugin(plugin2) |
| 365 | + assert result is prompt # Should return self for chaining |
| 366 | + assert len(prompt.plugins) == 2 |
| 367 | + assert prompt.plugins[1] is plugin2 |
| 368 | + |
| 369 | + @pytest.mark.asyncio |
| 370 | + async def test_on_before_send_hook(self, mock_model: MockAIModel) -> None: |
| 371 | + """Test that on_before_send hook can modify input messages""" |
| 372 | + plugin = MockPlugin("test_plugin") |
| 373 | + plugin.input_modifications = ["MODIFIED"] |
| 374 | + |
| 375 | + prompt = ChatPrompt(mock_model, plugins=[plugin]) |
| 376 | + result = await prompt.send("Original message") |
| 377 | + |
| 378 | + assert plugin.before_send_called |
| 379 | + assert result.response.content == "GENERATED - MODIFIED: Original message" |
| 380 | + |
| 381 | + @pytest.mark.asyncio |
| 382 | + async def test_on_after_send_hook(self, mock_model: MockAIModel) -> None: |
| 383 | + """Test that on_after_send hook can modify response messages""" |
| 384 | + plugin = MockPlugin("test_plugin") |
| 385 | + plugin.response_modifications = ["RESPONSE_MODIFIED"] |
| 386 | + |
| 387 | + prompt = ChatPrompt(mock_model, plugins=[plugin]) |
| 388 | + result = await prompt.send("Test message") |
| 389 | + |
| 390 | + assert plugin.after_send_called |
| 391 | + assert result.response.content == "RESPONSE_MODIFIED: GENERATED - Test message" |
| 392 | + |
| 393 | + @pytest.mark.asyncio |
| 394 | + async def test_on_build_system_message_hook(self, mock_model: MockAIModel) -> None: |
| 395 | + """Test that on_build_system_message hook is called""" |
| 396 | + plugin = MockPlugin("test_plugin") |
| 397 | + |
| 398 | + prompt = ChatPrompt(mock_model, plugins=[plugin]) |
| 399 | + |
| 400 | + # Test with None system message |
| 401 | + await prompt.send("Test", system_message=None) |
| 402 | + assert plugin.build_system_message_called |
| 403 | + |
| 404 | + # Reset and test with existing system message |
| 405 | + plugin.build_system_message_called = False |
| 406 | + system_msg = SystemMessage(content="Original system") |
| 407 | + await prompt.send("Test", system_message=system_msg) |
| 408 | + assert plugin.build_system_message_called |
| 409 | + |
| 410 | + @pytest.mark.asyncio |
| 411 | + async def test_function_call_hooks(self, mock_function_handler: Mock) -> None: |
| 412 | + """Test that function call hooks are properly executed""" |
| 413 | + plugin = MockPlugin("test_plugin") |
| 414 | + plugin.function_result_modifications = ["FUNCTION_MODIFIED"] |
| 415 | + |
| 416 | + # Create a mock model that will call functions |
| 417 | + mock_model = MockAIModel(should_call_function=True) |
| 418 | + |
| 419 | + # Create function with mock handler |
| 420 | + test_function = Function( |
| 421 | + name="test_function", |
| 422 | + description="A test function", |
| 423 | + parameter_schema=MockFunctionParams, |
| 424 | + handler=mock_function_handler, |
| 425 | + ) |
| 426 | + |
| 427 | + prompt = ChatPrompt(mock_model, functions=[test_function], plugins=[plugin]) |
| 428 | + result = await prompt.send("Call the function") |
| 429 | + |
| 430 | + # Verify before hook was called |
| 431 | + assert len(plugin.before_function_called) == 1 |
| 432 | + assert plugin.before_function_called[0][0] == "test_function" |
| 433 | + assert isinstance(plugin.before_function_called[0][1], MockFunctionParams) |
| 434 | + |
| 435 | + # Verify after hook was called and modified result |
| 436 | + assert len(plugin.after_function_called) == 1 |
| 437 | + assert plugin.after_function_called[0][0] == "test_function" |
| 438 | + assert result.response.content is not None |
| 439 | + assert "FUNCTION_MODIFIED: Function executed successfully" in result.response.content |
| 440 | + |
| 441 | + @pytest.mark.asyncio |
| 442 | + async def test_on_build_functions_hook( |
| 443 | + self, mock_model: MockAIModel, test_function: Function[MockFunctionParams] |
| 444 | + ) -> None: |
| 445 | + """Test that on_build_functions hook is called when functions are present""" |
| 446 | + plugin = MockPlugin("test_plugin") |
| 447 | + |
| 448 | + prompt = ChatPrompt(mock_model, functions=[test_function], plugins=[plugin]) |
| 449 | + await prompt.send("Test message") |
| 450 | + |
| 451 | + assert plugin.build_functions_called |
| 452 | + |
| 453 | + @pytest.mark.asyncio |
| 454 | + async def test_multiple_plugins_execution_order(self, mock_model: MockAIModel) -> None: |
| 455 | + """Test that multiple plugins execute in correct order""" |
| 456 | + plugin1 = MockPlugin("plugin1") |
| 457 | + plugin1.input_modifications = ["FIRST"] |
| 458 | + plugin1.response_modifications = ["FIRST_RESP"] |
| 459 | + |
| 460 | + plugin2 = MockPlugin("plugin2") |
| 461 | + plugin2.input_modifications = ["SECOND"] |
| 462 | + plugin2.response_modifications = ["SECOND_RESP"] |
| 463 | + |
| 464 | + prompt = ChatPrompt(mock_model, plugins=[plugin1, plugin2]) |
| 465 | + result = await prompt.send("Original") |
| 466 | + |
| 467 | + # Both plugins should be called |
| 468 | + assert plugin1.before_send_called |
| 469 | + assert plugin2.before_send_called |
| 470 | + assert plugin1.after_send_called |
| 471 | + assert plugin2.after_send_called |
| 472 | + |
| 473 | + # Input should be modified by both plugins in order |
| 474 | + assert result.response.content == "SECOND_RESP: FIRST_RESP: GENERATED - SECOND: FIRST: Original" |
| 475 | + |
| 476 | + @pytest.mark.asyncio |
| 477 | + async def test_plugin_returns_none_preserves_original(self, mock_model: MockAIModel) -> None: |
| 478 | + """Test that when plugin returns None, original values are preserved""" |
| 479 | + |
| 480 | + class NoOpPlugin(BaseAIPlugin): |
| 481 | + def __init__(self): |
| 482 | + super().__init__("noop") |
| 483 | + |
| 484 | + async def on_before_send(self, input: Message) -> Message | None: |
| 485 | + return None # Return None to preserve original |
| 486 | + |
| 487 | + async def on_after_send(self, response: ModelMessage) -> ModelMessage | None: |
| 488 | + return None # Return None to preserve original |
| 489 | + |
| 490 | + plugin = NoOpPlugin() |
| 491 | + prompt = ChatPrompt(mock_model, plugins=[plugin]) |
| 492 | + result = await prompt.send("Test message") |
| 493 | + |
| 494 | + # Should be unchanged since plugin returned None |
| 495 | + assert result.response.content == "GENERATED - Test message" |
| 496 | + |
| 497 | + @pytest.mark.asyncio |
| 498 | + async def test_empty_plugin_list_maintains_compatibility(self, mock_model: MockAIModel) -> None: |
| 499 | + """Test that ChatPrompt with no plugins behaves identically to original implementation""" |
| 500 | + prompt_with_plugins = ChatPrompt(mock_model, plugins=[]) |
| 501 | + prompt_without_plugins = ChatPrompt(mock_model) |
| 502 | + |
| 503 | + result_with = await prompt_with_plugins.send("Test message") |
| 504 | + result_without = await prompt_without_plugins.send("Test message") |
| 505 | + |
| 506 | + assert result_with.response.content == result_without.response.content |
| 507 | + |
| 508 | + @pytest.mark.asyncio |
| 509 | + async def test_plugin_with_async_function_handler(self, mock_function_handler: Mock) -> None: |
| 510 | + """Test plugin hooks work correctly with async function handlers""" |
| 511 | + plugin = MockPlugin("async_test") |
| 512 | + plugin.function_result_modifications = ["ASYNC_MODIFIED"] |
| 513 | + |
| 514 | + # Use the existing mock function handler (it's already set up correctly) |
| 515 | + mock_model = MockAIModel(should_call_function=True) |
| 516 | + |
| 517 | + test_function = Function( |
| 518 | + name="test_function", # Use same name as MockAIModel expects |
| 519 | + description="A test function", |
| 520 | + parameter_schema=MockFunctionParams, |
| 521 | + handler=mock_function_handler, |
| 522 | + ) |
| 523 | + |
| 524 | + prompt = ChatPrompt(mock_model, functions=[test_function], plugins=[plugin]) |
| 525 | + result = await prompt.send("Call the function") |
| 526 | + |
| 527 | + # Verify function was called and result was modified by plugin |
| 528 | + assert len(plugin.before_function_called) == 1 |
| 529 | + assert len(plugin.after_function_called) == 1 |
| 530 | + assert result.response.content is not None |
| 531 | + assert "ASYNC_MODIFIED: Function executed successfully" in result.response.content |
| 532 | + |
| 533 | + @pytest.mark.asyncio |
| 534 | + async def test_plugin_error_handling(self, mock_model: MockAIModel) -> None: |
| 535 | + """Test that plugin errors don't break the chat flow""" |
| 536 | + |
| 537 | + class FaultyPlugin(BaseAIPlugin): |
| 538 | + def __init__(self): |
| 539 | + super().__init__("faulty") |
| 540 | + |
| 541 | + async def on_before_send(self, input: Message) -> Message | None: |
| 542 | + raise ValueError("Plugin error") |
| 543 | + |
| 544 | + plugin = FaultyPlugin() |
| 545 | + prompt = ChatPrompt(mock_model, plugins=[plugin]) |
| 546 | + |
| 547 | + # Plugin error should propagate (this is expected behavior) |
| 548 | + with pytest.raises(ValueError, match="Plugin error"): |
| 549 | + await prompt.send("Test message") |
| 550 | + |
| 551 | + @pytest.mark.asyncio |
| 552 | + async def test_base_plugin_default_implementations(self, mock_model: MockAIModel) -> None: |
| 553 | + """Test that BaseAIPlugin provides working default implementations""" |
| 554 | + base_plugin = BaseAIPlugin("base") |
| 555 | + prompt = ChatPrompt(mock_model, plugins=[base_plugin]) |
| 556 | + |
| 557 | + # Should work without any issues using default implementations |
| 558 | + result = await prompt.send("Test with base plugin") |
| 559 | + assert result.response.content == "GENERATED - Test with base plugin" |
| 560 | + |
| 561 | + # Test with functions too |
| 562 | + def handler(params: MockFunctionParams) -> str: |
| 563 | + return "Base plugin test" |
| 564 | + |
| 565 | + test_function = Function("test", "test", MockFunctionParams, handler) |
| 566 | + prompt_with_func = ChatPrompt(mock_model, functions=[test_function], plugins=[base_plugin]) |
| 567 | + |
| 568 | + result2 = await prompt_with_func.send("Test with function") |
| 569 | + assert result2.response.content == "GENERATED - Test with function" |
0 commit comments