Skip to content

Commit 8ebd293

Browse files
committed
DPL: allow keeping track of outgoing messages
This reworks how sending messages works, by keeping track of the specific routes used to send messages, rather than simply keeping track of the FairMQ channel name. This should allow us to asses wether or not a given route has messages associated to it and therefore wether or not: * Mandatory inputs are there * Sporadic inputs are not As a bonus, we get rid of the convoluted and slow matching by string, we possibly avoid a copy when the message is created using the wrong transport and we isolate even further message sending in FairMQDeviceProxy. This also still behaves as before, so that messages to be sent via the same channel are still coalesced in a single multipart message.
1 parent 7709de1 commit 8ebd293

28 files changed

Lines changed: 347 additions & 226 deletions

Framework/Core/include/Framework/ArrowContext.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#define O2_FRAMEWORK_ARROWCONTEXT_H_
1313

1414
#include "Framework/FairMQDeviceProxy.h"
15+
#include "Framework/RoutingIndices.h"
1516
#include <cassert>
1617
#include <functional>
1718
#include <memory>
@@ -31,7 +32,7 @@ class FairMQResizableBuffer;
3132
class ArrowContext
3233
{
3334
public:
34-
ArrowContext(FairMQDeviceProxy proxy)
35+
ArrowContext(FairMQDeviceProxy& proxy)
3536
: mProxy{proxy}
3637
{
3738
}
@@ -43,20 +44,20 @@ class ArrowContext
4344
std::shared_ptr<FairMQResizableBuffer> buffer;
4445
/// The function to call to finalise the builder into the message
4546
std::function<void(std::shared_ptr<FairMQResizableBuffer>)> finalize;
46-
std::string channel;
47+
RouteIndex routeIndex;
4748
};
4849

4950
using Messages = std::vector<MessageRef>;
5051

5152
void addBuffer(std::unique_ptr<FairMQMessage> header,
5253
std::shared_ptr<FairMQResizableBuffer> buffer,
5354
std::function<void(std::shared_ptr<FairMQResizableBuffer>)> finalize,
54-
const std::string& channel)
55+
RouteIndex routeIndex)
5556
{
5657
mMessages.push_back(std::move(MessageRef{std::move(header),
5758
std::move(buffer),
5859
std::move(finalize),
59-
channel}));
60+
routeIndex}));
6061
}
6162

6263
Messages::iterator begin()
@@ -128,7 +129,7 @@ class ArrowContext
128129
}
129130

130131
private:
131-
FairMQDeviceProxy mProxy;
132+
FairMQDeviceProxy& mProxy;
132133
Messages mMessages;
133134
size_t mBytesSent = 0;
134135
size_t mBytesDestroyed = 0;

Framework/Core/include/Framework/CommonMessageBackends.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ namespace o2::framework
1919

