Fix channel closing, add unordered pipeline

This commit is contained in:
Timothy Baldridge 2019-11-09 14:29:55 -07:00
parent 67dfaa3581
commit 081dea2368
10 changed files with 277 additions and 14 deletions

View File

@ -13,7 +13,7 @@ namespace Wabbajack.Common.CSP
public abstract bool IsClosed { get; }
public abstract void Close();
public abstract (AsyncResult, bool) Put(TIn val, Handler<Action<bool>> handler);
public abstract (AsyncResult, TOut) Take(Handler<Action<TOut>> handler);
public abstract (AsyncResult, TOut) Take(Handler<Action<bool, TOut>> handler);
private Task<(bool, TOut)> _take_cancelled_task;

View File

@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Subjects;
using System.Text;
using System.Threading.Tasks;
@ -25,10 +26,19 @@ namespace Wabbajack.Common.CSP
/// <returns></returns>
public static IChannel<T, T> ToChannel<T>(this IEnumerable<T> coll)
{
return Channel.Create(coll.GetEnumerator());
var chan = Channel.Create(coll.GetEnumerator());
chan.Close();
return chan;
}
/// <summary>
/// Takes all the values from chan, once the channel closes returns a List of the values taken.
/// </summary>
/// <typeparam name="TOut"></typeparam>
/// <typeparam name="TIn"></typeparam>
/// <param name="chan"></param>
/// <returns></returns>
public static async Task<List<TOut>> TakeAll<TOut, TIn>(this IChannel<TIn, TOut> chan)
{
List<TOut> acc = new List<TOut>();
@ -42,5 +52,120 @@ namespace Wabbajack.Common.CSP
}
return acc;
}
/// <summary>
/// Pipes values from `from` into `to`
/// </summary>
/// <typeparam name="TIn"></typeparam>
/// <typeparam name="TMid"></typeparam>
/// <typeparam name="TOut"></typeparam>
/// <param name="from">source channel</param>
/// <param name="to">destination channel</param>
/// <param name="closeOnFinished">Tf true, will close the other channel when one channel closes</param>
/// <returns></returns>
public static async Task Pipe<TIn, TMid, TOut>(this IChannel<TIn, TMid> from, IChannel<TMid, TOut> to, bool closeOnFinished = true)
{
while (true)
{
var (isFromOpen, val) = await from.Take();
if (isFromOpen)
{
var isToOpen = await to.Put(val);
if (isToOpen) continue;
if (closeOnFinished)
@from.Close();
break;
}
if (closeOnFinished)
to.Close();
break;
}
}
/*
private static void PipelineInner<TInSrc, TOutSrc, TInDest, TOutDest>(int n,
IChannel<TInSrc, TOutSrc> from,
Func<TOutSrc, Task<TInDest>> fn,
IChannel<TInDest, TOutDest> to,
bool closeOnFinished)
{
var jobs = Channel.Create<TOutSrc>(n);
var results = Channel.Create<TInDest>(n);
{
bool Process(TOutSrc val, )
{
if ()
}
}
}*/
/// <summary>
/// Creates a pipeline that takes items from `from` transforms them with the pipeline given by `transform` and puts
/// the resulting values onto `to`. The pipeline may create 0 or more items for every input item and they will be
/// spooled onto `to` in a undefined order. `n` determines how many parallel tasks will be running at once. Each of
/// these tasks maintains its own transformation pipeline, so `transform` will be called once for every `n`. Completing
/// a `transform` pipeline has no effect.
/// </summary>
/// <typeparam name="TInSrc"></typeparam>
/// <typeparam name="TOutSrc"></typeparam>
/// <typeparam name="TInDest"></typeparam>
/// <typeparam name="TOutDest"></typeparam>
/// <param name="from"></param>
/// <param name="parallelism"></param>
/// <param name="to"></param>
/// <param name="transform"></param>
/// <param name="propagateClose"></param>
/// <returns></returns>
public static async Task UnorderedPipeline<TInSrc, TOutSrc, TInDest, TOutDest>(
this IChannel<TInSrc, TOutSrc> from,
int parallelism,
IChannel<TInDest, TOutDest> to,
Func<IObservable<TOutSrc>, IObservable<TInDest>> transform,
bool propagateClose = true)
{
async Task Pump()
{
var pipeline = new Subject<TOutSrc>();
var buffer = new List<TInDest>();
var dest = transform(pipeline);
dest.Subscribe(itm => buffer.Add(itm));
while (true)
{
var (is_open, tval) = await from.Take();
if (is_open)
{
pipeline.OnNext(tval);
foreach (var pval in buffer)
{
var is_put_open = await to.Put(pval);
if (is_put_open) continue;
if (propagateClose) @from.Close();
return;
}
buffer.Clear();
}
else
{
pipeline.OnCompleted();
if (buffer.Count > 0)
{
foreach (var pval in buffer)
if (!await to.Put(pval))
break;
}
if (propagateClose) to.Close();
break;
}
}
}
await Task.WhenAll(Enumerable.Range(0, parallelism)
.Select(idx => Task.Run(Pump)));
}
}
}

