Skip to content

Commit da1e3a5

Browse files
committed
Added support for sparsevec type
1 parent a86a6c1 commit da1e3a5

File tree

4 files changed

+324
-1
lines changed

4 files changed

+324
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## 0.1.5 (unreleased)
22

3-
- Added support for `halfvec` type
3+
- Added support for `halfvec` and `sparsevec` types
44

55
## 0.1.4 (2023-12-08)
66

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package com.pgvector;
2+
3+
import java.io.Serializable;
4+
import java.sql.Connection;
5+
import java.sql.SQLException;
6+
import java.util.Arrays;
7+
import java.util.List;
8+
import java.util.Objects;
9+
import org.postgresql.PGConnection;
10+
import org.postgresql.util.ByteConverter;
11+
import org.postgresql.util.PGBinaryObject;
12+
import org.postgresql.util.PGobject;
13+
14+
/**
15+
* PGsparsevec class
16+
*/
17+
public class PGsparsevec extends PGobject implements PGBinaryObject, Serializable, Cloneable {
18+
private int dimensions;
19+
private int[] indices;
20+
private float[] values;
21+
22+
/**
23+
* Constructor
24+
*/
25+
public PGsparsevec() {
26+
type = "sparsevec";
27+
}
28+
29+
/**
30+
* Constructor
31+
*
32+
* @param v float array
33+
*/
34+
public PGsparsevec(float[] v) {
35+
this();
36+
37+
int nnz = 0;
38+
for (int i = 0; i < v.length; i++) {
39+
if (v[i] != 0) {
40+
nnz++;
41+
}
42+
}
43+
44+
dimensions = v.length;
45+
indices = new int[nnz];
46+
values = new float[nnz];
47+
48+
int j = 0;
49+
for (int i = 0; i < v.length; i++) {
50+
if (v[i] != 0) {
51+
indices[j] = i;
52+
values[j] = v[i];
53+
j++;
54+
}
55+
}
56+
}
57+
58+
/**
59+
* Constructor
60+
*
61+
* @param <T> number
62+
* @param v list of numbers
63+
*/
64+
public <T extends Number> PGsparsevec(List<T> v) {
65+
this();
66+
if (Objects.isNull(v)) {
67+
indices = null;
68+
} else {
69+
int nnz = 0;
70+
for (T f : v) {
71+
if (f.floatValue() != 0) {
72+
nnz++;
73+
}
74+
}
75+
76+
dimensions = v.size();
77+
indices = new int[nnz];
78+
values = new float[nnz];
79+
80+
int i = 0;
81+
int j = 0;
82+
for (T f : v) {
83+
float fv = f.floatValue();
84+
if (fv != 0) {
85+
indices[j] = i;
86+
values[j] = fv;
87+
j++;
88+
}
89+
i++;
90+
}
91+
92+
}
93+
}
94+
95+
/**
96+
* Constructor
97+
*
98+
* @param s text representation of a sparse vector
99+
* @throws SQLException exception
100+
*/
101+
public PGsparsevec(String s) throws SQLException {
102+
this();
103+
setValue(s);
104+
}
105+
106+
/**
107+
* Sets the value from a text representation of a sparse vector
108+
*/
109+
public void setValue(String s) throws SQLException {
110+
if (s == null) {
111+
indices = null;
112+
} else {
113+
String[] sp = s.split("/", 2);
114+
String[] elements = sp[0].substring(1, sp[0].length() - 1).split(",");
115+
116+
dimensions = Integer.parseInt(sp[1]);
117+
indices = new int[elements.length];
118+
values = new float[elements.length];
119+
120+
for (int i = 0; i < elements.length; i++)
121+
{
122+
String[] ep = elements[i].split(":", 2);
123+
indices[i] = Integer.parseInt(ep[0]) - 1;
124+
values[i] = Float.parseFloat(ep[1]);
125+
}
126+
}
127+
}
128+
129+
/**
130+
* Returns the text representation of a sparse vector
131+
*/
132+
public String getValue() {
133+
if (indices == null) {
134+
return null;
135+
} else {
136+
StringBuilder sb = new StringBuilder(13 + 27 * indices.length);
137+
sb.append('{');
138+
139+
for (int i = 0; i < indices.length; i++) {
140+
if (i > 0) {
141+
sb.append(',');
142+
}
143+
sb.append(indices[i] + 1);
144+
sb.append(':');
145+
sb.append(values[i]);
146+
}
147+
148+
sb.append('}');
149+
sb.append('/');
150+
sb.append(dimensions);
151+
return sb.toString();
152+
}
153+
}
154+
155+
/**
156+
* Returns the number of bytes for the binary representation
157+
*/
158+
public int lengthInBytes() {
159+
return indices == null ? 0 : 12 + indices.length * 4 + values.length * 4;
160+
}
161+
162+
/**
163+
* Sets the value from a binary representation of a sparse vector
164+
*/
165+
public void setByteValue(byte[] value, int offset) throws SQLException {
166+
dimensions = ByteConverter.int4(value, offset);
167+
int nnz = ByteConverter.int4(value, offset + 4);
168+
169+
int unused = ByteConverter.int4(value, offset + 8);
170+
if (unused != 0) {
171+
throw new SQLException("expected unused to be 0");
172+
}
173+
174+
indices = new int[nnz];
175+
for (int i = 0; i < nnz; i++) {
176+
indices[i] = ByteConverter.int4(value, offset + 12 + i * 4);
177+
}
178+
179+
values = new float[nnz];
180+
for (int i = 0; i < nnz; i++) {
181+
values[i] = ByteConverter.float4(value, offset + 12 + nnz * 4 + i * 4);
182+
}
183+
}
184+
185+
/**
186+
* Writes the binary representation of a sparse vector
187+
*/
188+
public void toBytes(byte[] bytes, int offset) {
189+
if (indices == null) {
190+
return;
191+
}
192+
193+
// server will error on overflow due to unconsumed buffer
194+
// could set to Integer.MAX_VALUE for friendlier error message
195+
ByteConverter.int4(bytes, offset, dimensions);
196+
ByteConverter.int4(bytes, offset + 4, indices.length);
197+
ByteConverter.int4(bytes, offset + 8, 0);
198+
for (int i = 0; i < indices.length; i++) {
199+
ByteConverter.int4(bytes, offset + 12 + i * 4, indices[i]);
200+
}
201+
for (int i = 0; i < values.length; i++) {
202+
ByteConverter.float4(bytes, offset + 12 + indices.length * 4 + i * 4, values[i]);
203+
}
204+
}
205+
206+
/**
207+
* Returns an array
208+
*
209+
* @return an array
210+
*/
211+
public float[] toArray() {
212+
if (indices == null) {
213+
return null;
214+
}
215+
216+
float[] vec = new float[dimensions];
217+
for (int i = 0; i < indices.length; i++) {
218+
vec[indices[i]] = values[i];
219+
}
220+
return vec;
221+
}
222+
223+
/**
224+
* Registers the sparsevec type
225+
*
226+
* @param conn connection
227+
* @throws SQLException exception
228+
*/
229+
public static void addSparsevecType(Connection conn) throws SQLException {
230+
conn.unwrap(PGConnection.class).addDataType("sparsevec", PGsparsevec.class);
231+
}
232+
}

