Skip to content

Commit 54b8f98

Browse files
committed
Added support for halfvec type
1 parent 45d8b38 commit 54b8f98

File tree

4 files changed

+192
-0
lines changed

4 files changed

+192
-0
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.1.5 (unreleased)
2+
3+
- Added support for `halfvec` type
4+
15
## 0.1.4 (2023-12-08)
26

37
- Added `List` constructor
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package com.pgvector;
2+
3+
import java.io.Serializable;
4+
import java.sql.Connection;
5+
import java.sql.SQLException;
6+
import java.util.Arrays;
7+
import java.util.List;
8+
import java.util.Objects;
9+
import org.postgresql.PGConnection;
10+
import org.postgresql.util.PGobject;
11+
12+
/**
13+
* PGhalfvec class
14+
*/
15+
public class PGhalfvec extends PGobject implements Serializable, Cloneable {
16+
private float[] vec;
17+
18+
/**
19+
* Constructor
20+
*/
21+
public PGhalfvec() {
22+
type = "halfvec";
23+
}
24+
25+
/**
26+
* Constructor
27+
*
28+
* @param v float array
29+
*/
30+
public PGhalfvec(float[] v) {
31+
this();
32+
vec = v;
33+
}
34+
35+
/**
36+
* Constructor
37+
*
38+
* @param <T> number
39+
* @param v list of numbers
40+
*/
41+
public <T extends Number> PGhalfvec(List<T> v) {
42+
this();
43+
if (Objects.isNull(v)) {
44+
vec = null;
45+
} else {
46+
vec = new float[v.size()];
47+
int i = 0;
48+
for (T f : v) {
49+
vec[i++] = f.floatValue();
50+
}
51+
}
52+
}
53+
54+
/**
55+
* Constructor
56+
*
57+
* @param s text representation of a half vector
58+
* @throws SQLException exception
59+
*/
60+
public PGhalfvec(String s) throws SQLException {
61+
this();
62+
setValue(s);
63+
}
64+
65+
/**
66+
* Sets the value from a text representation of a half vector
67+
*/
68+
public void setValue(String s) throws SQLException {
69+
if (s == null) {
70+
vec = null;
71+
} else {
72+
String[] sp = s.substring(1, s.length() - 1).split(",");
73+
vec = new float[sp.length];
74+
for (int i = 0; i < sp.length; i++) {
75+
vec[i] = Float.parseFloat(sp[i]);
76+
}
77+
}
78+
}
79+
80+
/**
81+
* Returns the text representation of a half vector
82+
*/
83+
public String getValue() {
84+
if (vec == null) {
85+
return null;
86+
} else {
87+
return Arrays.toString(vec).replace(" ", "");
88+
}
89+
}
90+
91+
/**
92+
* Returns an array
93+
*
94+
* @return an array
95+
*/
96+
public float[] toArray() {
97+
return vec;
98+
}
99+
100+
/**
101+
* Registers the halfvec type
102+
*
103+
* @param conn connection
104+
* @throws SQLException exception
105+
*/
106+
public static void addHalfvecType(Connection conn) throws SQLException {
107+
conn.unwrap(PGConnection.class).addDataType("halfvec", PGhalfvec.class);
108+
}
109+
}

src/test/java/com/pgvector/JDBCJavaTest.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,40 @@ void example(boolean readBinary) throws SQLException {
6868

6969
conn.close();
7070
}
71+
72+
@Test
73+
void testHalfvec() throws SQLException {
74+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test");
75+
76+
Statement setupStmt = conn.createStatement();
77+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
78+
setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");
79+
80+
PGhalfvec.addHalfvecType(conn);
81+
82+
Statement createStmt = conn.createStatement();
83+
createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding halfvec(3))");
84+
85+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)");
86+
insertStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1}));
87+
insertStmt.setObject(2, new PGhalfvec(new float[] {2, 2, 2}));
88+
insertStmt.setObject(3, new PGhalfvec(new float[] {1, 1, 2}));
89+
insertStmt.setObject(4, null);
90+
insertStmt.executeUpdate();
91+
92+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <-> ? LIMIT 5");
93+
neighborStmt.setObject(1, new PGhalfvec(new float[] {1, 1, 1}));
94+
ResultSet rs = neighborStmt.executeQuery();
95+
List<Long> ids = new ArrayList<>();
96+
List<PGhalfvec> embeddings = new ArrayList<>();
97+
while (rs.next()) {
98+
ids.add(rs.getLong("id"));
99+
embeddings.add((PGhalfvec) rs.getObject("embedding"));
100+
}
101+
assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray());
102+
assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray());
103+
assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray());
104+
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
105+
assertNull(embeddings.get(3));
106+
}
71107
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package com.pgvector;
2+
3+
import java.sql.SQLException;
4+
import java.util.Arrays;
5+
import com.pgvector.PGhalfvec;
6+
import org.junit.jupiter.api.Test;
7+
8+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
11+
public class PGhalfvecTest {
12+
@Test
13+
void testArrayConstructor() {
14+
PGhalfvec vec = new PGhalfvec(new float[] {1, 2, 3});
15+
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
16+
}
17+
18+
@Test
19+
void testStringConstructor() throws SQLException {
20+
PGhalfvec vec = new PGhalfvec("[1,2,3]");
21+
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
22+
}
23+
24+
@Test
25+
void testFloatListConstructor() {
26+
Float[] a = new Float[] {Float.valueOf(1), Float.valueOf(2), Float.valueOf(3)};
27+
PGhalfvec vec = new PGhalfvec(Arrays.asList(a));
28+
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
29+
}
30+
31+
@Test
32+
void testDoubleListConstructor() {
33+
Double[] a = new Double[] {Double.valueOf(1), Double.valueOf(2), Double.valueOf(3)};
34+
PGhalfvec vec = new PGhalfvec(Arrays.asList(a));
35+
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
36+
}
37+
38+
@Test
39+
void testGetValue() {
40+
PGhalfvec vec = new PGhalfvec(new float[] {1, 2, 3});
41+
assertEquals("[1.0,2.0,3.0]", vec.getValue());
42+
}
43+
}

0 commit comments

Comments
 (0)