using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Wabbajack.RateLimiter;

namespace Wabbajack.Common;

/// <summary>
/// Extension helpers for running async work over collections in parallel and for
/// materializing / transforming <see cref="IAsyncEnumerable{T}"/> sequences.
/// </summary>
public static class AsyncParallelExtensions
{
    /// <summary>
    /// Starts <paramref name="mapFn"/> for every element at once and waits until all of them complete.
    /// </summary>
    /// <param name="coll">Input elements; enumerated once.</param>
    /// <param name="mapFn">Async action invoked for each element.</param>
    public static async Task PDoAll<TIn>(this IEnumerable<TIn> coll, Func<TIn, Task> mapFn)
    {
        var tasks = coll.Select(mapFn).ToList();
        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Starts <paramref name="mapFn"/> for every element, each invocation running while holding a job
    /// acquired from <paramref name="limiter"/>, and waits until all of them complete.
    /// </summary>
    /// <param name="coll">Input elements; enumerated once.</param>
    /// <param name="limiter">Resource a job is taken from around each invocation.</param>
    /// <param name="mapFn">Async action invoked for each element.</param>
    public static async Task PDoAll<TIn, TJob>(this IEnumerable<TIn> coll, IResource<TJob> limiter,
        Func<TIn, Task> mapFn)
    {
        var tasks = coll.Select(async x =>
        {
            using var job = await limiter.Begin("", 0, CancellationToken.None);
            await mapFn(x);
        }).ToList();

        await Task.WhenAll(tasks);
    }

    /// <summary>
    /// Runs <paramref name="mapFn"/> for every element via <see cref="Task.Run(Func{Task})"/>, starting all
    /// of them immediately, and yields the results in input order.
    /// </summary>
    /// <param name="coll">Input elements; enumerated once, eagerly.</param>
    /// <param name="mapFn">Async projection invoked for each element.</param>
    public static async IAsyncEnumerable<TOut> PMapAll<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, Task<TOut>> mapFn)
    {
        // Task.Run(Func<Task<TOut>>) unwraps the inner task, so each entry is a Task<TOut>.
        var tasks = coll.Select(x => Task.Run(() => mapFn(x))).ToList();
        foreach (var itm in tasks)
            yield return await itm;
    }

