// File: wabbajack/Wabbajack.Common/AsyncParallelExtensions.cs
// Snapshot: 2022-10-03 22:43:21 -06:00
//
// Original listing metadata: 267 lines, 8.5 KiB, C#

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Wabbajack.RateLimiter;
namespace Wabbajack.Common;
/// <summary>
/// Extension methods for fanning work out over collections concurrently (optionally gated by an
/// <see cref="IResource{T}"/> limiter) and for consuming, materializing and transforming
/// <see cref="IAsyncEnumerable{T}"/> sequences.
/// </summary>
public static class AsyncParallelExtensions
{
    /// <summary>
    /// Starts <paramref name="mapFn"/> for every element of <paramref name="coll"/> immediately and waits for
    /// all invocations to finish. No concurrency limit is applied.
    /// </summary>
    /// <param name="coll">Items to process; enumerated exactly once.</param>
    /// <param name="mapFn">Asynchronous action run for each item.</param>
    public static async Task PDoAll<TIn>(this IEnumerable<TIn> coll, Func<TIn, Task> mapFn)
    {
        var tasks = coll.Select(mapFn).ToList();
        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Runs <paramref name="mapFn"/> for every element of <paramref name="coll"/>, acquiring a job from
    /// <paramref name="limiter"/> before each invocation and disposing it when the invocation completes,
    /// then waits for all of them to finish.
    /// </summary>
    /// <param name="coll">Items to process; enumerated exactly once.</param>
    /// <param name="limiter">Resource a job is taken from around each <paramref name="mapFn"/> call.</param>
    /// <param name="mapFn">Asynchronous action run for each item.</param>
    public static async Task PDoAll<TIn, TJob>(this IEnumerable<TIn> coll, IResource<TJob> limiter,
        Func<TIn, Task> mapFn)
    {
        var tasks = coll.Select(async x =>
        {
            using var job = await limiter.Begin("", 0, CancellationToken.None);
            await mapFn(x);
        }).ToList();
        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Starts <paramref name="mapFn"/> for every element on the thread pool and yields the results in the
    /// same order as <paramref name="coll"/>. All work is started before the first result is yielded.
    /// </summary>
    /// <param name="coll">Items to map; enumerated exactly once.</param>
    /// <param name="mapFn">Asynchronous projection applied to each item.</param>
    /// <returns>The projected values, in input order.</returns>
    public static async IAsyncEnumerable<TOut> PMapAll<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, Task<TOut>> mapFn)
    {
        // Task.Run moves each invocation (including any synchronous prefix of mapFn) onto the thread pool.
        var tasks = coll.Select(x => Task.Run(() => mapFn(x))).ToList();
        foreach (var task in tasks)
        {
            yield return await task;
        }
    }

    /// <summary>
    /// Like <see cref="PMapAll{TIn,TOut}"/> but drops <c>null</c> results. All work is started up front;
    /// results are yielded in input order.
    /// </summary>
    /// <param name="coll">Items to map; enumerated exactly once.</param>
    /// <param name="mapFn">Asynchronous projection applied to each item; may return <c>null</c>.</param>
    /// <returns>The non-null projected values, in input order.</returns>
    public static async IAsyncEnumerable<TOut> PKeepAll<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, Task<TOut>> mapFn)
        where TOut : class
    {
        var tasks = coll.Select(mapFn).ToList();
        foreach (var task in tasks)
        {
            // Test the awaited value, not the Task object (a Task reference is never null).
            var value = await task;
            if (value is not null)
            {
                yield return value;
            }
        }
    }

    /// <summary>
    /// Maps every element with <paramref name="mapFn"/>, acquiring a job from <paramref name="limiter"/> before
    /// each invocation (the same pattern as <see cref="PDoAll{TIn,TJob}"/>), and yields the results in input
    /// order.
    /// </summary>
    /// <param name="coll">Items to map; enumerated exactly once.</param>
    /// <param name="limiter">Resource a job is taken from around each <paramref name="mapFn"/> call.</param>
    /// <param name="mapFn">Asynchronous projection applied to each item.</param>
    /// <returns>The projected values, in input order.</returns>
    public static async IAsyncEnumerable<TOut> PMapAll<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
    {
        // The job must be held while mapFn runs; acquiring it only around the consumer's await would let
        // every mapFn run unthrottled.
        var tasks = coll.Select(x => RunLimited(limiter, () => mapFn(x))).ToList();
        foreach (var task in tasks)
        {
            yield return await task;
        }
    }

