Skip to content

Commit dffa8c5

Browse files
umar456pradeep
authored andcommitted
Further improve the box muller function by fixing rounding issues
* Fix rounding of some operations by using fused operations like sincospi and fma instead of a multiply add. * Convert half constants to hex values and use __ushort_as_half to avoid redundant conversions from float * Pass integers instead of pointers in the OpenCL backend of the rng functions
1 parent e206e0b commit dffa8c5

8 files changed

Lines changed: 858 additions & 617 deletions

File tree

src/backend/cpu/kernel/random_engine.hpp

Lines changed: 109 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <types.hpp>
2020

2121
#include <algorithm>
22+
#include <cmath>
2223
#include <cstring>
2324

2425
using std::array;
@@ -31,89 +32,146 @@ static const double PI_VAL =
3132
3.1415926535897932384626433832795028841971693993751058209749445923078164;
3233

3334
// Conversion to half adapted from Random123
34-
#define USHORTMAX 0xffff
35-
#define HALF_FACTOR ((1.0f) / (USHORTMAX + (1.0f)))
35+
#define HALF_FACTOR ((1.0f) / (std::numeric_limits<ushort>::max() + (1.0f)))
3636
#define HALF_HALF_FACTOR ((0.5f) * HALF_FACTOR)
3737

38-
// Conversion to floats adapted from Random123
39-
#define UINTMAX 0xffffffff
40-
#define FLT_FACTOR ((1.0f) / (UINTMAX + (1.0f)))
41-
#define HALF_FLT_FACTOR ((0.5f) * FLT_FACTOR)
38+
// Conversion to half adapted from Random123
39+
#define SIGNED_HALF_FACTOR \
40+
((1.0f) / (std::numeric_limits<short>::max() + (1.0f)))
41+
#define SIGNED_HALF_HALF_FACTOR ((0.5f) * SIGNED_HALF_FACTOR)
4242

43-
#define UINTLMAX 0xffffffffffffffff
44-
#define DBL_FACTOR ((1.0) / (UINTLMAX + (1.0)))
43+
#define DBL_FACTOR \
44+
((1.0) / (std::numeric_limits<unsigned long long>::max() + (1.0)))
4545
#define HALF_DBL_FACTOR ((0.5) * DBL_FACTOR)
4646

