Skip to content

Commit 954fb5b

Browse files
some refinement
1 parent 6985ad0 commit 954fb5b

File tree

5 files changed

+50
-22
lines changed

5 files changed

+50
-22
lines changed

python/semantic_kernel/agents/agent.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from abc import ABC, abstractmethod
66
from collections.abc import AsyncIterable, Iterable
7-
from typing import Annotated, Any, ClassVar
7+
from typing import TYPE_CHECKING, Annotated, Any, ClassVar
88

99
from pydantic import Field, model_validator
1010

@@ -14,7 +14,6 @@
1414
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
1515
from semantic_kernel.functions.kernel_arguments import KernelArguments
1616
from semantic_kernel.functions.kernel_function_decorator import kernel_function
17-
from semantic_kernel.functions.kernel_plugin import KernelPlugin
1817
from semantic_kernel.kernel import Kernel
1918
from semantic_kernel.kernel_pydantic import KernelBaseModel
2019
from semantic_kernel.prompt_template.kernel_prompt_template import KernelPromptTemplate
@@ -23,6 +22,9 @@
2322
from semantic_kernel.utils.naming import generate_random_ascii_name
2423
from semantic_kernel.utils.validation import AGENT_NAME_REGEX
2524

25+
if TYPE_CHECKING:
26+
pass
27+
2628
logger: logging.Logger = logging.getLogger(__name__)
2729

2830

@@ -55,11 +57,6 @@ class Agent(KernelBaseModel, ABC):
5557
name: str = Field(default_factory=lambda: f"agent_{generate_random_ascii_name()}", pattern=AGENT_NAME_REGEX)
5658
prompt_template: PromptTemplateBase | None = None
5759

58-
@staticmethod
59-
def _get_plugin_name(plugin: KernelPlugin | object) -> str:
60-
"""Helper method to get the plugin name."""
61-
return getattr(plugin, "name", plugin.__class__.__name__)
62-
6360
@model_validator(mode="before")
6461
@classmethod
6562
def _configure_plugins(cls, data: Any) -> Any:
@@ -69,23 +66,22 @@ def _configure_plugins(cls, data: Any) -> Any:
6966
if not kernel:
7067
kernel = Kernel()
7168
for plugin in plugins:
72-
name = Agent._get_plugin_name(plugin)
73-
kernel.add_plugin(plugin, plugin_name=name)
69+
kernel.add_plugin(plugin)
7470
data["kernel"] = kernel
7571
return data
7672

7773
def model_post_init(self, __context: Any) -> None:
7874
"""Post initialization."""
7975

8076
@kernel_function(name=self.name, description=self.description)
81-
async def _invoke_agent_as_function_inner(
77+
async def _as_function(
8278
task: Annotated[str, "The task to perform."],
8379
) -> Annotated[str, "The response from the agent."]:
8480
history = ChatHistory()
8581
history.add_user_message(task)
8682
return (await self.get_response(history=history)).content
8783

88-
setattr(self, "_as_function", _invoke_agent_as_function_inner)
84+
setattr(self, "_as_function", _as_function)
8985

9086
@abstractmethod
9187
async def get_response(self, *args, **kwargs) -> ChatMessageContent:

python/semantic_kernel/functions/kernel_function_extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def add_plugin(
8585
self.plugins[plugin.name] = plugin
8686
return self.plugins[plugin.name]
8787
if not plugin_name:
88-
plugin_name = plugin.name if hasattr(plugin, "name") else plugin.__class__.__name__
88+
plugin_name = getattr(plugin, "name", plugin.__class__.__name__)
8989
if not isinstance(plugin_name, str):
9090
raise TypeError("plugin_name must be a string.")
9191
if plugin:

python/tests/unit/agents/test_agent.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77

88
import pytest
99

10+
from semantic_kernel.kernel import Kernel
11+
1012
if sys.version_info >= (3, 12):
1113
from typing import override # pragma: no cover
1214
else:
@@ -33,7 +35,7 @@ class MockAgent(Agent):
3335

3436
channel_type: ClassVar[type[AgentChannel]] = MockChannel
3537

36-
def __init__(self, name: str = "Test-Agent", description: str = "A test agent", id: str = None):
38+
def __init__(self, name: str = "TestAgent", description: str = "A test agent", id: str = None):
3739
args = {
3840
"name": name,
3941
"description": description,
@@ -171,3 +173,21 @@ def test_merge_arguments_both_not_none():
171173

172174
assert merged["param1"] == "baseVal", "Should retain base param from agent"
173175
assert merged["param2"] == "override_param", "Should include param from override"
176+
177+
178+
def test_function_from_agent():
179+
agent = MockAgent()
180+
assert hasattr(agent, "_as_function")
181+
func = agent._as_function
182+
assert hasattr(func, "__kernel_function__")
183+
assert func.__kernel_function_description__ == agent.description
184+
assert func.__kernel_function_name__ == agent.name
185+
assert len(func.__kernel_function_parameters__) == 1
186+
187+
188+
def test_add_agent_as_plugin(kernel: Kernel):
189+
agent = MockAgent()
190+
kernel.add_plugin(agent)
191+
assert len(kernel.plugins) == 1
192+
assert len(kernel.plugins[agent.name].functions) == 1
193+
assert kernel.plugins[agent.name].functions[agent.name].parameters[0].name == "task"

python/tests/unit/kernel/test_kernel.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import os
4+
from dataclasses import dataclass
45
from pathlib import Path
56
from typing import Union
67
from unittest.mock import AsyncMock, MagicMock, patch
@@ -479,9 +480,18 @@ def test_plugin_no_plugin(kernel: Kernel):
479480
kernel.add_plugin(plugin_name="test")
480481

481482

482-
def test_plugin_name_error(kernel: Kernel):
483-
with pytest.raises(ValueError):
484-
kernel.add_plugin(" ", None)
483+
def test_plugin_name_from_class_name(kernel: Kernel):
484+
kernel.add_plugin(" ", None)
485+
assert "str" in kernel.plugins
486+
487+
488+
def test_plugin_name_from_name_attribute(kernel: Kernel):
489+
@dataclass
490+
class TestPlugin:
491+
name: str = "test_plugin"
492+
493+
kernel.add_plugin(TestPlugin(), None)
494+
assert "test_plugin" in kernel.plugins
485495

486496

487497
def test_plugin_name_not_string_error(kernel: Kernel):

python/uv.lock

Lines changed: 8 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)