Skip to content

Commit d0dd1d6

Browse files
committed
Improved Hibernate example
1 parent 5379069 commit d0dd1d6

File tree

6 files changed

+138
-48
lines changed

6 files changed

+138
-48
lines changed

.github/workflows/build.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,7 @@ jobs:
2525
make
2626
sudo make install
2727
- run: psql -d pgvector_java_test -c "CREATE EXTENSION vector"
28+
# Hibernate 6.4 requires Java 11+
29+
- if: ${{ matrix.java == 8 }}
30+
run: rm src/test/java/com/pgvector/HibernateTest.java
2831
- run: mvn -B -ntp test

README.md

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ For other build tools, see [this page](https://central.sonatype.com/artifact/com
2828

2929
And follow the instructions for your database library:
3030

31-
- Java - [JDBC](#jdbc-java), [Spring JDBC](#spring-jdbc)
31+
- Java - [JDBC](#jdbc-java), [Spring JDBC](#spring-jdbc), [Hibernate](#hibernate)
3232
- Kotlin - [JDBC](#jdbc-kotlin)
3333
- Groovy - [JDBC](#jdbc-groovy), [Groovy SQL](#groovy-sql)
3434
- Scala - [JDBC](#jdbc-scala), [Slick](#slick)
@@ -142,6 +142,64 @@ Use `vector_ip_ops` for inner product and `vector_cosine_ops` for cosine distanc
142142

143143
See a [full example](src/test/java/com/pgvector/SpringJDBCTest.java)
144144

145+
## Hibernate
146+
147+
Hibernate 6.4+ has a [vector module](https://docs.jboss.org/hibernate/orm/6.4/userguide/html_single/Hibernate_User_Guide.html#vector-module) (use this instead of `com.pgvector.pgvector`).
148+
149+
For Maven, add to `pom.xml` under `<dependencies>`:
150+
151+
```xml
152+
<dependency>
153+
<groupId>org.hibernate.orm</groupId>
154+
<artifactId>hibernate-vector</artifactId>
155+
<version>6.4.0.Final</version>
156+
</dependency>
157+
```
158+
159+
Define an entity
160+
161+
```java
162+
import jakarta.persistence.*;
163+
import org.hibernate.annotations.Array;
164+
import org.hibernate.annotations.JdbcTypeCode;
165+
import org.hibernate.type.SqlTypes;
166+
167+
@Entity
168+
class Item {
169+
@Id
170+
@GeneratedValue
171+
private Long id;
172+
173+
@Column
174+
@JdbcTypeCode(SqlTypes.VECTOR)
175+
@Array(length = 3) // dimensions
176+
private float[] embedding;
177+
178+
public void setEmbedding(float[] embedding) {
179+
this.embedding = embedding;
180+
}
181+
}
182+
```
183+
184+
Insert a vector
185+
186+
```java
187+
Item item = new Item();
188+
item.setEmbedding(new float[] {1, 1, 1});
189+
entityManager.persist(item);
190+
```
191+
192+
Get the nearest neighbors
193+
194+
```java
195+
List<Item> items = entityManager
196+
.createQuery("FROM Item ORDER BY l2_distance(embedding, :embedding) LIMIT 5", Item.class)
197+
.setParameter("embedding", new float[] {1, 1, 1})
198+
.getResultList();
199+
```
200+
201+
See a [full example](src/test/java/com/pgvector/HibernateTest.java)
202+
145203
## JDBC (Kotlin)
146204

147205
Import the `PGvector` class

pom.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@
9696
<scope>test</scope>
9797
</dependency>
9898
<dependency>
99-
<groupId>org.hibernate</groupId>
100-
<artifactId>hibernate-core</artifactId>
101-
<version>5.6.15.Final</version>
99+
<groupId>org.hibernate.orm</groupId>
100+
<artifactId>hibernate-vector</artifactId>
101+
<version>6.4.0.Final</version>
102102
<scope>test</scope>
103103
</dependency>
104104
</dependencies>
Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,73 @@
11
package com.pgvector;
22

3-
import java.math.BigInteger;
4-
import java.sql.*;
3+
import jakarta.persistence.*;
4+
import java.sql.SQLException;
55
import java.util.ArrayList;
66
import java.util.List;
7-
import com.pgvector.PGvector;
8-
import org.hibernate.boot.registry.StandardServiceRegistry;
9-
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
10-
import org.hibernate.boot.MetadataSources;
11-
import org.hibernate.Session;
12-
import org.hibernate.SessionFactory;
7+
import org.hibernate.annotations.Array;
8+
import org.hibernate.annotations.JdbcTypeCode;
9+
import org.hibernate.type.SqlTypes;
1310
import org.junit.jupiter.api.Test;
1411

1512
import static org.junit.jupiter.api.Assertions.*;
1613

14+
@Entity
15+
@Table(name = "hibernate_items")
16+
class Item {
17+
@Id
18+
@GeneratedValue
19+
private Long id;
20+
21+
@Column
22+
@JdbcTypeCode(SqlTypes.VECTOR)
23+
@Array(length = 3)
24+
private float[] embedding;
25+
26+
public Long getId() {
27+
return id;
28+
}
29+
30+
public float[] getEmbedding() {
31+
return embedding;
32+
}
33+
34+
public void setEmbedding(float[] embedding) {
35+
this.embedding = embedding;
36+
}
37+
}
38+
1739
public class HibernateTest {
1840
@Test
1941
void example() throws SQLException {
2042
// disable logging
2143
System.setProperty("org.jboss.logging.provider", "slf4j");
2244

23-
StandardServiceRegistry registry = new StandardServiceRegistryBuilder().build();
24-
SessionFactory sessionFactory = new MetadataSources(registry).buildMetadata().buildSessionFactory();
25-
Session session = sessionFactory.openSession();
26-
session.beginTransaction();
27-
28-
session.createNativeQuery("CREATE EXTENSION IF NOT EXISTS vector").executeUpdate();
29-
session.createNativeQuery("DROP TABLE IF EXISTS hibernate_items").executeUpdate();
30-
session.createNativeQuery("CREATE TABLE hibernate_items (id bigserial PRIMARY KEY, embedding vector(3))").executeUpdate();
31-
32-
session.createNativeQuery("INSERT INTO hibernate_items (embedding) VALUES (CAST(? AS vector)), (CAST(? AS vector)), (CAST(? AS vector)), (NULL)")
33-
.setParameter(1, (new PGvector(new float[] {1, 1, 1})).getValue())
34-
.setParameter(2, (new PGvector(new float[] {2, 2, 2})).getValue())
35-
.setParameter(3, (new PGvector(new float[] {1, 1, 2})).getValue())
36-
.executeUpdate();
37-
38-
@SuppressWarnings("unchecked")
39-
List<Object[]> items = session
40-
.createNativeQuery("SELECT id, CAST(embedding AS text) FROM hibernate_items ORDER BY embedding <-> CAST(? AS vector) LIMIT 5")
41-
.setParameter(1, (new PGvector(new float[] {1, 1, 1})).getValue())
42-
.list();
43-
List<Long> ids = new ArrayList<>();
44-
List<PGvector> embeddings = new ArrayList<>();
45-
for (Object[] item : items) {
46-
ids.add(Long.valueOf(((BigInteger) item[0]).longValue()));
47-
embeddings.add(item[1] == null ? null : new PGvector((String) item[1]));
48-
}
49-
assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray());
50-
assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray());
51-
assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray());
52-
assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray());
53-
assertNull(embeddings.get(3));
54-
55-
session.getTransaction().commit();
56-
session.close();
45+
EntityManagerFactory entityManagerFactory = Persistence.createEntityManagerFactory("default");
46+
EntityManager entityManager = entityManagerFactory.createEntityManager();
47+
48+
entityManager.getTransaction().begin();
49+
50+
Item item1 = new Item();
51+
item1.setEmbedding(new float[] {1, 1, 1});
52+
entityManager.persist(item1);
53+
54+
Item item2 = new Item();
55+
item2.setEmbedding(new float[] {2, 2, 2});
56+
entityManager.persist(item2);
57+
58+
Item item3 = new Item();
59+
item3.setEmbedding(new float[] {1, 1, 2});
60+
entityManager.persist(item3);
61+
62+
List<Item> items = entityManager
63+
.createQuery("FROM Item ORDER BY l2_distance(embedding, :embedding) LIMIT 5", Item.class)
64+
.setParameter("embedding", new float[] {1, 1, 1})
65+
.getResultList();
66+
assertArrayEquals(new Long[] {1L, 3L, 2L}, items.stream().map(v -> v.getId()).toArray());
67+
assertArrayEquals(new float[] {1, 1, 1}, items.get(0).getEmbedding());
68+
assertArrayEquals(new float[] {1, 1, 2}, items.get(1).getEmbedding());
69+
assertArrayEquals(new float[] {2, 2, 2}, items.get(2).getEmbedding());
70+
71+
entityManager.getTransaction().commit();
5772
}
5873
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<persistence xmlns="http://xmlns.jcp.org/xml/ns/persistence"
2+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
3+
xsi:schemaLocation="http://xmlns.jcp.org/xml/ns/persistence
4+
http://xmlns.jcp.org/xml/ns/persistence/persistence_2_1.xsd"
5+
version="2.1">
6+
<persistence-unit name="default">
7+
<properties>
8+
<property name="jakarta.persistence.jdbc.url"
9+
value="jdbc:postgresql://localhost:5432/pgvector_java_test" />
10+
<property name="hibernate.show_sql"
11+
value="false" />
12+
<property name="hibernate.hbm2ddl.auto"
13+
value="create" />
14+
</properties>
15+
</persistence-unit>
16+
</persistence>

src/test/resources/hibernate.properties

Lines changed: 0 additions & 2 deletions
This file was deleted.

0 commit comments

Comments
 (0)