|
1 | 1 | import numpy as np |
2 | | -from pgvector.psycopg import register_vector, register_vector_async |
| 2 | +from pgvector.psycopg import register_vector, register_vector_async, HalfVec, SparseVec |
3 | 3 | import psycopg |
4 | 4 | import pytest |
5 | 5 |
|
@@ -27,7 +27,7 @@ def test_works(self): |
27 | 27 |
|
28 | 28 | def test_binary_format(self): |
29 | 29 | embedding = np.array([1.5, 2, 3]) |
30 | | - res = conn.execute('SELECT %b::vector', (embedding,)).fetchone()[0] |
| 30 | + res = conn.execute('SELECT %b::vector', (embedding,), binary=True).fetchone()[0] |
31 | 31 | assert np.array_equal(res, embedding) |
32 | 32 |
|
33 | 33 | def test_text_format(self): |
@@ -71,6 +71,45 @@ def test_binary_copy_set_types(self): |
71 | 71 | copy.set_types(['int8', 'vector']) |
72 | 72 | copy.write_row([1, embedding]) |
73 | 73 |
|
| 74 | + def test_halfvec(self): |
| 75 | + conn.execute('DROP TABLE IF EXISTS half_items') |
| 76 | + conn.execute('CREATE TABLE half_items (id bigserial PRIMARY KEY, embedding halfvec(3))') |
| 77 | + |
| 78 | + embedding = HalfVec([1.5, 2, 3]) |
| 79 | + conn.execute('INSERT INTO half_items (embedding) VALUES (%s)', (embedding,)) |
| 80 | + |
| 81 | + res = conn.execute('SELECT * FROM half_items ORDER BY id').fetchall() |
| 82 | + |
| 83 | + def test_halfvec_binary_format(self): |
| 84 | + embedding = HalfVec([1.5, 2, 3]) |
| 85 | + res = conn.execute('SELECT %b::halfvec', (embedding,), binary=True).fetchone()[0] |
| 86 | + assert res.to_list() == [1.5, 2, 3] |
| 87 | + |
| 88 | + def test_halfvec_text_format(self): |
| 89 | + embedding = HalfVec([1.5, 2, 3]) |
| 90 | + res = conn.execute('SELECT %t::halfvec', (embedding,)).fetchone()[0] |
| 91 | + assert res.to_list() == [1.5, 2, 3] |
| 92 | + |
| 93 | + def test_sparsevec(self): |
| 94 | + conn.execute('DROP TABLE IF EXISTS sparse_items') |
| 95 | + conn.execute('CREATE TABLE sparse_items (id bigserial PRIMARY KEY, embedding sparsevec(6))') |
| 96 | + |
| 97 | + embedding = SparseVec.from_dense([0, 1.5, 0, 2, 0, 3]) |
| 98 | + conn.execute('INSERT INTO sparse_items (embedding) VALUES (%s)', (embedding,)) |
| 99 | + |
| 100 | + res = conn.execute('SELECT * FROM sparse_items ORDER BY id').fetchall() |
| 101 | + assert res[0][1].to_dense() == [0, 1.5, 0, 2, 0, 3] |
| 102 | + |
| 103 | + def test_sparsevec_binary_format(self): |
| 104 | + embedding = SparseVec.from_dense([1.5, 2, 3]) |
| 105 | + res = conn.execute('SELECT %b::sparsevec', (embedding,), binary=True).fetchone()[0] |
| 106 | + assert res.to_dense() == [1.5, 2, 3] |
| 107 | + |
| 108 | + def test_sparsevec_text_format(self): |
| 109 | + embedding = SparseVec.from_dense([1.5, 2, 3]) |
| 110 | + res = conn.execute('SELECT %t::sparsevec', (embedding,)).fetchone()[0] |
| 111 | + assert res.to_dense() == [1.5, 2, 3] |
| 112 | + |
74 | 113 | def test_bit(self): |
75 | 114 | res = conn.execute('SELECT %s::bit(3)', ('101',)).fetchone()[0] |
76 | 115 | assert res == '101' |
|
0 commit comments