1+ using System ;
2+ using System . Threading ;
3+
4+ namespace TensorFlowNET . UnitTest
5+ {
6+ public delegate void MultiThreadedTestDelegate ( int threadid ) ;
7+
8+ /// <summary>
9+ /// Creates a synchronized eco-system of running code.
10+ /// </summary>
11+ public class MultiThreadedUnitTestExecuter : IDisposable
12+ {
13+ public int ThreadCount { get ; }
14+ public Thread [ ] Threads { get ; }
15+ private readonly SemaphoreSlim barrier_threadstarted ;
16+ private readonly ManualResetEventSlim barrier_corestart ;
17+ private readonly SemaphoreSlim done_barrier2 ;
18+
19+ public Action < MultiThreadedUnitTestExecuter > PostRun { get ; set ; }
20+
21+ #region Static
22+
23+ public static void Run ( int threadCount , MultiThreadedTestDelegate workload )
24+ {
25+ if ( workload == null ) throw new ArgumentNullException ( nameof ( workload ) ) ;
26+ if ( threadCount <= 0 ) throw new ArgumentOutOfRangeException ( nameof ( threadCount ) ) ;
27+ new MultiThreadedUnitTestExecuter ( threadCount ) . Run ( workload ) ;
28+ }
29+
30+ public static void Run ( int threadCount , params MultiThreadedTestDelegate [ ] workloads )
31+ {
32+ if ( workloads == null ) throw new ArgumentNullException ( nameof ( workloads ) ) ;
33+ if ( workloads . Length == 0 ) throw new ArgumentException ( "Value cannot be an empty collection." , nameof ( workloads ) ) ;
34+ if ( threadCount <= 0 ) throw new ArgumentOutOfRangeException ( nameof ( threadCount ) ) ;
35+ new MultiThreadedUnitTestExecuter ( threadCount ) . Run ( workloads ) ;
36+ }
37+
38+ public static void Run ( int threadCount , MultiThreadedTestDelegate workload , Action < MultiThreadedUnitTestExecuter > postRun )
39+ {
40+ if ( workload == null ) throw new ArgumentNullException ( nameof ( workload ) ) ;
41+ if ( postRun == null ) throw new ArgumentNullException ( nameof ( postRun ) ) ;
42+ if ( threadCount <= 0 ) throw new ArgumentOutOfRangeException ( nameof ( threadCount ) ) ;
43+ new MultiThreadedUnitTestExecuter ( threadCount ) { PostRun = postRun } . Run ( workload ) ;
44+ }
45+
46+ #endregion
47+
48+
49+ /// <summary>Initializes a new instance of the <see cref="T:System.Object"></see> class.</summary>
50+ public MultiThreadedUnitTestExecuter ( int threadCount )
51+ {
52+ if ( threadCount <= 0 )
53+ throw new ArgumentOutOfRangeException ( nameof ( threadCount ) ) ;
54+ ThreadCount = threadCount ;
55+ Threads = new Thread [ ThreadCount ] ;
56+ done_barrier2 = new SemaphoreSlim ( 0 , threadCount ) ;
57+ barrier_corestart = new ManualResetEventSlim ( ) ;
58+ barrier_threadstarted = new SemaphoreSlim ( 0 , threadCount ) ;
59+ }
60+
61+ public void Run ( params MultiThreadedTestDelegate [ ] workloads )
62+ {
63+ if ( workloads == null )
64+ throw new ArgumentNullException ( nameof ( workloads ) ) ;
65+ if ( workloads . Length != 1 && workloads . Length % ThreadCount != 0 )
66+ throw new InvalidOperationException ( $ "Run method must accept either 1 workload or n-threads workloads. Got { workloads . Length } workloads.") ;
67+
68+ if ( ThreadCount == 1 )
69+ {
70+ workloads [ 0 ] ( 0 ) ;
71+ return ;
72+ }
73+
74+ //thread core
75+ void ThreadCore ( MultiThreadedTestDelegate core , int threadid )
76+ {
77+ barrier_threadstarted . Release ( 1 ) ;
78+ barrier_corestart . Wait ( ) ;
79+
80+ //workload
81+ core ( threadid ) ;
82+
83+ done_barrier2 . Release ( 1 ) ;
84+ }
85+
86+ //initialize all threads
87+ if ( workloads . Length == 1 )
88+ {
89+ var workload = workloads [ 0 ] ;
90+ for ( int i = 0 ; i < ThreadCount ; i ++ )
91+ {
92+ var i_local = i ;
93+ Threads [ i ] = new Thread ( ( ) => ThreadCore ( workload , i_local ) ) ;
94+ }
95+ } else
96+ {
97+ for ( int i = 0 ; i < ThreadCount ; i ++ )
98+ {
99+ var i_local = i ;
100+ var workload = workloads [ i_local % workloads . Length ] ;
101+ Threads [ i ] = new Thread ( ( ) => ThreadCore ( workload , i_local ) ) ;
102+ }
103+ }
104+
105+ //run all threads
106+ for ( int i = 0 ; i < ThreadCount ; i ++ ) Threads [ i ] . Start ( ) ;
107+ //wait for threads to be started and ready
108+ for ( int i = 0 ; i < ThreadCount ; i ++ ) barrier_threadstarted . Wait ( ) ;
109+
110+ //signal threads to start
111+ barrier_corestart . Set ( ) ;
112+
113+ //wait for threads to finish
114+ for ( int i = 0 ; i < ThreadCount ; i ++ ) done_barrier2 . Wait ( ) ;
115+
116+ //checks after ended
117+ PostRun ? . Invoke ( this ) ;
118+ }
119+
120+ public void Dispose ( )
121+ {
122+ barrier_threadstarted . Dispose ( ) ;
123+ barrier_corestart . Dispose ( ) ;
124+ done_barrier2 . Dispose ( ) ;
125+ }
126+ }
127+ }
0 commit comments