Skip to content

Commit c2be9ba

Browse files
dfa1 and claude committed
perf(pco): hoist BATCH_N scratch arrays out of page loops
Allocate the four BATCH_N-sized scratch arrays (two long[256] and two int[256]) once per decode call instead of once per page. Thread the scratch parameters through decodeClassicPage, decodeConv1Page, decodeIntMultPage, decodeLookbackPage, and PcoTansDecoder.decodePage.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e9f0e0d commit c2be9ba

3 files changed

Lines changed: 41 additions & 27 deletions

File tree

core/src/main/java/io/github/dfa1/vortex/encoding/PcoEncoding.java

Lines changed: 29 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,11 @@ static Array decode(DecodeContext ctx) {
125125
int bufIdx = 0;
126126
long rawByteOffset = 0L;
127127

128+
long[] batchLowers1 = new long[PcoTansDecoder.BATCH_N];
129+
int[] batchOffsetBits1 = new int[PcoTansDecoder.BATCH_N];
130+
long[] batchLowers2 = new long[PcoTansDecoder.BATCH_N];
131+
int[] batchOffsetBits2 = new int[PcoTansDecoder.BATCH_N];
132+
128133
for (int c = 0; c < nChunks; c++) {
129134
EncodingProtos.PcoChunkInfo chunkInfo = meta.getChunks(c);
130135
MemorySegment chunkMetaBuf = ctx.buffer(bufIdx++);
@@ -153,7 +158,8 @@ static Array decode(DecodeContext ctx) {
153158
chunkMeta.conv1Quantization(), chunkMeta.conv1Bias(),
154159
chunkMeta.conv1Weights(),
155160
dtypeSize, pageBuf, pageN,
156-
rawLatents, rawByteOffset);
161+
rawLatents, rawByteOffset,
162+
batchLowers1, batchOffsetBits1);
157163
}
158164
} else if (deltaVariant == 2) {
159165
// Lookback delta: currently only Classic mode supported.
@@ -177,7 +183,9 @@ static Array decode(DecodeContext ctx) {
177183
primaryTans, chunkMeta.ansSizeLog(),
178184
stateN, windowN, mid, mask,
179185
dtypeSize, pageBuf, pageN,
180-
rawLatents, rawByteOffset, ctx.arena());
186+
rawLatents, rawByteOffset, ctx.arena(),
187+
batchLowers1, batchOffsetBits1,
188+
batchLowers2, batchOffsetBits2);
181189
}
182190
} else if (mode == 0 || mode == 4) {
183191
// Single-latent var: Classic or Dict.
@@ -188,7 +196,8 @@ static Array decode(DecodeContext ctx) {
188196
MemorySegment pageBuf = ctx.buffer(bufIdx++);
189197
rawByteOffset = decodeClassicPage(tans, chunkMeta.ansSizeLog(),
190198
chunkMeta.deltaOrder(), primaryDtypeSize,
191-
pageBuf, pageN, rawLatents, rawByteOffset);
199+
pageBuf, pageN, rawLatents, rawByteOffset,
200+
batchLowers1, batchOffsetBits1);
192201
}
193202
if (mode == 4) {
194203
combineDict(chunkMeta.dict(), chunkN, rawLatents, chunkStartOffset);
@@ -212,7 +221,9 @@ static Array decode(DecodeContext ctx) {
212221
secondaryTans, secondaryAnsSizeLog, secondaryDeltaOrder,
213222
dtypeSize, pageBuf, pageN,
214223
rawLatents, rawByteOffset,
215-
rawAdjs, adjByteOffset);
224+
rawAdjs, adjByteOffset,
225+
batchLowers1, batchOffsetBits1,
226+
batchLowers2, batchOffsetBits2);
216227
rawByteOffset += (long) pageN * Long.BYTES;
217228
adjByteOffset += (long) pageN * Long.BYTES;
218229
}
@@ -268,7 +279,8 @@ private static Array decodeChild(DecodeContext parent, int idx, DType dtype, lon
268279
/// Decode one Classic-mode page into rawLatents and return the updated byte offset.
269280
private static long decodeClassicPage(PcoTansDecoder tans, int ansSizeLog, int deltaOrder,
270281
int primaryDtypeSize, MemorySegment pageBuf, int pageN,
271-
MemorySegment rawLatents, long rawByteOffset) {
282+
MemorySegment rawLatents, long rawByteOffset,
283+
long[] batchLowers, int[] batchOffsetBits) {
272284
LeBitReader pageReader = new LeBitReader(pageBuf);
273285

274286
long[] moments = new long[deltaOrder];
@@ -283,7 +295,8 @@ private static long decodeClassicPage(PcoTansDecoder tans, int ansSizeLog, int d
283295
pageReader.alignToByte();
284296

285297
int decodedN = pageN - deltaOrder;
286-
tans.decodePage(pageReader, stateIdxs, decodedN, rawLatents, rawByteOffset);
298+
tans.decodePage(pageReader, stateIdxs, decodedN, rawLatents, rawByteOffset,
299+
batchLowers, batchOffsetBits);
287300

288301
if (deltaOrder > 0) {
289302
applyConsecutiveDelta(rawLatents, rawByteOffset, pageN, moments, primaryDtypeSize);
@@ -299,7 +312,9 @@ private static void decodeIntMultPage(
299312
PcoTansDecoder secondaryTans, int secondaryAnsSizeLog, int secondaryDeltaOrder,
300313
int dtypeSize, MemorySegment pageBuf, int pageN,
301314
MemorySegment rawMults, long multsOffset,
302-
MemorySegment rawAdjs, long adjsOffset) {
315+
MemorySegment rawAdjs, long adjsOffset,
316+
long[] batchLowersP, int[] batchOffsetBitsP,
317+
long[] batchLowersS, int[] batchOffsetBitsS) {
303318
LeBitReader pageReader = new LeBitReader(pageBuf);
304319

305320
long[] primaryMoments = new long[deltaOrder];
@@ -322,11 +337,6 @@ private static void decodeIntMultPage(
322337

323338
pageReader.alignToByte();
324339

325-
long[] batchLowersP = new long[PcoTansDecoder.BATCH_N];
326-
int[] batchOffsetBitsP = new int[PcoTansDecoder.BATCH_N];
327-
long[] batchLowersS = new long[PcoTansDecoder.BATCH_N];
328-
int[] batchOffsetBitsS = new int[PcoTansDecoder.BATCH_N];
329-
330340
int nRemaining = pageN;
331341
long primaryPos = multsOffset;
332342
long secondaryPos = adjsOffset;
@@ -364,7 +374,9 @@ private static long decodeLookbackPage(
364374
int stateN, int windowN, long mid, long mask,
365375
int dtypeSize, MemorySegment pageBuf, int pageN,
366376
MemorySegment rawLatents, long latentsOffset,
367-
SegmentAllocator arena) {
377+
SegmentAllocator arena,
378+
long[] batchLowersD, int[] batchOffsetBitsD,
379+
long[] batchLowersP, int[] batchOffsetBitsP) {
368380
if (pageN < stateN) {
369381
throw new VortexException(EncodingId.VORTEX_PCO,
370382
"pco corrupt lookback page: stateN " + stateN + " exceeds pageN " + pageN);
@@ -396,11 +408,6 @@ private static long decodeLookbackPage(
396408
MemorySegment rawLookbacks = arena.allocate((long) decodeN * Long.BYTES);
397409
MemorySegment rawResiduals = arena.allocate((long) decodeN * Long.BYTES);
398410

399-
long[] batchLowersD = new long[PcoTansDecoder.BATCH_N];
400-
int[] batchOffsetBitsD = new int[PcoTansDecoder.BATCH_N];
401-
long[] batchLowersP = new long[PcoTansDecoder.BATCH_N];
402-
int[] batchOffsetBitsP = new int[PcoTansDecoder.BATCH_N];
403-
404411
int remaining = decodeN;
405412
long dPos = 0L;
406413
long pPos = 0L;
@@ -459,7 +466,8 @@ private static long decodeConv1Page(
459466
PcoTansDecoder tans, int ansSizeLog,
460467
int order, int quantization, long bias, long[] weights,
461468
int dtypeSize, MemorySegment pageBuf, int pageN,
462-
MemorySegment rawLatents, long latentsOffset) {
469+
MemorySegment rawLatents, long latentsOffset,
470+
long[] batchLowers, int[] batchOffsetBits) {
463471
LeBitReader pageReader = new LeBitReader(pageBuf);
464472

465473
long[] state = new long[order];
@@ -483,7 +491,8 @@ private static long decodeConv1Page(
483491

484492
// Decode residuals directly into rawLatents[latentsOffset + order*8..].
485493
tans.decodePage(pageReader, stateIdxs, decodeN, rawLatents,
486-
latentsOffset + (long) order * Long.BYTES);
494+
latentsOffset + (long) order * Long.BYTES,
495+
batchLowers, batchOffsetBits);
487496

488497
// Toggle-center decoded residuals in-place.
489498
for (int i = order; i < pageN; i++) {

core/src/main/java/io/github/dfa1/vortex/encoding/PcoTansDecoder.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -95,10 +95,11 @@ static int[] spreadStateSymbols(int ansSizeLog, int[] weights, int tableSize) {
9595
/// Caller must have already read 4 initial ANS state indices and called
9696
/// {@link LeBitReader#alignToByte()} before this call.
9797
/// {@code ansStateIdxs} is modified in place and not valid after return.
98+
/// {@code batchLowers} and {@code batchOffsetBits} are caller-provided scratch arrays of
99+
/// length ≥ {@link #BATCH_N}; they are fully overwritten before use.
98100
void decodePage(LeBitReader reader, int[] ansStateIdxs, int n,
99-
MemorySegment out, long outByteOffset) {
100-
long[] batchLowers = new long[BATCH_N];
101-
int[] batchOffsetBits = new int[BATCH_N];
101+
MemorySegment out, long outByteOffset,
102+
long[] batchLowers, int[] batchOffsetBits) {
102103
int remaining = n;
103104
long pos = outByteOffset;
104105
while (remaining > 0) {

core/src/test/java/io/github/dfa1/vortex/encoding/PcoTansDecoderTest.java

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,8 @@ void singleBin_zeroBits_constantOutput() {
7474
MemorySegment out = Arena.ofAuto().allocate((long) n * Long.BYTES);
7575

7676
// When
77-
sut.decodePage(reader, stateIdxs, n, out, 0L);
77+
sut.decodePage(reader, stateIdxs, n, out, 0L,
78+
new long[PcoTansDecoder.BATCH_N], new int[PcoTansDecoder.BATCH_N]);
7879

7980
// Then — all raw latent values = 42
8081
for (int i = 0; i < n; i++) {
@@ -99,7 +100,8 @@ void singleBin_oneBitOffset_decodesOffset() {
99100
MemorySegment out = Arena.ofAuto().allocate((long) n * Long.BYTES);
100101

101102
// When
102-
sut.decodePage(reader, stateIdxs, n, out, 0L);
103+
sut.decodePage(reader, stateIdxs, n, out, 0L,
104+
new long[PcoTansDecoder.BATCH_N], new int[PcoTansDecoder.BATCH_N]);
103105

104106
// Then — offsets all zero → all values = lower + 0 = 10
105107
for (int i = 0; i < n; i++) {
@@ -122,7 +124,8 @@ void degenerateBins_zeroBins_outputZero() {
122124
MemorySegment out = Arena.ofAuto().allocate((long) n * Long.BYTES);
123125

124126
// When
125-
sut.decodePage(reader, stateIdxs, n, out, 0L);
127+
sut.decodePage(reader, stateIdxs, n, out, 0L,
128+
new long[PcoTansDecoder.BATCH_N], new int[PcoTansDecoder.BATCH_N]);
126129

127130
// Then — all zero
128131
for (int i = 0; i < n; i++) {
@@ -178,7 +181,8 @@ void moreThanOneBatch_decodesCorrectly() {
178181
MemorySegment out = Arena.ofAuto().allocate((long) n * Long.BYTES);
179182

180183
// When
181-
sut.decodePage(reader, stateIdxs, n, out, 0L);
184+
sut.decodePage(reader, stateIdxs, n, out, 0L,
185+
new long[PcoTansDecoder.BATCH_N], new int[PcoTansDecoder.BATCH_N]);
182186

183187
// Then — all values = 7
184188
for (int i = 0; i < n; i++) {

0 commit comments

Comments
 (0)