Skip to content

Commit ef8aa2a

Browse files
authored
.Net: Remove SqlDataReaderDictionary in SQL Server provider (#12241)
We currently write records to the database by first serializing them to a Dictionary, and then writing that Dictionary out as SqlParameters on the command. Similarly, we read records by allocating and populating a Dictionary, only to then iterate over that Dictionary and populate the .NET instance from it. This PR does away with the intermediate Dictionary representation. This optimizes serialization by removing the double-copy, reduces allocations, opens the way for boxing-free serialization, and also simplifies the code (fewer conversions and less indirection to follow).
1 parent afb1e15 commit ef8aa2a

File tree

8 files changed

+180
-239
lines changed

8 files changed

+180
-239
lines changed

dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlDataReaderDictionary.cs

Lines changed: 0 additions & 141 deletions
This file was deleted.

dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCollection.cs

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -259,20 +259,18 @@ public override async Task DeleteAsync(IEnumerable<TKey> keys, CancellationToken
259259
key,
260260
includeVectors);
261261

262-
using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync(
262+
return await connection.ExecuteWithErrorHandlingAsync(
263263
this._collectionMetadata,
264264
operationName: "Get",
265265
async () =>
266266
{
267-
SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
267+
using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false);
268268
await reader.ReadAsync(cancellationToken).ConfigureAwait(false);
269-
return reader;
269+
return reader.HasRows
270+
? this._mapper.MapFromStorageToDataModel(reader, includeVectors)
271+
: null;
270272
},
271273
cancellationToken).ConfigureAwait(false);
272-
273-
return reader.HasRows
274-
? this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, this._model.VectorProperties), includeVectors)
275-
: default;
276274
}
277275

278276
/// <inheritdoc/>
@@ -320,12 +318,22 @@ public override async IAsyncEnumerable<TRecord> GetAsync(IEnumerable<TKey> keys,
320318
() => command.ExecuteReaderAsync(cancellationToken),
321319
cancellationToken).ConfigureAwait(false);
322320

323-
while (await reader.ReadWithErrorHandlingAsync(
324-
this._collectionMetadata,
325-
"GetBatch",
326-
cancellationToken).ConfigureAwait(false))
321+
while (true)
327322
{
328-
yield return this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, this._model.VectorProperties), includeVectors);
323+
TRecord? record = await VectorStoreErrorHandler.RunOperationAsync<TRecord?, SqlException>(
324+
this._collectionMetadata,
325+
"GetBatch",
326+
async () => await reader.ReadAsync(cancellationToken).ConfigureAwait(false)
327+
? this._mapper.MapFromStorageToDataModel(reader, includeVectors)
328+
: null)
329+
.ConfigureAwait(false);
330+
331+
if (record is null)
332+
{
333+
break;
334+
}
335+
336+
yield return record;
329337
}
330338
} while (command.Parameters.Count == SqlServerConstants.MaxParameterCount);
331339
}
@@ -335,7 +343,7 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
335343
{
336344
Verify.NotNull(record);
337345

338-
IReadOnlyList<Embedding>?[]? generatedEmbeddings = null;
346+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>? generatedEmbeddings = null;
339347

340348
var vectorPropertyCount = this._model.VectorProperties.Count;
341349
for (var i = 0; i < vectorPropertyCount; i++)
@@ -354,8 +362,8 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
354362
// and generate embeddings for them in a single batch. That's some more complexity though.
355363
if (vectorProperty.TryGenerateEmbedding<TRecord, Embedding<float>>(record, cancellationToken, out var floatTask))
356364
{
357-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
358-
generatedEmbeddings[i] = [await floatTask.ConfigureAwait(false)];
365+
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
366+
generatedEmbeddings[vectorProperty] = [await floatTask.ConfigureAwait(false)];
359367
}
360368
else
361369
{
@@ -370,7 +378,8 @@ public override async Task UpsertAsync(TRecord record, CancellationToken cancell
370378
this._schema,
371379
this.Name,
372380
this._model,
373-
this._mapper.MapFromDataToStorageModel(record, recordIndex: 0, generatedEmbeddings));
381+
record,
382+
generatedEmbeddings);
374383

375384
await connection.ExecuteWithErrorHandlingAsync(
376385
this._collectionMetadata,
@@ -393,7 +402,7 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
393402
IReadOnlyList<TRecord>? recordsList = null;
394403

395404
// If an embedding generator is defined, invoke it once per property for all records.
396-
IReadOnlyList<Embedding>?[]? generatedEmbeddings = null;
405+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>? generatedEmbeddings = null;
397406

398407
var vectorPropertyCount = this._model.VectorProperties.Count;
399408
for (var i = 0; i < vectorPropertyCount; i++)
@@ -426,8 +435,8 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
426435
// and generate embeddings for them in a single batch. That's some more complexity though.
427436
if (vectorProperty.TryGenerateEmbeddings<TRecord, Embedding<float>>(records, cancellationToken, out var floatTask))
428437
{
429-
generatedEmbeddings ??= new IReadOnlyList<Embedding>?[vectorPropertyCount];
430-
generatedEmbeddings[i] = (IReadOnlyList<Embedding<float>>)await floatTask.ConfigureAwait(false);
438+
generatedEmbeddings ??= new Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>(vectorPropertyCount);
439+
generatedEmbeddings[vectorProperty] = (IReadOnlyList<Embedding<float>>)await floatTask.ConfigureAwait(false);
431440
}
432441
else
433442
{
@@ -459,9 +468,9 @@ public override async Task UpsertAsync(IEnumerable<TRecord> records, Cancellatio
459468
this._schema,
460469
this.Name,
461470
this._model,
462-
records.Skip(taken)
463-
.Take(SqlServerConstants.MaxParameterCount / parametersPerRecord)
464-
.Select((r, i) => this._mapper.MapFromDataToStorageModel(r, taken + i, generatedEmbeddings))))
471+
records.Skip(taken).Take(SqlServerConstants.MaxParameterCount / parametersPerRecord),
472+
firstRecordIndex: taken,
473+
generatedEmbeddings))
465474
{
466475
break; // records is empty
467476
}
@@ -613,7 +622,7 @@ private async IAsyncEnumerable<VectorSearchResult<TRecord>> ReadVectorSearchResu
613622
}
614623