    /// <summary>
    /// Faster version of PMapAll for when each function invocation takes a very small amount of time.
    /// Splits the inputs into N strided batches (item <c>i</c> goes to batch <c>i % N</c>), where N is
    /// <paramref name="limiter"/>'s <c>MaxTasks</c>, and processes each batch sequentially on its own task.
    /// Each batch task holds a single limiter job for its whole batch.
    /// </summary>
    /// <param name="coll">Items to map; materialized into a list once.</param>
    /// <param name="limiter">Resource whose <c>MaxTasks</c> determines the number of batches.</param>
    /// <param name="mapFn">Asynchronous projection applied to each item.</param>
    /// <typeparam name="TIn">Input element type.</typeparam>
    /// <typeparam name="TJob">Limiter job tag type.</typeparam>
    /// <typeparam name="TOut">Result element type.</typeparam>
    /// <returns>
    /// All projected values, grouped by batch (every value from batch 0, then batch 1, …). This is NOT the
    /// original input order.
    /// </returns>
    public static async IAsyncEnumerable<TOut> PMapAllBatched<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
    {
        var asList = coll.ToList();
        // Snapshot N once so every worker uses the same stride; otherwise a change to MaxTasks mid-run could
        // skip items or process them twice.
        var batchCount = limiter.MaxTasks;
        // No worker is started for a batch that would be empty (fewer items than batches).
        var tasks = Enumerable.Range(0, Math.Min(batchCount, asList.Count))
            .Select(batch => Task.Run(async () =>
            {
                using var job = await limiter.Begin(limiter.Name, BatchLength(asList.Count, batchCount, batch),
                    CancellationToken.None);
                var results = new List<TOut>();
                for (var idx = batch; idx < asList.Count; idx += batchCount)
                {
                    job.ReportNoWait(1);
                    results.Add(await mapFn(asList[idx]));
                }
                return results;
            }))
            .ToList();

        foreach (var task in tasks)
        {
            foreach (var itm in await task)
            {
                yield return itm;
            }
        }
    }

