Skip to content

Commit 0823eee

Browse files
f3schmartenole
authored andcommitted
TRD: PID: fixups
This commit renames the factory function to better reflect its purpose. Additionally, z-row merging and charge correction have been added to the codebase to improve functionality. Pytorch policy and LQND policy have been added as new policies. The README has also been added to provide additional explanation of the code. As part of this update, the pid policy map and pidvalue alias have been removed. A print overload has been added for the policy enum to improve readability. Various minor fixes have also been made to improve overall code quality. Also included are various changes to the class layout and the ccdb object for LUTs. Signed-off-by: Felix Schlepper <f3sch.git@outlook.com>
1 parent a7e5416 commit 0823eee

17 files changed

Lines changed: 614 additions & 210 deletions

File tree

DataFormats/Detectors/TRD/include/DataFormatsTRD/PID.h

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <array>
1919
#include <unordered_map>
2020
#include <string>
21+
#include <iostream>
2122

2223
namespace o2
2324
{
@@ -28,46 +29,71 @@ namespace trd
2829
enum class PIDPolicy : unsigned int {
2930
// Classical Algorithms
3031
LQ1D = 0, ///< 1-Dimensional Likelihood model
32+
LQ2D, ///< 2-Dimensional Likelihood model
3133
LQ3D, ///< 3-Dimensional Likelihood model
3234

35+
#ifdef TRDPID_WITH_ONNX
3336
// ML models
3437
XGB, ///< XGBOOST
38+
PY, ///< Pytorch
39+
#endif
3540

3641
// Do not add anything after this!
3742
NMODELS, ///< Count of all models
38-
Test, ///< Load object for testing
3943
Dummy, ///< Dummy object outputting -1.f
4044
DEFAULT = Dummy, ///< The default option
4145
};
4246

47+
inline std::ostream& operator<<(std::ostream& os, const PIDPolicy& policy)
48+
{
49+
std::string name;
50+
switch (policy) {
51+
case PIDPolicy::LQ1D:
52+
name = "LQ1D";
53+
break;
54+
case PIDPolicy::LQ2D:
55+
name = "LQ2D";
56+
break;
57+
case PIDPolicy::LQ3D:
58+
name = "LQ3D";
59+
break;
60+
#ifdef TRDPID_WITH_ONNX
61+
case PIDPolicy::XGB:
62+
name = "XGBoost";
63+
break;
64+
case PIDPolicy::PY:
65+
name = "PyTorch";
66+
break;
67+
#endif
68+
case PIDPolicy::Dummy:
69+
name = "Dummy";
70+
break;
71+
default:
72+
name = "Default";
73+
}
74+
os << name;
75+
return os;
76+
}
77+
4378
/// Transform PID policy from string to enum.
4479
static const std::unordered_map<std::string, PIDPolicy> PIDPolicyString{
4580
// Classical Algorithms
4681
{"LQ1D", PIDPolicy::LQ1D},
82+
{"LQ2D", PIDPolicy::LQ2D},
4783
{"LQ3D", PIDPolicy::LQ3D},
4884

85+
#ifdef TRDPID_WITH_ONNX
4986
// ML models
5087
{"XGB", PIDPolicy::XGB},
88+
{"PY", PIDPolicy::PY},
89+
#endif
5190

5291
// General
53-
{"TEST", PIDPolicy::Test},
5492
{"DUMMY", PIDPolicy::Dummy},
5593
// Default
5694
{"default", PIDPolicy::DEFAULT},
5795
};
5896

59-
/// Transform PID policy from string to enum.
60-
static const char* PIDPolicyEnum[] = {
61-
"LQ1D",
62-
"LQ3D",
63-
"XGBoost",
64-
"NMODELS",
65-
"Test",
66-
"Dummy",
67-
"default(=TODO)"};
68-
69-
using PIDValue = float;
70-
7197
} // namespace trd
7298
} // namespace o2
7399

