-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathcommon.h
More file actions
316 lines (275 loc) · 10.7 KB
/
common.h
File metadata and controls
316 lines (275 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
/* Copyright 2022 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef ARRAY_RECORD_CPP_COMMON_H_
#define ARRAY_RECORD_CPP_COMMON_H_
#include <sys/types.h>  // ssize_t

#include <cstddef>   // size_t
#include <iterator>  // std::begin, std::end
#include <tuple>     // std::tie
#include <utility>   // std::declval, std::forward

#include "absl/base/attributes.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
namespace array_record {
////////////////////////////////////////////////////////////////////////////////
// Canonical Errors (with formatting!)
////////////////////////////////////////////////////////////////////////////////
// Builds a FAILED_PRECONDITION status whose message is
// absl::StrFormat(fmt, args...). Because `fmt` is an absl::FormatSpec, the
// format string is type-checked against `args`.
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status FailedPreconditionError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kFailedPrecondition,
                      absl::StrFormat(fmt, args...));
}
// Builds an INTERNAL status whose message is absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status InternalError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kInternal,
                      absl::StrFormat(fmt, args...));
}
// Builds an INVALID_ARGUMENT status whose message is
// absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status InvalidArgumentError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kInvalidArgument,
                      absl::StrFormat(fmt, args...));
}
// Builds a NOT_FOUND status whose message is absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status NotFoundError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kNotFound,
                      absl::StrFormat(fmt, args...));
}
// Builds an OUT_OF_RANGE status whose message is
// absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status OutOfRangeError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kOutOfRange,
                      absl::StrFormat(fmt, args...));
}
// Builds an UNAVAILABLE status whose message is
// absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status UnavailableError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kUnavailable,
                      absl::StrFormat(fmt, args...));
}
// Builds an UNIMPLEMENTED status whose message is
// absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status UnimplementedError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kUnimplemented,
                      absl::StrFormat(fmt, args...));
}
// Builds an UNKNOWN status whose message is absl::StrFormat(fmt, args...).
template <typename... Args>
ABSL_MUST_USE_RESULT absl::Status UnknownError(
    const absl::FormatSpec<Args...>& fmt, const Args&... args) {
  return absl::Status(absl::StatusCode::kUnknown,
                      absl::StrFormat(fmt, args...));
}
// TODO(fchern): Align with what XLA does.
//
// Returns ceil(num / denom) for integral operands.
//
// `denom` is cast to the numerator's type so that e.g.
// DivRoundUp(my_uint64, 17) just works. Callers must pass a non-zero `denom`
// that is representable in `Int`.
//
// This is computed as "truncated quotient plus a correction" rather than the
// classic (num + denom - 1) / denom. The classic form overflows (silently
// returning a wrong answer, e.g. 0 for DivRoundUp(UINT64_MAX, 2)) when `num`
// is within `denom` of the type's maximum. It also rounds incorrectly when
// the operands are negative.
template <typename Int, typename DenomInt>
constexpr Int DivRoundUp(Int num, DenomInt denom) {
  const Int d = static_cast<Int>(denom);
  const Int quotient = num / d;  // Truncates toward zero.
  const Int remainder = num % d;  // Has the sign of `num` (or is zero).
  // Truncation toward zero already rounds *up* when the exact quotient is
  // negative. So only bump the result when there is a remainder and the exact
  // quotient is positive, i.e. `num` and `d` have the same sign.
  const bool same_sign = (remainder > 0) == (d > 0);
  return (remainder != 0 && same_sign) ? static_cast<Int>(quotient + 1)
                                       : quotient;
}
////////////////////////////////////////////////////////////////////////////////
// Class Decorators
////////////////////////////////////////////////////////////////////////////////
// DECLARE_COPYABLE_CLASS(ClassName): explicitly defaults the copy and move
// constructors and assignment operators. Use it inside the class body and
// supply the trailing ';' yourself. Note that because these constructors are
// user-declared, the class gets no implicit default constructor.
#define DECLARE_COPYABLE_CLASS(ClassName)            \
  ClassName(const ClassName&) = default;             \
  ClassName& operator=(const ClassName&) = default;  \
  ClassName(ClassName&&) = default;                  \
  ClassName& operator=(ClassName&&) = default
// DECLARE_MOVE_ONLY_CLASS(ClassName): deletes copying and defaults moving.
// Use it inside the class body and supply the trailing ';' yourself.
#define DECLARE_MOVE_ONLY_CLASS(ClassName)          \
  ClassName(const ClassName&) = delete;             \
  ClassName& operator=(const ClassName&) = delete;  \
  ClassName(ClassName&&) = default;                 \
  ClassName& operator=(ClassName&&) = default
// DECLARE_IMMOBILE_CLASS(ClassName): deletes both copying and moving, so
// instances can never be relocated. Use it inside the class body and supply
// the trailing ';' yourself.
#define DECLARE_IMMOBILE_CLASS(ClassName)           \
  ClassName(const ClassName&) = delete;             \
  ClassName& operator=(const ClassName&) = delete;  \
  ClassName(ClassName&&) = delete;                  \
  ClassName& operator=(ClassName&&) = delete
////////////////////////////////////////////////////////////////////////////////
// Seq / SeqWithStride / IndicesOf
////////////////////////////////////////////////////////////////////////////////
//
// Seq facilitates iterating over [begin, end) index ranges.
//
// * Avoids 3X stutter of the 'idx' variable, facilitating use of more
// descriptive variable names like 'datapoint_idx', 'centroid_idx', etc.
//
// * Unifies the syntax between ParallelFor and vanilla for-loops.
//
// * Reverse iteration is much easier to read and less error prone.
//
// * Strided iteration becomes harder to miss when skimming code.
//
// * Reduction in boilerplate '=', '<', '+=' symbols makes it easier to
// skim-read code with lots of small for-loops interleaved with operator-heavy
// logic (i.e., most of ScaM).
//
// * Zero runtime overhead.
//
//
// Equivalent for-loops (basic iteration):
//
// for (size_t idx : Seq(collection.size())) { ... }
// for (size_t idx : Seq(0, collection.size())) { ... }
// for (size_t idx = 0; idx < collection.size(); idx++) { ... }
//
//
// In particular, reverse iteration becomes much simpler and more readable:
//
// for (size_t idx : ReverseSeq(collection.size())) { ... }
// for (ssize_t idx = collection.size() - 1; idx >= 0; idx--) { ... }
//
//
// Strided iteration works too:
//
// for (size_t idx : SeqWithStride<8>(filenames.size())) { ... }
// for (size_t idx = 0; idx < filenames.size(); idx += 8) { ... }
//
//
// Iteration count without using a variable:
//
// for (auto _ : Seq(16)) { ... }
//
//
// Clarifies the ParallelFor syntax:
//
// ParallelFor<1>(Seq(dataset.size()), &pool, [&](size_t datapoint_idx) {
// ...
// });
//
template <ssize_t kStride = 1>
class SeqWithStride {
 public:
  // A zero stride would never advance, so every loop over the range would
  // spin forever. Asserting at class scope rejects it no matter which
  // constructor is used.
  static_assert(kStride != 0, "SeqWithStride requires a non-zero stride");

  // Returns kStride converted to size_t. NOTE: for a negative kStride this is
  // the value wrapped modulo 2^N (e.g. SIZE_MAX for -1), not a negative number.
  static constexpr size_t Stride() { return kStride; }

  // Constructor for iterating [0, end).
  inline explicit SeqWithStride(size_t end) : begin_(0), end_(end) {}

  // Constructor for iterating [begin, end). With a negative kStride the
  // iteration instead runs from `begin` down to and including `end` (see
  // Iterator::operator!=).
  inline SeqWithStride(size_t begin, size_t end) : begin_(begin), end_(end) {}

  // SizeT is an internal detail that helps suppress 'unused variable' compiler
  // errors. It's implicitly convertible to size_t, but by virtue of having a
  // user-provided destructor, the compiler doesn't complain about unused SizeT
  // variables.
  //
  // These are equivalent:
  //
  //   for (auto _ : Seq(10))   // Suppresses 'unused variable' error.
  //   for (SizeT _ : Seq(10))  // Suppresses 'unused variable' error.
  //
  // Prefer the 'auto' variant. Don't use SizeT directly.
  //
  class SizeT {
   public:
    // Implicit size_t <=> SizeT conversions.
    inline SizeT(size_t val) : val_(val) {}          // NOLINT
    inline operator size_t() const { return val_; }  // NOLINT

    // Defining a destructor suppresses 'unused variable' errors for the
    // following pattern: for (auto _ : Seq(kNumIters)) { ... }
    inline ~SizeT() {}

   private:
    size_t val_;
  };

  // Iterator implements the "!=", "++" and "*" operators required to support
  // the C++ for-each syntax. Not intended for direct use.
  class Iterator {
   public:
    // Constructor.
    inline explicit Iterator(size_t idx) : idx_(idx) {}

    // The '*' operator.
    inline SizeT operator*() const { return idx_; }

    // The '++' operator. A negative kStride is converted to size_t, so the
    // addition wraps modulo 2^N and effectively decrements idx_.
    inline Iterator& operator++() {
      idx_ += kStride;
      return *this;
    }

    // The '!=' operator.
    //
    // Note: for a positive stride the comparison is "<", not "!=", in order to
    // generate the correct behavior when (end - begin) is not a multiple of
    // kStride; the Iterator class only exists to support the C++ for-each
    // syntax, and is *not* intended for direct use.
    //
    // Consider the case where (end - begin) is not a multiple of kStride:
    //
    //   for (size_t j : SeqWithStride<5>(9)) { ... }
    //   for (size_t j = 0; j < 9; j += 5) { ... }  // '!=' wouldn't work.
    //
    // For a negative stride (reverse iteration) the comparison is a signed
    // ">=", so `end` itself is visited and the loop stops once idx_ wraps
    // below zero:
    //
    //   for (size_t j : SeqWithStride<-1>(sz - 1, 0)) { ... }
    //   for (ssize_t j = sz - 1; j >= 0; j--) { ... }
    //
    inline bool operator!=(Iterator end) const {
      if constexpr (kStride > 0) {
        return idx_ < end.idx_;
      } else {
        return static_cast<ssize_t>(idx_) >= static_cast<ssize_t>(end.idx_);
      }
    }

   private:
    size_t idx_;
  };

  using iterator = Iterator;
  inline Iterator begin() const { return Iterator(begin_); }
  inline Iterator end() const { return Iterator(end_); }

 private:
  size_t begin_;
  size_t end_;
};
// Seq iterates [0, end)
inline auto Seq(size_t end) { return SeqWithStride<1>(0, end); }
// Seq iterates [begin, end).
inline auto Seq(size_t begin, size_t end) {
return SeqWithStride<1>(begin, end);
}
// IndicesOf(container) is shorthand for Seq(container.size()); these two loops
// are equivalent:
//
//   for (size_t j : IndicesOf(container)) { ... }
//   for (size_t j : Seq(container.size())) { ... }
//
// `Container` only needs a size() member convertible to size_t.
template <typename Container>
SeqWithStride<1> IndicesOf(const Container& container) {
  const size_t element_count = container.size();
  return Seq(element_count);
}
////////////////////////////////////////////////////////////////////////////////
// Enumerate
////////////////////////////////////////////////////////////////////////////////
// Enumerate wraps an iterable so that range-for yields (index, element) pairs:
//
//   for (auto [idx, value] : Enumerate(container)) { ... }
//
// Each step produces a std::tuple of references. `idx` is a const reference to
// the running index (of type IdxType, starting at 0), and `value` is a
// reference to the current element. An lvalue argument is held by reference;
// an rvalue argument is moved into the returned wrapper, which therefore owns
// it for as long as the wrapper lives.
//
// TIter is deduced from an lvalue `T&` because iterator_wrapper calls
// std::begin on its (non-const) `iterable` member. This makes TIter exactly
// the type std::begin returns there, so no iterator -> const_iterator
// conversion is needed, and elements of an rvalue container stay mutable.
template <typename T, typename IdxType = size_t,
          typename TIter = decltype(std::begin(std::declval<T&>())),
          typename = decltype(std::end(std::declval<T&>()))>
constexpr auto Enumerate(T&& iterable) {
  class IteratorWithIndex {
   public:
    IteratorWithIndex(IdxType idx, TIter it) : idx_(idx), it_(it) {}
    // Only the underlying iterators are compared; the index carried by the
    // end() sentinel is never inspected.
    bool operator!=(const IteratorWithIndex& other) const {
      return it_ != other.it_;
    }
    void operator++() {
      ++idx_;
      ++it_;
    }
    auto operator*() const { return std::tie(idx_, *it_); }

   private:
    IdxType idx_;
    TIter it_;
  };
  struct iterator_wrapper {
    T iterable;
    auto begin() { return IteratorWithIndex{0, std::begin(iterable)}; }
    auto end() { return IteratorWithIndex{0, std::end(iterable)}; }
  };
  return iterator_wrapper{std::forward<T>(iterable)};
}
////////////////////////////////////////////////////////////////////////////////
// Profiling Helpers
////////////////////////////////////////////////////////////////////////////////
// Every AR_ENDO_* profiling hook below expands to nothing. Call sites
// therefore compile away entirely, and their arguments are never evaluated.

// Markers and scopes.
#define AR_ENDO_MARKER(...)
#define AR_ENDO_MARKER_TIMEOUT(...)
#define AR_ENDO_SCOPE(...)
#define AR_ENDO_SCOPE_TIMEOUT(...)

// Tasks and jobs.
#define AR_ENDO_TASK(...)
#define AR_ENDO_TASK_TIMEOUT(...)
#define AR_ENDO_JOB(...)
#define AR_ENDO_JOB_TIMEOUT(...)

// Events and annotations.
#define AR_ENDO_EVENT(...)
#define AR_ENDO_ERROR(...)
#define AR_ENDO_UNITS(...)
#define AR_ENDO_THREAD_NAME(...)
#define AR_ENDO_GROUP(...)
} // namespace array_record
#endif // ARRAY_RECORD_CPP_COMMON_H_