|
1 | 1 | package com.pgvector; |
2 | 2 |
|
3 | | -import java.math.BigInteger; |
4 | | -import java.sql.*; |
| 3 | +import jakarta.persistence.*; |
| 4 | +import java.sql.SQLException; |
5 | 5 | import java.util.ArrayList; |
6 | 6 | import java.util.List; |
7 | | -import com.pgvector.PGvector; |
8 | | -import org.hibernate.boot.registry.StandardServiceRegistry; |
9 | | -import org.hibernate.boot.registry.StandardServiceRegistryBuilder; |
10 | | -import org.hibernate.boot.MetadataSources; |
11 | | -import org.hibernate.Session; |
12 | | -import org.hibernate.SessionFactory; |
| 7 | +import org.hibernate.annotations.Array; |
| 8 | +import org.hibernate.annotations.JdbcTypeCode; |
| 9 | +import org.hibernate.type.SqlTypes; |
13 | 10 | import org.junit.jupiter.api.Test; |
14 | 11 |
|
15 | 12 | import static org.junit.jupiter.api.Assertions.*; |
16 | 13 |
|
| 14 | +@Entity |
| 15 | +@Table(name = "hibernate_items") |
| 16 | +class Item { |
| 17 | + @Id |
| 18 | + @GeneratedValue |
| 19 | + private Long id; |
| 20 | + |
| 21 | + @Column |
| 22 | + @JdbcTypeCode(SqlTypes.VECTOR) |
| 23 | + @Array(length = 3) |
| 24 | + private float[] embedding; |
| 25 | + |
| 26 | + public Long getId() { |
| 27 | + return id; |
| 28 | + } |
| 29 | + |
| 30 | + public float[] getEmbedding() { |
| 31 | + return embedding; |
| 32 | + } |
| 33 | + |
| 34 | + public void setEmbedding(float[] embedding) { |
| 35 | + this.embedding = embedding; |
| 36 | + } |
| 37 | +} |
| 38 | + |
17 | 39 | public class HibernateTest { |
18 | 40 | @Test |
19 | 41 | void example() throws SQLException { |
20 | 42 | // disable logging |
21 | 43 | System.setProperty("org.jboss.logging.provider", "slf4j"); |
22 | 44 |
|
23 | | - StandardServiceRegistry registry = new StandardServiceRegistryBuilder().build(); |
24 | | - SessionFactory sessionFactory = new MetadataSources(registry).buildMetadata().buildSessionFactory(); |
25 | | - Session session = sessionFactory.openSession(); |
26 | | - session.beginTransaction(); |
27 | | - |
28 | | - session.createNativeQuery("CREATE EXTENSION IF NOT EXISTS vector").executeUpdate(); |
29 | | - session.createNativeQuery("DROP TABLE IF EXISTS hibernate_items").executeUpdate(); |
30 | | - session.createNativeQuery("CREATE TABLE hibernate_items (id bigserial PRIMARY KEY, embedding vector(3))").executeUpdate(); |
31 | | - |
32 | | - session.createNativeQuery("INSERT INTO hibernate_items (embedding) VALUES (CAST(? AS vector)), (CAST(? AS vector)), (CAST(? AS vector)), (NULL)") |
33 | | - .setParameter(1, (new PGvector(new float[] {1, 1, 1})).getValue()) |
34 | | - .setParameter(2, (new PGvector(new float[] {2, 2, 2})).getValue()) |
35 | | - .setParameter(3, (new PGvector(new float[] {1, 1, 2})).getValue()) |
36 | | - .executeUpdate(); |
37 | | - |
38 | | - @SuppressWarnings("unchecked") |
39 | | - List<Object[]> items = session |
40 | | - .createNativeQuery("SELECT id, CAST(embedding AS text) FROM hibernate_items ORDER BY embedding <-> CAST(? AS vector) LIMIT 5") |
41 | | - .setParameter(1, (new PGvector(new float[] {1, 1, 1})).getValue()) |
42 | | - .list(); |
43 | | - List<Long> ids = new ArrayList<>(); |
44 | | - List<PGvector> embeddings = new ArrayList<>(); |
45 | | - for (Object[] item : items) { |
46 | | - ids.add(Long.valueOf(((BigInteger) item[0]).longValue())); |
47 | | - embeddings.add(item[1] == null ? null : new PGvector((String) item[1])); |
48 | | - } |
49 | | - assertArrayEquals(new Long[] {1L, 3L, 2L, 4L}, ids.toArray()); |
50 | | - assertArrayEquals(new float[] {1, 1, 1}, embeddings.get(0).toArray()); |
51 | | - assertArrayEquals(new float[] {1, 1, 2}, embeddings.get(1).toArray()); |
52 | | - assertArrayEquals(new float[] {2, 2, 2}, embeddings.get(2).toArray()); |
53 | | - assertNull(embeddings.get(3)); |
54 | | - |
55 | | - session.getTransaction().commit(); |
56 | | - session.close(); |
| 45 | + EntityManagerFactory entityManagerFactory = Persistence.createEntityManagerFactory("default"); |
| 46 | + EntityManager entityManager = entityManagerFactory.createEntityManager(); |
| 47 | + |
| 48 | + entityManager.getTransaction().begin(); |
| 49 | + |
| 50 | + Item item1 = new Item(); |
| 51 | + item1.setEmbedding(new float[] {1, 1, 1}); |
| 52 | + entityManager.persist(item1); |
| 53 | + |
| 54 | + Item item2 = new Item(); |
| 55 | + item2.setEmbedding(new float[] {2, 2, 2}); |
| 56 | + entityManager.persist(item2); |
| 57 | + |
| 58 | + Item item3 = new Item(); |
| 59 | + item3.setEmbedding(new float[] {1, 1, 2}); |
| 60 | + entityManager.persist(item3); |
| 61 | + |
| 62 | + List<Item> items = entityManager |
| 63 | + .createQuery("FROM Item ORDER BY l2_distance(embedding, :embedding) LIMIT 5", Item.class) |
| 64 | + .setParameter("embedding", new float[] {1, 1, 1}) |
| 65 | + .getResultList(); |
| 66 | + assertArrayEquals(new Long[] {1L, 3L, 2L}, items.stream().map(v -> v.getId()).toArray()); |
| 67 | + assertArrayEquals(new float[] {1, 1, 1}, items.get(0).getEmbedding()); |
| 68 | + assertArrayEquals(new float[] {1, 1, 2}, items.get(1).getEmbedding()); |
| 69 | + assertArrayEquals(new float[] {2, 2, 2}, items.get(2).getEmbedding()); |
| 70 | + |
| 71 | + entityManager.getTransaction().commit(); |
57 | 72 | } |
58 | 73 | } |
0 commit comments