Skip to content

Commit efd1996

Browse files
MichaelLettrich authored and shahor02 committed
[rANS] Fix Frequency Table Overflow
Solves an overflow issue encountered when building a dictionary from many timeframes. In certain cases the `uint32_t` data type used to count symbol frequencies can overflow. The result is a skewed distribution, leading to bad compression (the usual case) or unrecoverable data loss due to wrong encoding (unlikely but possible).

Changes:
* Use `uint64_t` for the sum when building cumulative frequencies during the renorming process.
* Introduce overflow checks when adding two frequency tables.
* Add fault tolerance in `CTFWriterSpec` in case the addition of frequency tables fails.
1 parent 9a27fc9 commit efd1996

4 files changed

Lines changed: 47 additions & 16 deletions

File tree

Detectors/CTF/workflow/src/CTFWriterSpec.cxx

Lines changed: 26 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@
4343
#include "DataFormatsCTP/CTF.h"
4444
#include "rANS/rans.h"
4545
#include <vector>
46+
#include <stdexcept>
4647
#include <array>
4748
#include <TStopwatch.h>
4849
#include <vector>
@@ -118,8 +119,8 @@ class CTFWriterSpec : public o2::framework::Task
118119
bool mCreateRunEnvDir = true;
119120
bool mStoreMetaFile = false;
120121
int mVerbosity = 0;
121-
int mSaveDictAfter = 0; // if positive and mWriteCTF==true, save dictionary after each mSaveDictAfter TFs processed
122-
int mFlagMinDet = 1; // append list of detectors to LHC period if their number is <= mFlagMinDet
122+
int mSaveDictAfter = 0; // if positive and mWriteCTF==true, save dictionary after each mSaveDictAfter TFs processed
123+
int mFlagMinDet = 1; // append list of detectors to LHC period if their number is <= mFlagMinDet
123124
uint32_t mPrevDictTimeStamp = 0; // timestamp of the previously stored dictionary
124125
uint32_t mDictTimeStamp = 0; // timestamp of the currently stored dictionary
125126
uint64_t mRun = 0;
@@ -160,6 +161,7 @@ class CTFWriterSpec : public o2::framework::Task
160161
// The metadata of the block (min,max) will be used for the consistency check at the decoding
161162
std::array<std::vector<FTrans>, DetID::nDetectors> mFreqsAccumulation;
162163
std::array<std::vector<o2::ctf::Metadata>, DetID::nDetectors> mFreqsMetaData;
164+
std::array<std::bitset<64>, DetID::nDetectors> mIsSaturatedFrequencyTable;
163165
std::array<std::shared_ptr<void>, DetID::nDetectors> mHeaders;
164166
TStopwatch mTimer;
165167

@@ -172,6 +174,7 @@ const std::string CTFWriterSpec::TMPFileEnding{".part"};
172174
/// Sets up the writer for the given detector mask, run number and output type.
///
/// \param dm        mask of detectors, stored in mDets
/// \param r         run number, stored in mRun
/// \param outType   output type string, stored in mOutputType
/// \param verbosity verbosity level, stored in mVerbosity
CTFWriterSpec::CTFWriterSpec(DetID::mask_t dm, uint64_t r, const std::string& outType, int verbosity)
  : mDets(dm), mRun(r), mOutputType(outType), mVerbosity(verbosity)
{
  // Start with no (detector, block) frequency table flagged as saturated.
  for (auto& saturationFlags : mIsSaturatedFrequencyTable) {
    saturationFlags.reset();
  }
  // Leave the stopwatch stopped and zeroed.
  mTimer.Stop();
  mTimer.Reset();
}
@@ -262,14 +265,27 @@ size_t CTFWriterSpec::processDet(o2::framework::ProcessingContext& pc, DetID det
262265
hb.det = det;
263266
}
264267
for (int ib = 0; ib < C::getNBlocks(); ib++) {
265-
const auto& bl = ctfImage.getBlock(ib);
266-
if (bl.getNDict()) {
267-
auto& freq = mFreqsAccumulation[det][ib];
268-
auto& mdSave = mFreqsMetaData[det][ib];
269-
const auto& md = ctfImage.getMetadata(ib);
270-
freq.addFrequencies(bl.getDict(), bl.getDict() + bl.getNDict(), md.min);
271-
auto newProbBits = uint8_t(o2::rans::computeRenormingPrecision(freq));
272-
mdSave = o2::ctf::Metadata{0, 0, md.messageWordSize, md.coderType, md.streamSize, newProbBits, md.opt, freq.getMinSymbol(), freq.getMaxSymbol(), (int)freq.size(), 0, 0};
268+
if (!mIsSaturatedFrequencyTable[det][ib]) {
269+
const auto& bl = ctfImage.getBlock(ib);
270+
if (bl.getNDict()) {
271+
auto freq = mFreqsAccumulation[det][ib];
272+
auto& mdSave = mFreqsMetaData[det][ib];
273+
const auto& md = ctfImage.getMetadata(ib);
274+
if ([&, this]() {
275+
try {
276+
freq.addFrequencies(bl.getDict(), bl.getDict() + bl.getNDict(), md.min);
277+
} catch (const std::overflow_error& e) {
278+
LOGP(warning, "unable to frequency table for {}, block {} due to overflow", det.getName(), ib);
279+
mIsSaturatedFrequencyTable[det][ib] = true;
280+
return false;
281+
}
282+
return true;
283+
}()) {
284+
auto newProbBits = static_cast<uint8_t>(o2::rans::computeRenormingPrecision(freq));
285+
mdSave = o2::ctf::Metadata{0, 0, md.messageWordSize, md.coderType, md.streamSize, newProbBits, md.opt, freq.getMinSymbol(), freq.getMaxSymbol(), static_cast<int32_t>(freq.size()), 0, 0};
286+
mFreqsAccumulation[det][ib] = std::move(freq);
287+
}
288+
}
273289
}
274290
}
275291
}

