Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for total count to ToBatchPageAsync. #7944

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using System.Collections.Concurrent;
using System.Collections.Immutable;
using System.Linq.Expressions;
using System.Reflection;
using HotChocolate.Pagination.Expressions;
using static HotChocolate.Pagination.Expressions.ExpressionHelpers;

Expand All @@ -11,6 +13,7 @@ namespace HotChocolate.Pagination;
public static class PagingQueryableExtensions
{
private static readonly AsyncLocal<InterceptorHolder> _interceptor = new();
private static readonly ConcurrentDictionary<(Type, Type), Expression> _countExpressionCache = new();

/// <summary>
/// Executes a query with paging and returns the selected page.
Expand Down Expand Up @@ -208,6 +211,49 @@ public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, T
where TKey : notnull
=> ToBatchPageAsync<TKey, TValue, TValue>(source, keySelector, t => t, arguments, cancellationToken);

/// <summary>
/// Executes a batch query with paging and returns the selected pages for each parent.
/// </summary>
/// <param name="source">
/// The queryable to be paged.
/// </param>
/// <param name="keySelector">
/// A function to select the key of the parent.
/// </param>
/// <param name="arguments">
/// The paging arguments.
/// </param>
/// <param name="includeTotalCount">
/// If set to <c>true</c> the total count will be included in the result.
/// </param>
/// <param name="cancellationToken">
/// The cancellation token.
/// </param>
/// <typeparam name="TKey">
/// The type of the parent key.
/// </typeparam>
/// <typeparam name="TValue">
/// The type of the items in the queryable.
/// </typeparam>
/// <returns></returns>
/// <exception cref="ArgumentException">
/// If the queryable does not have any keys specified.
/// </exception>
public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, TValue>(
this IQueryable<TValue> source,
Expression<Func<TValue, TKey>> keySelector,
PagingArguments arguments,
bool includeTotalCount,
CancellationToken cancellationToken = default)
where TKey : notnull
=> ToBatchPageAsync<TKey, TValue, TValue>(
source,
keySelector,
t => t,
arguments,
includeTotalCount: includeTotalCount,
cancellationToken);

/// <summary>
/// Executes a batch query with paging and returns the selected pages for each parent.
/// </summary>
Expand Down Expand Up @@ -239,11 +285,55 @@ public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, T
/// <exception cref="ArgumentException">
/// If the queryable does not have any keys specified.
/// </exception>
public static ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, TValue, TElement>(
this IQueryable<TElement> source,
Expression<Func<TElement, TKey>> keySelector,
Func<TElement, TValue> valueSelector,
PagingArguments arguments,
CancellationToken cancellationToken = default)
where TKey : notnull
=> ToBatchPageAsync(source, keySelector, valueSelector, arguments, includeTotalCount: false, cancellationToken);

/// <summary>
/// Executes a batch query with paging and returns the selected pages for each parent.
/// </summary>
/// <param name="source">
/// The queryable to be paged.
/// </param>
/// <param name="keySelector">
/// A function to select the key of the parent.
/// </param>
/// <param name="valueSelector">
/// A function to select the value of the items in the queryable.
/// </param>
/// <param name="arguments">
/// The paging arguments.
/// </param>
/// <param name="includeTotalCount">
/// If set to <c>true</c> the total count will be included in the result.
/// </param>
/// <param name="cancellationToken">
/// The cancellation token.
/// </param>
/// <typeparam name="TKey">
/// The type of the parent key.
/// </typeparam>
/// <typeparam name="TValue">
/// The type of the items in the queryable.
/// </typeparam>
/// <typeparam name="TElement">
/// The type of the items in the queryable.
/// </typeparam>
/// <returns></returns>
/// <exception cref="ArgumentException">
/// If the queryable does not have any keys specified.
/// </exception>
public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<TKey, TValue, TElement>(
this IQueryable<TElement> source,
Expression<Func<TElement, TKey>> keySelector,
Func<TElement, TValue> valueSelector,
PagingArguments arguments,
bool includeTotalCount,
CancellationToken cancellationToken = default)
where TKey : notnull
{
Expand All @@ -263,6 +353,12 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
nameof(arguments));
}

Dictionary<TKey, int>? counts = null;
if (includeTotalCount)
{
counts = await GetBatchCountsAsync(source, keySelector, cancellationToken);
}

source = QueryHelpers.EnsureOrderPropsAreSelected(source);

// we need to move the ordering into the select expression we are constructing
Expand Down Expand Up @@ -308,13 +404,67 @@ public static async ValueTask<Dictionary<TKey, Page<TValue>>> ToBatchPageAsync<T
builder.Add(valueSelector(item.Items[i]));
}

var page = CreatePage(builder.ToImmutable(), arguments, keys, item.Items.Count);
var totalCount = counts?.GetValueOrDefault(item.Key);
var page = CreatePage(builder.ToImmutable(), arguments, keys, item.Items.Count, totalCount);
map.Add(item.Key, page);
}

return map;
}

private static async Task<Dictionary<TKey, int>> GetBatchCountsAsync<TElement, TKey>(
IQueryable<TElement> source,
Expression<Func<TElement, TKey>> keySelector,
CancellationToken cancellationToken)
where TKey : notnull
{
var query = source
.GroupBy(keySelector)
.Select(GetOrCreateCountSelector<TElement, TKey>());

TryGetQueryInterceptor()?.OnBeforeExecute(query);

return await query.ToDictionaryAsync(t => t.Key, t => t.Count, cancellationToken);
}