View File

@ -2,6 +2,8 @@
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Reactive.Subjects;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
@ -47,5 +49,11 @@ namespace Wabbajack.Common.CSP
},
b => {}, buffer);
}
public static IChannel<TIn, TOut> Create<TIn, TOut>(int buffer_size, Func<IObservable<TIn>, IObservable<TOut>> transform)
{
var buf = new RxBuffer<TIn, TOut>(buffer_size, transform);
return new ManyToManyChannel<TIn, TOut>(null, RxBuffer<TIn,TOut>.TransformAdd, RxBuffer<TIn, TOut>.Finalize, buf);
}
}
}

View File

@ -34,7 +34,7 @@ namespace Wabbajack.Common.CSP
void Close();
(AsyncResult, bool) Put(TIn val, Handler<Action<bool>> handler);
(AsyncResult, TOut) Take(Handler<Action<TOut>> handler);
(AsyncResult, TOut) Take(Handler<Action<bool, TOut>> handler);
ValueTask<(bool, TOut)> Take(bool onCaller = true);
ValueTask<bool> Put(TIn val, bool onCaller = true);
}

View File

@ -23,7 +23,7 @@ namespace Wabbajack.Common.CSP
{
public const int MAX_QUEUE_SIZE = 1024;
private RingBuffer<Handler<Action<TOut>>> _takes = new RingBuffer<Handler<Action<TOut>>>(8);
private RingBuffer<Handler<Action<bool, TOut>>> _takes = new RingBuffer<Handler<Action<bool, TOut>>>(8);
private RingBuffer<(Handler<Action<bool>>, TIn)> _puts = new RingBuffer<(Handler<Action<bool>>, TIn)>(8);
private IBuffer<TOut> _buf;
private Func<IBuffer<TOut>, TIn, bool> _add;
@ -47,7 +47,7 @@ namespace Wabbajack.Common.CSP
_converter = converter;
}
private static bool IsActiveTake(Handler<Action<TOut>> handler)
private static bool IsActiveTake(Handler<Action<bool, TOut>> handler)
{
return handler.IsActive;
}
@ -108,7 +108,7 @@ namespace Wabbajack.Common.CSP
if (put_cb2 != null && take_cb != null)
{
Monitor.Exit(this);
Task.Run(() => take_cb(_converter(val)));
Task.Run(() => take_cb(true, _converter(val)));
return (AsyncResult.Completed, true);
}
@ -140,7 +140,7 @@ namespace Wabbajack.Common.CSP
return (AsyncResult.Enqueued, true);
}
public override (AsyncResult, TOut) Take(Handler<Action<TOut>> handler)
public override (AsyncResult, TOut) Take(Handler<Action<bool, TOut>> handler)
{
Monitor.Enter(this);
Cleanup();
@ -221,9 +221,22 @@ namespace Wabbajack.Common.CSP
_isClosed = true;
if (_buf != null && _puts.IsEmpty)
_finalize(_buf);
var cbs = GetTakersForBuffer();
while (!_takes.IsEmpty)
{
var take_cb = LockIfActiveCommit(_takes.Pop());
if (take_cb != null)
cbs.Add(() => take_cb(false, default));
}
Monitor.Exit(this);
foreach (var cb in cbs)
Task.Run(cb);
}
private (Action<TOut>, Action<bool>, TIn, bool) FindMatchingPut(Handler<Action<TOut>> handler)
private (Action<bool, TOut>, Action<bool>, TIn, bool) FindMatchingPut(Handler<Action<bool, TOut>> handler)
{
while (!_puts.IsEmpty)
{
@ -326,7 +339,7 @@ namespace Wabbajack.Common.CSP
if (take_cp != null)
{
var val = _buf.Remove();
ret.Add(() => take_cp(val));
ret.Add(() => take_cp(true, val));
}
}

View File

@ -1,5 +1,6 @@
using System;
using System.CodeDom.Compiler;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@ -9,7 +10,7 @@ using System.Windows.Forms;
namespace Wabbajack.Common.CSP
{
public struct RingBuffer<T>
public struct RingBuffer<T> : IEnumerable<T>
{
private int _size;
private int _length;
@ -104,5 +105,16 @@ namespace Wabbajack.Common.CSP
}
}
}
public IEnumerator<T> GetEnumerator()
{
while (!IsEmpty)
yield return Pop();
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

View File

@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Reactive.Subjects;
using System.Text;
using System.Threading.Tasks;
namespace Wabbajack.Common.CSP
{
public class RxBuffer<TIn, TOut> : FixedSizeBuffer<TOut>
{
private Subject<TIn> _inputSubject;
private IObservable<TOut> _outputObservable;
private bool _completed;
public RxBuffer(int size, Func<IObservable<TIn>, IObservable<TOut>> transform) : base(size)
{
_inputSubject = new Subject<TIn>();
_outputObservable = transform(_inputSubject);
_outputObservable.Subscribe(itm => base.Add(itm), () => {
_completed = true;
});
}
public bool TransformAdd(TIn val)
{
_inputSubject.OnNext(val);
return _completed;
}
public static bool TransformAdd(IBuffer<TOut> buf, TIn itm)
{
return ((RxBuffer<TIn, TOut>) buf).TransformAdd(itm);
}
public void Finalize()
{
_inputSubject.OnCompleted();
}
public static void Finalize(IBuffer<TOut> buf)
{
((RxBuffer<TIn, TOut>)buf).Finalize();
}
}
}

View File

@ -7,7 +7,7 @@ using System.Threading.Tasks;
namespace Wabbajack.Common.CSP
{
public class TakeTaskHandler<T> : Handler<Action<T>>
public class TakeTaskHandler<T> : Handler<Action<bool, T>>
{
private readonly bool _blockable;
private TaskCompletionSource<(bool, T)> _tcs;
@ -32,14 +32,14 @@ namespace Wabbajack.Common.CSP
public bool IsBlockable => _blockable;
public uint LockId => 0;
public Task<(bool, T)> Task => TaskCompletionSource.Task;
public Action<T> Commit()
public Action<bool, T> Commit()
{
return Handle;
}
private void Handle(T a)
private void Handle(bool is_open, T a)
{
TaskCompletionSource.SetResult((true, a));
TaskCompletionSource.SetResult((is_open, a));
}
}
}

View File

@ -101,6 +101,7 @@
<Compile Include="CSP\ManyToManyChannel.cs" />
<Compile Include="CSP\PutTaskHandler.cs" />
<Compile Include="CSP\RingBuffer.cs" />
<Compile Include="CSP\RxBuffer.cs" />
<Compile Include="CSP\TakeTaskHandler.cs" />
<Compile Include="DynamicIniData.cs" />
<Compile Include="Error States\ErrorResponse.cs" />

View File

@ -1,5 +1,6 @@
using System;
using System.Linq;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Wabbajack.Common.CSP;
@ -9,6 +10,11 @@ namespace Wabbajack.Test
[TestClass]
public class CSPTests
{
/// <summary>
/// Test that we can put a value onto a channel without a buffer, and that the put is released once the
/// take finalizes
/// </summary>
/// <returns></returns>
[TestMethod]
public async Task TestTakePutBlocking()
{
@ -21,6 +27,11 @@ namespace Wabbajack.Test
Assert.IsTrue(await ptask);
}
/// <summary>
/// If we create a channel with a fixed buffer size, we can enqueue that number of items without blocking
/// We can then take those items later on.
/// </summary>
/// <returns></returns>
[TestMethod]
public async Task TestTakePutBuffered()
{
@ -36,6 +47,10 @@ namespace Wabbajack.Test
}
}
/// <summary>
/// We can convert a IEnumerable into a channel by inlining the enumerable into the channel's buffer.
/// </summary>
/// <returns></returns>
[TestMethod]
public async Task TestToChannel()
{
@ -49,6 +64,12 @@ namespace Wabbajack.Test
}
}
/// <summary>
/// TakeAll will continue to take from a channel as long as the channel is open. Once the channel closes
/// TakeAll returns a list of the items taken.
/// </summary>
/// <returns></returns>
[TestMethod]
public async Task TestTakeAll()
{
@ -56,5 +77,41 @@ namespace Wabbajack.Test
CollectionAssert.AreEqual(Enumerable.Range(0, 10).ToList(), results);
}
/// <summary>
/// We can add Rx transforms as transforms inside a channel. This allows for cheap conversion and calcuation
/// to be performed in a channel without incuring the dispatch overhead of swapping values between threads.
/// These calculations happen inside the channel's lock, however, so be sure to keep these operations relatively
/// cheap.
/// </summary>
/// <returns></returns>
[TestMethod]
public async Task RxTransformInChannel()
{
var chan = Channel.Create<int, int>(1, o => o.Select(v => v + 1));
var finished = Enumerable.Range(0, 10).OntoChannel(chan);
foreach (var itm in Enumerable.Range(0, 10))
{
var (is_open, val) = await chan.Take();
Assert.AreEqual(itm + 1, val);
Assert.IsTrue(is_open);
}
await finished;
}
[TestMethod]
public async Task UnorderedPipeline()
{
var o = Channel.Create<string>(3);
var finished = Enumerable.Range(0, 3)
.ToChannel()
.UnorderedPipeline(1, o, obs => obs.Select(itm => itm.ToString()));
var results = (await o.TakeAll()).OrderBy(e => e).ToList();
var expected = Enumerable.Range(0, 3).Select(i => i.ToString()).ToList();
CollectionAssert.AreEqual(expected, results);
}
}
}