47+
// Conversion to floats adapted from Random123
48+
#define SIGNED_DBL_FACTOR \
49+
((1.0) / (std::numeric_limits<long long>::max() + (1.0)))
50+
#define SIGNED_HALF_DBL_FACTOR ((0.5) * SIGNED_DBL_FACTOR)
51+
4752
template<typename T>
48-
T transform(uint *val, int index) {
49-
T *oval = (T *)val;
50-
return oval[index];
53+
T transform(uint *val, uint index);
54+
55+
template<>
56+
uintl transform<uintl>(uint *val, uint index) {
57+
uint index2 = index << 1;
58+
uintl v = ((static_cast<uintl>(val[index2]) << 32) |
59+
(static_cast<uintl>(val[index2 + 1])));
60+
return v;
61+
}
62+
63+
// Generates rationals in [0, 1)
64+
float getFloat01(uint *val, uint index) {
65+
// Conversion to floats adapted from Random123
66+
constexpr float factor =
67+
((1.0f) /
68+
(static_cast<float>(std::numeric_limits<unsigned int>::max()) +
69+
(1.0f)));
70+
constexpr float half_factor = ((0.5f) * factor);
71+
return fmaf(val[index], factor, half_factor);
72+
}
73+
74+
// Generates rationals in (-1, 1]
75+
static float getFloatNegative11(uint *val, uint index) {
76+
// Conversion to floats adapted from Random123
77+
constexpr float factor =
78+
((1.0) /
79+
(static_cast<double>(std::numeric_limits<int>::max()) + (1.0)));
80+
constexpr float half_factor = ((0.5f) * factor);
81+
82+
return fmaf(static_cast<float>(val[index]), factor, half_factor);
83+
}
84+
85+
// Generates rationals in [0, 1)
86+
common::half getHalf01(uint *val, uint index) {
87+
float v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
88+
return static_cast<common::half>(fmaf(v, HALF_FACTOR, HALF_HALF_FACTOR));
89+
}
90+
91+
// Generates rationals in (-1, 1]
92+
static common::half getHalfNegative11(uint *val, uint index) {
93+
float v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
94+
return static_cast<common::half>(
95+
fmaf(v, SIGNED_HALF_FACTOR, SIGNED_HALF_HALF_FACTOR));
96+
}
97+
98+
// Generates rationals in [0, 1)
99+
double getDouble01(uint *val, uint index) {
100+
uintl v = transform<uintl>(val, index);
101+
constexpr double factor =
102+
((1.0) / (std::numeric_limits<unsigned long long>::max() +
103+
static_cast<long double>(1.0l)));
104+
constexpr double half_factor((0.5) * factor);
105+
return fma(v, factor, half_factor);
51106
}
52107

53108
template<>
54-
char transform<char>(uint *val, int index) {
109+
char transform<char>(uint *val, uint index) {
55110
char v = val[index >> 2] >> (8 << (index & 3));
56111
v = (v & 0x1) ? 1 : 0;
57112
return v;
58113
}
59114

60115
template<>
61-
uchar transform<uchar>(uint *val, int index) {
116+
uchar transform<uchar>(uint *val, uint index) {
62117
uchar v = val[index >> 2] >> (index << 3);
63118
return v;
64119
}
65120

66121
template<>
67-
ushort transform<ushort>(uint *val, int index) {
122+
ushort transform<ushort>(uint *val, uint index) {
68123
ushort v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
69124
return v;
70125
}
71126

72127
template<>
73-
short transform<short>(uint *val, int index) {
128+
short transform<short>(uint *val, uint index) {
74129
return transform<ushort>(val, index);
75130
}
76131

77132
template<>
78-
uint transform<uint>(uint *val, int index) {
133+
uint transform<uint>(uint *val, uint index) {
79134
return val[index];
80135
}
81136

82137
template<>
83-
int transform<int>(uint *val, int index) {
138+
int transform<int>(uint *val, uint index) {
84139
return transform<uint>(val, index);
85140
}
86141

87142
template<>
88-
uintl transform<uintl>(uint *val, int index) {
89-
uintl v = (((uintl)val[index << 1]) << 32) | ((uintl)val[(index << 1) + 1]);
143+
intl transform<intl>(uint *val, uint index) {
144+
uintl v = transform<uintl>(val, index);
145+
intl out;
146+
memcpy(&out, &v, sizeof(intl));
90147
return v;
91148
}
92149

93150
template<>
94-
intl transform<intl>(uint *val, int index) {
95-
return transform<uintl>(val, index);
151+
float transform<float>(uint *val, uint index) {
152+
return 1.f - getFloat01(val, index);
96153
}
97154

98-
// Generates rationals in [0, 1)
99155
template<>
100-
float transform<float>(uint *val, int index) {
101-
return 1.f - (val[index] * FLT_FACTOR + HALF_FLT_FACTOR);
156+
double transform<double>(uint *val, uint index) {
157+
return 1. - getDouble01(val, index);
102158
}
103159

104-
// Generates rationals in [0, 1)
105160
template<>
106-
common::half transform<common::half>(uint *val, int index) {
161+
common::half transform<common::half>(uint *val, uint index) {
107162
float v = val[index >> 1U] >> (16U * (index & 1U)) & 0x0000ffff;
108163
return static_cast<common::half>(1.f -
109-
(v * HALF_FACTOR + HALF_HALF_FACTOR));
164+
fmaf(v, HALF_FACTOR, HALF_HALF_FACTOR));
110165
}
111166

112-
// Generates rationals in [0, 1)
113-
template<>
114-
double transform<double>(uint *val, int index) {
115-
uintl v = transform<uintl>(val, index);
116-
return 1.0 - (v * DBL_FACTOR + HALF_DBL_FACTOR);
167+
// Generates rationals in [-1, 1)
168+
double getDoubleNegative11(uint *val, uint index) {
169+
intl v = transform<intl>(val, index);
170+
// Conversion to doubles adapted from Random123
171+
constexpr double signed_factor =
172+
((1.0l) / (std::numeric_limits<long long>::max() + (1.0l)));
173+
constexpr double half_factor = ((0.5) * signed_factor);
174+
return fma(v, signed_factor, half_factor);
117175
}
118176

119177
#define MAX_RESET_CTR_VAL 64
@@ -201,34 +259,35 @@ void boxMullerTransform(data_t<T> *const out1, data_t<T> *const out2,
201259
* The log of a real value x where 0 < x < 1 is negative.
202260
*/
203261
using Tc = compute_t<T>;
204-
Tc r = sqrt((Tc)(-2.0) * log((Tc)(1.0) - static_cast<Tc>(r1)));
205-
Tc theta = 2 * (Tc)PI_VAL * ((Tc)(1.0) - static_cast<Tc>(r2));
206-
*out1 = r * sin(theta);
207-
*out2 = r * cos(theta);
262+
Tc r = sqrt((Tc)(-2.0) * log(static_cast<Tc>(r2)));
263+
Tc theta = PI_VAL * (static_cast<Tc>(r1));
264+
265+
*out1 = r * sin(theta);
266+
*out2 = r * cos(theta);
208267
}
209268

210269
void boxMullerTransform(uint val[4], double *temp) {
211-
boxMullerTransform<double>(&temp[0], &temp[1], transform<double>(val, 0),
212-
transform<double>(val, 1));
270+
boxMullerTransform<double>(&temp[0], &temp[1], getDoubleNegative11(val, 0),
271+
getDouble01(val, 1));
213272
}
214273

215274
void boxMullerTransform(uint val[4], float *temp) {
216-
boxMullerTransform<float>(&temp[0], &temp[1], transform<float>(val, 0),
217-
transform<float>(val, 1));
218-
boxMullerTransform<float>(&temp[2], &temp[3], transform<float>(val, 2),
219-
transform<float>(val, 3));
275+
boxMullerTransform<float>(&temp[0], &temp[1], getFloatNegative11(val, 0),
276+
getFloat01(val, 1));
277+
boxMullerTransform<float>(&temp[2], &temp[3], getFloatNegative11(val, 2),
278+
getFloat01(val, 3));
220279
}
221280

222281
void boxMullerTransform(uint val[4], common::half *temp) {
223282
using common::half;
224-
boxMullerTransform<half>(&temp[0], &temp[1], transform<half>(val, 0),
225-
transform<half>(val, 1));
226-
boxMullerTransform<half>(&temp[2], &temp[3], transform<half>(val, 2),
227-
transform<half>(val, 3));
228-
boxMullerTransform<half>(&temp[4], &temp[5], transform<half>(val, 4),
229-
transform<half>(val, 5));
230-
boxMullerTransform<half>(&temp[6], &temp[7], transform<half>(val, 6),
231-
transform<half>(val, 7));
283+
boxMullerTransform<half>(&temp[0], &temp[1], getHalfNegative11(val, 0),
284+
getHalf01(val, 1));
285+
boxMullerTransform<half>(&temp[2], &temp[3], getHalfNegative11(val, 2),
286+
getHalf01(val, 3));
287+
boxMullerTransform<half>(&temp[4], &temp[5], getHalfNegative11(val, 4),
288+
getHalf01(val, 5));
289+
boxMullerTransform<half>(&temp[6], &temp[7], getHalfNegative11(val, 6),
290+
getHalf01(val, 7));
232291
}
233292

234293
template<typename T>

0 commit comments

Comments
 (0)