private static Expression<Func<IGrouping<TKey, TElement>, CountResult<TKey>>> GetOrCreateCountSelector<TElement, TKey>()
{
return (Expression<Func<IGrouping<TKey, TElement>, CountResult<TKey>>>)
_countExpressionCache.GetOrAdd(
(typeof(TKey), typeof(TElement)),
static _ =>
{
var groupingType = typeof(IGrouping<,>).MakeGenericType(typeof(TKey), typeof(TElement));
var param = Expression.Parameter(groupingType, "g");
var keyProperty = Expression.Property(param, nameof(IGrouping<TKey, TElement>.Key));
var countMethod = typeof(Enumerable)
.GetMethods(BindingFlags.Static | BindingFlags.Public)
.First(m => m.Name == nameof(Enumerable.Count) && m.GetParameters().Length == 1)
.MakeGenericMethod(typeof(TElement));
var countCall = Expression.Call(countMethod, param);

var resultCtor = typeof(CountResult<TKey>).GetConstructor(Type.EmptyTypes)!;
var newExpr = Expression.New(resultCtor);

var bindings = new List<MemberBinding>
{
Expression.Bind(typeof(CountResult<TKey>).GetProperty(nameof(CountResult<TKey>.Key))!,
keyProperty),
Expression.Bind(typeof(CountResult<TKey>).GetProperty(nameof(CountResult<TKey>.Count))!,
countCall)
};

var body = Expression.MemberInit(newExpr, bindings);
return Expression.Lambda<Func<IGrouping<TKey, TElement>, CountResult<TKey>>>(body, param);
});
}

private class CountResult<TKey>
{
public required TKey Key { get; set; }
public required int Count { get; set; }
}

private static Page<T> CreatePage<T>(
ImmutableArray<T> items,
PagingArguments arguments,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,61 @@ public async Task Query_Owner_Animals()

var operationResult = result.ExpectOperationResult();

#if NET9_0_OR_GREATER
await Snapshot.Create("NET_9_0")
#else
await Snapshot.Create()
#endif
.AddQueries(queries)
.Add(operationResult.WithExtensions(ImmutableDictionary<string, object?>.Empty))
.MatchMarkdownAsync();
}

[Fact]
public async Task Query_Owner_Animals_With_TotalCount()
{
var connectionString = CreateConnectionString();
await SeedAsync(connectionString);

var queries = new List<QueryInfo>();
using var capture = new CapturePagingQueryInterceptor(queries);

var result = await new ServiceCollection()
.AddScoped(_ => new AnimalContext(connectionString))
.AddGraphQL()
.AddQueryType<Query>()
.AddTypeExtension(typeof(OwnerExtensionsWithTotalCount))
.AddDataLoader<AnimalsByOwnerWithCountDataLoader>()
.AddObjectType<Cat>()
.AddObjectType<Dog>()
.AddPagingArguments()
.ModifyRequestOptions(o => o.IncludeExceptionDetails = true)
.ModifyPagingOptions(o => o.IncludeTotalCount = true)
.ExecuteRequestAsync(
OperationRequestBuilder.New()
.SetDocument(
"""
{
owners(first: 10) {
nodes {
id
name
pets(first: 10) {
nodes {
__typename
id
name
}
totalCount
}
}
}
}
""")
.Build());

var operationResult = result.ExpectOperationResult();

#if NET9_0_OR_GREATER
await Snapshot.Create("NET_9_0")
#else
Expand Down Expand Up @@ -314,6 +369,24 @@ public static async Task<Connection<Animal>> GetPetsAsync(
.ToConnectionAsync();
}

[ExtendObjectType<Owner>]
public static class OwnerExtensionsWithTotalCount
{
[BindMember(nameof(Owner.Pets))]
[UsePaging]
public static async Task<Connection<Animal>> GetPetsAsync(
[Parent("Id")] Owner owner,
PagingArguments pagingArgs,
AnimalsByOwnerWithCountDataLoader animalsByOwner,
ISelection selection,
CancellationToken cancellationToken)
=> await animalsByOwner
.WithPagingArguments(pagingArgs)
.Select(selection)
.LoadAsync(owner.Id, cancellationToken)
.ToConnectionAsync();
}

public sealed class AnimalsByOwnerDataLoader
: StatefulBatchDataLoader<int, Page<Animal>>
{
Expand Down Expand Up @@ -352,6 +425,46 @@ protected override async Task<IReadOnlyDictionary<int, Page<Animal>>> LoadBatchA
cancellationToken);
}
}

public sealed class AnimalsByOwnerWithCountDataLoader
: StatefulBatchDataLoader<int, Page<Animal>>
{
private readonly IServiceProvider _services;

public AnimalsByOwnerWithCountDataLoader(
IServiceProvider services,
IBatchScheduler batchScheduler,
DataLoaderOptions options)
: base(batchScheduler, options)
{
_services = services;
}

protected override async Task<IReadOnlyDictionary<int, Page<Animal>>> LoadBatchAsync(
IReadOnlyList<int> keys,
DataLoaderFetchContext<Page<Animal>> context,
CancellationToken cancellationToken)
{
var pagingArgs = context.GetPagingArguments();
// var selector = context.GetSelector();

await using var scope = _services.CreateAsyncScope();
var dbContext = scope.ServiceProvider.GetRequiredService<AnimalContext>();

return await dbContext.Owners
.Where(t => keys.Contains(t.Id))
.SelectMany(t => t.Pets)
.OrderBy(t => t.Name)
.ThenBy(t => t.Id)
// selections do not work when inheritance is used for nested batching.
// .Select(selector, t => t.OwnerId)
.ToBatchPageAsync(
t => t.OwnerId,
pagingArgs,
includeTotalCount: true,
cancellationToken);
}
}
}

file static class Extensions
Expand Down
Loading
Loading