Skip to content

Commit 43ca269

Browse files
authored
Add AsIEmbeddingGenerator for Azure.AI.Inference ImageEmbeddingsClient (#6363)
1 parent 651c27c commit 43ca269

File tree

5 files changed

+294
-2
lines changed

5 files changed

+294
-2
lines changed

src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceEmbeddingGenerator.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ public AzureAIInferenceEmbeddingGenerator(
8989
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
9090
IEnumerable<string> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
9191
{
92+
_ = Throw.IfNull(values);
93+
9294
var azureAIOptions = ToAzureAIOptions(values, options, EmbeddingEncodingFormat.Base64);
9395

9496
var embeddings = (await _embeddingsClient.EmbedAsync(azureAIOptions, cancellationToken).ConfigureAwait(false)).Value;
@@ -118,7 +120,7 @@ void IDisposable.Dispose()
118120
// Nothing to dispose. Implementation required for the IEmbeddingGenerator interface.
119121
}
120122

121-
private static float[] ParseBase64Floats(BinaryData binaryData)
123+
internal static float[] ParseBase64Floats(BinaryData binaryData)
122124
{
123125
ReadOnlySpan<byte> base64 = binaryData.ToMemory().Span;
124126

@@ -161,7 +163,7 @@ static void ThrowInvalidData() =>
161163
throw new FormatException("The input is not a valid Base64 string of encoded floats.");
162164
}
163165

164-
/// <summary>Converts an extensions options instance to an OpenAI options instance.</summary>
166+
/// <summary>Converts an extensions options instance to an Azure.AI.Inference options instance.</summary>
165167
private EmbeddingsOptions ToAzureAIOptions(IEnumerable<string> inputs, EmbeddingGenerationOptions? options, EmbeddingEncodingFormat format)
166168
{
167169
EmbeddingsOptions result = new(inputs)

src/Libraries/Microsoft.Extensions.AI.AzureAIInference/AzureAIInferenceExtensions.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,13 @@ public static IChatClient AsIChatClient(
2424
public static IEmbeddingGenerator<string, Embedding<float>> AsIEmbeddingGenerator(
2525
this EmbeddingsClient embeddingsClient, string? defaultModelId = null, int? defaultModelDimensions = null) =>
2626
new AzureAIInferenceEmbeddingGenerator(embeddingsClient, defaultModelId, defaultModelDimensions);
27+
28+
/// <summary>Gets an <see cref="IEmbeddingGenerator{DataContent, Single}"/> for use with this <see cref="EmbeddingsClient"/>.</summary>
29+
/// <param name="imageEmbeddingsClient">The client.</param>
30+
/// <param name="defaultModelId">The ID of the model to use. If <see langword="null"/>, it can be provided per request via <see cref="ChatOptions.ModelId"/>.</param>
31+
/// <param name="defaultModelDimensions">The number of dimensions generated in each embedding.</param>
32+
/// <returns>An <see cref="IEmbeddingGenerator{DataContent, Embedding}"/> that can be used to generate embeddings via the <see cref="ImageEmbeddingsClient"/>.</returns>
33+
public static IEmbeddingGenerator<DataContent, Embedding<float>> AsIEmbeddingGenerator(
34+
this ImageEmbeddingsClient imageEmbeddingsClient, string? defaultModelId = null, int? defaultModelDimensions = null) =>
35+
new AzureAIInferenceImageEmbeddingGenerator(imageEmbeddingsClient, defaultModelId, defaultModelDimensions);
2736
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
using System.Reflection;
8+
using System.Text.Json;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
using Azure.AI.Inference;
12+
using Microsoft.Shared.Diagnostics;
13+
14+
#pragma warning disable EA0002 // Use 'System.TimeProvider' to make the code easier to test
15+
#pragma warning disable S3011 // Reflection should not be used to increase accessibility of classes, methods, or fields
16+
#pragma warning disable S109 // Magic numbers should not be used
17+
18+
namespace Microsoft.Extensions.AI;
19+
20+
/// <summary>Represents an <see cref="IEmbeddingGenerator{DataContent, Embedding}"/> for an Azure.AI.Inference <see cref="ImageEmbeddingsClient"/>.</summary>
21+
internal sealed class AzureAIInferenceImageEmbeddingGenerator :
22+
IEmbeddingGenerator<DataContent, Embedding<float>>
23+
{
24+
/// <summary>Metadata about the embedding generator.</summary>
25+
private readonly EmbeddingGeneratorMetadata _metadata;
26+
27+
/// <summary>The underlying <see cref="ImageEmbeddingsClient" />.</summary>
28+
private readonly ImageEmbeddingsClient _imageEmbeddingsClient;
29+
30+
/// <summary>The number of dimensions produced by the generator.</summary>
31+
private readonly int? _dimensions;
32+
33+
/// <summary>Initializes a new instance of the <see cref="AzureAIInferenceImageEmbeddingGenerator"/> class.</summary>
34+
/// <param name="imageEmbeddingsClient">The underlying client.</param>
35+
/// <param name="defaultModelId">
36+
/// The ID of the model to use. This can also be overridden per request via <see cref="EmbeddingGenerationOptions.ModelId"/>.
37+
/// Either this parameter or <see cref="EmbeddingGenerationOptions.ModelId"/> must provide a valid model ID.
38+
/// </param>
39+
/// <param name="defaultModelDimensions">The number of dimensions to generate in each embedding.</param>
40+
/// <exception cref="ArgumentNullException"><paramref name="imageEmbeddingsClient"/> is <see langword="null"/>.</exception>
41+
/// <exception cref="ArgumentException"><paramref name="defaultModelId"/> is empty or composed entirely of whitespace.</exception>
42+
/// <exception cref="ArgumentOutOfRangeException"><paramref name="defaultModelDimensions"/> is not positive.</exception>
43+
public AzureAIInferenceImageEmbeddingGenerator(
44+
ImageEmbeddingsClient imageEmbeddingsClient, string? defaultModelId = null, int? defaultModelDimensions = null)
45+
{
46+
_ = Throw.IfNull(imageEmbeddingsClient);
47+
48+
if (defaultModelId is not null)
49+
{
50+
_ = Throw.IfNullOrWhitespace(defaultModelId);
51+
}
52+
53+
if (defaultModelDimensions is < 1)
54+
{
55+
Throw.ArgumentOutOfRangeException(nameof(defaultModelDimensions), "Value must be greater than 0.");
56+
}
57+
58+
_imageEmbeddingsClient = imageEmbeddingsClient;
59+
_dimensions = defaultModelDimensions;
60+
61+
// https://github.com/Azure/azure-sdk-for-net/issues/46278
62+
// The endpoint isn't currently exposed, so use reflection to get at it, temporarily. Once packages
63+
// implement the abstractions directly rather than providing adapters on top of the public APIs,
64+
// the package can provide such implementations separate from what's exposed in the public API.
65+
var providerUrl = typeof(ImageEmbeddingsClient).GetField("_endpoint", BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
66+
?.GetValue(imageEmbeddingsClient) as Uri;
67+
68+
_metadata = new EmbeddingGeneratorMetadata("az.ai.inference", providerUrl, defaultModelId, defaultModelDimensions);
69+
}
70+
71+
/// <inheritdoc />
72+
object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey)
73+
{
74+
_ = Throw.IfNull(serviceType);
75+
76+
return
77+
serviceKey is not null ? null :
78+
serviceType == typeof(ImageEmbeddingsClient) ? _imageEmbeddingsClient :
79+
serviceType == typeof(EmbeddingGeneratorMetadata) ? _metadata :
80+
serviceType.IsInstanceOfType(this) ? this :
81+
null;
82+
}
83+
84+
/// <inheritdoc />
85+
public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
86+
IEnumerable<DataContent> values, EmbeddingGenerationOptions? options = null, CancellationToken cancellationToken = default)
87+
{
88+
_ = Throw.IfNull(values);
89+
90+
var azureAIOptions = ToAzureAIOptions(values, options, EmbeddingEncodingFormat.Base64);
91+
92+
var embeddings = (await _imageEmbeddingsClient.EmbedAsync(azureAIOptions, cancellationToken).ConfigureAwait(false)).Value;
93+
94+
GeneratedEmbeddings<Embedding<float>> result = new(embeddings.Data.Select(e =>
95+
new Embedding<float>(AzureAIInferenceEmbeddingGenerator.ParseBase64Floats(e.Embedding))
96+
{
97+
CreatedAt = DateTimeOffset.UtcNow,
98+
ModelId = embeddings.Model ?? azureAIOptions.Model,
99+
}));
100+
101+
if (embeddings.Usage is not null)
102+
{
103+
result.Usage = new()
104+
{
105+
InputTokenCount = embeddings.Usage.PromptTokens,
106+
TotalTokenCount = embeddings.Usage.TotalTokens
107+
};
108+
}
109+
110+
return result;
111+
}
112+
113+
/// <inheritdoc />
114+
void IDisposable.Dispose()
115+
{
116+
// Nothing to dispose. Implementation required for the IEmbeddingGenerator interface.
117+
}
118+
119+
/// <summary>Converts an extensions options instance to an Azure.AI.Inference options instance.</summary>
120+
private ImageEmbeddingsOptions ToAzureAIOptions(IEnumerable<DataContent> inputs, EmbeddingGenerationOptions? options, EmbeddingEncodingFormat format)
121+
{
122+
ImageEmbeddingsOptions result = new(inputs.Select(dc => new ImageEmbeddingInput(dc.Uri)))
123+
{
124+
Dimensions = options?.Dimensions ?? _dimensions,
125+
Model = options?.ModelId ?? _metadata.DefaultModelId,
126+
EncodingFormat = format,
127+
};
128+
129+
if (options?.AdditionalProperties is { } props)
130+
{
131+
foreach (var prop in props)
132+
{
133+
if (prop.Value is not null)
134+
{
135+
byte[] data = JsonSerializer.SerializeToUtf8Bytes(prop.Value, AIJsonUtilities.DefaultOptions.GetTypeInfo(typeof(object)));
136+
result.AdditionalProperties[prop.Key] = new BinaryData(data);
137+
}
138+
}
139+
}
140+
141+
return result;
142+
}
143+
}

0 commit comments

Comments
 (0)