Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,21 @@ private static IEnumerable<AsyncMethodGrouping> GetMethodsGroupedBySyntaxTree(Ge

private static string GenerateOverloads(AsyncMethodGrouping grouping, GenerationOptions options)
{
var usings = grouping.SyntaxTree.GetRoot() is CompilationUnitSyntax compilationUnit
? compilationUnit.Usings.ToString()
: string.Empty;

var overloads = new StringBuilder();
overloads.AppendLine("#nullable enable");
overloads.AppendLine(usings);
overloads.AppendLine("namespace System.Linq");
overloads.AppendLine("{");
overloads.AppendLine(" partial class AsyncEnumerable");
overloads.AppendLine(" {");

foreach (var method in grouping.Methods)
overloads.AppendLine(GenerateOverload(method, options));

overloads.AppendLine(" }");
overloads.AppendLine("}");

return overloads.ToString();
var compilationRoot = grouping.SyntaxTree.GetCompilationUnitRoot();
var namespaceDeclaration = compilationRoot.ChildNodes().OfType<NamespaceDeclarationSyntax>().Single();
var classDeclaration = namespaceDeclaration.ChildNodes().OfType<ClassDeclarationSyntax>().Single();

return CompilationUnit()
.WithUsings(List(compilationRoot.Usings.Select(@using => @using.WithoutTrivia())))
.AddMembers(NamespaceDeclaration(namespaceDeclaration.Name)
.AddMembers(ClassDeclaration(classDeclaration.Identifier)
.AddModifiers(Token(SyntaxKind.PartialKeyword))
.WithMembers(List(grouping.Methods.Select(method => GenerateOverload(method, options))))))
.NormalizeWhitespace()
.ToFullString();
}

private static string GenerateOverload(AsyncMethod method, GenerationOptions options)
private static MemberDeclarationSyntax GenerateOverload(AsyncMethod method, GenerationOptions options)
=> MethodDeclaration(method.Syntax.ReturnType, GetMethodName(method.Symbol, options))
.WithModifiers(TokenList(Token(SyntaxKind.PublicKeyword), Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(method.Syntax.TypeParameterList)
Expand All @@ -87,9 +80,7 @@ private static string GenerateOverload(AsyncMethod method, GenerationOptions opt
method.Syntax.ParameterList.Parameters
.Select(p => Argument(IdentifierName(p.Identifier))))))))
.WithSemicolonToken(Token(SyntaxKind.SemicolonToken))
.WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax))
.NormalizeWhitespace()
.ToFullString();
.WithLeadingTrivia(method.Syntax.GetLeadingTrivia().Where(t => t.GetStructure() is not DirectiveTriviaSyntax));