615624
yield return new VectorSearchResult<TRecord>(
616-
this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorProperties), includeVectors),
625+
this._mapper.MapFromStorageToDataModel(reader, includeVectors),
617626
reader.GetDouble(scoreIndex));
618627
}
619628
}
@@ -655,7 +664,7 @@ public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord
655664
operationName: "GetAsync",
656665
cancellationToken).ConfigureAwait(false))
657666
{
658-
yield return this._mapper.MapFromStorageToDataModel(new SqlDataReaderDictionary(reader, vectorProperties), options.IncludeVectors);
667+
yield return this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors);
659668
}
660669
}
661670
}

dotnet/src/Connectors/Connectors.Memory.SqlServer/SqlServerCommandBuilder.cs

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Text;
77
using System.Text.Json;
88
using Microsoft.Data.SqlClient;
9+
using Microsoft.Extensions.AI;
910
using Microsoft.Extensions.VectorData;
1011
using Microsoft.Extensions.VectorData.ProviderServices;
1112

@@ -122,7 +123,8 @@ internal static SqlCommand MergeIntoSingle(
122123
string? schema,
123124
string tableName,
124125
CollectionModel model,
125-
IDictionary<string, object?> record)
126+
object record,
127+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>? generatedEmbeddings)
126128
{
127129
SqlCommand command = connection.CreateCommand();
128130
StringBuilder sb = new(200);
@@ -134,8 +136,13 @@ internal static SqlCommand MergeIntoSingle(
134136

135137
foreach (var property in model.Properties)
136138
{
137-
sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(',');
138-
command.AddParameter(property, paramName, record[property.StorageName]);
139+
sb.AppendParameterName(property, ref paramIndex, out var paramName).Append(',');
140+
141+
var value = property is VectorPropertyModel vectorProperty && generatedEmbeddings?.TryGetValue(vectorProperty, out var ge) == true
142+
? ge[0]
143+
: property.GetValueAsObject(record);
144+
145+
command.AddParameter(property, paramName, value);
139146
}
140147

141148
sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis
@@ -174,7 +181,9 @@ internal static bool MergeIntoMany(
174181
string? schema,
175182
string tableName,
176183
CollectionModel model,
177-
IEnumerable<IDictionary<string, object?>> records)
184+
IEnumerable<object> records,
185+
int firstRecordIndex,
186+
Dictionary<VectorPropertyModel, IReadOnlyList<Embedding>>? generatedEmbeddings)
178187
{
179188
StringBuilder sb = new(200);
180189
// The DECLARE statement creates a table variable to store the keys of the inserted rows.
@@ -189,11 +198,18 @@ internal static bool MergeIntoMany(
189198
foreach (var record in records)
190199
{
191200
sb.Append('(');
201+
192202
foreach (var property in model.Properties)
193203
{
194-
sb.AppendParameterName(property, ref paramIndex, out string paramName).Append(',');
195-
command.AddParameter(property, paramName, record[property.StorageName]);
204+
sb.AppendParameterName(property, ref paramIndex, out var paramName).Append(',');
205+
206+
var value = property is VectorPropertyModel vectorProperty && generatedEmbeddings?.TryGetValue(vectorProperty, out var ge) == true
207+
? ge[firstRecordIndex + rowIndex]
208+
: property.GetValueAsObject(record);
209+
210+
command.AddParameter(property, paramName, value);
196211
}
212+
197213
sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis
198214
sb.AppendLine(",");
199215
rowIndex++;
@@ -585,14 +601,19 @@ private static void AddParameter(this SqlCommand command, PropertyModel? propert
585601
command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = DBNull.Value;
586602
break;
587603
case null:
588-
case ReadOnlyMemory<float> vector when vector.Length == 0:
589604
command.Parameters.AddWithValue(name, DBNull.Value);
590605
break;
591606
case byte[] buffer:
592607
command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = buffer;
593608
break;
594609
case ReadOnlyMemory<float> vector:
595-
command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vector));
610+
command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vector, SqlServerJsonSerializerContext.Default.ReadOnlyMemorySingle));
611+
break;
612+
case Embedding<float> { Vector: var vector }:
613+
command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vector, SqlServerJsonSerializerContext.Default.ReadOnlyMemorySingle));
614+
break;
615+
case float[] vectorArray:
616+
command.Parameters.AddWithValue(name, JsonSerializer.Serialize(vectorArray, SqlServerJsonSerializerContext.Default.SingleArray));
596617
break;
597618
default:
598619
command.Parameters.AddWithValue(name, value);
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
3+
using System;
4+
using System.Text.Json.Serialization;
5+
6+
namespace Microsoft.SemanticKernel.Connectors.SqlServer;
7+
8+
// Note: this is temporary - SQL Server will switch away from using JSON arrays to represent embeddings in the future.
9+
[JsonSerializable(typeof(float[]))]
10+
[JsonSerializable(typeof(ReadOnlyMemory<float>))]
11+
internal partial class SqlServerJsonSerializerContext : JsonSerializerContext;

0 commit comments

Comments
 (0)