Skip to content

Commit 00956d3

Browse files
committed
512-bit vectors in utf8 validator
1 parent 1f9074d commit 00956d3

3 files changed

Lines changed: 36 additions & 57 deletions

File tree

src/main/java/org/simdjson/StructuralIndexer.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import jdk.incubator.vector.ByteVector;
44
import jdk.incubator.vector.VectorSpecies;
5-
import java.lang.invoke.MethodType;
65

76
import static jdk.incubator.vector.VectorOperators.UNSIGNED_LE;
87

src/main/java/org/simdjson/Utf8Validator.java

Lines changed: 33 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44

55
import java.util.Arrays;
66

7-
public class Utf8Validator {
8-
private static final VectorSpecies<Byte> VECTOR_SPECIES = ByteVector.SPECIES_256;
7+
class Utf8Validator {
8+
9+
private static final VectorSpecies<Byte> VECTOR_SPECIES = StructuralIndexer.SPECIES;
910
private static final ByteVector INCOMPLETE_CHECK = getIncompleteCheck();
10-
private static final VectorShuffle<Integer> SHIFT_FOUR_BYTES_FORWARD = VectorShuffle.iota(IntVector.SPECIES_256,
11-
IntVector.SPECIES_256.elementSize() - 1, 1, true);
12-
private static final ByteVector LOW_NIBBLE_MASK = ByteVector.broadcast(VECTOR_SPECIES, 0b0000_1111);
11+
private static final byte LOW_NIBBLE_MASK = 0x0f;
1312
private static final ByteVector ALL_ASCII_MASK = ByteVector.broadcast(VECTOR_SPECIES, (byte) 0b1000_0000);
1413

1514
/**
@@ -19,9 +18,9 @@ public class Utf8Validator {
1918
* @throws JsonParsingException if the input is not valid UTF8
2019
*/
2120
static void validate(byte[] inputBytes) {
22-
long previousIncomplete = 0;
23-
long errors = 0;
24-
int previousFourUtf8Bytes = 0;
21+
boolean previousIncomplete = false;
22+
boolean errors = false;
23+
ByteVector prevChunk = ByteVector.zero(VECTOR_SPECIES);
2524

2625
int idx = 0;
2726
for (; idx < VECTOR_SPECIES.loopBound(inputBytes.length); idx += VECTOR_SPECIES.vectorByteSize()) {
@@ -32,14 +31,12 @@ static void validate(byte[] inputBytes) {
3231
} else {
3332
previousIncomplete = isIncomplete(utf8Vector);
3433

35-
var fourBytesPrevious = fourBytesPreviousSlice(utf8Vector, previousFourUtf8Bytes);
36-
37-
ByteVector firstCheck = firstTwoByteSequenceCheck(utf8Vector.reinterpretAsInts(), fourBytesPrevious);
38-
ByteVector secondCheck = lastTwoByteSequenceCheck(utf8Vector.reinterpretAsInts(), fourBytesPrevious, firstCheck);
34+
ByteVector firstCheck = firstTwoByteSequenceCheck(utf8Vector, prevChunk);
35+
ByteVector secondCheck = lastTwoByteSequenceCheck(utf8Vector, prevChunk, firstCheck);
3936

40-
errors |= secondCheck.compare(VectorOperators.NE, 0).toLong();
37+
errors |= secondCheck.compare(VectorOperators.NE, 0).anyTrue();
4138
}
42-
previousFourUtf8Bytes = utf8Vector.reinterpretAsInts().lane(IntVector.SPECIES_256.length() - 1);
39+
prevChunk = utf8Vector;
4340
}
4441

4542
// if the input file doesn't align with the vector width, pad the missing bytes with zero
@@ -48,73 +45,56 @@ static void validate(byte[] inputBytes) {
4845
if (!isAscii(lastVectorChunk)) {
4946
previousIncomplete = isIncomplete(lastVectorChunk);
5047

51-
var fourBytesPrevious = fourBytesPreviousSlice(lastVectorChunk, previousFourUtf8Bytes);
48+
ByteVector firstCheck = firstTwoByteSequenceCheck(lastVectorChunk, prevChunk);
49+
ByteVector secondCheck = lastTwoByteSequenceCheck(lastVectorChunk, prevChunk, firstCheck);
5250

53-
ByteVector firstCheck = firstTwoByteSequenceCheck(lastVectorChunk.reinterpretAsInts(), fourBytesPrevious);
54-
ByteVector secondCheck = lastTwoByteSequenceCheck(lastVectorChunk.reinterpretAsInts(), fourBytesPrevious, firstCheck);
55-
56-
errors |= secondCheck.compare(VectorOperators.NE, 0).toLong();
51+
errors |= secondCheck.compare(VectorOperators.NE, 0).anyTrue();
5752
}
5853

59-
if ((errors | previousIncomplete) != 0) {
54+
if (errors | previousIncomplete) {
6055
throw new JsonParsingException("Invalid UTF8");
6156
}
6257
}
6358

64-
/* Shuffles the input forward by four bytes to make space for the previous four bytes.
65-
The previous three bytes are required for validation, pulling in the last integer will give the previous four bytes.
66-
The switch to integer vectors is to allow for integer shifting instead of the more expensive shuffle / slice operations */
67-
private static IntVector fourBytesPreviousSlice(ByteVector vectorChunk, int previousFourUtf8Bytes) {
68-
return vectorChunk.reinterpretAsInts()
69-
.rearrange(SHIFT_FOUR_BYTES_FORWARD)
70-
.withLane(0, previousFourUtf8Bytes);
71-
}
72-
73-
// works similar to previousUtf8Vector.slice(VECTOR_SPECIES.length() - numOfBytesToInclude, utf8Vector) but without the performance cost
74-
private static ByteVector previousVectorSlice(IntVector utf8Vector, IntVector fourBytesPrevious, int numOfPreviousBytes) {
75-
return utf8Vector
76-
.lanewise(VectorOperators.LSHL, Byte.SIZE * numOfPreviousBytes)
77-
.or(fourBytesPrevious.lanewise(VectorOperators.LSHR, Byte.SIZE * (4 - numOfPreviousBytes)))
78-
.reinterpretAsBytes();
79-
}
80-
81-
private static ByteVector firstTwoByteSequenceCheck(IntVector utf8Vector, IntVector fourBytesPrevious) {
59+
private static ByteVector firstTwoByteSequenceCheck(ByteVector utf8Vector, ByteVector prevChunk) {
8260
// shift the current input forward by 1 byte to include 1 byte from the previous input
83-
var oneBytePrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 1);
61+
var oneBytePrevious = concatenate(utf8Vector, prevChunk, 1);
8462

8563
// high nibbles of the current input (e.g. 0xC3 >> 4 = 0xC)
86-
ByteVector byte2HighNibbles = utf8Vector.lanewise(VectorOperators.LSHR, 4)
87-
.reinterpretAsBytes().and(LOW_NIBBLE_MASK);
64+
ByteVector byte2HighNibbles = utf8Vector.lanewise(VectorOperators.LSHR, 4);
8865

8966
// high nibbles of the shifted input
90-
ByteVector byte1HighNibbles = oneBytePrevious.reinterpretAsInts().lanewise(VectorOperators.LSHR, 4)
91-
.reinterpretAsBytes().and(LOW_NIBBLE_MASK);
67+
ByteVector byte1HighNibbles = oneBytePrevious.lanewise(VectorOperators.LSHR, 4);
9268

9369
// low nibbles of the shifted input (e.g. 0xC3 & 0xF = 0x3)
9470
ByteVector byte1LowNibbles = oneBytePrevious.and(LOW_NIBBLE_MASK);
95-
96-
ByteVector byte1HighState = byte1HighNibbles.selectFrom(LookupTable.byte1High);
97-
ByteVector byte1LowState = byte1LowNibbles.selectFrom(LookupTable.byte1Low);
98-
ByteVector byte2HighState = byte2HighNibbles.selectFrom(LookupTable.byte2High);
99-
71+
ByteVector byte1HighState = byte2HighNibbles.selectFrom(LookupTable.byte2High);
72+
ByteVector byte1LowState = byte1HighNibbles.selectFrom(LookupTable.byte1High);
73+
ByteVector byte2HighState = byte1LowNibbles.selectFrom(LookupTable.byte1Low);
10074
return byte1HighState.and(byte1LowState).and(byte2HighState);
10175
}
10276

10377
// All remaining checks are invalid 3–4 byte sequences, which either have too many continuations bytes or not enough
104-
private static ByteVector lastTwoByteSequenceCheck(IntVector utf8Vector, IntVector fourBytesPrevious, ByteVector firstCheck) {
78+
private static ByteVector lastTwoByteSequenceCheck(ByteVector utf8Vector, ByteVector prevChunk, ByteVector firstCheck) {
10579
// the minimum 3byte lead - 1110_0000 is always greater than the max 2byte lead - 110_11111
106-
ByteVector twoBytesPrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 2);
80+
ByteVector twoBytesPrevious = concatenate(utf8Vector, prevChunk, 2);
81+
10782
VectorMask<Byte> is3ByteLead = twoBytesPrevious.compare(VectorOperators.UNSIGNED_GT, (byte) 0b110_11111);
10883

10984
// the minimum 4byte lead - 1111_0000 is always greater than the max 3byte lead - 1110_1111
110-
ByteVector threeBytesPrevious = previousVectorSlice(utf8Vector, fourBytesPrevious, 3);
85+
ByteVector threeBytesPrevious = concatenate(utf8Vector, prevChunk, 3);
86+
11187
VectorMask<Byte> is4ByteLead = threeBytesPrevious.compare(VectorOperators.UNSIGNED_GT, (byte) 0b1110_1111);
11288

11389
// the firstCheck vector contains 0x80 values on continuation byte indexes
11490
// the 3/4 byte lead bytes should match up with these indexes and zero them out
11591
return firstCheck.add((byte) 0x80, is3ByteLead.or(is4ByteLead));
11692
}
11793

94+
private static ByteVector concatenate(ByteVector curr, ByteVector prev, int byteCountFromPrev) {
95+
return prev.slice(VECTOR_SPECIES.length() - byteCountFromPrev, curr);
96+
}
97+
11898
/* checks that the previous vector isn't in an incomplete state.
11999
Previous vector is in an incomplete state if the last byte is smaller than 0xC0,
120100
or the second last byte is smaller than 0xE0, or the third last byte is smaller than 0xF0.*/
@@ -128,8 +108,8 @@ private static ByteVector getIncompleteCheck() {
128108
return ByteVector.fromArray(VECTOR_SPECIES, eofArray, 0);
129109
}
130110

131-
private static long isIncomplete(ByteVector utf8Vector) {
132-
return utf8Vector.compare(VectorOperators.UNSIGNED_GE, INCOMPLETE_CHECK).toLong();
111+
private static boolean isIncomplete(ByteVector utf8Vector) {
112+
return utf8Vector.compare(VectorOperators.UNSIGNED_GE, INCOMPLETE_CHECK).anyTrue();
133113
}
134114

135115
// ASCII will never exceed 01111_1111

src/test/java/org/simdjson/Utf8ValidatorTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
package org.simdjson;
22

3-
import jdk.incubator.vector.ByteVector;
43
import jdk.incubator.vector.VectorSpecies;
54
import org.junit.jupiter.api.Test;
65
import org.junit.jupiter.params.ParameterizedTest;
76
import org.junit.jupiter.params.provider.ValueSource;
87

98
import java.io.IOException;
109
import java.util.Arrays;
11-
import java.util.Objects;
1210

13-
import static org.assertj.core.api.Assertions.*;
11+
import static org.assertj.core.api.Assertions.assertThatCode;
12+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
1413

1514
class Utf8ValidatorTest {
15+
1616
private static final VectorSpecies<Byte> VECTOR_SPECIES = StructuralIndexer.SPECIES;
1717

1818

0 commit comments

Comments
 (0)