Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Globalization;
using System.Text;
using System.Threading.Tasks;
using Azure.Identity;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.Configuration;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using SemanticKernel.IntegrationTests.TestSettings;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.AzureOpenAI;

public sealed class AzureOpenAIChatClientNoneFunctionChoiceBehaviorTests : BaseIntegrationTest
{
private readonly Kernel _kernel;
private readonly FakeFunctionFilter _autoFunctionInvocationFilter;
private readonly IChatClient _chatClient;

public AzureOpenAIChatClientNoneFunctionChoiceBehaviorTests()
{
this._autoFunctionInvocationFilter = new FakeFunctionFilter();

this._kernel = this.InitializeKernel();
this._kernel.AutoFunctionInvocationFilters.Add(this._autoFunctionInvocationFilter);
this._chatClient = this._kernel.GetRequiredService<IChatClient>();
}

[Fact]
public async Task SpecifiedInCodeInstructsConnectorNotToInvokeKernelFunctionAsync()
{
// Arrange
var plugin = this._kernel.CreatePluginFromType<DateTimeUtils>();
this._kernel.Plugins.Add(plugin);

var invokedFunctions = new List<string>();

this._autoFunctionInvocationFilter.RegisterFunctionInvocationHandler(async (context, next) =>
{
invokedFunctions.Add(context.Function.Name);
await next(context);
});

// Act
var settings = new AzureOpenAIPromptExecutionSettings() { FunctionChoiceBehavior = FunctionChoiceBehavior.None() };
var chatOptions = settings.ToChatOptions(this._kernel);

var messages = new List<ChatMessage>
{
new(ChatRole.User, "How many days until Christmas?")
};

var response = await this._chatClient.GetResponseAsync(messages, chatOptions);

// Assert
Assert.NotNull(response);

Assert.Empty(invokedFunctions);
}

[Fact]
public async Task SpecifiedInPromptInstructsConnectorNotToInvokeKernelFunctionAsync()
{
// Arrange
this._kernel.ImportPluginFromType<DateTimeUtils>();

var invokedFunctions = new List<string>();

this._autoFunctionInvocationFilter.RegisterFunctionInvocationHandler(async (context, next) =>
{
invokedFunctions.Add(context.Function.Name);
await next(context);
});

var promptTemplate = """"
template_format: semantic-kernel
template: How many days until Christmas?
execution_settings:
default:
temperature: 0.1
function_choice_behavior:
type: none
"""";

var promptFunction = KernelFunctionYaml.FromPromptYaml(promptTemplate);

// Act
var result = await this._kernel.InvokeAsync(promptFunction);

// Assert
Assert.NotNull(result);

Assert.Empty(invokedFunctions);
}

[Fact]
public async Task SpecifiedInCodeInstructsConnectorNotToInvokeKernelFunctionForStreamingAsync()
{
// Arrange
var plugin = this._kernel.CreatePluginFromType<DateTimeUtils>();
this._kernel.Plugins.Add(plugin);

var invokedFunctions = new List<string>();

this._autoFunctionInvocationFilter.RegisterFunctionInvocationHandler(async (context, next) =>
{
invokedFunctions.Add(context.Function.Name);
await next(context);
});

var settings = new AzureOpenAIPromptExecutionSettings { FunctionChoiceBehavior = FunctionChoiceBehavior.None() };
var chatOptions = settings.ToChatOptions(this._kernel);

var messages = new List<ChatMessage>
{
new(ChatRole.User, "How many days until Christmas?")
};

StringBuilder result = new();

// Act
await foreach (var update in this._chatClient.GetStreamingResponseAsync(messages, chatOptions))
{
foreach (var content in update.Contents)
{
if (content is Microsoft.Extensions.AI.TextContent textContent)
{
result.Append(textContent.Text);
}
}
}

// Assert
Assert.NotNull(result);

Assert.Empty(invokedFunctions);
}

[Fact]
public async Task SpecifiedInPromptInstructsConnectorNotToInvokeKernelFunctionForStreamingAsync()
{
// Arrange
this._kernel.ImportPluginFromType<DateTimeUtils>();

var invokedFunctions = new List<string>();

this._autoFunctionInvocationFilter.RegisterFunctionInvocationHandler(async (context, next) =>
{
invokedFunctions.Add(context.Function.Name);
await next(context);
});

var promptTemplate = """"
template_format: semantic-kernel
template: How many days until Christmas?
execution_settings:
default:
temperature: 0.1
function_choice_behavior:
type: none
"""";

var promptFunction = KernelFunctionYaml.FromPromptYaml(promptTemplate);

StringBuilder result = new();

// Act
await foreach (string update in promptFunction.InvokeStreamingAsync<string>(this._kernel))
{
result.Append(update);
}

// Assert
Assert.NotNull(result);

Assert.Empty(invokedFunctions);
}

private Kernel InitializeKernel()
{
var azureOpenAIConfiguration = this._configuration.GetSection("AzureOpenAI").Get<AzureOpenAIConfiguration>();
Assert.NotNull(azureOpenAIConfiguration);
Assert.NotNull(azureOpenAIConfiguration.ChatDeploymentName);
Assert.NotNull(azureOpenAIConfiguration.Endpoint);

var kernelBuilder = base.CreateKernelBuilder();

kernelBuilder.AddAzureOpenAIChatClient(
deploymentName: azureOpenAIConfiguration.ChatDeploymentName,
modelId: azureOpenAIConfiguration.ChatModelId,
endpoint: azureOpenAIConfiguration.Endpoint,
credentials: new AzureCliCredential());

return kernelBuilder.Build();
}

private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
.AddJsonFile(path: "testsettings.json", optional: false, reloadOnChange: true)
.AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
.AddEnvironmentVariables()
.AddUserSecrets<AzureOpenAIChatClientNoneFunctionChoiceBehaviorTests>()
.Build();

/// <summary>
/// A plugin that returns the current time.
/// </summary>
#pragma warning disable CA1812 // Avoid uninstantiated internal classes
private sealed class DateTimeUtils
#pragma warning restore CA1812 // Avoid uninstantiated internal classes
{
[KernelFunction]
[Description("Retrieves the current date.")]
public string GetCurrentDate() => DateTime.UtcNow.ToString("d", CultureInfo.InvariantCulture);
}

#region private

private sealed class FakeFunctionFilter : IAutoFunctionInvocationFilter
{
private Func<AutoFunctionInvocationContext, Func<AutoFunctionInvocationContext, Task>, Task>? _onFunctionInvocation;

public void RegisterFunctionInvocationHandler(Func<AutoFunctionInvocationContext, Func<AutoFunctionInvocationContext, Task>, Task> onFunctionInvocation)
{
this._onFunctionInvocation = onFunctionInvocation;
}

public Task OnAutoFunctionInvocationAsync(AutoFunctionInvocationContext context, Func<AutoFunctionInvocationContext, Task> next)
{
if (this._onFunctionInvocation is null)
{
return next(context);
}

return this._onFunctionInvocation?.Invoke(context, next) ?? Task.CompletedTask;
}
}

#endregion
}
Loading
Loading