Utilities/rANS/include/rANS/FrequencyTable.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ FrequencyTable& FrequencyTable::addFrequencies(Freq_IT begin, Freq_IT end, symbo
239239
histogramOverlap = utils::intersection(newHistogram, addedHistogram);
240240
assert(!histogramOverlap.empty());
241241
assert(histogramOverlap.size() == addedHistogram.size());
242-
std::transform(addedHistogram.begin(), addedHistogram.end(), histogramOverlap.begin(), histogramOverlap.begin(), [this](const count_t& a, const count_t& b) { return this->frequencyCountingDecorator(a) + b; });
242+
std::transform(addedHistogram.begin(), addedHistogram.end(), histogramOverlap.begin(), histogramOverlap.begin(), [this](const count_t& a, const count_t& b) { return internal::safeadd(this->frequencyCountingDecorator(a), b); });
243243

244244
this->mFrequencyTable = std::move(newFreequencyTable);
245245
this->mOffset = newHistogram.getOffset();
@@ -259,7 +259,7 @@ FrequencyTable& FrequencyTable::addFrequencies(Freq_IT begin, Freq_IT end, symbo
259259
if (!overlapAdded.empty()) {
260260
assert(overlapAdded.getMin() == overlapThis.getMin());
261261
assert(overlapAdded.size() == overlapThis.size());
262-
std::transform(overlapAdded.begin(), overlapAdded.end(), overlapThis.begin(), overlapThis.begin(), [this](const count_t& a, const count_t& b) { return this->frequencyCountingDecorator(a) + b; });
262+
std::transform(overlapAdded.begin(), overlapAdded.end(), overlapThis.begin(), overlapThis.begin(), [this](const count_t& a, const count_t& b) { return internal::safeadd(this->frequencyCountingDecorator(a), b); });
263263
}
264264

265265
// right incompressible tail

Utilities/rANS/include/rANS/internal/helper.h

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,9 @@
2525
#include <sstream>
2626
#include <vector>
2727

28+
#define rans_likely(x) __builtin_expect((x), 1)
29+
#define rans_unlikely(x) __builtin_expect((x), 0)
30+
2831
namespace o2
2932
{
3033
namespace rans
@@ -142,6 +145,15 @@ class JSONArrayLogger
142145
bool mReverse{false};
143146
};
144147

148+
/// Adds two 32-bit counts and refuses to wrap around silently.
///
/// Implemented in standard C++ only, so it does not depend on compiler-specific
/// overflow builtins or on helper macros being defined.
///
/// \param a first summand
/// \param b second summand
/// \return a + b when the sum is representable as uint32_t
/// \throws std::overflow_error when the sum does not fit into uint32_t
inline uint32_t safeadd(uint32_t a, uint32_t b)
{
  // Unsigned arithmetic wraps modulo 2^32, so a wrapped sum is always smaller than either operand.
  const uint32_t sum = static_cast<uint32_t>(a + b);
  if (sum < a) {
    throw std::overflow_error("arithmetic overflow during addition");
  }
  return sum;
}
156+
145157
template <typename T, typename IT>
146158
inline constexpr bool isCompatibleIter_v = std::is_convertible_v<typename std::iterator_traits<IT>::value_type, T>;
147159
template <typename IT>

Utilities/rANS/src/RenormedFrequencyTable.cxx

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ RenormedFrequencyTable renorm(FrequencyTable frequencyTable, size_t newPrecision
3737
newPrecision = computeRenormingPrecision(frequencyTable);
3838
}
3939

40-
count_t nSamples = frequencyTable.getNumSamples();
40+
size_t nSamples = frequencyTable.getNumSamples();
4141
count_t nIncompressible = frequencyTable.getIncompressibleSymbolFrequency();
4242
count_t nUsedAlphabetSymbols = frequencyTable.getNUsedAlphabetSymbols();
4343
const symbol_t offset = frequencyTable.getMinSymbol();
@@ -52,11 +52,14 @@ RenormedFrequencyTable renorm(FrequencyTable frequencyTable, size_t newPrecision
5252
histogram_t frequencies = std::move(frequencyTable).release();
5353
frequencies.push_back(nIncompressible);
5454

55-
histogram_t cumulativeFrequencies(frequencies.size() + 1);
55+
std::vector<uint64_t> cumulativeFrequencies(frequencies.size() + 1);
5656
cumulativeFrequencies[0] = 0;
57-
std::inclusive_scan(frequencies.begin(), frequencies.end(), ++cumulativeFrequencies.begin());
57+
std::inclusive_scan(frequencies.begin(), frequencies.end(), ++cumulativeFrequencies.begin(), std::plus<>(), 0ull);
5858

59-
auto getFrequency = [&cumulativeFrequencies](count_t i) { return cumulativeFrequencies[i + 1] - cumulativeFrequencies[i]; };
59+
auto getFrequency = [&cumulativeFrequencies](count_t i) {
60+
assert(cumulativeFrequencies[i + 1] >= cumulativeFrequencies[i]);
61+
return cumulativeFrequencies[i + 1] - cumulativeFrequencies[i];
62+
};
6063

6164
const auto sortIdx = [&]() {
6265
std::vector<size_t> indices;

0 commit comments

Comments
 (0)