private static INamedTypeSymbol GetAsyncOverloadAttributeSymbol(GeneratorExecutionContext context)
=> context.Compilation.GetTypeByMetadataName("System.Linq.GenerateAsyncOverloadAttribute") ?? throw new InvalidOperationException();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableArray : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableArray_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableArrayAsyncEnumerableExtensions.ToImmutableArrayAsync<int>(default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableArrayAsyncEnumerableExtensions.ToImmutableArrayAsync<int>(default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableArray_IAsyncIListProvider_Simple()
{
var xs = new[] { 42, 25, 39 };
var res = xs.ToAsyncEnumerable().ToImmutableArrayAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_IAsyncIListProvider_Empty1()
{
var xs = new int[0];
var res = xs.ToAsyncEnumerable().ToImmutableArrayAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_IAsyncIListProvider_Empty2()
{
var xs = new HashSet<int>();
var res = xs.ToAsyncEnumerable().ToImmutableArrayAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_Empty()
{
var xs = AsyncEnumerable.Empty<int>();
var res = xs.ToImmutableArrayAsync();
Assert.True((await res).Length == 0);
}

[Fact]
public async Task ToImmutableArray_Throw()
{
var ex = new Exception("Bang!");
var res = Throw<int>(ex).ToImmutableArrayAsync();
await AssertThrowsAsync(res, ex);
}

[Fact]
public async Task ToImmutableArray_Query()
{
var xs = await AsyncEnumerable.Range(5, 50).Take(10).ToImmutableArrayAsync();
var ex = new[] { 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 };

Assert.True(ex.SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableArray_Set()
{
var res = new[] { 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 };
var xs = new HashSet<int>(res);

var arr = await xs.ToAsyncEnumerable().ToImmutableArrayAsync();

Assert.True(res.SequenceEqual(arr));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableDictionary : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableDictionary_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default(Func<int, int>)).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0, EqualityComparer<int>.Default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, EqualityComparer<int>.Default).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, default, x => 0).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0, EqualityComparer<int>.Default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, x => 0, EqualityComparer<int>.Default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default, EqualityComparer<int>.Default).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default(Func<int, int>), CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int>(default, x => 0, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, EqualityComparer<int>.Default, CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, default, x => 0, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default, CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(default, x => 0, x => 0, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync(Return42, default, x => 0, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableDictionaryAsyncEnumerableExtensions.ToImmutableDictionaryAsync<int, int, int>(Return42, x => 0, default, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableDictionary1Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2);
Assert.True(res[0] == 4);
Assert.True(res[1] == 1);
}

[Fact]
public async Task ToImmutableDictionary2Async()
{
var xs = new[] { 1, 4, 2 }.ToAsyncEnumerable();
await AssertThrowsAsync<ArgumentException>(xs.ToImmutableDictionaryAsync(x => x % 2).AsTask());
}

[Fact]
public async Task ToImmutableDictionary3Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2, x => x + 1);
Assert.True(res[0] == 5);
Assert.True(res[1] == 2);
}

[Fact]
public async Task ToImmutableDictionary4Async()
{
var xs = new[] { 1, 4, 2 }.ToAsyncEnumerable();
await AssertThrowsAsync<ArgumentException>(xs.ToImmutableDictionaryAsync(x => x % 2, x => x + 1).AsTask());
}

[Fact]
public async Task ToImmutableDictionary5Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2, new Eq());
Assert.True(res[0] == 4);
Assert.True(res[1] == 1);
}

[Fact]
public async Task ToImmutableDictionary6Async()
{
var xs = new[] { 1, 4, 2 }.ToAsyncEnumerable();
await AssertThrowsAsync<ArgumentException>(xs.ToImmutableDictionaryAsync(x => x % 2, new Eq()).AsTask());
}

[Fact]
public async Task ToImmutableDictionary7Async()
{
var xs = new[] { 1, 4 }.ToAsyncEnumerable();
var res = await xs.ToImmutableDictionaryAsync(x => x % 2, x => x, new Eq());
Assert.True(res[0] == 4);
Assert.True(res[1] == 1);
}

private sealed class Eq : IEqualityComparer<int>
{
public bool Equals(int x, int y) => EqualityComparer<int>.Default.Equals(Math.Abs(x), Math.Abs(y));

public int GetHashCode(int obj) => EqualityComparer<int>.Default.GetHashCode(Math.Abs(obj));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableHashSet : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableHashSet_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableHashSetAsyncEnumerableExtensions.ToImmutableHashSetAsync<int>(default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableHashSetAsyncEnumerableExtensions.ToImmutableHashSetAsync<int>(default, CancellationToken.None).AsTask());

await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableHashSetAsyncEnumerableExtensions.ToImmutableHashSetAsync(default, EqualityComparer<int>.Default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableHashSet_Simple()
{
var xs = new[] { 1, 2, 1, 2, 3, 4, 1, 2, 3, 4 };
var res = xs.ToAsyncEnumerable().ToImmutableHashSetAsync();
Assert.True((await res).OrderBy(x => x).SequenceEqual(new[] { 1, 2, 3, 4 }));
}

[Fact]
public async Task ToImmutableHashSet_Comparer()
{
var xs = new[] { 1, 12, 11, 2, 3, 14, 1, 12, 13, 4 };
var res = xs.ToAsyncEnumerable().ToImmutableHashSetAsync(new Eq());
Assert.True((await res).OrderBy(x => x).SequenceEqual(new[] { 1, 3, 12, 14 }));
}

private class Eq : IEqualityComparer<int>
{
public bool Equals(int x, int y) => x % 10 == y % 10;

public int GetHashCode(int obj) => obj % 10;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Immutable;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Tests
{
public class ToImmutableList : AsyncEnumerableTests
{
[Fact]
public async Task ToImmutableList_Null()
{
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableListAsyncEnumerableExtensions.ToImmutableListAsync<int>(default).AsTask());
await Assert.ThrowsAsync<ArgumentNullException>(() => ImmutableListAsyncEnumerableExtensions.ToImmutableListAsync<int>(default, CancellationToken.None).AsTask());
}

[Fact]
public async Task ToImmutableList_Simple()
{
var xs = new[] { 42, 25, 39 };
var res = xs.ToAsyncEnumerable().ToImmutableListAsync();
Assert.True((await res).SequenceEqual(xs));
}

[Fact]
public async Task ToImmutableList_Empty()
{
var xs = AsyncEnumerable.Empty<int>();
var res = xs.ToImmutableListAsync();
Assert.True((await res).Count == 0);
}

[Fact]
public async Task ToImmutableList_Throw()
{
var ex = new Exception("Bang!");
var res = Throw<int>(ex).ToImmutableListAsync();
await AssertThrowsAsync(res, ex);
}
}
}
Loading