    /// <summary>
    /// Faster version of PDoAll for when each function invocation takes a very small amount of time.
    /// Splits the inputs into N strided batches (item <c>i</c> goes to batch <c>i % N</c>), where N is
    /// <paramref name="limiter"/>'s <c>MaxTasks</c>, and processes each batch sequentially on its own task.
    /// Each batch task holds a single limiter job for its whole batch.
    /// </summary>
    /// <param name="coll">Items to process; materialized into a list once.</param>
    /// <param name="limiter">Resource whose <c>MaxTasks</c> determines the number of batches.</param>
    /// <param name="mapFn">Asynchronous action run for each item.</param>
    /// <typeparam name="TIn">Input element type.</typeparam>
    /// <typeparam name="TJob">Limiter job tag type.</typeparam>
    public static async Task PDoAllBatched<TIn, TJob>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task> mapFn)
    {
        var asList = coll.ToList();
        // Snapshot N once so every worker uses the same stride (see PMapAllBatched).
        var batchCount = limiter.MaxTasks;
        var tasks = Enumerable.Range(0, Math.Min(batchCount, asList.Count))
            .Select(batch => Task.Run(async () =>
            {
                using var job = await limiter.Begin(limiter.Name, BatchLength(asList.Count, batchCount, batch),
                    CancellationToken.None);
                for (var idx = batch; idx < asList.Count; idx += batchCount)
                {
                    job.ReportNoWait(1);
                    await mapFn(asList[idx]);
                }
            }))
            .ToList();
        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Legacy overload of <see cref="PDoAllBatched{TIn,TJob}"/>, kept for source compatibility.
    /// <typeparamref name="TOut"/> is unused. Because it cannot be inferred, callers of this overload must
    /// spell out all three type arguments.
    /// </summary>
    /// <param name="coll">Items to process; materialized into a list once.</param>
    /// <param name="limiter">Resource whose <c>MaxTasks</c> determines the number of batches.</param>
    /// <param name="mapFn">Asynchronous action run for each item.</param>
    /// <typeparam name="TIn">Input element type.</typeparam>
    /// <typeparam name="TJob">Limiter job tag type.</typeparam>
    /// <typeparam name="TOut">Unused.</typeparam>
    public static async Task PDoAllBatched<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task> mapFn)
    {
        await PDoAllBatched<TIn, TJob>(coll, limiter, mapFn);
    }

    /// <summary>
    /// Like <see cref="PMapAll{TIn,TJob,TOut}"/> (a limiter job is acquired before each invocation) but
    /// drops <c>null</c> results. Results are yielded in input order.
    /// </summary>
    /// <param name="coll">Items to map; enumerated exactly once.</param>
    /// <param name="limiter">Resource a job is taken from around each <paramref name="mapFn"/> call.</param>
    /// <param name="mapFn">Asynchronous projection applied to each item; may return <c>null</c>.</param>
    /// <returns>The non-null projected values, in input order.</returns>
    public static async IAsyncEnumerable<TOut> PKeepAll<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
        where TOut : class
    {
        var tasks = coll.Select(x => RunLimited(limiter, () => mapFn(x))).ToList();
        foreach (var task in tasks)
        {
            var value = await task;
            if (value is not null)
            {
                yield return value;
            }
        }
    }

    /// <summary>Drains <paramref name="coll"/> into a new <see cref="List{T}"/>.</summary>
    public static async Task<List<T>> ToList<T>(this IAsyncEnumerable<T> coll)
    {
        List<T> lst = new();
        await foreach (var itm in coll)
        {
            lst.Add(itm);
        }
        return lst;
    }

    /// <summary>
    /// Consumes an <see cref="IAsyncEnumerable{T}"/> without doing anything with its items. Useful for
    /// forcing a lazily-evaluated pipeline to run.
    /// </summary>
    /// <param name="coll">Sequence to drain.</param>
    /// <typeparam name="T">Element type.</typeparam>
    public static async Task Sink<T>(this IAsyncEnumerable<T> coll)
    {
        await foreach (var _ in coll)
        {
            // Intentionally empty: only the side effects of enumerating matter.
        }
    }

    /// <summary>Drains <paramref name="coll"/> into a new array.</summary>
    public static async Task<T[]> ToArray<T>(this IAsyncEnumerable<T> coll)
    {
        List<T> lst = new();
        await foreach (var itm in coll)
        {
            lst.Add(itm);
        }
        return lst.ToArray();
    }

    /// <summary>Drains <paramref name="coll"/> into a read-only collection (backed by a <see cref="List{T}"/>).</summary>
    public static async Task<IReadOnlyCollection<T>> ToReadOnlyCollection<T>(this IAsyncEnumerable<T> coll)
    {
        List<T> lst = new();
        await foreach (var itm in coll)
        {
            lst.Add(itm);
        }
        return lst;
    }

    /// <summary>
    /// Drains <paramref name="coll"/> into a <see cref="HashSet{T}"/>, keeping only items accepted by
    /// <paramref name="filter"/> when one is supplied.
    /// </summary>
    /// <param name="coll">Sequence to drain.</param>
    /// <param name="filter">Optional predicate; when <c>null</c>, every item is kept.</param>
    public static async Task<HashSet<T>> ToHashSet<T>(this IAsyncEnumerable<T> coll, Predicate<T>? filter = default)
    {
        HashSet<T> set = new();
        await foreach (var itm in coll)
        {
            if (filter is null || filter(itm))
            {
                set.Add(itm);
            }
        }
        return set;
    }

    /// <summary>Awaits <paramref name="fn"/> for each item of <paramref name="coll"/>, one at a time, in order.</summary>
    public static async Task Do<T>(this IAsyncEnumerable<T> coll, Func<T, Task> fn)
    {
        await foreach (var itm in coll)
        {
            await fn(itm);
        }
    }

    /// <summary>Invokes <paramref name="fn"/> for each item of <paramref name="coll"/>, in order.</summary>
    public static async Task Do<T>(this IAsyncEnumerable<T> coll, Action<T> fn)
    {
        await foreach (var itm in coll)
        {
            fn(itm);
        }
    }

    /// <summary>
    /// Drains <paramref name="coll"/> into a dictionary keyed by <paramref name="kSelector"/>.
    /// </summary>
    /// <exception cref="ArgumentException">Two items produce the same key.</exception>
    public static async Task<IDictionary<TK, T>> ToDictionary<T, TK>(this IAsyncEnumerable<T> coll,
        Func<T, TK> kSelector)
        where TK : notnull
    {
        Dictionary<TK, T> dict = new();
        await foreach (var itm in coll)
        {
            dict.Add(kSelector(itm), itm);
        }
        return dict;
    }

    /// <summary>
    /// Drains <paramref name="coll"/> into a dictionary using <paramref name="kSelector"/> for keys and
    /// <paramref name="vSelector"/> for values.
    /// </summary>
    /// <exception cref="ArgumentException">Two items produce the same key.</exception>
    public static async Task<IDictionary<TK, TV>> ToDictionary<T, TK, TV>(this IAsyncEnumerable<T> coll,
        Func<T, TK> kSelector, Func<T, TV> vSelector)
        where TK : notnull
    {
        Dictionary<TK, TV> dict = new();
        await foreach (var itm in coll)
        {
            dict.Add(kSelector(itm), vSelector(itm));
        }
        return dict;
    }

    /// <summary>Lazily yields the items of <paramref name="coll"/> that satisfy <paramref name="p"/>.</summary>
    public static async IAsyncEnumerable<T> Where<T>(this IAsyncEnumerable<T> coll, Predicate<T> p)
    {
        await foreach (var itm in coll)
        {
            if (p(itm))
            {
                yield return itm;
            }
        }
    }

    /// <summary>
    /// Lazily projects each item of a synchronous sequence with an asynchronous selector. Items are processed
    /// one at a time, in order.
    /// </summary>
    public static async IAsyncEnumerable<TOut> SelectAsync<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, ValueTask<TOut>> fn)
    {
        foreach (var itm in coll)
        {
            yield return await fn(itm);
        }
    }

    /// <summary>
    /// Lazily flattens the sequences produced by an asynchronous selector over a synchronous sequence.
    /// Items are processed one at a time, in order.
    /// </summary>
    public static async IAsyncEnumerable<TOut> SelectMany<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, ValueTask<IEnumerable<TOut>>> fn)
    {
        foreach (var itm in coll)
        {
            foreach (var inner in await fn(itm))
            {
                yield return inner;
            }
        }
    }

    /// <summary>
    /// Lazily projects each item of an asynchronous sequence with an asynchronous selector, one at a time.
    /// </summary>
    public static async IAsyncEnumerable<TOut> Select<TIn, TOut>(this IAsyncEnumerable<TIn> coll,
        Func<TIn, ValueTask<TOut>> fn)
    {
        await foreach (var itm in coll)
        {
            yield return await fn(itm);
        }
    }

    /// <summary>Lazily flattens the synchronous sequences produced for each item of an asynchronous sequence.</summary>
    public static async IAsyncEnumerable<TOut> SelectMany<TIn, TOut>(this IAsyncEnumerable<TIn> coll,
        Func<TIn, IEnumerable<TOut>> fn)
    {
        await foreach (var itm in coll)
        {
            foreach (var inner in fn(itm))
            {
                yield return inner;
            }
        }
    }

    // Acquires a job from `limiter` before starting `work` and disposes it once `work` completes (the same
    // pattern PDoAll<TIn, TJob> uses), so the limiter gates the work itself rather than the consumer.
    private static async Task<TOut> RunLimited<TJob, TOut>(IResource<TJob> limiter, Func<Task<TOut>> work)
    {
        using var job = await limiter.Begin("", 0, CancellationToken.None);
        return await work();
    }

    // Number of indices in [batch, count) when stepping by batchCount, i.e. exactly how many items the
    // worker for `batch` processes (and therefore how many ReportNoWait(1) calls it makes).
    // Callers guarantee batchCount >= 1.
    private static int BatchLength(int count, int batchCount, int batch) =>
        batch < count ? (count - batch - 1) / batchCount + 1 : 0;
}