Skip to content

Commit 6fc0bd5

Browse files
committed
TopoQueue improvements.
1 parent fa3e4dc commit 6fc0bd5

1 file changed

Lines changed: 53 additions & 21 deletions

File tree

include/react/common/TopoQueue.h

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,29 @@
1717

1818
/***************************************/ REACT_IMPL_BEGIN /**************************************/
1919

20+
template <typename T>
21+
struct NodeLevelHelper
22+
{
23+
int operator()(const T& x) const { return x.Level; }
24+
};
25+
26+
template <typename T>
27+
struct NodeLevelHelper<T*>
28+
{
29+
int operator()(const T* x) const { return x->Level; }
30+
};
31+
2032
///////////////////////////////////////////////////////////////////////////////////////////////////
21-
/// TopoQueue
33+
/// Sequential TopoQueue
2234
///////////////////////////////////////////////////////////////////////////////////////////////////
2335
template <typename T>
2436
class TopoQueue
2537
{
2638
public:
27-
using DataT = std::vector<T*>;
39+
using DataT = std::vector<T>;
40+
using LevelFunctorT = NodeLevelHelper<T>;
2841

29-
void Push(T* node)
42+
void Push(const T& node)
3043
{
3144
data_.push_back(node);
3245
}
@@ -37,8 +50,11 @@ class TopoQueue
3750

3851
minLevel_ = INT_MAX;
3952
for (const auto& e : data_)
40-
if (minLevel_ > e->Level)
41-
minLevel_ = e->Level;
53+
{
54+
auto l = LevelFunctorT{}(e);
55+
if (minLevel_ > l)
56+
minLevel_ = l;
57+
}
4258

4359
auto p = std::partition(data_.begin(), data_.end(), CompFunctor{ minLevel_ });
4460

@@ -54,7 +70,7 @@ class TopoQueue
5470
struct CompFunctor
5571
{
5672
CompFunctor(int level) : Level{ level } {}
57-
bool operator()(T* x) { return x->Level != Level; }
73+
bool operator()(const T& x) { return LevelFunctorT{}(x) != Level; }
5874
const int Level;
5975
};
6076

@@ -66,6 +82,18 @@ class TopoQueue
6682
///////////////////////////////////////////////////////////////////////////////////////////////////
6783
/// WeightedRange
6884
///////////////////////////////////////////////////////////////////////////////////////////////////
85+
template <typename T>
86+
struct NodeWeightHelper
87+
{
88+
int operator()(const T& x) const { return x.Weight; }
89+
};
90+
91+
template <typename T>
92+
struct NodeWeightHelper<T*>
93+
{
94+
int operator()(const T* x) const { return x->Weight; }
95+
};
96+
6997
template
7098
<
7199
typename TValue,
@@ -136,17 +164,20 @@ template <typename T, uint grain_size>
136164
class ConcurrentTopoQueue
137165
{
138166
public:
139-
using NodesT = std::vector<T*>;
140-
using RangeT = WeightedRange<typename NodesT::const_iterator, grain_size>;
167+
using DataT = std::vector<T>;
168+
using RangeT = WeightedRange<typename DataT::const_iterator, grain_size>;
169+
using LevelFunctorT = NodeLevelHelper<T>;
170+
using WeightFunctorT = NodeWeightHelper<T>;
141171

142-
void Push(T* node)
172+
void Push(const T& node)
143173
{
144174
auto& t = collectBuffer_.local();
145175

146176
t.Data.push_back(node);
147-
t.Weight += node->Weight;
148-
if (t.MinLevel > node->Level)
149-
t.MinLevel = node->Level;
177+
t.Weight += WeightFunctorT{}(node);
178+
auto l = LevelFunctorT{}(node);
179+
if (t.MinLevel > l)
180+
t.MinLevel = l;
150181
}
151182

152183
bool FetchNext()
@@ -178,11 +209,12 @@ class ConcurrentTopoQueue
178209
buf.MinLevel = INT_MAX;
179210
int oldWeight = buf.Weight;
180211
buf.Weight = 0;
181-
for (const T* x : v)
212+
for (const T& x : v)
182213
{
183-
buf.Weight += x->Weight;
184-
if (buf.MinLevel > x->Level)
185-
buf.MinLevel = x->Level;
214+
buf.Weight += WeightFunctorT{}(x);
215+
auto l = LevelFunctorT{}(x);
216+
if (buf.MinLevel > l)
217+
buf.MinLevel = l;
186218
}
187219

188220
// Add diff to nodes_ weight
@@ -204,19 +236,19 @@ class ConcurrentTopoQueue
204236
struct CompFunctor
205237
{
206238
CompFunctor(int level) : Level{ level } {}
207-
bool operator()(T* other) { return other->Level != Level; }
239+
bool operator()(const T& x) { return LevelFunctorT{}(x) != Level; }
208240
const int Level;
209241
};
210242

211243
struct ThreadLocalBuffer
212244
{
213-
std::vector<T*> Data;
214-
int MinLevel = INT_MAX;
215-
uint Weight = 0;
245+
DataT Data;
246+
int MinLevel = INT_MAX;
247+
uint Weight = 0;
216248
};
217249

218250
int minLevel_ = INT_MAX;
219-
std::vector<T*> nodes_;
251+
DataT nodes_;
220252
RangeT range_;
221253

222254
tbb::enumerable_thread_specific<ThreadLocalBuffer> collectBuffer_;

0 commit comments

Comments
 (0)