Skip to content

Commit e26f206

Browse files
committed
FEAT: Adding setSeed and getSeed for all backends
1 parent 4dfbf29 commit e26f206

File tree

11 files changed

+157
-18
lines changed

11 files changed

+157
-18
lines changed

include/af/data.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,31 @@ namespace af
107107
@}
108108
*/
109109

110+
/**
111+
\defgroup data_func_setseed setSeed
112+
Set the seed for the random number generator
113+
114+
115+
\param[in] seed is a 64 bit unsigned integer
116+
117+
\ingroup data_mat
118+
\ingroup arrayfire_func
119+
*/
120+
AFAPI void setSeed(const uintl seed);
121+
122+
/**
123+
\defgroup data_func_getseed getSeed
124+
Get the seed for the random number generator
125+
126+
127+
\return seed which is a 64 bit unsigned integer
128+
129+
\ingroup data_mat
130+
\ingroup arrayfire_func
131+
*/
132+
AFAPI uintl getSeed();
133+
134+
110135
/**
111136
\defgroup data_func_identity identity
112137
Create an identity array
@@ -413,6 +438,31 @@ extern "C" {
413438
@}
414439
*/
415440

441+
/**
442+
\defgroup data_func_setseed setSeed
443+
Set the seed for the random number generator
444+
445+
446+
\param[in] seed is a 64 bit unsigned integer
447+
448+
\ingroup data_mat
449+
\ingroup arrayfire_func
450+
*/
451+
AFAPI af_err af_set_seed(const uintl seed);
452+
453+
/**
454+
\defgroup data_func_getseed getSeed
455+
Get the seed for the random number generator
456+
457+
458+
\param[out] seed which is a 64 bit unsigned integer
459+
460+
\ingroup data_mat
461+
\ingroup arrayfire_func
462+
*/
463+
AFAPI af_err af_get_seed(uintl *seed);
464+
465+
416466
/**
417467
\defgroup data_func_identity identity
418468
Create an identity array

src/api/c/data.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,22 @@ af_err af_randn(af_array *out, const unsigned ndims, const dim_type * const dims
310310
return AF_SUCCESS;
311311
}
312312

313+
af_err af_set_seed(const uintl seed)
314+
{
315+
try {
316+
setSeed(seed);
317+
} CATCHALL;
318+
return AF_SUCCESS;
319+
}
320+
321+
af_err af_get_seed(uintl *seed)
322+
{
323+
try {
324+
*seed = getSeed();
325+
} CATCHALL;
326+
return AF_SUCCESS;
327+
}
328+
313329
af_err af_identity(af_array *out, const unsigned ndims, const dim_type * const dims, const af_dtype type)
314330
{
315331
try {
@@ -656,4 +672,3 @@ af_err af_write_array(af_array arr, const void *data, const size_t bytes, af_sou
656672
CATCHALL
657673
return AF_SUCCESS;
658674
}
659-

src/api/cpp/data.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,18 @@ namespace af
174174
return randn(dim4(d0, d1, d2, d3), ty);
175175
}
176176

177+
void setSeed(const uintl seed)
178+
{
179+
AF_THROW(af_set_seed(seed));
180+
}
181+
182+
uintl getSeed()
183+
{
184+
uintl seed = 0;
185+
AF_THROW(af_get_seed(&seed));
186+
return seed;
187+
}
188+
177189
array range(const dim4 &dims, const int seq_dim, af::dtype ty)
178190
{
179191
af_array out;

src/backend/cpu/random.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,15 @@ nrand(GenType &generator)
6969
}
7070

7171
static default_random_engine generator;
72+
static unsigned long long gen_seed = 0;
73+
static bool is_first = true;
7274

7375
template<typename T>
7476
Array<T> randn(const af::dim4 &dims)
7577
{
76-
static auto gen = nrand<T>(generator);
78+
if (is_first) setSeed(gen_seed);
79+
80+
auto gen = nrand<T>(generator);
7781

7882
Array<T> outArray = createEmptyArray<T>(dims);
7983
T *outPtr = outArray.get();
@@ -84,7 +88,9 @@ Array<T> randn(const af::dim4 &dims)
8488
template<typename T>
8589
Array<T> randu(const af::dim4 &dims)
8690
{
87-
static auto gen = urand<T>(generator);
91+
if (is_first) setSeed(gen_seed);
92+
93+
auto gen = urand<T>(generator);
8894

8995
Array<T> outArray = createEmptyArray<T>(dims);
9096
T *outPtr = outArray.get();
@@ -117,7 +123,7 @@ INSTANTIATE_NORMAL(cdouble)
117123
template<>
118124
Array<char> randu(const af::dim4 &dims)
119125
{
120-
static auto gen = urand<float>(generator);
126+
auto gen = urand<float>(generator);
121127

122128
Array<char> outArray = createEmptyArray<char>(dims);
123129
char *outPtr = outArray.get();
@@ -127,4 +133,16 @@ Array<char> randu(const af::dim4 &dims)
127133
return outArray;
128134
}
129135

136+
void setSeed(const uintl seed)
137+
{
138+
generator.seed(seed);
139+
is_first = false;
140+
gen_seed = seed;
141+
}
142+
143+
uintl getSeed()
144+
{
145+
return gen_seed;
146+
}
147+
130148
}

src/backend/cpu/random.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ namespace cpu
1717

1818
template<typename T>
1919
Array<T> randn(const af::dim4 &dims);
20+
21+
void setSeed(const uintl seed);
22+
uintl getSeed();
2023
}

src/backend/cuda/kernel/random.hpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ namespace kernel
2020

2121
static const int THREADS = 256;
2222
static const int BLOCKS = 64;
23-
static unsigned long long uniform_seed = 0;
24-
static unsigned long long normal_seed = 0;
23+
static unsigned long long seed = 0;
2524
static curandState_t *states[DeviceManager::MAX_DEVICES];
25+
static bool is_first = true;
2626

2727
template<typename T>
2828
__device__
@@ -128,6 +128,19 @@ namespace kernel
128128
states[id] = state;
129129
}
130130

131+
void setup_states()
132+
{
133+
int device = getActiveDeviceId();
134+
135+
if (is_first) {
136+
CUDA_CHECK(cudaMalloc(&states[device], BLOCKS * THREADS * sizeof(curandState_t)));
137+
is_first = false;
138+
}
139+
140+
setup_kernel<<<BLOCKS, THREADS>>>(states[device], seed);
141+
POST_LAUNCH_CHECK();
142+
}
143+
131144
template<typename T>
132145
void randu(T *out, size_t elements)
133146
{
@@ -136,17 +149,7 @@ namespace kernel
136149
int threads = THREADS;
137150
int blocks = divup(elements, THREADS);
138151
if (blocks > BLOCKS) blocks = BLOCKS;
139-
140-
if (!states[device]) {
141-
CUDA_CHECK(cudaMalloc(&states[device], BLOCKS * THREADS * sizeof(curandState_t)));
142-
143-
setup_kernel<<<BLOCKS, THREADS>>>(states[device], uniform_seed);
144-
145-
POST_LAUNCH_CHECK();
146-
}
147-
148152
uniform_kernel<<<blocks, threads>>>(out, states[device], elements);
149-
150153
POST_LAUNCH_CHECK();
151154
}
152155

@@ -162,7 +165,7 @@ namespace kernel
162165
if (!states[device]) {
163166
CUDA_CHECK(cudaMalloc(&states[device], BLOCKS * THREADS * sizeof(curandState_t)));
164167

165-
setup_kernel<<<BLOCKS, THREADS>>>(states[device], uniform_seed);
168+
setup_kernel<<<BLOCKS, THREADS>>>(states[device], seed);
166169

167170
POST_LAUNCH_CHECK();
168171
}

src/backend/cuda/random.cu

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace cuda
1919
template<typename T>
2020
Array<T> randu(const af::dim4 &dims)
2121
{
22+
if (kernel::is_first) kernel::setup_states();
2223
Array<T> out = createEmptyArray<T>(dims);
2324
kernel::randu(out.get(), out.elements());
2425
return out;
@@ -27,6 +28,7 @@ namespace cuda
2728
template<typename T>
2829
Array<T> randn(const af::dim4 &dims)
2930
{
31+
if (kernel::is_first) kernel::setup_states();
3032
Array<T> out = createEmptyArray<T>(dims);
3133
kernel::randn(out.get(), out.elements());
3234
return out;
@@ -46,4 +48,17 @@ namespace cuda
4648
template Array<cfloat> randn<cfloat> (const af::dim4 &dims);
4749
template Array<cdouble> randn<cdouble> (const af::dim4 &dims);
4850

51+
52+
void setSeed(const uintl seed)
53+
{
54+
kernel::seed = seed;
55+
kernel::setup_states();
56+
}
57+
58+
uintl getSeed()
59+
{
60+
return kernel::seed;
61+
}
62+
63+
4964
}

src/backend/cuda/random.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@ namespace cuda
1717

1818
template<typename T>
1919
Array<T> randn(const af::dim4 &dims);
20+
21+
void setSeed(const uintl seed);
22+
uintl getSeed();
2023
}

src/backend/opencl/kernel/random.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ namespace opencl
3737
{
3838
static const uint REPEAT = 32;
3939
static const uint THREADS = 256;
40-
static uint random_seed[2];
40+
41+
static uint random_seed[2] = {0, 0};
4142
static unsigned counter;
4243

4344
template<typename T, bool isRandu>

src/backend/opencl/random.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,20 @@ namespace opencl
7070
COMPLEX_RANDOM(randn, cfloat, float, false)
7171
COMPLEX_RANDOM(randn, cdouble, double, false)
7272

73+
74+
void setSeed(const uintl seed)
75+
{
76+
uintl hi = (seed & 0xffffffff00000000) >> 32;
77+
uintl lo = (seed & 0x00000000ffffffff);
78+
kernel::random_seed[0] = (unsigned)hi;
79+
kernel::random_seed[1] = (unsigned)lo;
80+
kernel::counter = 0;
81+
}
82+
83+
uintl getSeed()
84+
{
85+
uintl hi = kernel::random_seed[0];
86+
uintl lo = kernel::random_seed[1];
87+
return hi << 32 | lo;
88+
}
7389
}

0 commit comments

Comments
 (0)