@@ -20,9 +20,9 @@ namespace kernel
2020
2121 static const int THREADS = 256 ;
2222 static const int BLOCKS = 64 ;
23- static unsigned long long uniform_seed = 0 ;
24- static unsigned long long normal_seed = 0 ;
23+ static unsigned long long seed = 0 ;
2524 static curandState_t *states[DeviceManager::MAX_DEVICES];
25+ static bool is_first = true ;
2626
2727 template <typename T>
2828 __device__
@@ -128,6 +128,19 @@ namespace kernel
128128 states[id] = state;
129129 }
130130
/**
 * Prepares the curand state buffer of the currently active device and
 * (re)seeds it with the file-level `seed`.
 *
 * The buffer holds BLOCKS * THREADS curandState_t entries and is allocated
 * once per device, the first time this function runs while that device is
 * active. Every call relaunches setup_kernel over the whole buffer, so calling
 * this again after `seed` changes re-seeds the existing states.
 *
 * Errors from the allocation and the launch are reported through CUDA_CHECK
 * and POST_LAUNCH_CHECK respectively.
 */
void setup_states()
{
    int device = getActiveDeviceId();

    // Allocation is tracked per device through the pointer itself rather than
    // through the single global `is_first` flag: `states` has static storage
    // (so every slot starts out null) and holds one buffer per device. Guarding
    // on the flag would skip the allocation for every device after the first
    // and would re-allocate (leaking the old buffer) when the slot had already
    // been filled by another code path that checks `!states[device]`.
    if (!states[device]) {
        CUDA_CHECK(cudaMalloc(&states[device],
                              BLOCKS * THREADS * sizeof(curandState_t)));
    }

    // Still cleared so that any code consulting `is_first` observes that the
    // setup has run at least once.
    is_first = false;

    // One state per thread of a BLOCKS x THREADS launch, all derived from `seed`.
    // NOTE(review): setup_kernel is defined outside this view; presumably it
    // calls curand_init per thread - confirm.
    setup_kernel<<<BLOCKS, THREADS>>>(states[device], seed);
    POST_LAUNCH_CHECK();
}
143+
131144 template <typename T>
132145 void randu (T *out, size_t elements)
133146 {
@@ -136,17 +149,7 @@ namespace kernel
136149 int threads = THREADS;
137150 int blocks = divup (elements, THREADS);
138151 if (blocks > BLOCKS) blocks = BLOCKS;
139-
140- if (!states[device]) {
141- CUDA_CHECK (cudaMalloc (&states[device], BLOCKS * THREADS * sizeof (curandState_t)));
142-
143- setup_kernel<<<BLOCKS, THREADS>>>(states[device], uniform_seed);
144-
145- POST_LAUNCH_CHECK ();
146- }
147-
148152 uniform_kernel<<<blocks, threads>>>(out, states[device], elements);
149-
150153 POST_LAUNCH_CHECK ();
151154 }
152155
@@ -162,7 +165,7 @@ namespace kernel
162165 if (!states[device]) {
163166 CUDA_CHECK (cudaMalloc (&states[device], BLOCKS * THREADS * sizeof (curandState_t)));
164167
165- setup_kernel<<<BLOCKS, THREADS>>>(states[device], uniform_seed );
168+ setup_kernel<<<BLOCKS, THREADS>>>(states[device], seed );
166169
167170 POST_LAUNCH_CHECK ();
168171 }
0 commit comments