src/test/java/com/pgvector/JDBCJavaTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,52 @@ void halfvecExample(boolean readBinary) throws SQLException {
116116
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
117117
assertNull(embeddings.get(3));
118118
}
119+
120+
@Test
121+
void testSparsevecReadText() throws SQLException {
122+
sparsevecExample(false);
123+
}
124+
125+
@Test
126+
void testSparsevecReadBinary() throws SQLException {
127+
sparsevecExample(true);
128+
}
129+
130+
void sparsevecExample(boolean readBinary) throws SQLException {
131+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test");
132+
if (readBinary) {
133+
conn.unwrap(PGConnection.class).setPrepareThreshold(-1);
134+
}
135+
136+
Statement setupStmt = conn.createStatement();
137+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
138+
setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");
139+
140+
PGsparsevec.addSparsevecType(conn);
141+
142+
Statement createStmt = conn.createStatement();
143+
createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding sparsevec(3))");
144+
145+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)");
146+
insertStmt.setObject(1, new PGsparsevec(new float[] {1, 1, 1}));
147+
insertStmt.setObject(2, new PGsparsevec(new float[] {2, 2, 2}));
148+
insertStmt.setObject(3, new PGsparsevec(new float[] {1, 1, 2}));
149+
insertStmt.setObject(4, null);
150+
insertStmt.executeUpdate();
151+
152+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <-> ? LIMIT 5");
153+
neighborStmt.setObject(1, new PGsparsevec(new float[] {1, 1, 1}));
154+
ResultSet rs = neighborStmt.executeQuery();
155+
List<Long> ids = new ArrayList<>();
156+
List<PGsparsevec> embeddings = new ArrayList<>();
157+
while (rs.next()) {
158+
ids.add(rs.getLong("id"));
159+
embeddings.add((PGsparsevec) rs.getObject("embedding"));
160+
}
161+
assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray());
162+
assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray());
163+
assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray());
164+
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
165+
assertNull(embeddings.get(3));
166+
}
119167
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.pgvector;
2+
3+
import java.sql.SQLException;
4+
import java.util.Arrays;
5+
import com.pgvector.PGsparsevec;
6+
import org.junit.jupiter.api.Test;
7+
8+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
11+
public class PGsparsevecTest {
12+
@Test
13+
void testArrayConstructor() {
14+
PGsparsevec vec = new PGsparsevec(new float[] {1, 0, 2, 0, 3, 0});
15+
assertArrayEquals(new float[] {1, 0, 2, 0, 3, 0}, vec.toArray());
16+
}
17+
18+
@Test
19+
void testStringConstructor() throws SQLException {
20+
PGsparsevec vec = new PGsparsevec("{1:1,3:2,5:3}/6");
21+
assertArrayEquals(new float[] {1, 0, 2, 0, 3, 0}, vec.toArray());
22+
}
23+
24+
@Test
25+
void testFloatListConstructor() {
26+
Float[] a = new Float[] {Float.valueOf(1), Float.valueOf(2), Float.valueOf(3)};
27+
PGsparsevec vec = new PGsparsevec(Arrays.asList(a));
28+
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
29+
}
30+
31+
@Test
32+
void testDoubleListConstructor() {
33+
Double[] a = new Double[] {Double.valueOf(1), Double.valueOf(2), Double.valueOf(3)};
34+
PGsparsevec vec = new PGsparsevec(Arrays.asList(a));
35+
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
36+
}
37+
38+
@Test
39+
void testGetValue() {
40+
PGsparsevec vec = new PGsparsevec(new float[] {1, 0, 2, 0, 3, 0});
41+
assertEquals("{1:1.0,3:2.0,5:3.0}/6", vec.getValue());
42+
}
43+
}

0 commit comments

Comments
 (0)