Skip to content

Commit dc07c98

Browse files
committed
unit tests
1 parent 68f6f46 commit dc07c98

File tree

2 files changed

+293
-2
lines changed

2 files changed

+293
-2
lines changed

packages/ai/src/microsoft/teams/ai/agent.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .chat_prompt import ChatPrompt, ChatSendResult
1010
from .function import Function
1111
from .memory import ListMemory, Memory
12-
from .message import Message
12+
from .message import Message, SystemMessage
1313

1414

1515
class Agent(ChatPrompt):
@@ -26,7 +26,8 @@ async def send(
2626
self,
2727
input: str | Message,
2828
*,
29+
system_message: SystemMessage | None = None,
2930
memory: Memory | None = None,
3031
on_chunk: Callable[[str], Awaitable[None]] | Callable[[str], None] | None = None,
3132
) -> ChatSendResult:
32-
return await super().send(input, memory=memory or self.memory, on_chunk=on_chunk)
33+
return await super().send(input, memory=memory or self.memory, system_message=system_message, on_chunk=on_chunk)

packages/ai/tests/test_chat_prompt.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
FunctionCall,
1515
ListMemory,
1616
Memory,
17+
Message,
1718
ModelMessage,
1819
SystemMessage,
1920
UserMessage,
2021
)
22+
from microsoft.teams.ai.plugin import BaseAIPlugin
2123
from pydantic import BaseModel
2224

25+
# pyright: basic
26+
2327

