Skip to content

Commit 70a90e5

Browse files
committed
TensorflowNET.UnitTest: Added MultiThreadedUnitTestExecuter
1 parent a039e64 commit 70a90e5

1 file changed

Lines changed: 127 additions & 0 deletions

File tree

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
using System;
2+
using System.Threading;
3+
4+
namespace TensorFlowNET.UnitTest
5+
{
6+
public delegate void MultiThreadedTestDelegate(int threadid);
7+
8+
/// <summary>
9+
/// Creates a synchronized eco-system of running code.
10+
/// </summary>
11+
public class MultiThreadedUnitTestExecuter : IDisposable
12+
{
13+
public int ThreadCount { get; }
14+
public Thread[] Threads { get; }
15+
private readonly SemaphoreSlim barrier_threadstarted;
16+
private readonly ManualResetEventSlim barrier_corestart;
17+
private readonly SemaphoreSlim done_barrier2;
18+
19+
public Action<MultiThreadedUnitTestExecuter> PostRun { get; set; }
20+
21+
#region Static
22+
23+
public static void Run(int threadCount, MultiThreadedTestDelegate workload)
24+
{
25+
if (workload == null) throw new ArgumentNullException(nameof(workload));
26+
if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
27+
new MultiThreadedUnitTestExecuter(threadCount).Run(workload);
28+
}
29+
30+
public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads)
31+
{
32+
if (workloads == null) throw new ArgumentNullException(nameof(workloads));
33+
if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
34+
if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
35+
new MultiThreadedUnitTestExecuter(threadCount).Run(workloads);
36+
}
37+
38+
public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action<MultiThreadedUnitTestExecuter> postRun)
39+
{
40+
if (workload == null) throw new ArgumentNullException(nameof(workload));
41+
if (postRun == null) throw new ArgumentNullException(nameof(postRun));
42+
if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
43+
new MultiThreadedUnitTestExecuter(threadCount) {PostRun = postRun}.Run(workload);
44+
}
45+
46+
#endregion
47+
48+
49+
/// <summary>Initializes a new instance of the <see cref="T:System.Object"></see> class.</summary>
50+
public MultiThreadedUnitTestExecuter(int threadCount)
51+
{
52+
if (threadCount <= 0)
53+
throw new ArgumentOutOfRangeException(nameof(threadCount));
54+
ThreadCount = threadCount;
55+
Threads = new Thread[ThreadCount];
56+
done_barrier2 = new SemaphoreSlim(0, threadCount);
57+
barrier_corestart = new ManualResetEventSlim();
58+
barrier_threadstarted = new SemaphoreSlim(0, threadCount);
59+
}
60+
61+
public void Run(params MultiThreadedTestDelegate[] workloads)
62+
{
63+
if (workloads == null)
64+
throw new ArgumentNullException(nameof(workloads));
65+
if (workloads.Length != 1 && workloads.Length % ThreadCount != 0)
66+
throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads.");
67+
68+
if (ThreadCount == 1)
69+
{
70+
workloads[0](0);
71+
return;
72+
}
73+
74+
//thread core
75+
void ThreadCore(MultiThreadedTestDelegate core, int threadid)
76+
{
77+
barrier_threadstarted.Release(1);
78+
barrier_corestart.Wait();
79+
80+
//workload
81+
core(threadid);
82+
83+
done_barrier2.Release(1);
84+
}
85+
86+
//initialize all threads
87+
if (workloads.Length == 1)
88+
{
89+
var workload = workloads[0];
90+
for (int i = 0; i < ThreadCount; i++)
91+
{
92+
var i_local = i;
93+
Threads[i] = new Thread(() => ThreadCore(workload, i_local));
94+
}
95+
} else
96+
{
97+
for (int i = 0; i < ThreadCount; i++)
98+
{
99+
var i_local = i;
100+
var workload = workloads[i_local % workloads.Length];
101+
Threads[i] = new Thread(() => ThreadCore(workload, i_local));
102+
}
103+
}
104+
105+
//run all threads
106+
for (int i = 0; i < ThreadCount; i++) Threads[i].Start();
107+
//wait for threads to be started and ready
108+
for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait();
109+
110+
//signal threads to start
111+
barrier_corestart.Set();
112+
113+
//wait for threads to finish
114+
for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait();
115+
116+
//checks after ended
117+
PostRun?.Invoke(this);
118+
}
119+
120+
public void Dispose()
121+
{
122+
barrier_threadstarted.Dispose();
123+
barrier_corestart.Dispose();
124+
done_barrier2.Dispose();
125+
}
126+
}
127+
}

0 commit comments

Comments
 (0)