Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Linq.Expressions;
using System.Text.RegularExpressions;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.VectorData;
Expand All @@ -29,14 +30,25 @@ namespace Microsoft.SemanticKernel.Data;
[Experimental("SKEXP0130")]
[RequiresDynamicCode("This API is not compatible with NativeAOT.")]
[RequiresUnreferencedCode("This API is not compatible with trimming.")]
public sealed class TextSearchStore<TKey> : ITextSearch, IDisposable
public sealed partial class TextSearchStore<TKey> : ITextSearch, IDisposable
where TKey : notnull
{
#if NET7_0_OR_GREATER
// Source-generated regex that matches each maximal run of Unicode letters (\p{L}).
[GeneratedRegex(@"\p{L}+", RegexOptions.IgnoreCase, "en-US")]
private static partial Regex AnyLanguageWordRegex();
#else
// Targets without the regex source generator share one cached, compiled instance instead.
private static readonly Regex s_anyLanguageWordRegex = new(@"\p{L}+", RegexOptions.Compiled);
private static Regex AnyLanguageWordRegex() => s_anyLanguageWordRegex;
#endif

/// <summary>
/// Default word segmenter used for hybrid search when no custom segmenter is supplied via the options.
/// Returns every run of consecutive Unicode letters found in the input, in order of appearance;
/// any non-letter character (digit, punctuation, whitespace) terminates a word.
/// </summary>
private static readonly Func<string, ICollection<string>> s_defaultWordSegementer = text =>
    AnyLanguageWordRegex().Matches(text)
        // MatchCollection only implements the generic IEnumerable<Match> on newer runtimes; a hard cast
        // to IEnumerable<Match> throws InvalidCastException on .NET Framework. Cast<Match>() enumerates
        // through the non-generic IEnumerable, which is available on every target.
        .Cast<Match>()
        .Select(match => match.Value)
        .ToList();

private readonly VectorStore _vectorStore;
private readonly int _vectorDimensions;
private readonly TextSearchStoreOptions _options;
private readonly Func<string, ICollection<string>> _wordSegmenter;

private readonly Lazy<VectorStoreCollection<TKey, TextRagStorageDocument<TKey>>> _vectorStoreRecordCollection;
private readonly VectorStoreCollection<TKey, TextRagStorageDocument<TKey>> _vectorStoreRecordCollection;
private readonly SemaphoreSlim _collectionInitializationLock = new(1, 1);
private bool _collectionInitialized = false;
private bool _disposedValue;
Expand Down Expand Up @@ -74,6 +86,7 @@ public TextSearchStore(
this._vectorStore = vectorStore;
this._vectorDimensions = vectorDimensions;
this._options = options ?? new TextSearchStoreOptions();
this._wordSegmenter = this._options.WordSegementer ?? s_defaultWordSegementer;

// Create a definition so that we can use the dimensions provided at runtime.
VectorStoreCollectionDefinition ragDocumentDefinition = new()
Expand All @@ -83,15 +96,14 @@ public TextSearchStore(
new VectorStoreKeyProperty("Key", typeof(TKey)),
new VectorStoreDataProperty("Namespaces", typeof(List<string>)) { IsIndexed = true },
new VectorStoreDataProperty("SourceId", typeof(string)) { IsIndexed = true },
new VectorStoreDataProperty("Text", typeof(string)),
new VectorStoreDataProperty("Text", typeof(string)) { IsFullTextIndexed = true },
new VectorStoreDataProperty("SourceName", typeof(string)),
new VectorStoreDataProperty("SourceLink", typeof(string)),
new VectorStoreVectorProperty("TextEmbedding", typeof(string), vectorDimensions),
}
};

this._vectorStoreRecordCollection = new Lazy<VectorStoreCollection<TKey, TextRagStorageDocument<TKey>>>(() =>
this._vectorStore.GetCollection<TKey, TextRagStorageDocument<TKey>>(collectionName, ragDocumentDefinition));
this._vectorStoreRecordCollection = this._vectorStore.GetCollection<TKey, TextRagStorageDocument<TKey>>(collectionName, ragDocumentDefinition);
}