2428
class MockFunctionParams(BaseModel):
2529
value: str
@@ -61,6 +65,11 @@ async def generate_text(
6165
try:
6266
params = function.parameter_schema(value="test_input")
6367
result = function.handler(params)
68+
# Handle both sync and async results
69+
from inspect import isawaitable
70+
71+
if isawaitable(result):
72+
result = await result
6473
# In real implementation, function result would be added to memory
6574
# and conversation would continue recursively
6675
content += f" | Function result: {result}"
@@ -277,3 +286,284 @@ async def test_different_message_types(self, mock_model: MockAIModel) -> None:
277286
model_msg = ModelMessage(content="Model message", function_calls=None)
278287
result3 = await prompt.send(model_msg)
279288
assert result3.response.content == "GENERATED - Model message"
289+
290+
291+
class MockPlugin(BaseAIPlugin):
292+
"""Mock plugin for testing that tracks all hook calls"""
293+
294+
def __init__(self, name: str):
295+
super().__init__(name)
296+
self.before_send_called = False
297+
self.after_send_called = False
298+
self.before_function_called: list[tuple[str, BaseModel]] = []
299+
self.after_function_called: list[tuple[str, BaseModel, str]] = []
300+
self.build_functions_called = False
301+
self.build_system_message_called = False
302+
self.input_modifications: list[str] = []
303+
self.response_modifications: list[str] = []
304+
self.function_result_modifications: list[str] = []
305+
306+
async def on_before_send(self, input: Message) -> Message | None:
307+
self.before_send_called = True
308+
if self.input_modifications:
309+
modification = self.input_modifications.pop(0)
310+
if isinstance(input, UserMessage):
311+
return UserMessage(content=f"{modification}: {input.content}")
312+
elif isinstance(input, ModelMessage):
313+
return ModelMessage(
314+
content=f"{modification}: {input.content}" if input.content else modification,
315+
function_calls=input.function_calls,
316+
)
317+
return input
318+
319+
async def on_after_send(self, response: ModelMessage) -> ModelMessage | None:
320+
self.after_send_called = True
321+
if self.response_modifications:
322+
modification = self.response_modifications.pop(0)
323+
return ModelMessage(
324+
content=f"{modification}: {response.content}" if response.content else modification,
325+
function_calls=response.function_calls,
326+
)
327+
return response
328+
329+
async def on_before_function_call(self, function_name: str, args: BaseModel) -> None:
330+
self.before_function_called.append((function_name, args))
331+
332+
async def on_after_function_call(self, function_name: str, args: BaseModel, result: str) -> str | None:
333+
self.after_function_called.append((function_name, args, result))
334+
if self.function_result_modifications:
335+
modification = self.function_result_modifications.pop(0)
336+
return f"{modification}: {result}"
337+
return result
338+
339+
async def on_build_functions(self, functions: list[Function[BaseModel]]) -> list[Function[BaseModel]] | None:
340+
self.build_functions_called = True
341+
return functions
342+
343+
async def on_build_system_message(self, system_message: SystemMessage | None) -> SystemMessage | None:
344+
self.build_system_message_called = True
345+
if system_message is None:
346+
return SystemMessage(content="Plugin-generated system message")
347+
return SystemMessage(content=f"Plugin-modified: {system_message.content}")
348+
349+
350+
class TestChatPromptPlugins:
351+
"""Test suite for plugin functionality in ChatPrompt"""
352+
353+
def test_plugin_initialization_and_registration(self, mock_model: MockAIModel) -> None:
354+
"""Test plugin initialization and registration"""
355+
plugin1 = MockPlugin("plugin1")
356+
plugin2 = MockPlugin("plugin2")
357+
358+
# Test initialization with plugins
359+
prompt = ChatPrompt(mock_model, plugins=[plugin1])
360+
assert len(prompt.plugins) == 1
361+
assert prompt.plugins[0] is plugin1
362+
363+
# Test with_plugin method
364+
result = prompt.with_plugin(plugin2)
365+
assert result is prompt # Should return self for chaining
366+
assert len(prompt.plugins) == 2
367+
assert prompt.plugins[1] is plugin2
368+
369+
@pytest.mark.asyncio
370+
async def test_on_before_send_hook(self, mock_model: MockAIModel) -> None:
371+
"""Test that on_before_send hook can modify input messages"""
372+
plugin = MockPlugin("test_plugin")
373+
plugin.input_modifications = ["MODIFIED"]
374+
375+
prompt = ChatPrompt(mock_model, plugins=[plugin])
376+
result = await prompt.send("Original message")
377+
378+
assert plugin.before_send_called
379+
assert result.response.content == "GENERATED - MODIFIED: Original message"
380+
381+
@pytest.mark.asyncio
382+
async def test_on_after_send_hook(self, mock_model: MockAIModel) -> None:
383+
"""Test that on_after_send hook can modify response messages"""
384+
plugin = MockPlugin("test_plugin")
385+
plugin.response_modifications = ["RESPONSE_MODIFIED"]
386+
387+
prompt = ChatPrompt(mock_model, plugins=[plugin])
388+
result = await prompt.send("Test message")
389+
390+
assert plugin.after_send_called
391+
assert result.response.content == "RESPONSE_MODIFIED: GENERATED - Test message"
392+
393+
@pytest.mark.asyncio
394+
async def test_on_build_system_message_hook(self, mock_model: MockAIModel) -> None:
395+
"""Test that on_build_system_message hook is called"""
396+
plugin = MockPlugin("test_plugin")
397+
398+
prompt = ChatPrompt(mock_model, plugins=[plugin])
399+
400+
# Test with None system message
401+
await prompt.send("Test", system_message=None)
402+
assert plugin.build_system_message_called
403+
404+
# Reset and test with existing system message
405+
plugin.build_system_message_called = False
406+
system_msg = SystemMessage(content="Original system")
407+
await prompt.send("Test", system_message=system_msg)
408+
assert plugin.build_system_message_called
409+
410+
@pytest.mark.asyncio
411+
async def test_function_call_hooks(self, mock_function_handler: Mock) -> None:
412+
"""Test that function call hooks are properly executed"""
413+
plugin = MockPlugin("test_plugin")
414+
plugin.function_result_modifications = ["FUNCTION_MODIFIED"]
415+
416+
# Create a mock model that will call functions
417+
mock_model = MockAIModel(should_call_function=True)
418+
419+
# Create function with mock handler
420+
test_function = Function(
421+
name="test_function",
422+
description="A test function",
423+
parameter_schema=MockFunctionParams,
424+
handler=mock_function_handler,
425+
)
426+
427+
prompt = ChatPrompt(mock_model, functions=[test_function], plugins=[plugin])
428+
result = await prompt.send("Call the function")
429+
430+
# Verify before hook was called
431+
assert len(plugin.before_function_called) == 1
432+
assert plugin.before_function_called[0][0] == "test_function"
433+
assert isinstance(plugin.before_function_called[0][1], MockFunctionParams)
434+
435+
# Verify after hook was called and modified result
436+
assert len(plugin.after_function_called) == 1
437+
assert plugin.after_function_called[0][0] == "test_function"
438+
assert result.response.content is not None
439+
assert "FUNCTION_MODIFIED: Function executed successfully" in result.response.content
440+
441+
@pytest.mark.asyncio
442+
async def test_on_build_functions_hook(
443+
self, mock_model: MockAIModel, test_function: Function[MockFunctionParams]
444+
) -> None:
445+
"""Test that on_build_functions hook is called when functions are present"""
446+
plugin = MockPlugin("test_plugin")
447+
448+
prompt = ChatPrompt(mock_model, functions=[test_function], plugins=[plugin])
449+
await prompt.send("Test message")
450+
451+
assert plugin.build_functions_called
452+
453+
@pytest.mark.asyncio
454+
async def test_multiple_plugins_execution_order(self, mock_model: MockAIModel) -> None:
455+
"""Test that multiple plugins execute in correct order"""
456+
plugin1 = MockPlugin("plugin1")
457+
plugin1.input_modifications = ["FIRST"]
458+
plugin1.response_modifications = ["FIRST_RESP"]
459+
460+
plugin2 = MockPlugin("plugin2")
461+
plugin2.input_modifications = ["SECOND"]
462+
plugin2.response_modifications = ["SECOND_RESP"]
463+
464+
prompt = ChatPrompt(mock_model, plugins=[plugin1, plugin2])
465+
result = await prompt.send("Original")
466+
467+
# Both plugins should be called
468+
assert plugin1.before_send_called
469+
assert plugin2.before_send_called
470+
assert plugin1.after_send_called
471+
assert plugin2.after_send_called
472+
473+
# Input should be modified by both plugins in order
474+
assert result.response.content == "SECOND_RESP: FIRST_RESP: GENERATED - SECOND: FIRST: Original"
475+
476+
@pytest.mark.asyncio
477+
async def test_plugin_returns_none_preserves_original(self, mock_model: MockAIModel) -> None:
478+
"""Test that when plugin returns None, original values are preserved"""
479+
480+
class NoOpPlugin(BaseAIPlugin):
481+
def __init__(self):
482+
super().__init__("noop")
483+
484+
async def on_before_send(self, input: Message) -> Message | None:
485+
return None # Return None to preserve original
486+
487+
async def on_after_send(self, response: ModelMessage) -> ModelMessage | None:
488+
return None # Return None to preserve original
489+
490+
plugin = NoOpPlugin()
491+
prompt = ChatPrompt(mock_model, plugins=[plugin])
492+
result = await prompt.send("Test message")
493+
494+
# Should be unchanged since plugin returned None
495+
assert result.response.content == "GENERATED - Test message"
496+
497+
@pytest.mark.asyncio
498+
async def test_empty_plugin_list_maintains_compatibility(self, mock_model: MockAIModel) -> None:
499+
"""Test that ChatPrompt with no plugins behaves identically to original implementation"""
500+
prompt_with_plugins = ChatPrompt(mock_model, plugins=[])
501+
prompt_without_plugins = ChatPrompt(mock_model)
502+
503+
result_with = await prompt_with_plugins.send("Test message")
504+
result_without = await prompt_without_plugins.send("Test message")
505+
506+
assert result_with.response.content == result_without.response.content
507+
508+
@pytest.mark.asyncio
509+
async def test_plugin_with_async_function_handler(self, mock_function_handler: Mock) -> None:
510+
"""Test plugin hooks work correctly with async function handlers"""
511+
plugin = MockPlugin("async_test")
512+
plugin.function_result_modifications = ["ASYNC_MODIFIED"]
513+
514+
# Use the existing mock function handler (it's already set up correctly)
515+
mock_model = MockAIModel(should_call_function=True)
516+
517+
test_function = Function(
518+
name="test_function", # Use same name as MockAIModel expects
519+
description="A test function",
520+
parameter_schema=MockFunctionParams,
521+
handler=mock_function_handler,
522+
)
523+
524+
prompt = ChatPrompt(mock_model, functions=[test_function], plugins=[plugin])
525+
result = await prompt.send("Call the function")
526+
527+
# Verify function was called and result was modified by plugin
528+
assert len(plugin.before_function_called) == 1
529+
assert len(plugin.after_function_called) == 1
530+
assert result.response.content is not None
531+
assert "ASYNC_MODIFIED: Function executed successfully" in result.response.content
532+
533+
@pytest.mark.asyncio
534+
async def test_plugin_error_handling(self, mock_model: MockAIModel) -> None:
535+
"""Test that plugin errors don't break the chat flow"""
536+
537+
class FaultyPlugin(BaseAIPlugin):
538+
def __init__(self):
539+
super().__init__("faulty")
540+
541+
async def on_before_send(self, input: Message) -> Message | None:
542+
raise ValueError("Plugin error")
543+
544+
plugin = FaultyPlugin()
545+
prompt = ChatPrompt(mock_model, plugins=[plugin])
546+
547+
# Plugin error should propagate (this is expected behavior)
548+
with pytest.raises(ValueError, match="Plugin error"):
549+
await prompt.send("Test message")
550+
551+
@pytest.mark.asyncio
552+
async def test_base_plugin_default_implementations(self, mock_model: MockAIModel) -> None:
553+
"""Test that BaseAIPlugin provides working default implementations"""
554+
base_plugin = BaseAIPlugin("base")
555+
prompt = ChatPrompt(mock_model, plugins=[base_plugin])
556+
557+
# Should work without any issues using default implementations
558+
result = await prompt.send("Test with base plugin")
559+
assert result.response.content == "GENERATED - Test with base plugin"
560+
561+
# Test with functions too
562+
def handler(params: MockFunctionParams) -> str:
563+
return "Base plugin test"
564+
565+
test_function = Function("test", "test", MockFunctionParams, handler)
566+
prompt_with_func = ChatPrompt(mock_model, functions=[test_function], plugins=[base_plugin])
567+
568+
result2 = await prompt_with_func.send("Test with function")
569+
assert result2.response.content == "GENERATED - Test with function"

0 commit comments

Comments
 (0)