@@ -125,6 +125,11 @@ static Array decode(DecodeContext ctx) {
125125 int bufIdx = 0 ;
126126 long rawByteOffset = 0L ;
127127
128+ long [] batchLowers1 = new long [PcoTansDecoder .BATCH_N ];
129+ int [] batchOffsetBits1 = new int [PcoTansDecoder .BATCH_N ];
130+ long [] batchLowers2 = new long [PcoTansDecoder .BATCH_N ];
131+ int [] batchOffsetBits2 = new int [PcoTansDecoder .BATCH_N ];
132+
128133 for (int c = 0 ; c < nChunks ; c ++) {
129134 EncodingProtos .PcoChunkInfo chunkInfo = meta .getChunks (c );
130135 MemorySegment chunkMetaBuf = ctx .buffer (bufIdx ++);
@@ -153,7 +158,8 @@ static Array decode(DecodeContext ctx) {
153158 chunkMeta .conv1Quantization (), chunkMeta .conv1Bias (),
154159 chunkMeta .conv1Weights (),
155160 dtypeSize , pageBuf , pageN ,
156- rawLatents , rawByteOffset );
161+ rawLatents , rawByteOffset ,
162+ batchLowers1 , batchOffsetBits1 );
157163 }
158164 } else if (deltaVariant == 2 ) {
159165 // Lookback delta: currently only Classic mode supported.
@@ -177,7 +183,9 @@ static Array decode(DecodeContext ctx) {
177183 primaryTans , chunkMeta .ansSizeLog (),
178184 stateN , windowN , mid , mask ,
179185 dtypeSize , pageBuf , pageN ,
180- rawLatents , rawByteOffset , ctx .arena ());
186+ rawLatents , rawByteOffset , ctx .arena (),
187+ batchLowers1 , batchOffsetBits1 ,
188+ batchLowers2 , batchOffsetBits2 );
181189 }
182190 } else if (mode == 0 || mode == 4 ) {
183191 // Single-latent var: Classic or Dict.
@@ -188,7 +196,8 @@ static Array decode(DecodeContext ctx) {
188196 MemorySegment pageBuf = ctx .buffer (bufIdx ++);
189197 rawByteOffset = decodeClassicPage (tans , chunkMeta .ansSizeLog (),
190198 chunkMeta .deltaOrder (), primaryDtypeSize ,
191- pageBuf , pageN , rawLatents , rawByteOffset );
199+ pageBuf , pageN , rawLatents , rawByteOffset ,
200+ batchLowers1 , batchOffsetBits1 );
192201 }
193202 if (mode == 4 ) {
194203 combineDict (chunkMeta .dict (), chunkN , rawLatents , chunkStartOffset );
@@ -212,7 +221,9 @@ static Array decode(DecodeContext ctx) {
212221 secondaryTans , secondaryAnsSizeLog , secondaryDeltaOrder ,
213222 dtypeSize , pageBuf , pageN ,
214223 rawLatents , rawByteOffset ,
215- rawAdjs , adjByteOffset );
224+ rawAdjs , adjByteOffset ,
225+ batchLowers1 , batchOffsetBits1 ,
226+ batchLowers2 , batchOffsetBits2 );
216227 rawByteOffset += (long ) pageN * Long .BYTES ;
217228 adjByteOffset += (long ) pageN * Long .BYTES ;
218229 }
@@ -268,7 +279,8 @@ private static Array decodeChild(DecodeContext parent, int idx, DType dtype, lon
268279 /// Decode one Classic-mode page into rawLatents and return the updated byte offset.
269280 private static long decodeClassicPage (PcoTansDecoder tans , int ansSizeLog , int deltaOrder ,
270281 int primaryDtypeSize , MemorySegment pageBuf , int pageN ,
271- MemorySegment rawLatents , long rawByteOffset ) {
282+ MemorySegment rawLatents , long rawByteOffset ,
283+ long [] batchLowers , int [] batchOffsetBits ) {
272284 LeBitReader pageReader = new LeBitReader (pageBuf );
273285
274286 long [] moments = new long [deltaOrder ];
@@ -283,7 +295,8 @@ private static long decodeClassicPage(PcoTansDecoder tans, int ansSizeLog, int d
283295 pageReader .alignToByte ();
284296
285297 int decodedN = pageN - deltaOrder ;
286- tans .decodePage (pageReader , stateIdxs , decodedN , rawLatents , rawByteOffset );
298+ tans .decodePage (pageReader , stateIdxs , decodedN , rawLatents , rawByteOffset ,
299+ batchLowers , batchOffsetBits );
287300
288301 if (deltaOrder > 0 ) {
289302 applyConsecutiveDelta (rawLatents , rawByteOffset , pageN , moments , primaryDtypeSize );
@@ -299,7 +312,9 @@ private static void decodeIntMultPage(
299312 PcoTansDecoder secondaryTans , int secondaryAnsSizeLog , int secondaryDeltaOrder ,
300313 int dtypeSize , MemorySegment pageBuf , int pageN ,
301314 MemorySegment rawMults , long multsOffset ,
302- MemorySegment rawAdjs , long adjsOffset ) {
315+ MemorySegment rawAdjs , long adjsOffset ,
316+ long [] batchLowersP , int [] batchOffsetBitsP ,
317+ long [] batchLowersS , int [] batchOffsetBitsS ) {
303318 LeBitReader pageReader = new LeBitReader (pageBuf );
304319
305320 long [] primaryMoments = new long [deltaOrder ];
@@ -322,11 +337,6 @@ private static void decodeIntMultPage(
322337
323338 pageReader .alignToByte ();
324339
325- long [] batchLowersP = new long [PcoTansDecoder .BATCH_N ];
326- int [] batchOffsetBitsP = new int [PcoTansDecoder .BATCH_N ];
327- long [] batchLowersS = new long [PcoTansDecoder .BATCH_N ];
328- int [] batchOffsetBitsS = new int [PcoTansDecoder .BATCH_N ];
329-
330340 int nRemaining = pageN ;
331341 long primaryPos = multsOffset ;
332342 long secondaryPos = adjsOffset ;
@@ -364,7 +374,9 @@ private static long decodeLookbackPage(
364374 int stateN , int windowN , long mid , long mask ,
365375 int dtypeSize , MemorySegment pageBuf , int pageN ,
366376 MemorySegment rawLatents , long latentsOffset ,
367- SegmentAllocator arena ) {
377+ SegmentAllocator arena ,
378+ long [] batchLowersD , int [] batchOffsetBitsD ,
379+ long [] batchLowersP , int [] batchOffsetBitsP ) {
368380 if (pageN < stateN ) {
369381 throw new VortexException (EncodingId .VORTEX_PCO ,
370382 "pco corrupt lookback page: stateN " + stateN + " exceeds pageN " + pageN );
@@ -396,11 +408,6 @@ private static long decodeLookbackPage(
396408 MemorySegment rawLookbacks = arena .allocate ((long ) decodeN * Long .BYTES );
397409 MemorySegment rawResiduals = arena .allocate ((long ) decodeN * Long .BYTES );
398410
399- long [] batchLowersD = new long [PcoTansDecoder .BATCH_N ];
400- int [] batchOffsetBitsD = new int [PcoTansDecoder .BATCH_N ];
401- long [] batchLowersP = new long [PcoTansDecoder .BATCH_N ];
402- int [] batchOffsetBitsP = new int [PcoTansDecoder .BATCH_N ];
403-
404411 int remaining = decodeN ;
405412 long dPos = 0L ;
406413 long pPos = 0L ;
@@ -459,7 +466,8 @@ private static long decodeConv1Page(
459466 PcoTansDecoder tans , int ansSizeLog ,
460467 int order , int quantization , long bias , long [] weights ,
461468 int dtypeSize , MemorySegment pageBuf , int pageN ,
462- MemorySegment rawLatents , long latentsOffset ) {
469+ MemorySegment rawLatents , long latentsOffset ,
470+ long [] batchLowers , int [] batchOffsetBits ) {
463471 LeBitReader pageReader = new LeBitReader (pageBuf );
464472
465473 long [] state = new long [order ];
@@ -483,7 +491,8 @@ private static long decodeConv1Page(
483491
484492 // Decode residuals directly into rawLatents[latentsOffset + order*8..].
485493 tans .decodePage (pageReader , stateIdxs , decodeN , rawLatents ,
486- latentsOffset + (long ) order * Long .BYTES );
494+ latentsOffset + (long ) order * Long .BYTES ,
495+ batchLowers , batchOffsetBits );
487496
488497 // Toggle-center decoded residuals in-place.
489498 for (int i = order ; i < pageN ; i ++) {
0 commit comments