Detectors/TRD/pid/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ if(ONNXRuntime_FOUND)
3131
HEADERS include/TRDPID/PIDBase.h
3232
include/TRDPID/PIDParameters.h
3333
include/TRDPID/ML.h
34+
include/TRDPID/LQND.h
3435
include/TRDPID/Dummy.h)
3536
else()
3637
o2_target_root_dictionary(TRDPID
3738
HEADERS include/TRDPID/PIDBase.h
3839
include/TRDPID/PIDParameters.h
3940
include/TRDPID/Dummy.h
41+
include/TRDPID/LQND.h
4042
LINKDEF src/TRDPIDNoMLLinkDef.h)
4143
endif()
4244

Detectors/TRD/pid/README.md

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Particle Identification with TRD
2+
## Usage
3+
Activate PID during tracking with the '--with-pid' flag.
4+
5+
o2-trd-global-tracking --with-pid --policy ML
6+
7+
Specify a which algorithm (called policy) should be use.
8+
Implemented are the following:
9+
10+
- LQ1D
11+
- LQ2D
12+
- LQ3D
13+
- ML (every model, which is exported to the ONNX format):
14+
- XGB (XGBoost model)
15+
- NN (Pytorch model)
16+
- Dummy (returns only -1)
17+
- Test (one of the above)
18+
- Default (one of the above, gets picked if '--policy' is unspecified)
19+
20+
## Implementation details
21+
### Tracking workflow
22+
Every TRDTrack gets a PID value set (mSignal), which then gets propergated to the AO2D writer.
23+
24+
### Basic Interface
25+
The base interface for PID is defined in [here](include/TRDPID/PIDBase.h).
26+
The 'init' function is such that each policy can specify what if anything it needs from the CCDB.
27+
For the 'process' each policy defines how a TRDTrack gets assigned a PID value.
28+
Additionally, the base class implements how to get the _corrected charges_ from the tracklets.
29+
_Corrected charges_ means z-row merged and calibrated charges.
30+
31+
### Classical Likelihood
32+
The classical LQND policies ([here](include/TRDPID/LQND.h)) require an array of lookup tables (LUTs) from the ccdb.
33+
$N$ stands for the dimension.
34+
Then the electron likelihood for layer $i$ is defined as this:
35+
36+
$$L_i(e|Q_i)=\frac{P(Q_i|e)}{P(Q_i|e)+P(Q_i|\pi)}$$
37+
38+
From the charge $Q_i$ the LUTs give the corresponding $L_i$.
39+
The _combined electron likelihood_ is obtained by this formula:
40+
41+
$$L(e|Q)=\frac{\prod_i L_i(e|Q_i)}{\prod_i L_i(e|Q_i) + \prod_i L_i(\pi|Q_i)}$$
42+
43+
where $L_i(\pi|Q_i)=1-L_i(e|Q_i)$.
44+
45+
46+
Extension to higher dimensions is easy each tracklet has charges $Q_j$ which cover the integral of the pulse height spectrum in different slice ($j\in [0,1,2]$).
47+
In our case $Q0$ covers the pulse height peak, $Q1$ the Transition Radiation peak and $Q2$ the plateau.
48+
For each charge $j$ a LUT is available which gives the likelihood $L^e_j$.
49+
For each layer $i$ the likelihood is then:
50+
51+
$$L_i(e|Q)=\frac{\prod_j L_{i,j}(e|Q_j)}{\prod_j L_{i,j}(e|Q_j) + \prod_j L_{i,j}(\pi|Q_j)}$$
52+
53+
The combined electron likelihood is then:
54+
55+
$$L(e|Q)=\frac{\prod_{i,j} L_{i,j}(e|Q_j)}{\prod_{i,j} L_{i,j}(e|Q_j) + \prod_{i,j} L_{i,j}(\pi|Q_j)}$$
56+
57+
58+
### Machine Learning
59+
The ML policies ([here](include/TRDPID/ML.h)) are uploaded to the CCDB in the ONNX file format (most python machine learning libraries support this standardized format).
60+
In O2 we leverage the ONNXRuntime to use these formats and calculate a PID value.
61+
The models can thus be trained in python which is very convenient.
62+
The code should take care of most of the annoying stuff.
63+
Policies just have to specify how to get the electron likelihood from the ONNXRuntime output (each python library varies in that somewhat).
64+
The 'prepareModelInput' prepares the TRDTracks as input.