/// <summary>
Expand All @@ -114,11 +126,9 @@ public async Task UpsertTextAsync(IEnumerable<string> textChunks, CancellationTo
throw new ArgumentException("One of the provided text chunks is null.", nameof(textChunks));
}

var key = GenerateUniqueKey<TKey>(null);

return new TextRagStorageDocument<TKey>
{
Key = key,
Key = GenerateUniqueKey<TKey>(null),
Text = textChunk,
TextEmbedding = textChunk,
};
Expand Down Expand Up @@ -214,20 +224,41 @@ public async Task<KernelSearchResults<object>> GetSearchResultsAsync(string quer
/// <returns>The search results.</returns>
private async Task<IEnumerable<TextRagStorageDocument<TKey>>> SearchInternalAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
// Short circuit if the query is empty.
if (string.IsNullOrWhiteSpace(query))
{
return Enumerable.Empty<TextRagStorageDocument<TKey>>();
}

var vectorStoreRecordCollection = await this.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);

// If the user has not opted out of hybrid search, check if the vector store supports it.
var hybridSearchCollection = this._options.UseHybridSearch ?? true ?
vectorStoreRecordCollection.GetService(typeof(IKeywordHybridSearchable<TextRagStorageDocument<TKey>>)) as IKeywordHybridSearchable<TextRagStorageDocument<TKey>> :
null;

// Optional filter to limit the search to a specific namespace.
Expression<Func<TextRagStorageDocument<TKey>, bool>>? filter = string.IsNullOrWhiteSpace(this._options.SearchNamespace) ? null : x => x.Namespaces.Contains(this._options.SearchNamespace);

// Generate the vector for the query and search.
var searchResult = vectorStoreRecordCollection.SearchAsync(
query,
searchOptions?.Top ?? 3,
options: new()
{
Filter = filter,
},
cancellationToken: cancellationToken);
// Execute a hybrid search if possible, otherwise perform a regular vector search.
var searchResult = hybridSearchCollection is null
? vectorStoreRecordCollection.SearchAsync(
query,
searchOptions?.Top ?? 3,
options: new()
{
Filter = filter,
},
cancellationToken: cancellationToken)
: hybridSearchCollection.HybridSearchAsync(
query,
this._wordSegmenter(query),
searchOptions?.Top ?? 3,
options: new()
{
Filter = filter,
},
cancellationToken: cancellationToken);

// Retrieve the documents from the search results.
var searchResponseDocs = await searchResult
Expand Down Expand Up @@ -281,12 +312,10 @@ private async Task<IEnumerable<TextRagStorageDocument<TKey>>> SearchInternalAsyn
/// <returns>The created collection.</returns>
private async Task<VectorStoreCollection<TKey, TextRagStorageDocument<TKey>>> EnsureCollectionExistsAsync(CancellationToken cancellationToken)
{
var vectorStoreRecordCollection = this._vectorStoreRecordCollection.Value;

// Return immediately if the collection is already created, no need to do any locking in this case.
if (this._collectionInitialized)
{
return vectorStoreRecordCollection;
return this._vectorStoreRecordCollection;
}

// Wait on a lock to ensure that only one thread can create the collection.
Expand All @@ -297,21 +326,21 @@ private async Task<VectorStoreCollection<TKey, TextRagStorageDocument<TKey>>> En
if (this._collectionInitialized)
{
this._collectionInitializationLock.Release();
return vectorStoreRecordCollection;
return this._vectorStoreRecordCollection;
}

// Only the winning thread should reach this point and create the collection.
try
{
await vectorStoreRecordCollection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);
await this._vectorStoreRecordCollection.EnsureCollectionExistsAsync(cancellationToken).ConfigureAwait(false);
this._collectionInitialized = true;
}
finally
{
this._collectionInitializationLock.Release();
}

return vectorStoreRecordCollection;
return this._vectorStoreRecordCollection;
}

