Add a bunch of unit tests fixed a nasty race condition in the put/take handlers

This commit is contained in:
Timothy Baldridge 2019-11-10 15:15:52 -07:00
parent e9c2ababec
commit f66427c2ea
8 changed files with 301 additions and 18 deletions

View File

@ -29,7 +29,7 @@ namespace Wabbajack.Common.CSP
private Func<IBuffer<TOut>, TIn, bool> _add;
private Action<IBuffer<TOut>> _finalize;
private Func<TIn, TOut> _converter;
bool _isClosed = false;
volatile bool _isClosed = false;
public ManyToManyChannel(Func<TIn, TOut> converter)
{
@ -323,7 +323,7 @@ namespace Wabbajack.Common.CSP
var put_cb = LockIfActiveCommit(handler);
if (put_cb != null)
{
Task.Run(() => put_cb(true));
Task.Run(() => put_cb(false));
}
}
_puts.Cleanup(x => false);
@ -347,12 +347,6 @@ namespace Wabbajack.Common.CSP
return ret;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static T IfActiveCommit<T>(Handler<T> handler)
{
return handler.IsActive ? handler.Commit() : default;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static T LockIfActiveCommit<T>(Handler<T> handler)
{

View File

@ -9,7 +9,7 @@ namespace Wabbajack.Common.CSP
class PutTaskHandler<T> : Handler<Action<bool>>
{
private readonly bool _blockable;
private TaskCompletionSource<bool> _tcs;
private TaskCompletionSource<bool> _tcs = new TaskCompletionSource<bool>();
public PutTaskHandler(bool blockable = true)
{

View File

@ -39,8 +39,7 @@ namespace Wabbajack.Common.CSP
public T Peek()
{
if (_length == 0) return default;
return _arr[_tail];
return _length == 0 ? default : _arr[_tail];
}
public void Unshift(T x)

View File

@ -8,17 +8,19 @@ using System.Threading.Tasks;
namespace Wabbajack.Common.CSP
{
public class RxBuffer<TIn, TOut> : FixedSizeBuffer<TOut>
public class RxBuffer<TIn, TOut> : LinkedList<TOut>, IBuffer<TOut>
{
private Subject<TIn> _inputSubject;
private IObservable<TOut> _outputObservable;
private bool _completed;
private int _maxSize;
public RxBuffer(int size, Func<IObservable<TIn>, IObservable<TOut>> transform) : base(size)
public RxBuffer(int size, Func<IObservable<TIn>, IObservable<TOut>> transform) : base()
{
_maxSize = size;
_inputSubject = new Subject<TIn>();
_outputObservable = transform(_inputSubject);
_outputObservable.Subscribe(itm => base.Add(itm), () => {
_outputObservable.Subscribe(itm => AddFirst(itm), () => {
_completed = true;
});
}
@ -39,9 +41,28 @@ namespace Wabbajack.Common.CSP
_inputSubject.OnCompleted();
}
public void Dispose()
{
throw new NotImplementedException();
}
public static void Finalize(IBuffer<TOut> buf)
{
((RxBuffer<TIn, TOut>)buf).Finalize();
}
public bool IsFull => Count >= _maxSize;
public bool IsEmpty => Count == 0;
public TOut Remove()
{
var ret = Last.Value;
RemoveLast();
return ret;
}
public void Add(TOut itm)
{
}
}
}

View File

@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Wabbajack.Common.CSP
@ -22,7 +23,11 @@ namespace Wabbajack.Common.CSP
get
{
if (_tcs == null)
_tcs = new TaskCompletionSource<(bool, T)>();
{
var new_tcs = new TaskCompletionSource<(bool, T)>();
Interlocked.CompareExchange(ref _tcs, new_tcs, null);
}
return _tcs;
}
}

View File

@ -2,14 +2,29 @@
using System.Linq;
using System.Reactive.Linq;
using System.Threading.Tasks;
using Alphaleonis.Win32.Filesystem;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Wabbajack.Common.CSP;
namespace Wabbajack.Test
namespace Wabbajack.Test.CSP
{
[TestClass]
public class CSPTests
{
public TestContext TestContext { get; set; }
public void Log(string msg)
{
TestContext.WriteLine(msg);
}
[TestInitialize]
public void Startup()
{
}
/// <summary>
/// Test that we can put a value onto a channel without a buffer, and that the put is released once the
/// take finalizes
@ -137,7 +152,9 @@ namespace Wabbajack.Test
var putter = Task.Run(async () =>
{
for (var i = 0; i < 1000; i++)
await chan.Put(i);
{
var result = await chan.Put(i);
}
});
var taker = Task.Run(async () =>
@ -168,7 +185,9 @@ namespace Wabbajack.Test
var putter = Task.Run(async () =>
{
for (var i = 0; i < 1000; i++)
{
await chan.Put(i);
}
});
var taker = Task.Run(async () =>

View File

@ -0,0 +1,243 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reactive.Linq;
using System.Security.Policy;
using System.Text;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Wabbajack.Common.CSP;
namespace Wabbajack.Test.CSP
{
[TestClass]
public class ChannelTests
{
[TestMethod]
public async Task PutThenTakeNoBuffer()
{
var chan = Channel.Create<int>();
var putter = chan.Put(42);
var taker = chan.Take();
Assert.IsTrue(await putter);
Assert.AreEqual((true, 42), await taker);
}
[TestMethod]
public async Task TakeThenPushNoBuffer()
{
var chan = Channel.Create<int>();
var taker = chan.Take();
var putter = chan.Put(42);
Assert.IsTrue(await putter);
Assert.AreEqual((true, 42), await taker);
}
[TestMethod]
public async Task TakeFromBufferAfterPut()
{
var chan = Channel.Create<int>(1);
var putter = chan.Put(42);
var taker = chan.Take();
Assert.IsTrue(await putter);
Assert.AreEqual((true, 42), await taker);
}
[TestMethod]
public async Task TakeFromBufferBeforePut()
{
var chan = Channel.Create<int>(1);
var taker = chan.Take();
var putter = chan.Put(42);
Assert.IsTrue(await putter);
Assert.AreEqual((true, 42), await taker);
}
[TestMethod]
public async Task TakesAreReleasedAfterClose()
{
var chan = Channel.Create<int>();
var taker = chan.Take();
chan.Close();
Assert.AreEqual((false, 0), await taker);
}
[TestMethod]
public async Task ExpandingTransformsReleaseMultipleTakes()
{
var chan = Channel.Create<int, int>(1, i => i.SelectMany(len => Enumerable.Range(0, len)));
var take1 = chan.Take();
var take2 = chan.Take();
await chan.Put(2);
Assert.AreEqual((true, 0), await take1);
Assert.AreEqual((true, 1), await take2);
}
[TestMethod]
public async Task TransformsCanCloseChannel()
{
var chan = Channel.Create<int, int>(1, i => i.Take(1));
var take1 = chan.Take();
var take2 = chan.Take();
await chan.Put(1);
await chan.Put(2);
Assert.IsTrue(chan.IsClosed);
Assert.AreEqual((true, 1), await take1);
Assert.AreEqual((false, 0), await take2);
}
[TestMethod]
public async Task TransformsCanCloseDuringExpand()
{
var chan = Channel.Create<int, int>(1, i => i.SelectMany(len => Enumerable.Range(1, len)).Take(1));
var take1 = chan.Take();
var take2 = chan.Take();
await chan.Put(2);
Assert.IsTrue(chan.IsClosed);
Assert.AreEqual((true, 1), await take1);
Assert.AreEqual((false, 0), await take2);
}
[TestMethod]
public async Task TransformsCanFilterTakeFirst()
{
var chan = Channel.Create<int, int>(1, i => i.Where(x => x == 2));
var take1 = chan.Take();
var take2 = chan.Take();
await chan.Put(1);
await chan.Put(2);
chan.Close();
Assert.IsTrue(chan.IsClosed);
Assert.AreEqual((true, 2), await take1);
Assert.AreEqual((false, 0), await take2);
}
[TestMethod]
public async Task TransformsCanReturnNothingTakeFirst()
{
var chan = Channel.Create<int, int>(1, i => i.Take(0));
var take1 = chan.Take();
var take2 = chan.Take();
await chan.Put(1);
Assert.IsTrue(chan.IsClosed);
Assert.AreEqual((false, 0), await take1);
Assert.AreEqual((false, 0), await take2);
}
[TestMethod]
public async Task TransformsCanFilterTakeAfter()
{
var chan = Channel.Create<int, int>(1, i => i.Where(x => x == 2));
await chan.Put(1);
await chan.Put(2);
var take1 = chan.Take();
var take2 = chan.Take();
chan.Close();
Assert.IsTrue(chan.IsClosed);
Assert.AreEqual((true, 2), await take1);
Assert.AreEqual((false, 0), await take2);
}
[TestMethod]
public async Task TransformsCanReturnNothingTakeAfter()
{
var chan = Channel.Create<int, int>(1, i => i.Take(0));
await chan.Put(1);
var take1 = chan.Take();
var take2 = chan.Take();
Assert.IsTrue(chan.IsClosed);
Assert.AreEqual((false, 0), await take1);
Assert.AreEqual((false, 0), await take2);
}
[TestMethod]
public void TooManyTakesCausesException()
{
var chan = Channel.Create<int>();
Assert.ThrowsException<ManyToManyChannel<int, int>.TooManyHanldersException>(() =>
{
for (var x = 0; x < ManyToManyChannel<int, int>.MAX_QUEUE_SIZE + 1; x++)
chan.Take();
});
}
[TestMethod]
public void TooManyPutsCausesException()
{
var chan = Channel.Create<int>();
Assert.ThrowsException<ManyToManyChannel<int, int>.TooManyHanldersException>(() =>
{
for (var x = 0; x < ManyToManyChannel<int, int>.MAX_QUEUE_SIZE + 1; x++)
chan.Put(x);
});
}
[TestMethod]
public async Task BlockingPutsGoThroughTransform()
{
var chan = Channel.Create<int, int>(1, i => i.Take(2));
var put1 = chan.Put(1);
var put2 = chan.Put(2);
var put3 = chan.Put(3);
var put4 = chan.Put(4);
var take1 = chan.Take();
var take2 = chan.Take();
var take3 = chan.Take();
Assert.AreEqual((true, 1), await take1);
Assert.AreEqual((true, 2), await take2);
Assert.AreEqual((false, 0), await take3);
Assert.IsTrue(await put1);
Assert.IsTrue(await put2);
Assert.IsFalse(await put3);
Assert.IsFalse(await put4);
Assert.IsTrue(chan.IsClosed);
}
}
}

View File

@ -93,7 +93,8 @@
</ItemGroup>
<ItemGroup>
<Compile Include="ACompilerTest.cs" />
<Compile Include="CSPTests.cs" />
<Compile Include="CSP\ChannelTests.cs" />
<Compile Include="CSP\CSPTests.cs" />
<Compile Include="DownloaderTests.cs" />
<Compile Include="EndToEndTests.cs" />
<Compile Include="Extensions.cs" />
@ -147,6 +148,7 @@
<Version>4.2.0</Version>
</PackageReference>
</ItemGroup>
<ItemGroup />
<Import Project="$(VSToolsPath)\TeamTest\Microsoft.TestTools.targets" Condition="Exists('$(VSToolsPath)\TeamTest\Microsoft.TestTools.targets')" />
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
</Project>