2020
/// A few ServiceSpecs data sending backends
2121
struct CommonMessageBackends {
22+
static ServiceSpec fairMQDeviceProxy();
2223
static ServiceSpec fairMQBackendSpec();
2324
static ServiceSpec stringBackendSpec();
2425
static ServiceSpec rawBufferBackendSpec();

Framework/Core/include/Framework/DataAllocator.h

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "Framework/CheckTypes.h"
2828
#include "Framework/ServiceRegistry.h"
2929
#include "Framework/RuntimeError.h"
30+
#include "Framework/RouteState.h"
3031

3132
#include "Headers/DataHeader.h"
3233
#include <TClass.h>
@@ -106,35 +107,35 @@ class DataAllocator
106107
// plain buffer as polymorphic spectator std::vector, which does not run constructors / destructors
107108
using ValueType = typename T::value_type;
108109
auto& timingInfo = mRegistry->get<TimingInfo>();
109-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
110+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
110111
auto& context = mRegistry->get<MessageContext>();
111112

112113
// Note: initial payload size is 0 and will be set by the context before sending
113-
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, channel, o2::header::gSerializationMethodNone, 0);
114+
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
114115
return context.add<MessageContext::VectorObject<ValueType, MessageContext::ContainerRefObject<std::vector<ValueType, o2::pmr::NoConstructAllocator<ValueType>>>>>(
115-
std::move(headerMessage), channel, 0, std::forward<Args>(args)...)
116+
std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...)
116117
.get();
117118
} else if constexpr (is_specialization_v<T, std::vector> && has_messageable_value_type<T>::value) {
118119
// this catches all std::vector objects with messageable value type before checking if is also
119120
// has a root dictionary, so non-serialized transmission is preferred
120121
using ValueType = typename T::value_type;
121122
auto& timingInfo = mRegistry->get<TimingInfo>();
122-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
123+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
123124
auto& context = mRegistry->get<MessageContext>();
124125

125126
// Note: initial payload size is 0 and will be set by the context before sending
126-
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, channel, o2::header::gSerializationMethodNone, 0);
127-
return context.add<MessageContext::VectorObject<ValueType>>(std::move(headerMessage), channel, 0, std::forward<Args>(args)...).get();
127+
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
128+
return context.add<MessageContext::VectorObject<ValueType>>(std::move(headerMessage), routeIndex, 0, std::forward<Args>(args)...).get();
128129
} else if constexpr (has_root_dictionary<T>::value == true && is_messageable<T>::value == false) {
129130
// Extended support for types implementing the Root ClassDef interface, both TObject
130131
// derived types and others
131132
auto& timingInfo = mRegistry->get<TimingInfo>();
132-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
133+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
133134
auto& context = mRegistry->get<MessageContext>();
134135

135136
// Note: initial payload size is 0 and will be set by the context before sending
136-
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, channel, o2::header::gSerializationMethodROOT, 0);
137-
return context.add<MessageContext::RootSerializedObject<T>>(std::move(headerMessage), channel, std::forward<Args>(args)...).get();
137+
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodROOT, 0);
138+
return context.add<MessageContext::RootSerializedObject<T>>(std::move(headerMessage), routeIndex, std::forward<Args>(args)...).get();
138139
} else if constexpr (std::is_base_of_v<std::string, T>) {
139140
auto* s = new std::string(args...);
140141
adopt(spec, s);
@@ -172,11 +173,11 @@ class DataAllocator
172173
auto [nElements] = std::make_tuple(args...);
173174
auto size = nElements * sizeof(T);
174175
auto& timingInfo = mRegistry->get<TimingInfo>();
175-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
176+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
176177
auto& context = mRegistry->get<MessageContext>();
177178

178-
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, channel, o2::header::gSerializationMethodNone, size);
179-
return context.add<MessageContext::SpanObject<T>>(std::move(headerMessage), channel, 0, nElements).get();
179+
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, size);
180+
return context.add<MessageContext::SpanObject<T>>(std::move(headerMessage), routeIndex, 0, nElements).get();
180181
}
181182
} else if constexpr (std::is_same_v<FirstArg, std::shared_ptr<arrow::Schema>>) {
182183
if constexpr (std::is_base_of_v<arrow::ipc::RecordBatchWriter, T>) {
@@ -239,10 +240,10 @@ class DataAllocator
239240

240241
char* payload = reinterpret_cast<char*>(ptr);
241242
auto& timingInfo = mRegistry->get<TimingInfo>();
242-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
243+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
243244
// the correct payload size is set later when sending the
244245
// RawBufferContext, see DataProcessor::doSend
245-
auto header = headerMessageFromOutput(spec, channel, o2::header::gSerializationMethodNone, 0);
246+
auto header = headerMessageFromOutput(spec, routeIndex, o2::header::gSerializationMethodNone, 0);
246247

247248
auto lambdaSerialize = [voidPtr = payload]() {
248249
return o2::utils::BoostSerialize<type>(*(reinterpret_cast<type*>(voidPtr)));
@@ -253,7 +254,7 @@ class DataAllocator
253254
delete tmpPtr;
254255
};
255256

256-
mRegistry->get<RawBufferContext>().addRawBuffer(std::move(header), std::move(payload), std::move(channel), std::move(lambdaSerialize), std::move(lambdaDestructor));
257+
mRegistry->get<RawBufferContext>().addRawBuffer(std::move(header), std::move(payload), routeIndex, std::move(lambdaSerialize), std::move(lambdaDestructor));
257258
}
258259

259260
/// Send a snapshot of an object, depending on the object type it is serialized before.
@@ -278,12 +279,13 @@ class DataAllocator
278279
template <typename T>
279280
void snapshot(const Output& spec, T const& object)
280281
{
281-
auto proxy = mRegistry->get<MessageContext>().proxy();
282+
auto& proxy = mRegistry->get<MessageContext>().proxy();
282283
FairMQMessagePtr payloadMessage;
283284
auto serializationType = o2::header::gSerializationMethodNone;
285+
RouteIndex routeIndex = matchDataHeader(spec, mRegistry->get<TimingInfo>().timeslice);
284286
if constexpr (is_messageable<T>::value == true) {
285287
// Serialize a snapshot of a trivially copyable, non-polymorphic object,
286-
payloadMessage = proxy.createMessage(sizeof(T));
288+
payloadMessage = proxy.createMessage(routeIndex, sizeof(T));
287289
memcpy(payloadMessage->GetData(), &object, sizeof(T));
288290

289291
serializationType = o2::header::gSerializationMethodNone;
@@ -296,7 +298,7 @@ class DataAllocator
296298
// reference object
297299
constexpr auto elementSizeInBytes = sizeof(ElementType);
298300
auto sizeInBytes = elementSizeInBytes * object.size();
299-
payloadMessage = proxy.createMessage(sizeInBytes);
301+
payloadMessage = proxy.createMessage(routeIndex, sizeInBytes);
300302

301303
if constexpr (std::is_pointer<typename T::value_type>::value == false) {
302304
// vector of elements
@@ -324,7 +326,7 @@ class DataAllocator
324326
}
325327
} else if constexpr (has_root_dictionary<T>::value == true || is_specialization_v<T, ROOTSerialized> == true) {
326328
// Serialize a snapshot of an object with root dictionary
327-
payloadMessage = proxy.createMessage();
329+
payloadMessage = proxy.createMessage(routeIndex);
328330
if constexpr (is_specialization_v<T, ROOTSerialized> == true) {
329331
// Explicitely ROOT serialize a snapshot of object.
330332
// An object wrapped into type `ROOTSerialized` is explicitely marked to be ROOT serialized
@@ -403,9 +405,9 @@ class DataAllocator
403405
o2::pmr::FairMQMemoryResource* getMemoryResource(const Output& spec)
404406
{
405407
auto& timingInfo = mRegistry->get<TimingInfo>();
406-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
407-
auto& context = mRegistry->get<MessageContext>();
408-
return *context.proxy().getTransport(channel);
408+
auto& proxy = mRegistry->get<FairMQDeviceProxy>();
409+
RouteIndex routeIndex = matchDataHeader(spec, timingInfo.timeslice);
410+
return *proxy.getTransport(routeIndex);
409411
}
410412

411413
//make a stl (pmr) vector
@@ -486,9 +488,9 @@ class DataAllocator
486488
AllowedOutputRoutes mAllowedOutputRoutes;
487489
ServiceRegistry* mRegistry;
488490

489-
std::string const& matchDataHeader(const Output& spec, size_t timeframeId);
491+
RouteIndex matchDataHeader(const Output& spec, size_t timeframeId);
490492
FairMQMessagePtr headerMessageFromOutput(Output const& spec, //
491-
std::string const& channel, //
493+
RouteIndex index, //
492494
o2::header::SerializationMethod serializationMethod, //
493495
size_t payloadSize); //
494496

@@ -504,12 +506,12 @@ DataAllocator::CacheId DataAllocator::adoptContainer(const Output& spec, Contain
504506
// Find a matching channel, extract the message for it form the container
505507
// and put it in the queue to be sent at the end of the processing
506508
auto& timingInfo = mRegistry->get<TimingInfo>();
507-
std::string const& channel = matchDataHeader(spec, timingInfo.timeslice);
509+
auto routeIndex = matchDataHeader(spec, timingInfo.timeslice);
508510

509511
auto& context = mRegistry->get<MessageContext>();
510-
FairMQMessagePtr payloadMessage = o2::pmr::getMessage(std::forward<ContainerT>(container), *context.proxy().getTransport(channel));
511-
512-
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, channel, //
512+
auto* transport = mRegistry->get<FairMQDeviceProxy>().getTransport(routeIndex);
513+
FairMQMessagePtr payloadMessage = o2::pmr::getMessage(std::forward<ContainerT>(container), *transport);
514+
FairMQMessagePtr headerMessage = headerMessageFromOutput(spec, routeIndex, //
513515
method, //
514516
payloadMessage->GetSize() //
515517
);
@@ -522,7 +524,7 @@ DataAllocator::CacheId DataAllocator::adoptContainer(const Output& spec, Contain
522524
cacheId.value = context.addToCache(payloadMessage);
523525
}
524526

525-
context.add<MessageContext::TrivialObject>(std::move(headerMessage), std::move(payloadMessage), channel);
527+
context.add<MessageContext::TrivialObject>(std::move(headerMessage), std::move(payloadMessage), routeIndex);
526528
return cacheId;
527529
}
528530

Framework/Core/include/Framework/DataSender.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#ifndef O2_FRAMEWORK_DATASENDER_H_
1212
#define O2_FRAMEWORK_DATASENDER_H_
1313

14+
#include "Framework/RoutingIndices.h"
1415
#include "Framework/SendingPolicy.h"
1516
#include "Framework/Tracing.h"
1617
#include "Framework/OutputSpec.h"
@@ -32,11 +33,11 @@ class DataSender
3233
public:
3334
DataSender(ServiceRegistry& registry,
3435
SendingPolicy const& policy);
35-
void send(FairMQParts&, std::string const& s);
36-
std::unique_ptr<FairMQMessage> create();
36+
void send(FairMQParts&, ChannelIndex index);
37+
std::unique_ptr<FairMQMessage> create(RouteIndex index);
3738

3839
private:
39-
void* mContext;
40+
FairMQDeviceProxy& mProxy;
4041
ServiceRegistry& mRegistry;
4142
DeviceSpec const& mSpec;
4243
std::vector<OutputSpec> mOutputs;

Framework/Core/include/Framework/DispatchControl.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,23 @@
88
// In applying this license CERN does not waive the privileges and immunities
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
11-
#ifndef FRAMEWORK_DISPATCHCONTROL_H
12-
#define FRAMEWORK_DISPATCHCONTROL_H
11+
#ifndef O2_FRAMEWORK_DISPATCHCONTROL_H_
12+
#define O2_FRAMEWORK_DISPATCHCONTROL_H_
1313

1414
#include "Framework/DispatchPolicy.h"
15+
#include "Framework/OutputRoute.h"
16+
#include "Framework/RoutingIndices.h"
1517
#include <functional>
1618
#include <string>
1719

1820
#include <fairmq/FwdDecls.h>
1921

20-
namespace o2
21-
{
22-
namespace header
22+
namespace o2::header
2323
{
2424
struct DataHeader;
2525
}
2626

27-
namespace framework
27+
namespace o2::framework
2828
{
2929
/// @struct DispatchControl
3030
/// @brief Control for the message dispatching within message context.
@@ -33,14 +33,13 @@ namespace framework
3333
/// is used to decide when to sent the scheduled messages via the actual dispatch
3434
/// callback.
3535
struct DispatchControl {
36-
using DispatchCallback = std::function<void(FairMQParts&& message, std::string const&, int)>;
36+
using DispatchCallback = std::function<void(FairMQParts&& message, ChannelIndex index, int)>;
3737
using DispatchTrigger = std::function<bool(o2::header::DataHeader const&)>;
3838
// dispatcher callback
3939
DispatchCallback dispatch;
4040
// matcher to trigger sending of scheduled messages
4141
DispatchTrigger trigger;
4242
};
4343

44-
} // namespace framework
45-
} // namespace o2
46-
#endif // FRAMEWORK_DISPATCHCONTROL_H
44+
} // namespace o2::framework
45+
#endif // O2_FRAMEWORK_DISPATCHCONTROL_H_

Framework/Core/include/Framework/FairMQDeviceProxy.h

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,46 +8,45 @@
88
// In applying this license CERN does not waive the privileges and immunities
99
// granted to it by virtue of its status as an Intergovernmental Organization
1010
// or submit itself to any jurisdiction.
11-
#ifndef FRAMEWORK_FAIRMQDEVICEPROXY_H
12-
#define FRAMEWORK_FAIRMQDEVICEPROXY_H
11+
#ifndef O2_FRAMEWORK_FAIRMQDEVICEPROXY_H_
12+
#define O2_FRAMEWORK_FAIRMQDEVICEPROXY_H_
1313

1414
#include <memory>
1515

16+
#include "Framework/RoutingIndices.h"
17+
#include "Framework/RouteState.h"
18+
#include "Framework/OutputRoute.h"
1619
#include <fairmq/FwdDecls.h>
20+
#include <vector>
1721

18-
namespace o2
19-
{
20-
namespace framework
22+
namespace o2::framework
2123
{
2224
/// Helper class to hide FairMQDevice headers in the DataAllocator header.
2325
/// This is done because FairMQDevice brings in a bunch of boost.mpl /
2426
/// boost.fusion stuff, slowing down compilation times enourmously.
2527
class FairMQDeviceProxy
2628
{
2729
public:
28-
FairMQDeviceProxy(FairMQDevice* device)
29-
: mDevice{device}
30-
{
31-
}
32-
33-
/// To be used in DataAllocator.cxx to avoid reimplenting any device
34-
/// API.
35-
FairMQDevice* getDevice()
36-
{
37-
return mDevice;
38-
}
39-
40-
/// Looks like what we really need in the headers is just the transport.
41-
FairMQTransportFactory* getTransport();
42-
FairMQTransportFactory* getTransport(const std::string& channel, int index = 0);
43-
std::unique_ptr<FairMQMessage> createMessage() const;
44-
std::unique_ptr<FairMQMessage> createMessage(const size_t size) const;
30+
FairMQDeviceProxy() = default;
31+
FairMQDeviceProxy(FairMQDeviceProxy const&) = delete;
32+
void bindRoutes(std::vector<OutputRoute> const& routes, FairMQDevice& device);
33+
34+
/// Retrieve the transport associated to a given route.
35+
fair::mq::TransportFactory* getTransport(RouteIndex routeIndex) const;
36+
/// ChannelIndex from a RouteIndex
37+
ChannelIndex getChannelIndex(RouteIndex routeIndex) const;
38+
/// Retrieve the channel associated to a given route.
39+
fair::mq::Channel* getChannel(ChannelIndex channelIndex) const;
40+
41+
std::unique_ptr<FairMQMessage> createMessage(RouteIndex routeIndex) const;
42+
std::unique_ptr<FairMQMessage> createMessage(RouteIndex routeIndex, const size_t size) const;
43+
size_t getNumChannels() const { return mChannels.size(); }
4544

4645
private:
47-
FairMQDevice* mDevice;
46+
std::vector<RouteState> mRoutes;
47+
std::vector<fair::mq::Channel*> mChannels;
4848
};
4949

50-
} // namespace framework
51-
} // namespace o2
50+
} // namespace o2::framework
5251

53-
#endif // FRAMEWORK_FAIRMQDEVICEPROXY_H
52+
#endif // O2_FRAMEWORK_FAIRMQDEVICEPROXY_H_

0 commit comments

Comments
 (0)