/// <summary>
Expand All @@ -338,6 +367,7 @@ private void Dispose(bool disposing)
{
if (disposing)
{
this._vectorStoreRecordCollection.Dispose();
this._collectionInitializationLock.Dispose();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,23 @@ public sealed class TextSearchStoreOptions
/// </value>
public bool? UseSourceIdAsPrimaryKey { get; init; }

/// <summary>
/// Gets or sets a value indicating whether to use hybrid search (combined vector and keyword search)
/// if it is available for the provided vector store.
/// When hybrid search is disabled, or the vector store collection does not support it, a regular vector search is performed instead.
/// </summary>
/// <value>
/// Defaults to <c>true</c> if not set.
/// </value>
public bool? UseHybridSearch { get; init; }

/// <summary>
/// Gets or sets a word segmenter function to split search text into separate words (keywords) for the purposes of hybrid search.
/// This will not be used if <see cref="UseHybridSearch"/> is set to <c>false</c>, or if the vector store does not support hybrid search.
/// </summary>
/// <remarks>
/// Defaults to a simple letter-based segmenter that treats every run of consecutive Unicode letters as one word.
/// Any other character (for example digits, punctuation or whitespace) acts as a separator and is dropped,
/// so <c>"query word1 wordtwo"</c> is segmented into <c>"query"</c>, <c>"word"</c> and <c>"wordtwo"</c>.
/// </remarks>
public Func<string, ICollection<string>>? WordSegementer { get; init; }

/// <summary>
/// Gets or sets an optional callback to load the source text using the source id or source link
/// if the source text is not persisted in the database.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ public class TextSearchStoreTests
{
private readonly Mock<VectorStore> _vectorStoreMock;
private readonly Mock<VectorStoreCollection<string, TextSearchStore<string>.TextRagStorageDocument<string>>> _recordCollectionMock;
private readonly Mock<IKeywordHybridSearchable<TextSearchStore<string>.TextRagStorageDocument<string>>> _keywordHybridSearchableMock;

public TextSearchStoreTests()
{
this._vectorStoreMock = new Mock<VectorStore>();
this._recordCollectionMock = new Mock<VectorStoreCollection<string, TextSearchStore<string>.TextRagStorageDocument<string>>>();
this._keywordHybridSearchableMock = new Mock<IKeywordHybridSearchable<TextSearchStore<string>.TextRagStorageDocument<string>>>();

this._vectorStoreMock
.Setup(v => v.GetCollection<string, TextSearchStore<string>.TextRagStorageDocument<string>>("testCollection", It.IsAny<VectorStoreCollectionDefinition>()))
Expand Down Expand Up @@ -243,6 +245,39 @@ public async Task SearchAsyncReturnsSearchResults()
Assert.Equal("Sample text", actualResultsList[0]);
}

[Fact]
public async Task SearchAsyncWithHybridReturnsSearchResults()
{
    // Arrange
    // A single canned hit that the hybrid search mock will hand back.
    var expectedHits = new List<VectorSearchResult<TextSearchStore<string>.TextRagStorageDocument<string>>>
    {
        new(new TextSearchStore<string>.TextRagStorageDocument<string> { Text = "Sample text" }, 0.9f)
    };

    // The collection advertises hybrid search support through its GetService lookup.
    this._recordCollectionMock
        .Setup(collection => collection.GetService(typeof(IKeywordHybridSearchable<TextSearchStore<string>.TextRagStorageDocument<string>>), null))
        .Returns(this._keywordHybridSearchableMock.Object);

    // Only respond when the query is forwarded verbatim together with the keywords the
    // default segmenter is expected to extract ("word1" is reduced to "word").
    this._keywordHybridSearchableMock
        .Setup(searchable => searchable.HybridSearchAsync(
            "query word1 wordtwo",
            It.Is<ICollection<string>>(keywords => keywords.Contains("query") && keywords.Contains("word") && keywords.Contains("wordtwo")),
            3,
            It.IsAny<HybridSearchOptions<TextSearchStore<string>.TextRagStorageDocument<string>>>(),
            It.IsAny<CancellationToken>()))
        .Returns(expectedHits.ToAsyncEnumerable());

    using var sut = new TextSearchStore<string>(this._vectorStoreMock.Object, "testCollection", 128);

    // Act
    var searchResults = await sut.SearchAsync("query word1 wordtwo");

    // Assert
    var returnedTexts = await searchResults.Results.ToListAsync();
    Assert.Single(returnedTexts);
    Assert.Equal("Sample text", returnedTexts[0]);
}

[Fact]
public async Task SearchAsyncWithHydrationCallsCallbackAndReturnsSearchResults()
{
Expand Down
Loading