    /// <summary>
    /// Like <c>PMapAll</c>, but skips results that are <c>null</c>.
    /// All invocations are started immediately; surviving results are yielded in input order.
    /// </summary>
    /// <param name="coll">Input elements; enumerated once, eagerly.</param>
    /// <param name="mapFn">Async projection invoked for each element; may return <c>null</c>.</param>
    public static async IAsyncEnumerable<TOut> PKeepAll<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, Task<TOut>> mapFn)
        where TOut : class
    {
        var tasks = coll.Select(mapFn).ToList();
        foreach (var itm in tasks)
        {
            // Test the awaited value, not the Task itself (a Task is never null).
            var value = await itm;
            if (value is not null)
                yield return value;
        }
    }

    /// <summary>
    /// Starts <paramref name="mapFn"/> for every element, each invocation running while holding a job
    /// acquired from <paramref name="limiter"/>, and yields the results in input order.
    /// </summary>
    /// <param name="coll">Input elements; enumerated once, eagerly.</param>
    /// <param name="limiter">Resource a job is taken from around each invocation.</param>
    /// <param name="mapFn">Async projection invoked for each element.</param>
    public static async IAsyncEnumerable<TOut> PMapAll<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
    {
        var tasks = StartLimited(coll, limiter, mapFn);
        foreach (var itm in tasks)
            yield return await itm;
    }

    /// <summary>
    /// Faster version of <c>PMapAll</c> for when each function invocation takes a very small amount of time.
    /// Splits the inputs into N groups and processes each group sequentially on its own task, where N is
    /// the number of tasks supported by the limiter (at least 1).
    /// </summary>
    /// <remarks>
    /// Element <c>i</c> lands in group <c>i % N</c>, and results are yielded group by group, so the output
    /// order differs from the input order whenever N &gt; 1.
    /// </remarks>
    /// <param name="coll">Input elements; copied into a list up front.</param>
    /// <param name="limiter">Resource one job per group is taken from; also supplies N and the job name.</param>
    /// <param name="mapFn">Async projection invoked for each element.</param>
    public static async IAsyncEnumerable<TOut> PMapAllBatched<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
    {
        var asList = coll.ToList();

        // Read the task count once so every group uses the same stride even if the limiter changes while
        // we run; clamp to 1 so a zero setting does not silently drop every input.
        var batchCount = Math.Max(1, limiter.MaxTasks);

        var tasks = Enumerable.Range(0, batchCount)
            .Select(batch => Task.Run(async () =>
            {
                using var job = await limiter.Begin(limiter.Name, asList.Count / batchCount,
                    CancellationToken.None);
                var results = new List<TOut>();
                for (var idx = batch; idx < asList.Count; idx += batchCount)
                {
                    job.ReportNoWait(1);
                    results.Add(await mapFn(asList[idx]));
                }

                return results;
            }))
            .ToList();

        foreach (var batchTask in tasks)
        {
            foreach (var itm in await batchTask)
                yield return itm;
        }
    }

    /// <summary>
    /// Like the limited <c>PMapAll</c>, but skips results that are <c>null</c>.
    /// </summary>
    /// <param name="coll">Input elements; enumerated once, eagerly.</param>
    /// <param name="limiter">Resource a job is taken from around each invocation.</param>
    /// <param name="mapFn">Async projection invoked for each element; may return <c>null</c>.</param>
    public static async IAsyncEnumerable<TOut> PKeepAll<TIn, TJob, TOut>(this IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
        where TOut : class
    {
        var tasks = StartLimited(coll, limiter, mapFn);
        foreach (var itm in tasks)
        {
            var value = await itm;
            if (value is not null)
                yield return value;
        }
    }

    // Eagerly starts mapFn for every element; each invocation runs while holding its own limiter job,
    // mirroring the limited PDoAll overload.
    private static List<Task<TOut>> StartLimited<TIn, TJob, TOut>(IEnumerable<TIn> coll,
        IResource<TJob> limiter, Func<TIn, Task<TOut>> mapFn)
    {
        return coll.Select(async x =>
        {
            using var job = await limiter.Begin("", 0, CancellationToken.None);
            return await mapFn(x);
        }).ToList();
    }

    /// <summary>Asynchronously collects every element of <paramref name="coll"/> into a list.</summary>
    public static async Task<List<T>> ToList<T>(this IAsyncEnumerable<T> coll)
    {
        List<T> lst = new();
        await foreach (var itm in coll)
            lst.Add(itm);
        return lst;
    }

    /// <summary>
    /// Consumes an <see cref="IAsyncEnumerable{T}"/> without doing anything with its elements
    /// (useful to force side effects of a lazy pipeline).
    /// </summary>
    /// <param name="coll">Sequence to drain.</param>
    public static async Task Sink<T>(this IAsyncEnumerable<T> coll)
    {
        await foreach (var _ in coll)
        {
        }
    }

    /// <summary>Asynchronously collects every element of <paramref name="coll"/> into an array.</summary>
    public static async Task<T[]> ToArray<T>(this IAsyncEnumerable<T> coll)
    {
        List<T> lst = new();
        await foreach (var itm in coll)
            lst.Add(itm);
        return lst.ToArray();
    }

    /// <summary>Asynchronously collects every element of <paramref name="coll"/> into a read-only collection.</summary>
    public static async Task<IReadOnlyCollection<T>> ToReadOnlyCollection<T>(this IAsyncEnumerable<T> coll)
    {
        List<T> lst = new();
        await foreach (var itm in coll)
            lst.Add(itm);
        return lst;
    }

    /// <summary>
    /// Asynchronously collects the elements of <paramref name="coll"/> into a hash set, optionally keeping
    /// only those matching <paramref name="filter"/>.
    /// </summary>
    /// <param name="coll">Sequence to collect.</param>
    /// <param name="filter">Optional predicate; when <c>null</c>, every element is kept.</param>
    public static async Task<HashSet<T>> ToHashSet<T>(this IAsyncEnumerable<T> coll, Predicate<T>? filter = default)
    {
        var source = filter is null ? coll : coll.Where(filter);
        HashSet<T> set = new();
        await foreach (var itm in source)
            set.Add(itm);
        return set;
    }

    /// <summary>Awaits <paramref name="fn"/> for each element, one at a time, in sequence order.</summary>
    public static async Task Do<T>(this IAsyncEnumerable<T> coll, Func<T, Task> fn)
    {
        await foreach (var itm in coll)
            await fn(itm);
    }

    /// <summary>Invokes <paramref name="fn"/> for each element in sequence order.</summary>
    public static async Task Do<T>(this IAsyncEnumerable<T> coll, Action<T> fn)
    {
        await foreach (var itm in coll)
            fn(itm);
    }

    /// <summary>
    /// Builds a dictionary keyed by <paramref name="kSelector"/> with the elements as values.
    /// Throws <see cref="ArgumentException"/> (from <see cref="Dictionary{TKey,TValue}.Add"/>) on duplicate keys.
    /// </summary>
    public static async Task<Dictionary<TK, T>> ToDictionary<T, TK>(this IAsyncEnumerable<T> coll,
        Func<T, TK> kSelector) where TK : notnull
    {
        Dictionary<TK, T> dict = new();
        await foreach (var itm in coll)
            dict.Add(kSelector(itm), itm);
        return dict;
    }

    /// <summary>
    /// Builds a dictionary using <paramref name="kSelector"/> for keys and <paramref name="vSelector"/> for values.
    /// Throws <see cref="ArgumentException"/> (from <see cref="Dictionary{TKey,TValue}.Add"/>) on duplicate keys.
    /// </summary>
    public static async Task<Dictionary<TK, TV>> ToDictionary<T, TK, TV>(this IAsyncEnumerable<T> coll,
        Func<T, TK> kSelector, Func<T, TV> vSelector) where TK : notnull
    {
        Dictionary<TK, TV> dict = new();
        await foreach (var itm in coll)
            dict.Add(kSelector(itm), vSelector(itm));
        return dict;
    }

    /// <summary>Lazily yields the elements of <paramref name="coll"/> that satisfy <paramref name="p"/>.</summary>
    public static async IAsyncEnumerable<T> Where<T>(this IAsyncEnumerable<T> coll, Predicate<T> p)
    {
        await foreach (var itm in coll)
        {
            if (p(itm))
                yield return itm;
        }
    }

    /// <summary>Lazily awaits <paramref name="fn"/> for each element, one at a time, yielding each result.</summary>
    public static async IAsyncEnumerable<TOut> SelectAsync<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, Task<TOut>> fn)
    {
        foreach (var itm in coll)
            yield return await fn(itm);
    }

    /// <summary>Lazily awaits <paramref name="fn"/> for each element and flattens the returned sequences.</summary>
    public static async IAsyncEnumerable<TOut> SelectMany<TIn, TOut>(this IEnumerable<TIn> coll,
        Func<TIn, Task<IEnumerable<TOut>>> fn)
    {
        foreach (var itm in coll)
        {
            foreach (var inner in await fn(itm))
                yield return inner;
        }
    }

    /// <summary>Lazily awaits <paramref name="fn"/> for each element of an async sequence, yielding each result.</summary>
    public static async IAsyncEnumerable<TOut> Select<TIn, TOut>(this IAsyncEnumerable<TIn> coll,
        Func<TIn, Task<TOut>> fn)
    {
        await foreach (var itm in coll)
            yield return await fn(itm);
    }

    /// <summary>Lazily flattens the sequences produced by <paramref name="fn"/> for each element of an async sequence.</summary>
    public static async IAsyncEnumerable<TOut> SelectMany<TIn, TOut>(this IAsyncEnumerable<TIn> coll,
        Func<TIn, IEnumerable<TOut>> fn)
    {
        await foreach (var itm in coll)
        {
            foreach (var inner in fn(itm))
                yield return inner;
        }
    }
}