Detectors/TRD/pid/include/TRDPID/Dummy.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ class Dummy final : public PIDBase
3434
using PIDBase::PIDBase;
3535

3636
public:
37-
~Dummy() final = default;
37+
~Dummy() = default;
3838

3939
/// Do absolutely nothing.
4040
void init(o2::framework::ProcessingContext& pc) final{};
4141

4242
/// Everything below 0.f indicates nothing available.
43-
PIDValue process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPC) final
43+
float process(const TrackTRD& trk, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final
4444
{
4545
return -1.f;
4646
};
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
2+
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
3+
// All rights not expressly granted are reserved.
4+
//
5+
// This software is distributed under the terms of the GNU General Public
6+
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
7+
//
8+
// In applying this license CERN does not waive the privileges and immunities
9+
// granted to it by virtue of its status as an Intergovernmental Organization
10+
// or submit itself to any jurisdiction.
11+
12+
/// \file LQND.h
13+
/// \brief This file provides the interface for loglikehood policies
14+
/// \author Felix Schlepper
15+
16+
#ifndef O2_TRD_LQND_H
17+
#define O2_TRD_LQND_H
18+
19+
#include "TGraph.h"
20+
#include "TRDPID/PIDBase.h"
21+
#include "DataFormatsTRD/PID.h"
22+
#include "DataFormatsTRD/Constants.h"
23+
#include "DataFormatsTRD/HelperMethods.h"
24+
#include "Framework/ProcessingContext.h"
25+
#include "Framework/InputRecord.h"
26+
#include "DataFormatsTRD/CalibratedTracklet.h"
27+
#include "DetectorsBase/Propagator.h"
28+
#include "Framework/Logger.h"
29+
#include "ReconstructionDataFormats/TrackParametrization.h"
30+
31+
#include <memory>
32+
#include <vector>
33+
#include <array>
34+
#include <string>
35+
#include <numeric>
36+
37+
namespace o2
38+
{
39+
namespace trd
40+
{
41+
namespace detail
42+
{
43+
/// Lookup Table class for ccdb upload
44+
template <int nDim>
45+
class LUT
46+
{
47+
public:
48+
LUT() = default;
49+
LUT(std::vector<float> p, std::vector<TGraph> l) : mIntervalsP(p), mLUTs(l) {}
50+
51+
//
52+
const TGraph& get(float p, bool isNegative, int iDim = 0) const
53+
{
54+
auto upper = std::upper_bound(mIntervalsP.begin(), mIntervalsP.end(), p);
55+
if (upper == mIntervalsP.end()) {
56+
// outside of momentum intervals, should not happen
57+
return mLUTs[0];
58+
}
59+
auto index = std::distance(mIntervalsP.begin(), upper);
60+
index += (isNegative) ? 0 : mIntervalsP.size() * nDim;
61+
return mLUTs[index + iDim];
62+
}
63+
64+
private:
65+
std::vector<float> mIntervalsP; ///< half-open interval upper bounds starting at 0, e.g., for {1.0,2.0,...} is (-inf,1.0], (1.0,2.0], (2.0, ...)
66+
std::vector<TGraph> mLUTs; ///< corresponding likelihood lookup tables
67+
68+
ClassDefNV(LUT, 1);
69+
};
70+
} // namespace detail
71+
72+
/// This is the ML Base class which defines the interface all machine learning
73+
/// models.
74+
template <int nDim>
75+
class LQND : public PIDBase
76+
{
77+
static_assert(nDim == 1 || nDim == 2 || nDim == 3, "Likelihood only for 1/2/3 dimension");
78+
using PIDBase::PIDBase;
79+
80+
public:
81+
~LQND() = default;
82+
83+
void init(o2::framework::ProcessingContext& pc) final
84+
{
85+
// retrieve lookup tables (LUTs) from ccdb
86+
mLUTs = *(pc.inputs().get<detail::LUT<nDim>*>(Form("lq%ddlut", nDim)));
87+
}
88+
89+
float process(const TrackTRD& trkIn, const o2::globaltracking::RecoContainer& input, bool isTPCTRD) const final
90+
{
91+
const auto& trkSeed = isTPCTRD ? input.getTPCTracks()[trkIn.getRefGlobalTrackId()].getParamOut() : input.getTPCITSTracks()[trkIn.getRefGlobalTrackId()].getParamOut(); // seeding track
92+
auto trk = trkSeed;
93+
94+
const auto isNegative = std::signbit(trkSeed.getSign()); // positive and negative charged particles are treated differently since ExB effects the charge distributions
95+
const auto& trackletsRaw = input.getTRDTracklets();
96+
float lei0{1.f}, lei1{1.f}, lei2{1.f}, lpi0{1.f}, lpi1{1.f}, lpi2{1.f}; // likelihood per layer
97+
for (int iLayer = 0; iLayer < constants::NLAYER; ++iLayer) {
98+
int trkltId = trkIn.getTrackletIndex(iLayer);
99+
if (trkltId < 0) { // no tracklet attached
100+
continue;
101+
}
102+
const auto xCalib = input.getTRDCalibratedTracklets()[trkIn.getTrackletIndex(iLayer)].getX();
103+
auto bz = o2::base::Propagator::Instance()->getNominalBz();
104+
const auto tgl = trk.getTgl();
105+
const auto snp = trk.getSnpAt(o2::math_utils::sector2Angle(HelperMethods::getSector(input.getTRDTracklets()[trkIn.getTrackletIndex(iLayer)].getDetector())), xCalib, bz);
106+
const auto& trklt = trackletsRaw[trkltId];
107+
const auto [q0, q1, q2] = getCharges(trklt, iLayer, trkIn, input, snp, tgl); // correct charges
108+
if constexpr (nDim == 1) {
109+
auto lut = mLUTs.get(trk.getP(), isNegative);
110+
auto ll1{1.f};
111+
ll1 = lut.Eval(q0 + q1 + q2);
112+
lei0 *= ll1;
113+
lpi0 *= (1.f - ll1);
114+
} else if (nDim == 2) {
115+
auto lut1 = mLUTs.get(trk.getP(), isNegative, 0);
116+
auto lut2 = mLUTs.get(trk.getP(), isNegative, 1);
117+
auto ll1{1.f};
118+
auto ll2{1.f};
119+
ll1 = lut1.Eval(q0 + q2);
120+
ll2 = lut2.Eval(q1);
121+
lei0 *= ll1;
122+
lei1 *= ll2;
123+
lpi0 *= (1.f - ll1);
124+
lpi1 *= (1.f - ll2);
125+
} else {
126+
auto lut1 = mLUTs.get(trk.getP(), isNegative, 0);
127+
auto lut2 = mLUTs.get(trk.getP(), isNegative, 1);
128+
auto lut3 = mLUTs.get(trk.getP(), isNegative, 2);
129+
auto ll1{1.f};
130+
auto ll2{1.f};
131+
auto ll3{1.f};
132+
ll1 = lut1.Eval(q0);
133+
ll2 = lut2.Eval(q1);
134+
ll3 = lut3.Eval(q2);
135+
lei0 *= ll1;
136+
lei1 *= ll2;
137+
lei2 *= ll3;
138+
lpi0 *= (1.f - ll1);
139+
lpi1 *= (1.f - ll2);
140+
lpi2 *= (1.f - ll3);
141+
}
142+
}
143+
144+
return (lei0 * lei1 * lei2) / (lei0 * lei1 * lei2 + lpi0 * lpi1 * lpi2); // combined likelihood
145+
}
146+
147+
private:
148+
detail::LUT<nDim> mLUTs; ///< likelihood lookup tables
149+
150+
ClassDefNV(LQND, 1);
151+
};
152+
153+
using LQ1D = LQND<1>;
154+
using LQ2D = LQND<2>;
155+
using LQ3D = LQND<3>;
156+
157+
} // namespace trd
158+
} // namespace o2
159+
160+
#endif

0 commit comments

Comments
 (0)